├── .idea
    ├── .gitignore
    ├── DGNet.iml
    ├── inspectionProfiles
    │   └── profiles_settings.xml
    ├── misc.xml
    ├── modules.xml
    └── vcs.xml
├── README.md
├── eval.py
├── figures
    ├── model.png
    └── result.png
├── inference.py
├── loaders
    ├── M&Ms_Dataset_Information.xlsx
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   └── mms_dataloader_meta_split.cpython-37.pyc
    ├── mms_dataloader_meta_split.py
    └── mms_dataloader_meta_split_test.py
├── losses
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   └── losses.cpython-37.pyc
    └── losses.py
├── metrics
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   ├── dice_loss.cpython-37.pyc
    │   ├── focal_loss.cpython-37.pyc
    │   └── gan_loss.cpython-37.pyc
    ├── dice_loss.py
    ├── focal_loss.py
    ├── gan_loss.py
    └── hausdorff.py
├── models
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   ├── dgnet.cpython-37.pyc
    │   ├── meta_decoder.cpython-37.pyc
    │   ├── meta_segmentor.cpython-37.pyc
    │   ├── meta_styleencoder.cpython-37.pyc
    │   ├── meta_unet.cpython-37.pyc
    │   ├── ops.cpython-37.pyc
    │   ├── sdnet_ada.cpython-37.pyc
    │   ├── unet_parts.cpython-37.pyc
    │   └── weight_init.cpython-37.pyc
    ├── dgnet.py
    ├── meta_decoder.py
    ├── meta_segmentor.py
    ├── meta_styleencoder.py
    ├── meta_unet.py
    ├── ops.py
    ├── unet_parts.py
    └── weight_init.py
├── preprocess
    ├── save_MNMS_2D.py
    ├── save_MNMS_re.py
    ├── save_SCGM_2D.py
    └── split_MNMS_data.py
├── train_meta.py
└── utils
    ├── __init__.py
    ├── __pycache__
        ├── __init__.cpython-37.pyc
        └── model_utils.cpython-37.pyc
    └── model_utils.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | # Default ignored files
3 | /workspace.xml
--------------------------------------------------------------------------------
/.idea/DGNet.iml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 6 | 7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semi-supervised Meta-learning with Disentanglement for Domain-generalised Medical Image Segmentation
2 | ![model](figures/model.png)
3 | 
4 | This repository contains the official PyTorch implementation of [Semi-supervised Meta-learning with Disentanglement for Domain-generalised Medical Image Segmentation](https://arxiv.org/abs/2106.13292) (accepted by [MICCAI 2021](https://miccai2021.org/en/) as an oral presentation). Check the [presentation](https://www.youtube.com/watch?v=IgXAcO8Zyj8) on our official YouTube channel.
5 | 
6 | The repository was created by [Xiao Liu](https://github.com/xxxliu95), [Spyridon Thermos](https://github.com/spthermo), [Alison O'Neil](https://vios.science/team/oneil), and [Sotirios A. Tsaftaris](https://www.eng.ed.ac.uk/about/people/dr-sotirios-tsaftaris), as a result of the collaboration between [The University of Edinburgh](https://www.eng.ed.ac.uk/) and [Canon Medical Systems Europe](https://eu.medical.canon/).
You are welcome to visit our group website: [vios.s](https://vios.science/)
7 | 
8 | # System Requirements
9 | * PyTorch 1.5.1 or higher with GPU support
10 | * Python 3.7.2 or higher
11 | * SciPy 1.5.2 or higher
12 | * CUDA toolkit 10 or newer
13 | * Nibabel
14 | * Pillow
15 | * Scikit-image
16 | * TensorBoard
17 | * Tqdm
18 | 
19 | 
20 | # Datasets
21 | We used two datasets in the paper: [Multi-Centre, Multi-Vendor & Multi-Disease
22 | Cardiac Image Segmentation Challenge (M&Ms) dataset](https://www.ub.edu/mnms/) and [Spinal cord grey matter segmentation challenge dataset](http://niftyweb.cs.ucl.ac.uk/challenge/index.php). The dataloader in this repo is only for the M&Ms dataset.
23 | 
24 | # Preprocessing
25 | 
26 | You first need to change the directories in the scripts in the preprocess folder. Download the M&Ms data and run ```split_MNMS_data.py``` to split the original dataset into different domains. Then run ```save_MNMS_2D.py``` to save the original 4D data as 2D numpy arrays. Finally, run ```save_MNMS_re.py``` to save the resolution of each datum.
27 | 
28 | # Training
29 | Note that the hyperparameters in the current version are tuned for the BCD to A case. For other cases, the hyperparameters and a few specific layers of the model are slightly different. To train the model with 5% labeled data, run:
30 | ```
31 | python train_meta.py -e 150 -c cp_dgnet_meta_5_tvA/ -t A -w DGNetRE_COM_META_5_tvA -g 0
32 | ```
33 | Here the default learning rate is 4e-5. You can change the learning rate by adding ```-lr 0.00002``` (sometimes this is better).
34 | 
35 | To train the model with 100% labeled data, try changing the training parameters to:
36 | ```
37 | k_un = 1
38 | k1 = 20
39 | k2 = 2
40 | ```
41 | The first parameter controls how many iterations the model is trained with unlabeled data for every iteration of training. ```k1 = 20``` means the learning rate will start to decay after 20 epochs and ```k2 = 2``` means it will check whether to decay the learning rate every 2 epochs.
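One way to read the three parameters above is sketched below. This is only an illustration of the schedule as described in the text, not the logic in `train_meta.py`: the helper names `iteration_plan` and `is_decay_check_epoch` are hypothetical, the ordering of labeled and unlabeled steps is an assumption, and "after 20 epochs" is assumed to mean from epoch index `k1` onward.

```python
# Hypothetical helpers illustrating the k_un / k1 / k2 description.
# They are assumptions about the schedule, not code from train_meta.py.

def iteration_plan(n_iters, k_un=1):
    """List the kind of update performed at each step.

    Assumes every training iteration is followed by k_un unlabeled updates.
    """
    steps = []
    for _ in range(n_iters):
        steps.append("labeled")
        steps.extend(["unlabeled"] * k_un)
    return steps


def is_decay_check_epoch(epoch, k1=20, k2=2):
    """True on epochs where a learning-rate decay check would take place.

    Assumes checks begin at epoch index k1 and repeat every k2 epochs.
    """
    if k2 <= 0:
        raise ValueError("k2 must be positive")
    return epoch >= k1 and (epoch - k1) % k2 == 0


if __name__ == "__main__":
    print(iteration_plan(2, k_un=1))
    print([e for e in range(30) if is_decay_check_epoch(e)])
```

With `k1 = 20` and `k2 = 2`, this reading gives decay checks at epochs 20, 22, 24 and so on. The actual decay criterion is not covered here.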
42 | 
43 | Also, change the ratio ```k=0.05``` (line 221) to ```k=1``` in ```mms_dataloader_meta_split.py```.
44 | 
45 | Then, run:
46 | ```
47 | python train_meta.py -e 80 -c cp_dgnet_meta_100_tvA/ -t A -w DGNetRE_COM_META_100_tvA -g 0
48 | ```
49 | Finally, when training the model, changing ```resampling_rate=1.2``` (line 47) in ```mms_dataloader_meta_split.py``` to a value between 1.1 and 1.3 may give better results. This changes the rescale ratio used when preprocessing the images, which affects the size of the anatomy of interest.
50 | 
51 | # Inference
52 | After training, you can test the model:
53 | ```
54 | python inference.py -bs 1 -c cp_dgnet_meta_100_tvA/ -t A -g 0
55 | ```
56 | This will output the Dice and Hausdorff results as well as their standard deviations. Similarly, changing ```resampling_rate=1.2``` (line 47) in ```mms_dataloader_meta_split_test.py``` to a value between 1.1 and 1.3 may give better results.
57 | 
58 | # Qualitative results
59 | ![results](figures/result.png)
60 | 
61 | # Citation
62 | ```
63 | @inproceedings{liu2021semi,
64 |   title={Semi-supervised Meta-learning with Disentanglement for Domain-generalised Medical Image Segmentation},
65 |   author={Liu, Xiao and Thermos, Spyridon and O’Neil, Alison and Tsaftaris, Sotirios A},
66 |   booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
67 |   pages={307--317},
68 |   year={2021},
69 |   organization={Springer}
70 | }
71 | ```
72 | 
73 | # Acknowledgement
74 | Part of the code is based on [SDNet](https://github.com/spthermo/SDNet), [MLDG](https://github.com/HAHA-DL/MLDG), [medical-mldg-seg](https://github.com/Pulkit-Khandelwal/medical-mldg-seg) and [Pytorch-UNet](https://github.com/milesial/Pytorch-UNet).
75 | 
76 | # License
77 | All scripts are released under the MIT License.
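As a note on the Inference section: the summary that `inference.py` prints is a mean and standard deviation over per-patient averages of slice-level scores. A minimal sketch of that aggregation follows, assuming each 2D slice comes with a patient identifier (the script takes one from the file name). The function name `summarize_per_patient` and the input format are hypothetical, and this is not the script's own code.

```python
import statistics
from collections import OrderedDict


def summarize_per_patient(slice_scores):
    """Average slice scores per patient, then summarize across patients.

    slice_scores: iterable of (patient_id, score) pairs, one per 2D slice.
    Returns (per-patient means, mean over patients, std over patients).
    """
    per_patient = OrderedDict()
    for patient_id, score in slice_scores:
        per_patient.setdefault(patient_id, []).append(score)

    if not per_patient:
        raise ValueError("no scores given")

    patient_means = [sum(v) / len(v) for v in per_patient.values()]
    mean = sum(patient_means) / len(patient_means)
    # statistics.stdev needs at least two values; fall back to 0.0 otherwise.
    std = statistics.stdev(patient_means) if len(patient_means) > 1 else 0.0
    return patient_means, mean, std


if __name__ == "__main__":
    scores = [("001", 0.5), ("001", 1.0), ("002", 0.25)]
    print(summarize_per_patient(scores))
```

Grouping by key, rather than detecting a change of identifier inside the loop, keeps the last patient in the summary without a separate flush step after the loop.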
78 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import torch.nn.functional as F 4 | from metrics.dice_loss import dice_coeff 5 | 6 | def eval_dgnet(net, loader, device, mode): 7 | """Evaluation without the densecrf with the dice coefficient""" 8 | net.eval() 9 | mask_type = torch.float32 10 | n_val = len(loader) # the number of batch 11 | tot = 0 12 | tot_lv = 0 13 | tot_myo = 0 14 | tot_rv = 0 15 | with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar: 16 | if mode=='val': 17 | for imgs, true_masks, _ in loader: 18 | imgs = imgs.to(device=device, dtype=torch.float32) 19 | true_masks = true_masks.to(device=device, dtype=mask_type) 20 | 21 | with torch.no_grad(): 22 | reco, z_out, mu_tilde, a_out, mask_pred, mu, logvar, _, _ = net(imgs, true_masks, 'test') 23 | 24 | mask_pred = a_out[:, :4, :, :] 25 | pred = F.softmax(mask_pred, dim=1) 26 | pred = (pred > 0.5).float() 27 | tot += dice_coeff(pred[:, 0:3, :, :], true_masks[:, 0:3, :, :], device).item() 28 | tot_lv += dice_coeff(pred[:, 0, :, :], true_masks[:, 0, :, :], device).item() 29 | tot_myo += dice_coeff(pred[:, 1, :, :], true_masks[:, 1, :, :], device).item() 30 | tot_rv += dice_coeff(pred[:, 2, :, :], true_masks[:, 2, :, :], device).item() 31 | pbar.update() 32 | else: 33 | for imgs, true_masks in loader: 34 | imgs = imgs.to(device=device, dtype=torch.float32) 35 | true_masks = true_masks.to(device=device, dtype=mask_type) 36 | 37 | with torch.no_grad(): 38 | reco, z_out, mu_tilde, a_out, mask_pred, mu, logvar, _, _ = net(imgs, true_masks, 'test') 39 | 40 | mask_pred = a_out[:, :4, :, :] 41 | pred = F.softmax(mask_pred, dim=1) 42 | pred = (pred > 0.5).float() 43 | tot += dice_coeff(pred[:, 0:3, :, :], true_masks[:, 0:3, :, :], device).item() 44 | tot_lv += dice_coeff(pred[:, 0, :, :], true_masks[:, 0, :, :], 
device).item() 45 | tot_myo += dice_coeff(pred[:, 1, :, :], true_masks[:, 1, :, :], device).item() 46 | tot_rv += dice_coeff(pred[:, 2, :, :], true_masks[:, 2, :, :], device).item() 47 | pbar.update() 48 | 49 | net.train() 50 | return tot / n_val, tot_lv / n_val, tot_myo / n_val, tot_rv / n_val -------------------------------------------------------------------------------- /figures/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/figures/model.png -------------------------------------------------------------------------------- /figures/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/figures/result.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import torch.nn.functional as F 4 | import statistics 5 | import utils 6 | from loaders.mms_dataloader_meta_split_test import get_meta_split_data_loaders 7 | import models 8 | from metrics.dice_loss import dice_coeff 9 | from metrics.hausdorff import hausdorff_distance 10 | 11 | # python inference.py -bs 1 -c cp_dgnet_gan_meta_dir_5_tvA/ -t A -mn dgnet -g 0 12 | 13 | 14 | def get_args(): 15 | usage_text = ( 16 | "SNet Pytorch Implementation" 17 | "Usage: python train.py [options]," 18 | " with [options]:" 19 | ) 20 | parser = argparse.ArgumentParser(description=usage_text) 21 | #training details 22 | parser.add_argument('-bs','--batch_size', type=int, default=4, help='Number of inputs per batch') 23 | parser.add_argument('-c', '--cp', type=str, default='checkpoints/', help='The name of the checkpoints.') 24 | parser.add_argument('-t', '--tv', type=str, default='D', help='The name of the target 
vendor.') 25 | parser.add_argument('-w', '--wc', type=str, default='DGNet_LR00002_LDv5', help='The name of the summary writer.') 26 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. Used when storing information.') 27 | parser.add_argument('-mn','--model_name', type=str, default='dgnet', help='Name of the model architecture to be used for training/testing.') 28 | parser.add_argument('-lr','--learning_rate', type=float, default=0.00002, help='The learning rate for model training') 29 | parser.add_argument('-wi','--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)') 30 | parser.add_argument('--save_path', type=str, default='checkpoints', help= 'Path to save model checkpoints') 31 | parser.add_argument('--decoder_type', type=str, default='film', help='Choose decoder type between FiLM and SPADE') 32 | #hardware 33 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). 
Use -1 for CPU.') 34 | parser.add_argument('--num_workers' ,type= int, default = 0, help='Number of workers to use for dataload') 35 | 36 | return parser.parse_args() 37 | 38 | args = get_args() 39 | device = torch.device('cuda:'+str(args.gpu) if torch.cuda.is_available() else 'cpu') 40 | 41 | batch_size = args.batch_size 42 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') 43 | 44 | dir_checkpoint = args.cp 45 | test_vendor = args.tv 46 | # wc = args.wc 47 | model_name = args.model_name 48 | 49 | # Model selection and initialization 50 | model_params = { 51 | 'width': 288, 52 | 'height': 288, 53 | 'ndf': 64, 54 | 'norm': "batchnorm", 55 | 'upsample': "nearest", 56 | 'num_classes': 3, 57 | 'decoder_type': args.decoder_type, 58 | 'anatomy_out_channels': 8, 59 | 'z_length': 8, 60 | 'num_mask_channels': 8, 61 | 62 | } 63 | model = models.get_model(model_name, model_params) 64 | num_params = utils.count_parameters(model) 65 | print('Model Parameters: ', num_params) 66 | model.load_state_dict(torch.load(dir_checkpoint+'CP_epoch.pth', map_location=device)) 67 | model.to(device) 68 | 69 | # writer = SummaryWriter(comment=wc) 70 | 71 | _, _, \ 72 | _, _, \ 73 | _, _, \ 74 | test_loader, \ 75 | _, _, _ = get_meta_split_data_loaders( 76 | batch_size, test_vendor=test_vendor, image_size=224) 77 | 78 | step = 0 79 | tot = [] 80 | tot_sub = [] 81 | tot_hsd = [] 82 | tot_sub_hsd = [] 83 | flag = '000' 84 | # i = 0 85 | for imgs, true_masks, path_img in test_loader: 86 | model.eval() 87 | imgs = imgs.to(device=device, dtype=torch.float32) 88 | mask_type = torch.float32 89 | true_masks = true_masks.to(device=device, dtype=mask_type) 90 | print(flag) 91 | if path_img[0][-10: -7] != flag: 92 | # if i > 10: 93 | # break 94 | # i += 1 95 | flag = path_img[0][-10: -7] 96 | tot.append(sum(tot_sub)/len(tot_sub)) 97 | tot_sub = [] 98 | tot_hsd.append(sum(tot_sub_hsd)/len(tot_sub_hsd)) 99 | tot_sub_hsd = [] 100 | with torch.no_grad(): 101 | reco, 
z_out, z_out_tilde, a_out, _, mu, logvar, cls_out, _ = model(imgs, true_masks, 'test') 102 | 103 | mask_pred = a_out[:, :4, :, :] 104 | pred = F.softmax(mask_pred, dim=1) 105 | pred = (pred > 0.5).float() 106 | dice = dice_coeff(pred[:, 0:3, :, :], true_masks[:, 0:3, :, :], device).item() 107 | hsd = hausdorff_distance(pred[:, 0:3, :, :], true_masks[:, 0:3, :, :]) 108 | tot_sub.append(dice) 109 | tot_sub_hsd.append(hsd) 110 | print(step) 111 | step += 1 112 | 113 | print(tot) 114 | 115 | print(sum(tot)/len(tot)) 116 | print(statistics.stdev(tot)) 117 | 118 | print(tot_hsd) 119 | 120 | print(sum(tot_hsd)/len(tot_hsd)) 121 | print(statistics.stdev(tot_hsd)) 122 | -------------------------------------------------------------------------------- /loaders/M&Ms_Dataset_Information.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/loaders/M&Ms_Dataset_Information.xlsx -------------------------------------------------------------------------------- /loaders/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /loaders/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/loaders/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /loaders/__pycache__/mms_dataloader_meta_split.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/loaders/__pycache__/mms_dataloader_meta_split.cpython-37.pyc -------------------------------------------------------------------------------- 
/loaders/mms_dataloader_meta_split.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torchfile 3 | from torch.utils.data import DataLoader, TensorDataset 4 | from torchvision import transforms 5 | import torchvision.transforms.functional as F 6 | import torch 7 | import torch.nn as nn 8 | import os 9 | import torchvision.utils as vutils 10 | import numpy as np 11 | import torch.nn.init as init 12 | import torch.utils.data as data 13 | import torch 14 | import random 15 | import xlrd 16 | import math 17 | from skimage.exposure import match_histograms 18 | 19 | # Data directories 20 | LabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorA/' 21 | LabeledVendorA_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorA/' 22 | ReA_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorA/' 23 | 24 | LabeledVendorB2_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorB/center2/' 25 | LabeledVendorB2_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorB/center2/' 26 | ReB2_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorB/center2/' 27 | 28 | LabeledVendorB3_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorB/center3/' 29 | LabeledVendorB3_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorB/center3/' 30 | ReB3_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorB/center3/' 31 | 32 | LabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorC/' 33 | LabeledVendorC_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorC/' 34 | ReC_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorC/' 35 | 36 | LabeledVendorD_data_dir = 
'/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorD/' 37 | LabeledVendorD_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorD/' 38 | ReD_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorD/' 39 | 40 | UnlabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Unlabeled/vendorC/' 41 | UnReC_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Unlabeled/vendorC/' 42 | 43 | Re_dir = [ReA_dir, ReB2_dir, ReB3_dir, ReC_dir, ReD_dir] 44 | Labeled_data_dir = [LabeledVendorA_data_dir, LabeledVendorB2_data_dir, LabeledVendorB3_data_dir, LabeledVendorC_data_dir, LabeledVendorD_data_dir] 45 | Labeled_mask_dir = [LabeledVendorA_mask_dir, LabeledVendorB2_mask_dir, LabeledVendorB3_mask_dir, LabeledVendorC_mask_dir, LabeledVendorD_mask_dir] 46 | 47 | resampling_rate = 1.2 48 | 49 | def get_meta_split_data_loaders(batch_size, test_vendor='D', image_size=224): 50 | 51 | random.seed(14) 52 | 53 | domain_1_labeled_loader, domain_1_unlabeled_loader, \ 54 | domain_2_labeled_loader, domain_2_unlabeled_loader,\ 55 | domain_3_labeled_loader, domain_3_unlabeled_loader, \ 56 | test_loader, \ 57 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset = \ 58 | get_data_loader_folder(Labeled_data_dir, Labeled_mask_dir, batch_size, image_size, test_num=test_vendor) 59 | 60 | return domain_1_labeled_loader, domain_1_unlabeled_loader, \ 61 | domain_2_labeled_loader, domain_2_unlabeled_loader,\ 62 | domain_3_labeled_loader, domain_3_unlabeled_loader, \ 63 | test_loader, \ 64 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset 65 | 66 | 67 | def get_data_loader_folder(data_folders, mask_folders, batch_size, new_size=288, test_num='D', num_workers=2): 68 | if test_num=='A': 69 | domain_1_img_dirs = [data_folders[1], data_folders[2]] 70 | domain_1_mask_dirs = [mask_folders[1], mask_folders[2]] 71 | domain_2_img_dirs = [data_folders[3]] 72 | 
domain_2_mask_dirs = [mask_folders[3]] 73 | domain_3_img_dirs = [data_folders[4]] 74 | domain_3_mask_dirs = [mask_folders[4]] 75 | 76 | test_data_dirs = [data_folders[0]] 77 | test_mask_dirs = [mask_folders[0]] 78 | 79 | domain_1_re = [Re_dir[1], Re_dir[2]] 80 | domain_2_re = [Re_dir[3]] 81 | domain_3_re = [Re_dir[4]] 82 | 83 | test_re = [Re_dir[0]] 84 | 85 | domain_1_num = [74, 51] 86 | domain_2_num = [50] 87 | domain_3_num = [50] 88 | test_num = [95] 89 | 90 | elif test_num=='B': 91 | domain_1_img_dirs = [data_folders[0]] 92 | domain_1_mask_dirs = [mask_folders[0]] 93 | domain_2_img_dirs = [data_folders[3]] 94 | domain_2_mask_dirs = [mask_folders[3]] 95 | domain_3_img_dirs = [data_folders[4]] 96 | domain_3_mask_dirs = [mask_folders[4]] 97 | 98 | test_data_dirs = [data_folders[1], data_folders[2]] 99 | test_mask_dirs = [mask_folders[1], mask_folders[2]] 100 | 101 | domain_1_re = [Re_dir[0]] 102 | domain_2_re = [Re_dir[3]] 103 | domain_3_re = [Re_dir[4]] 104 | test_re = [Re_dir[1], Re_dir[2]] 105 | 106 | domain_1_num = [95] 107 | domain_2_num = [50] 108 | domain_3_num = [50] 109 | test_num = [74, 51] 110 | 111 | elif test_num=='C': 112 | domain_1_img_dirs = [data_folders[0]] 113 | domain_1_mask_dirs = [mask_folders[0]] 114 | domain_2_img_dirs = [data_folders[1], data_folders[2]] 115 | domain_2_mask_dirs = [mask_folders[1], mask_folders[2]] 116 | domain_3_img_dirs = [data_folders[4]] 117 | domain_3_mask_dirs = [mask_folders[4]] 118 | 119 | test_data_dirs = [data_folders[3]] 120 | test_mask_dirs = [mask_folders[3]] 121 | 122 | domain_1_re = [Re_dir[0]] 123 | domain_2_re = [Re_dir[1], Re_dir[2]] 124 | domain_3_re = [Re_dir[4]] 125 | test_re = [Re_dir[3]] 126 | 127 | domain_1_num = [95] 128 | domain_2_num = [74, 51] 129 | domain_3_num = [50] 130 | test_num = [50] 131 | 132 | elif test_num=='D': 133 | domain_1_img_dirs = [data_folders[0]] 134 | domain_1_mask_dirs = [mask_folders[0]] 135 | domain_2_img_dirs = [data_folders[1], data_folders[2]] 136 | domain_2_mask_dirs = 
[mask_folders[1], mask_folders[2]] 137 | domain_3_img_dirs = [data_folders[3]] 138 | domain_3_mask_dirs = [mask_folders[3]] 139 | 140 | test_data_dirs = [data_folders[4]] 141 | test_mask_dirs = [mask_folders[4]] 142 | 143 | domain_1_re = [Re_dir[0]] 144 | domain_2_re = [Re_dir[1], Re_dir[2]] 145 | domain_3_re = [Re_dir[3]] 146 | test_re = [Re_dir[4]] 147 | 148 | domain_1_num = [95] 149 | domain_2_num = [74, 51] 150 | domain_3_num = [50] 151 | test_num = [50] 152 | 153 | else: 154 | print('Wrong test vendor!') 155 | 156 | 157 | domain_1_labeled_dataset = ImageFolder(domain_1_img_dirs, domain_1_mask_dirs, domain_1_img_dirs, domain_1_re, label=0, num_label=domain_1_num, train=True, labeled=True) 158 | domain_2_labeled_dataset = ImageFolder(domain_2_img_dirs, domain_2_mask_dirs, domain_1_img_dirs, domain_2_re, label=1, num_label=domain_2_num, train=True, labeled=True) 159 | domain_3_labeled_dataset = ImageFolder(domain_3_img_dirs, domain_3_mask_dirs, domain_1_img_dirs, domain_3_re, label=2, num_label=domain_3_num, train=True, labeled=True) 160 | 161 | 162 | # domain_1_labeled_loader = DataLoader(dataset=domain_1_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True) 163 | # domain_2_labeled_loader = DataLoader(dataset=domain_2_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True) 164 | # domain_3_labeled_loader = DataLoader(dataset=domain_3_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True) 165 | 166 | domain_1_labeled_loader = None 167 | domain_2_labeled_loader = None 168 | domain_3_labeled_loader = None 169 | 170 | domain_1_unlabeled_dataset = ImageFolder(domain_1_img_dirs, domain_1_mask_dirs, domain_1_img_dirs, domain_1_re, label=0, train=True, labeled=False) 171 | domain_2_unlabeled_dataset = ImageFolder(domain_2_img_dirs, domain_2_mask_dirs, domain_1_img_dirs, domain_2_re, label=1, 
train=True, labeled=False) 172 | domain_3_unlabeled_dataset = ImageFolder(domain_3_img_dirs, domain_3_mask_dirs, domain_1_img_dirs, domain_3_re, label=2, train=True, labeled=False) 173 | 174 | domain_1_unlabeled_loader = DataLoader(dataset=domain_1_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True) 175 | domain_2_unlabeled_loader = DataLoader(dataset=domain_2_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True) 176 | domain_3_unlabeled_loader = DataLoader(dataset=domain_3_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True) 177 | 178 | 179 | test_dataset = ImageFolder(test_data_dirs, test_mask_dirs, domain_1_img_dirs, test_re, train=False, labeled=True) 180 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True) 181 | 182 | return domain_1_labeled_loader, domain_1_unlabeled_loader, \ 183 | domain_2_labeled_loader, domain_2_unlabeled_loader,\ 184 | domain_3_labeled_loader, domain_3_unlabeled_loader, \ 185 | test_loader, \ 186 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset 187 | 188 | def default_loader(path): 189 | return np.load(path)['arr_0'] 190 | 191 | def make_dataset(dir): 192 | images = [] 193 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 194 | 195 | for root, _, fnames in sorted(os.walk(dir)): 196 | for fname in fnames: 197 | path = os.path.join(root, fname) 198 | images.append(path) 199 | return images 200 | 201 | class ImageFolder(data.Dataset): 202 | def __init__(self, data_dirs, mask_dirs, ref_dir, re, train=True, label=None, num_label=None, labeled=True, loader=default_loader): 203 | 204 | print(data_dirs) 205 | print(mask_dirs) 206 | 207 | reso_dir = re 208 | temp_imgs = [] 209 | temp_masks = [] 210 | temp_re = [] 211 | 
domain_labels = [] 212 | 213 | tem_ref_imgs = [] 214 | 215 | if train: 216 | #100% 217 | # k = 1 218 | #10% 219 | # k = 0.1 220 | #5% 221 | k = 0.05 222 | #2% 223 | # k = 0.02 224 | else: 225 | k = 1 226 | 227 | 228 | for num_set in range(len(data_dirs)): 229 | re_roots = sorted(make_dataset(reso_dir[num_set])) 230 | data_roots = sorted(make_dataset(data_dirs[num_set])) 231 | mask_roots = sorted(make_dataset(mask_dirs[num_set])) 232 | num_label_data = 0 233 | for num_data in range(len(data_roots)): 234 | if labeled: 235 | if train: 236 | n_label = str(math.ceil(num_label[num_set] * k + 1)) 237 | if '00'+n_label==data_roots[num_data][-10:-7] or '0'+n_label==data_roots[num_data][-10:-7]: 238 | print(n_label) 239 | print(data_roots[num_data][-10:-7]) 240 | break 241 | 242 | for num_mask in range(len(mask_roots)): 243 | if data_roots[num_data][-10:-4] == mask_roots[num_mask][-10:-4]: 244 | temp_re.append(re_roots[num_data]) 245 | temp_imgs.append(data_roots[num_data]) 246 | temp_masks.append(mask_roots[num_mask]) 247 | domain_labels.append(label) 248 | num_label_data += 1 249 | else: 250 | pass 251 | else: 252 | temp_re.append(re_roots[num_data]) 253 | temp_imgs.append(data_roots[num_data]) 254 | domain_labels.append(label) 255 | 256 | for num_set in range(len(ref_dir)): 257 | data_roots = sorted(make_dataset(ref_dir[num_set])) 258 | for num_data in range(len(data_roots)): 259 | tem_ref_imgs.append(data_roots[num_data]) 260 | 261 | reso = temp_re 262 | imgs = temp_imgs 263 | masks = temp_masks 264 | labels = domain_labels 265 | 266 | print(len(masks)) 267 | 268 | ref_imgs = tem_ref_imgs 269 | 270 | # add something here to index, such that can split the data 271 | # index = random.sample(range(len(temp_img)), len(temp_mask)) 272 | 273 | self.reso = reso 274 | self.imgs = imgs 275 | self.masks = masks 276 | self.labels = labels 277 | self.new_size = 288 278 | self.loader = loader 279 | self.labeled = labeled 280 | self.train = train 281 | self.ref = ref_imgs 282 | 283 | 
def __getitem__(self, index): 284 | if self.train: 285 | index = random.randrange(len(self.imgs)) 286 | else: 287 | pass 288 | 289 | path_re = self.reso[index] 290 | re = self.loader(path_re) 291 | re = re[0] 292 | 293 | path_img = self.imgs[index] 294 | img = self.loader(path_img) # numpy, HxW, numpy.Float64 295 | # 296 | # ref_paths = random.sample(self.ref, 1) 297 | # ref_img = self.loader(ref_paths[0]) 298 | # 299 | # img = match_histograms(img, ref_img) 300 | 301 | label = self.labels[index] 302 | 303 | 304 | if label==0: 305 | one_hot_label = torch.tensor([[1], [0], [0]]) 306 | elif label==1: 307 | one_hot_label = torch.tensor([[0], [1], [0]]) 308 | elif label==2: 309 | one_hot_label = torch.tensor([[0], [0], [1]]) 310 | else: 311 | one_hot_label = torch.tensor([[0], [0], [0]]) 312 | 313 | # Intensity cropping: 314 | p5 = np.percentile(img.flatten(), 0.5) 315 | p95 = np.percentile(img.flatten(), 99.5) 316 | img = np.clip(img, p5, p95) 317 | 318 | img -= img.min() 319 | img /= img.max() 320 | img = img.astype('float32') 321 | 322 | crop_size = 300 323 | 324 | # Augmentations: 325 | # 1. random rotation 326 | # 2. random scaling 0.8 - 1.2 327 | # 3. random crop from 280x280 328 | # 4. random hflip 329 | # 5. random vflip 330 | # 6. color jitter 331 | # 7. 
Gaussian filtering 332 | 333 | img_tensor = F.to_tensor(np.array(img)) 334 | img_size = img_tensor.size() 335 | 336 | if self.labeled: 337 | if self.train: 338 | img = F.to_pil_image(img) 339 | # rotate, random angle between 0 - 90 340 | angle = random.randint(0, 90) 341 | img = F.rotate(img, angle, Image.BILINEAR) 342 | 343 | path_mask = self.masks[index] 344 | mask = Image.open(path_mask) # numpy, HxWx3 345 | # rotate, random angle between 0 - 90 346 | mask = F.rotate(mask, angle, Image.NEAREST) 347 | 348 | ## Find the region of mask 349 | norm_mask = F.to_tensor(np.array(mask)) 350 | region = norm_mask[0] + norm_mask[1] + norm_mask[2] 351 | non_zero_index = torch.nonzero(region == 1, as_tuple=False) 352 | if region.sum() > 0: 353 | len_m = len(non_zero_index[0]) 354 | x_region = non_zero_index[len_m//2][0] 355 | y_region = non_zero_index[len_m//2][1] 356 | x_region = int(x_region.item()) 357 | y_region = int(y_region.item()) 358 | else: 359 | x_region = norm_mask.size(-2) // 2 360 | y_region = norm_mask.size(-1) // 2 361 | 362 | # resize and center-crop to 280x280 363 | resize_order = re / resampling_rate 364 | resize_size_h = int(img_size[-2] * resize_order) 365 | resize_size_w = int(img_size[-1] * resize_order) 366 | 367 | left_size = 0 368 | top_size = 0 369 | right_size = 0 370 | bot_size = 0 371 | if resize_size_h < self.new_size: 372 | top_size = (self.new_size - resize_size_h) // 2 373 | bot_size = (self.new_size - resize_size_h) - top_size 374 | if resize_size_w < self.new_size: 375 | left_size = (self.new_size - resize_size_w) // 2 376 | right_size = (self.new_size - resize_size_w) - left_size 377 | 378 | transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))] 379 | transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list 380 | transform = transforms.Compose(transform_list) 381 | 382 | img = transform(img) 383 | 384 | 385 | ## Define the crop index 386 | if top_size >= 0: 387 | top_crop = 0 388 | else: 
389 | if x_region > self.new_size//2: 390 | if x_region - self.new_size//2 + self.new_size <= norm_mask.size(-2): 391 | top_crop = x_region - self.new_size//2 392 | else: 393 | top_crop = norm_mask.size(-2) - self.new_size 394 | else: 395 | top_crop = 0 396 | 397 | if left_size >= 0: 398 | left_crop = 0 399 | else: 400 | if y_region > self.new_size//2: 401 | if y_region - self.new_size//2 + self.new_size <= norm_mask.size(-1): 402 | left_crop = y_region - self.new_size//2 403 | else: 404 | left_crop = norm_mask.size(-1) - self.new_size 405 | else: 406 | left_crop = 0 407 | 408 | # random crop to 224x224 409 | img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size) 410 | 411 | # random flip 412 | hflip_p = random.random() 413 | img = F.hflip(img) if hflip_p >= 0.5 else img 414 | vflip_p = random.random() 415 | img = F.vflip(img) if vflip_p >= 0.5 else img 416 | 417 | img = F.to_tensor(np.array(img)) 418 | 419 | # # Gamma correction: random gamma from [0.5, 1.5] 420 | # gamma = 0.5 + random.random() 421 | # img_max = img.max() 422 | # img = img_max * torch.pow((img / img_max), gamma) 423 | 424 | # Gaussian bluring: 425 | transform_list = [transforms.GaussianBlur(5, sigma=(0.25, 1.25))] 426 | transform = transforms.Compose(transform_list) 427 | img = transform(img) 428 | 429 | # resize and center-crop to 280x280 430 | transform_mask_list = [transforms.Pad( 431 | (left_size, top_size, right_size, bot_size))] 432 | transform_mask_list = [transforms.Resize((resize_size_h, resize_size_w), 433 | interpolation=Image.NEAREST)] + transform_mask_list 434 | transform_mask = transforms.Compose(transform_mask_list) 435 | 436 | mask = transform_mask(mask) # C,H,W 437 | 438 | # random crop to 224x224 439 | mask = F.crop(mask, top_crop, left_crop, self.new_size, self.new_size) 440 | 441 | # random flip 442 | mask = F.hflip(mask) if hflip_p >= 0.5 else mask 443 | mask = F.vflip(mask) if vflip_p >= 0.5 else mask 444 | 445 | mask = F.to_tensor(np.array(mask)) 446 | 447 | 
mask_bg = (mask.sum(0) == 0).type_as(mask) # H,W 448 | mask_bg = mask_bg.reshape((1, mask_bg.size(0), mask_bg.size(1))) 449 | mask = torch.cat((mask, mask_bg), dim=0) 450 | 451 | return img, mask, one_hot_label.squeeze() # pytorch: N,C,H,W 452 | 453 | else: 454 | path_mask = self.masks[index] 455 | mask = Image.open(path_mask) # numpy, HxWx3 456 | # resize and center-crop to 280x280 457 | 458 | ## Find the region of mask 459 | norm_mask = F.to_tensor(np.array(mask)) 460 | region = norm_mask[0] + norm_mask[1] + norm_mask[2] 461 | non_zero_index = torch.nonzero(region == 1, as_tuple=False) 462 | if region.sum() > 0: 463 | len_m = len(non_zero_index[0]) 464 | x_region = non_zero_index[len_m//2][0] 465 | y_region = non_zero_index[len_m//2][1] 466 | x_region = int(x_region.item()) 467 | y_region = int(y_region.item()) 468 | else: 469 | x_region = norm_mask.size(-2) // 2 470 | y_region = norm_mask.size(-1) // 2 471 | 472 | resize_order = re / resampling_rate 473 | resize_size_h = int(img_size[-2] * resize_order) 474 | resize_size_w = int(img_size[-1] * resize_order) 475 | 476 | left_size = 0 477 | top_size = 0 478 | right_size = 0 479 | bot_size = 0 480 | if resize_size_h < self.new_size: 481 | top_size = (self.new_size - resize_size_h) // 2 482 | bot_size = (self.new_size - resize_size_h) - top_size 483 | if resize_size_w < self.new_size: 484 | left_size = (self.new_size - resize_size_w) // 2 485 | right_size = (self.new_size - resize_size_w) - left_size 486 | 487 | 488 | # transform_list = [transforms.CenterCrop((crop_size, crop_size))] 489 | transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))] 490 | transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list 491 | transform_list = [transforms.ToPILImage()] + transform_list 492 | transform = transforms.Compose(transform_list) 493 | img = transform(img) 494 | img = F.to_tensor(np.array(img)) 495 | 496 | ## Define the crop index 497 | if top_size >= 0: 498 | top_crop = 
0 499 | else: 500 | if x_region > self.new_size//2: 501 | if x_region - self.new_size//2 + self.new_size <= norm_mask.size(-2): 502 | top_crop = x_region - self.new_size//2 503 | else: 504 | top_crop = norm_mask.size(-2) - self.new_size 505 | else: 506 | top_crop = 0 507 | 508 | if left_size >= 0: 509 | left_crop = 0 510 | else: 511 | if y_region > self.new_size//2: 512 | if y_region - self.new_size//2 + self.new_size <= norm_mask.size(-1): 513 | left_crop = y_region - self.new_size//2 514 | else: 515 | left_crop = norm_mask.size(-1) - self.new_size 516 | else: 517 | left_crop = 0 518 | 519 | # random crop to 224x224 520 | img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size) 521 | 522 | # resize and center-crop to 280x280 523 | # transform_mask_list = [transforms.CenterCrop((crop_size, crop_size))] 524 | transform_mask_list = [transforms.Pad( 525 | (left_size, top_size, right_size, bot_size))] 526 | transform_mask_list = [transforms.Resize((resize_size_h, resize_size_w), 527 | interpolation=Image.NEAREST)] + transform_mask_list 528 | transform_mask = transforms.Compose(transform_mask_list) 529 | 530 | mask = transform_mask(mask) # C,H,W 531 | mask = F.crop(mask, top_crop, left_crop, self.new_size, self.new_size) 532 | mask = F.to_tensor(np.array(mask)) 533 | 534 | mask_bg = (mask.sum(0) == 0).type_as(mask) # H,W 535 | mask_bg = mask_bg.reshape((1, mask_bg.size(0), mask_bg.size(1))) 536 | mask = torch.cat((mask, mask_bg), dim=0) 537 | 538 | return img, mask 539 | 540 | else: 541 | img = F.to_pil_image(img) 542 | # rotate, random angle between 0 - 90 543 | angle = random.randint(0, 90) 544 | img = F.rotate(img, angle, Image.BILINEAR) 545 | 546 | # resize and center-crop to 280x280 547 | resize_order = re / resampling_rate 548 | resize_size_h = int(img_size[-2] * resize_order) 549 | resize_size_w = int(img_size[-1] * resize_order) 550 | 551 | left_size = 0 552 | top_size = 0 553 | right_size = 0 554 | bot_size = 0 555 | if resize_size_h < crop_size: 
556 | top_size = (crop_size - resize_size_h) // 2 557 | bot_size = (crop_size - resize_size_h) - top_size 558 | if resize_size_w < crop_size: 559 | left_size = (crop_size - resize_size_w) // 2 560 | right_size = (crop_size - resize_size_w) - left_size 561 | 562 | transform_list = [transforms.CenterCrop((crop_size, crop_size))] 563 | transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))] + transform_list 564 | transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list 565 | transform = transforms.Compose(transform_list) 566 | 567 | img = transform(img) 568 | 569 | # random crop to 224x224 570 | top_crop = random.randint(0, crop_size - self.new_size) 571 | left_crop = random.randint(0, crop_size - self.new_size) 572 | img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size) 573 | 574 | # random flip 575 | hflip_p = random.random() 576 | img = F.hflip(img) if hflip_p >= 0.5 else img 577 | vflip_p = random.random() 578 | img = F.vflip(img) if vflip_p >= 0.5 else img 579 | 580 | img = F.to_tensor(np.array(img)) 581 | 582 | # # Gamma correction: random gamma from [0.5, 1.5] 583 | # gamma = 0.5 + random.random() 584 | # img_max = img.max() 585 | # img = img_max*torch.pow((img/img_max), gamma) 586 | 587 | # Gaussian bluring: 588 | transform_list = [transforms.GaussianBlur(5, sigma=(0.25, 1.25))] 589 | transform = transforms.Compose(transform_list) 590 | img = transform(img) 591 | 592 | return img, one_hot_label.squeeze() # pytorch: N,C,H,W 593 | 594 | def __len__(self): 595 | return len(self.imgs) 596 | 597 | 598 | -------------------------------------------------------------------------------- /loaders/mms_dataloader_meta_split_test.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torchfile 3 | from torch.utils.data import DataLoader, TensorDataset 4 | from torchvision import transforms 5 | import torchvision.transforms.functional as F 6 
| import torch 7 | import torch.nn as nn 8 | import os 9 | import torchvision.utils as vutils 10 | import numpy as np 11 | import torch.nn.init as init 12 | import torch.utils.data as data 13 | import torch 14 | import random 15 | import xlrd 16 | import math 17 | from skimage.exposure import match_histograms 18 | 19 | # Data directories 20 | LabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorA/' 21 | LabeledVendorA_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorA/' 22 | ReA_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorA/' 23 | 24 | LabeledVendorB2_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorB/center2/' 25 | LabeledVendorB2_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorB/center2/' 26 | ReB2_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorB/center2/' 27 | 28 | LabeledVendorB3_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorB/center3/' 29 | LabeledVendorB3_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorB/center3/' 30 | ReB3_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorB/center3/' 31 | 32 | LabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorC/' 33 | LabeledVendorC_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorC/' 34 | ReC_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorC/' 35 | 36 | LabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorD/' 37 | LabeledVendorD_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorD/' 38 | ReD_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorD/' 39 | 40 | UnlabeledVendorC_data_dir = 
'/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Unlabeled/vendorC/' 41 | UnReC_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Unlabeled/vendorC/' 42 | 43 | Re_dir = [ReA_dir, ReB2_dir, ReB3_dir, ReC_dir, ReD_dir] 44 | Labeled_data_dir = [LabeledVendorA_data_dir, LabeledVendorB2_data_dir, LabeledVendorB3_data_dir, LabeledVendorC_data_dir, LabeledVendorD_data_dir] 45 | Labeled_mask_dir = [LabeledVendorA_mask_dir, LabeledVendorB2_mask_dir, LabeledVendorB3_mask_dir, LabeledVendorC_mask_dir, LabeledVendorD_mask_dir] 46 | 47 | resampling_rate = 1.2 48 | 49 | def get_meta_split_data_loaders(batch_size, test_vendor='D', image_size=224): 50 | 51 | random.seed(14) 52 | 53 | domain_1_labeled_loader, domain_1_unlabeled_loader, \ 54 | domain_2_labeled_loader, domain_2_unlabeled_loader,\ 55 | domain_3_labeled_loader, domain_3_unlabeled_loader, \ 56 | test_loader, \ 57 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset = \ 58 | get_data_loader_folder(Labeled_data_dir, Labeled_mask_dir, batch_size, image_size, test_num=test_vendor) 59 | 60 | return domain_1_labeled_loader, domain_1_unlabeled_loader, \ 61 | domain_2_labeled_loader, domain_2_unlabeled_loader,\ 62 | domain_3_labeled_loader, domain_3_unlabeled_loader, \ 63 | test_loader, \ 64 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset 65 | 66 | 67 | def get_data_loader_folder(data_folders, mask_folders, batch_size, new_size=288, test_num='D', num_workers=2): 68 | if test_num=='A': 69 | domain_1_img_dirs = [data_folders[1], data_folders[2]] 70 | domain_1_mask_dirs = [mask_folders[1], mask_folders[2]] 71 | domain_2_img_dirs = [data_folders[3]] 72 | domain_2_mask_dirs = [mask_folders[3]] 73 | domain_3_img_dirs = [data_folders[4]] 74 | domain_3_mask_dirs = [mask_folders[4]] 75 | 76 | test_data_dirs = [data_folders[0]] 77 | test_mask_dirs = [mask_folders[0]] 78 | 79 | domain_1_re = [Re_dir[1], Re_dir[2]] 80 | domain_2_re = [Re_dir[3]] 81 | domain_3_re = 
[Re_dir[4]] 82 | 83 | test_re = [Re_dir[0]] 84 | 85 | domain_1_num = [74, 51] 86 | domain_2_num = [50] 87 | domain_3_num = [50] 88 | test_num = [95] 89 | 90 | elif test_num=='B': 91 | domain_1_img_dirs = [data_folders[0]] 92 | domain_1_mask_dirs = [mask_folders[0]] 93 | domain_2_img_dirs = [data_folders[3]] 94 | domain_2_mask_dirs = [mask_folders[3]] 95 | domain_3_img_dirs = [data_folders[4]] 96 | domain_3_mask_dirs = [mask_folders[4]] 97 | 98 | test_data_dirs = [data_folders[1], data_folders[2]] 99 | test_mask_dirs = [mask_folders[1], mask_folders[2]] 100 | 101 | domain_1_re = [Re_dir[0]] 102 | domain_2_re = [Re_dir[3]] 103 | domain_3_re = [Re_dir[4]] 104 | test_re = [Re_dir[1], Re_dir[2]] 105 | 106 | domain_1_num = [95] 107 | domain_2_num = [50] 108 | domain_3_num = [50] 109 | test_num = [74, 51] 110 | 111 | elif test_num=='C': 112 | domain_1_img_dirs = [data_folders[0]] 113 | domain_1_mask_dirs = [mask_folders[0]] 114 | domain_2_img_dirs = [data_folders[1], data_folders[2]] 115 | domain_2_mask_dirs = [mask_folders[1], mask_folders[2]] 116 | domain_3_img_dirs = [data_folders[4]] 117 | domain_3_mask_dirs = [mask_folders[4]] 118 | 119 | test_data_dirs = [data_folders[3]] 120 | test_mask_dirs = [mask_folders[3]] 121 | 122 | domain_1_re = [Re_dir[0]] 123 | domain_2_re = [Re_dir[1], Re_dir[2]] 124 | domain_3_re = [Re_dir[4]] 125 | test_re = [Re_dir[3]] 126 | 127 | domain_1_num = [95] 128 | domain_2_num = [74, 51] 129 | domain_3_num = [50] 130 | test_num = [50] 131 | 132 | elif test_num=='D': 133 | domain_1_img_dirs = [data_folders[0]] 134 | domain_1_mask_dirs = [mask_folders[0]] 135 | domain_2_img_dirs = [data_folders[1], data_folders[2]] 136 | domain_2_mask_dirs = [mask_folders[1], mask_folders[2]] 137 | domain_3_img_dirs = [data_folders[3]] 138 | domain_3_mask_dirs = [mask_folders[3]] 139 | 140 | test_data_dirs = [data_folders[4]] 141 | test_mask_dirs = [mask_folders[4]] 142 | 143 | domain_1_re = [Re_dir[0]] 144 | domain_2_re = [Re_dir[1], Re_dir[2]] 145 | 
domain_3_re = [Re_dir[3]] 146 | test_re = [Re_dir[4]] 147 | 148 | domain_1_num = [95] 149 | domain_2_num = [74, 51] 150 | domain_3_num = [50] 151 | test_num = [50] 152 | 153 | else: 154 | raise ValueError('Wrong test vendor!') 155 | 156 | 157 | domain_1_labeled_dataset = ImageFolder(domain_1_img_dirs, domain_1_mask_dirs, domain_1_img_dirs, domain_1_re, label=0, num_label=domain_1_num, train=True, labeled=True) 158 | domain_2_labeled_dataset = ImageFolder(domain_2_img_dirs, domain_2_mask_dirs, domain_1_img_dirs, domain_2_re, label=1, num_label=domain_2_num, train=True, labeled=True) 159 | domain_3_labeled_dataset = ImageFolder(domain_3_img_dirs, domain_3_mask_dirs, domain_1_img_dirs, domain_3_re, label=2, num_label=domain_3_num, train=True, labeled=True) 160 | 161 | 162 | # domain_1_labeled_loader = DataLoader(dataset=domain_1_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True) 163 | # domain_2_labeled_loader = DataLoader(dataset=domain_2_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True) 164 | # domain_3_labeled_loader = DataLoader(dataset=domain_3_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True) 165 | 166 | domain_1_labeled_loader = None 167 | domain_2_labeled_loader = None 168 | domain_3_labeled_loader = None 169 | 170 | domain_1_unlabeled_dataset = ImageFolder(domain_1_img_dirs, domain_1_mask_dirs, domain_1_img_dirs, domain_1_re, label=0, train=True, labeled=False) 171 | domain_2_unlabeled_dataset = ImageFolder(domain_2_img_dirs, domain_2_mask_dirs, domain_1_img_dirs, domain_2_re, label=1, train=True, labeled=False) 172 | domain_3_unlabeled_dataset = ImageFolder(domain_3_img_dirs, domain_3_mask_dirs, domain_1_img_dirs, domain_3_re, label=2, train=True, labeled=False) 173 | 174 | domain_1_unlabeled_loader = DataLoader(dataset=domain_1_unlabeled_dataset, batch_size=batch_size, shuffle=False,
drop_last=True, num_workers=num_workers, pin_memory=True) 175 | domain_2_unlabeled_loader = DataLoader(dataset=domain_2_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True) 176 | domain_3_unlabeled_loader = DataLoader(dataset=domain_3_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True) 177 | 178 | 179 | test_dataset = ImageFolder(test_data_dirs, test_mask_dirs, domain_1_img_dirs, test_re, train=False, labeled=True) 180 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True) 181 | 182 | return domain_1_labeled_loader, domain_1_unlabeled_loader, \ 183 | domain_2_labeled_loader, domain_2_unlabeled_loader,\ 184 | domain_3_labeled_loader, domain_3_unlabeled_loader, \ 185 | test_loader, \ 186 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset 187 | 188 | def default_loader(path): 189 | return np.load(path)['arr_0'] 190 | 191 | def make_dataset(dir): 192 | images = [] 193 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 194 | 195 | for root, _, fnames in sorted(os.walk(dir)): 196 | for fname in fnames: 197 | path = os.path.join(root, fname) 198 | images.append(path) 199 | return images 200 | 201 | class ImageFolder(data.Dataset): 202 | def __init__(self, data_dirs, mask_dirs, ref_dir, re, train=True, label=None, num_label=None, labeled=True, loader=default_loader): 203 | 204 | print(data_dirs) 205 | print(mask_dirs) 206 | 207 | reso_dir = re 208 | temp_imgs = [] 209 | temp_masks = [] 210 | temp_re = [] 211 | domain_labels = [] 212 | 213 | tem_ref_imgs = [] 214 | 215 | if train: 216 | #100% 217 | # k = 1 218 | #10% 219 | # k = 0.1 220 | #5% 221 | # k = 0.05 222 | #2% 223 | k = 0.02 224 | else: 225 | k = 1 226 | 227 | 228 | for num_set in range(len(data_dirs)): 229 | re_roots = sorted(make_dataset(reso_dir[num_set])) 230 
| data_roots = sorted(make_dataset(data_dirs[num_set])) 231 | mask_roots = sorted(make_dataset(mask_dirs[num_set])) 232 | num_label_data = 0 233 | for num_data in range(len(data_roots)): 234 | if labeled: 235 | if train: 236 | n_label = str(math.ceil(num_label[num_set] * k + 1)) 237 | if '00'+n_label==data_roots[num_data][-10:-7] or '0'+n_label==data_roots[num_data][-10:-7]: 238 | print(n_label) 239 | print(data_roots[num_data][-10:-7]) 240 | break 241 | 242 | for num_mask in range(len(mask_roots)): 243 | if data_roots[num_data][-10:-4] == mask_roots[num_mask][-10:-4]: 244 | temp_re.append(re_roots[num_data]) 245 | temp_imgs.append(data_roots[num_data]) 246 | temp_masks.append(mask_roots[num_mask]) 247 | domain_labels.append(label) 248 | num_label_data += 1 249 | else: 250 | pass 251 | else: 252 | temp_re.append(re_roots[num_data]) 253 | temp_imgs.append(data_roots[num_data]) 254 | domain_labels.append(label) 255 | 256 | for num_set in range(len(ref_dir)): 257 | data_roots = sorted(make_dataset(ref_dir[num_set])) 258 | for num_data in range(len(data_roots)): 259 | tem_ref_imgs.append(data_roots[num_data]) 260 | 261 | reso = temp_re 262 | imgs = temp_imgs 263 | masks = temp_masks 264 | labels = domain_labels 265 | 266 | print(len(masks)) 267 | 268 | ref_imgs = tem_ref_imgs 269 | 270 | # add something here to index, such that can split the data 271 | # index = random.sample(range(len(temp_img)), len(temp_mask)) 272 | 273 | self.reso = reso 274 | self.imgs = imgs 275 | self.masks = masks 276 | self.labels = labels 277 | self.new_size = 288 278 | self.loader = loader 279 | self.labeled = labeled 280 | self.train = train 281 | self.ref = ref_imgs 282 | 283 | def __getitem__(self, index): 284 | if self.train: 285 | index = random.randrange(len(self.imgs)) 286 | else: 287 | pass 288 | 289 | path_re = self.reso[index] 290 | re = self.loader(path_re) 291 | re = re[0] 292 | 293 | path_img = self.imgs[index] 294 | img = self.loader(path_img) # numpy, HxW, numpy.Float64 295 | 
296 | # ref_paths = random.sample(self.ref, 1) 297 | # ref_img = self.loader(ref_paths[0]) 298 | # 299 | # img = match_histograms(img, ref_img) 300 | 301 | label = self.labels[index] 302 | 303 | 304 | if label==0: 305 | one_hot_label = torch.tensor([[1], [0], [0]]) 306 | elif label==1: 307 | one_hot_label = torch.tensor([[0], [1], [0]]) 308 | elif label==2: 309 | one_hot_label = torch.tensor([[0], [0], [1]]) 310 | else: 311 | one_hot_label = torch.tensor([[0], [0], [0]]) 312 | 313 | # Intensity cropping: 314 | p5 = np.percentile(img.flatten(), 0.5) 315 | p95 = np.percentile(img.flatten(), 99.5) 316 | img = np.clip(img, p5, p95) 317 | 318 | img -= img.min() 319 | img /= img.max() 320 | img = img.astype('float32') 321 | 322 | crop_size = 300 323 | 324 | # Augmentations: 325 | # 1. random rotation 326 | # 2. random scaling 0.8 - 1.2 327 | # 3. random crop from 280x280 328 | # 4. random hflip 329 | # 5. random vflip 330 | # 6. color jitter 331 | # 7. Gaussian filtering 332 | 333 | img_tensor = F.to_tensor(np.array(img)) 334 | img_size = img_tensor.size() 335 | 336 | 337 | if self.labeled: 338 | if self.train: 339 | 340 | img = F.to_pil_image(img) 341 | # rotate, random angle between 0 - 90 342 | angle = random.randint(0, 90) 343 | img = F.rotate(img, angle, Image.BILINEAR) 344 | 345 | path_mask = self.masks[index] 346 | mask = Image.open(path_mask) # numpy, HxWx3 347 | # rotate, random angle between 0 - 90 348 | mask = F.rotate(mask, angle, Image.NEAREST) 349 | 350 | ## Find the region of mask 351 | norm_mask = F.to_tensor(np.array(mask)) 352 | region = norm_mask[0] + norm_mask[1] + norm_mask[2] 353 | non_zero_index = torch.nonzero(region == 1, as_tuple=False) 354 | if region.sum() > 0: 355 | len_m = len(non_zero_index[0]) 356 | x_region = non_zero_index[len_m//2][0] 357 | y_region = non_zero_index[len_m//2][1] 358 | x_region = int(x_region.item()) 359 | y_region = int(y_region.item()) 360 | else: 361 | x_region = norm_mask.size(-2) // 2 362 | y_region = 
norm_mask.size(-1) // 2 363 | 364 | # resize and center-crop to 280x280 365 | resize_order = re / resampling_rate 366 | resize_size_h = int(img_size[-2] * resize_order) 367 | resize_size_w = int(img_size[-1] * resize_order) 368 | 369 | left_size = 0 370 | top_size = 0 371 | right_size = 0 372 | bot_size = 0 373 | if resize_size_h < self.new_size: 374 | top_size = (self.new_size - resize_size_h) // 2 375 | bot_size = (self.new_size - resize_size_h) - top_size 376 | if resize_size_w < self.new_size: 377 | left_size = (self.new_size - resize_size_w) // 2 378 | right_size = (self.new_size - resize_size_w) - left_size 379 | 380 | transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))] 381 | transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list 382 | transform = transforms.Compose(transform_list) 383 | 384 | img = transform(img) 385 | 386 | 387 | ## Define the crop index 388 | if top_size >= 0: 389 | top_crop = 0 390 | else: 391 | if x_region > self.new_size//2: 392 | if x_region - self.new_size//2 + self.new_size <= norm_mask.size(-2): 393 | top_crop = x_region - self.new_size//2 394 | else: 395 | top_crop = norm_mask.size(-2) - self.new_size 396 | else: 397 | top_crop = 0 398 | 399 | if left_size >= 0: 400 | left_crop = 0 401 | else: 402 | if y_region > self.new_size//2: 403 | if y_region - self.new_size//2 + self.new_size <= norm_mask.size(-1): 404 | left_crop = y_region - self.new_size//2 405 | else: 406 | left_crop = norm_mask.size(-1) - self.new_size 407 | else: 408 | left_crop = 0 409 | 410 | # random crop to 224x224 411 | img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size) 412 | 413 | # random flip 414 | hflip_p = random.random() 415 | img = F.hflip(img) if hflip_p >= 0.5 else img 416 | vflip_p = random.random() 417 | img = F.vflip(img) if vflip_p >= 0.5 else img 418 | 419 | img = F.to_tensor(np.array(img)) 420 | 421 | # # Gamma correction: random gamma from [0.5, 1.5] 422 | # gamma = 0.5 + 
random.random() 423 | # img_max = img.max() 424 | # img = img_max * torch.pow((img / img_max), gamma) 425 | 426 | # Gaussian bluring: 427 | transform_list = [transforms.GaussianBlur(5, sigma=(0.25, 1.25))] 428 | transform = transforms.Compose(transform_list) 429 | img = transform(img) 430 | 431 | # resize and center-crop to 280x280 432 | transform_mask_list = [transforms.Pad( 433 | (left_size, top_size, right_size, bot_size))] 434 | transform_mask_list = [transforms.Resize((resize_size_h, resize_size_w), 435 | interpolation=Image.NEAREST)] + transform_mask_list 436 | transform_mask = transforms.Compose(transform_mask_list) 437 | 438 | mask = transform_mask(mask) # C,H,W 439 | 440 | # random crop to 224x224 441 | mask = F.crop(mask, top_crop, left_crop, self.new_size, self.new_size) 442 | 443 | # random flip 444 | mask = F.hflip(mask) if hflip_p >= 0.5 else mask 445 | mask = F.vflip(mask) if vflip_p >= 0.5 else mask 446 | 447 | mask = F.to_tensor(np.array(mask)) 448 | 449 | mask_bg = (mask.sum(0) == 0).type_as(mask) # H,W 450 | mask_bg = mask_bg.reshape((1, mask_bg.size(0), mask_bg.size(1))) 451 | mask = torch.cat((mask, mask_bg), dim=0) 452 | 453 | return img, mask, one_hot_label.squeeze() # pytorch: N,C,H,W 454 | 455 | else: 456 | path_mask = self.masks[index] 457 | mask = Image.open(path_mask) # numpy, HxWx3 458 | # resize and center-crop to 280x280 459 | 460 | ## Find the region of mask 461 | norm_mask = F.to_tensor(np.array(mask)) 462 | region = norm_mask[0] + norm_mask[1] + norm_mask[2] 463 | non_zero_index = torch.nonzero(region == 1, as_tuple=False) 464 | if region.sum() > 0: 465 | len_m = len(non_zero_index[0]) 466 | x_region = non_zero_index[len_m//2][0] 467 | y_region = non_zero_index[len_m//2][1] 468 | x_region = int(x_region.item()) 469 | y_region = int(y_region.item()) 470 | else: 471 | x_region = norm_mask.size(-2) // 2 472 | y_region = norm_mask.size(-1) // 2 473 | 474 | resize_order = re / resampling_rate 475 | resize_size_h = int(img_size[-2] * 
resize_order) 476 | resize_size_w = int(img_size[-1] * resize_order) 477 | 478 | left_size = 0 479 | top_size = 0 480 | right_size = 0 481 | bot_size = 0 482 | if resize_size_h < self.new_size: 483 | top_size = (self.new_size - resize_size_h) // 2 484 | bot_size = (self.new_size - resize_size_h) - top_size 485 | if resize_size_w < self.new_size: 486 | left_size = (self.new_size - resize_size_w) // 2 487 | right_size = (self.new_size - resize_size_w) - left_size 488 | 489 | 490 | # transform_list = [transforms.CenterCrop((crop_size, crop_size))] 491 | transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))] 492 | transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list 493 | transform_list = [transforms.ToPILImage()] + transform_list 494 | transform = transforms.Compose(transform_list) 495 | img = transform(img) 496 | img = F.to_tensor(np.array(img)) 497 | 498 | ## Define the crop index 499 | if top_size >= 0: 500 | top_crop = 0 501 | else: 502 | if x_region > self.new_size//2: 503 | if x_region - self.new_size//2 + self.new_size <= norm_mask.size(-2): 504 | top_crop = x_region - self.new_size//2 505 | else: 506 | top_crop = norm_mask.size(-2) - self.new_size 507 | else: 508 | top_crop = 0 509 | 510 | if left_size >= 0: 511 | left_crop = 0 512 | else: 513 | if y_region > self.new_size//2: 514 | if y_region - self.new_size//2 + self.new_size <= norm_mask.size(-1): 515 | left_crop = y_region - self.new_size//2 516 | else: 517 | left_crop = norm_mask.size(-1) - self.new_size 518 | else: 519 | left_crop = 0 520 | 521 | # random crop to 224x224 522 | img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size) 523 | 524 | # resize and center-crop to 280x280 525 | # transform_mask_list = [transforms.CenterCrop((crop_size, crop_size))] 526 | transform_mask_list = [transforms.Pad( 527 | (left_size, top_size, right_size, bot_size))] 528 | transform_mask_list = [transforms.Resize((resize_size_h, resize_size_w), 529 | 
interpolation=Image.NEAREST)] + transform_mask_list 530 | transform_mask = transforms.Compose(transform_mask_list) 531 | 532 | mask = transform_mask(mask) # C,H,W 533 | mask = F.crop(mask, top_crop, left_crop, self.new_size, self.new_size) 534 | mask = F.to_tensor(np.array(mask)) 535 | 536 | mask_bg = (mask.sum(0) == 0).type_as(mask) # H,W 537 | mask_bg = mask_bg.reshape((1, mask_bg.size(0), mask_bg.size(1))) 538 | mask = torch.cat((mask, mask_bg), dim=0) 539 | 540 | return img, mask, path_img 541 | 542 | else: 543 | img = F.to_pil_image(img) 544 | # rotate, random angle between 0 - 90 545 | angle = random.randint(0, 90) 546 | img = F.rotate(img, angle, Image.BILINEAR) 547 | 548 | # resize and center-crop to 280x280 549 | resize_order = re / resampling_rate 550 | resize_size_h = int(img_size[-2] * resize_order) 551 | resize_size_w = int(img_size[-1] * resize_order) 552 | 553 | left_size = 0 554 | top_size = 0 555 | right_size = 0 556 | bot_size = 0 557 | if resize_size_h < crop_size: 558 | top_size = (crop_size - resize_size_h) // 2 559 | bot_size = (crop_size - resize_size_h) - top_size 560 | if resize_size_w < crop_size: 561 | left_size = (crop_size - resize_size_w) // 2 562 | right_size = (crop_size - resize_size_w) - left_size 563 | 564 | transform_list = [transforms.CenterCrop((crop_size, crop_size))] 565 | transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))] + transform_list 566 | transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list 567 | transform = transforms.Compose(transform_list) 568 | 569 | img = transform(img) 570 | 571 | # random crop to 224x224 572 | top_crop = random.randint(0, crop_size - self.new_size) 573 | left_crop = random.randint(0, crop_size - self.new_size) 574 | img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size) 575 | 576 | # random flip 577 | hflip_p = random.random() 578 | img = F.hflip(img) if hflip_p >= 0.5 else img 579 | vflip_p = random.random() 580 | img = 
F.vflip(img) if vflip_p >= 0.5 else img 581 | 582 | img = F.to_tensor(np.array(img)) 583 | 584 | # # Gamma correction: random gamma from [0.5, 1.5] 585 | # gamma = 0.5 + random.random() 586 | # img_max = img.max() 587 | # img = img_max*torch.pow((img/img_max), gamma) 588 | 589 | # Gaussian blurring: 590 | transform_list = [transforms.GaussianBlur(5, sigma=(0.25, 1.25))] 591 | transform = transforms.Compose(transform_list) 592 | img = transform(img) 593 | 594 | return img, one_hot_label.squeeze() # pytorch: N,C,H,W 595 | 596 | def __len__(self): 597 | return len(self.imgs) 598 | 599 | 600 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * 2 | -------------------------------------------------------------------------------- /losses/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/losses/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /losses/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/losses/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def beta_vae_loss(reco_x, x, logvar, mu, beta, type='L2'): 6 | if type == 'BCE': 7 | reco_x_loss = F.binary_cross_entropy(reco_x, x, reduction='sum') 8 | elif type == 'L1': 9 | reco_x_loss = F.l1_loss(reco_x, x, reduction='sum') 10 | else: 11 |
reco_x_loss = F.mse_loss(reco_x, x, reduction='sum') 12 | kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 13 | 14 | return (reco_x_loss + beta*kld)/x.shape[0] 15 | 16 | def KL_divergence(logvar, mu): 17 | kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1) 18 | return kld.mean() 19 | 20 | def dice_loss(pred, target): 21 | """ 22 | This definition generalizes to real-valued pred and target vectors. 23 | This should be differentiable. 24 | pred: tensor with first dimension as batch 25 | target: tensor with first dimension as batch 26 | """ 27 | smooth = 0.1 #1e-12 28 | 29 | # contiguous() is needed because the tensors may come from a torch.view op 30 | iflat = pred.contiguous().view(-1) 31 | tflat = target.contiguous().view(-1) 32 | intersection = (iflat * tflat).sum() 33 | 34 | #A_sum = torch.sum(tflat * iflat) 35 | #B_sum = torch.sum(tflat * tflat) 36 | dice = ((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)).mean() 37 | 38 | return 1 - dice 39 | 40 | def HSIC_lossfunc(x, y): 41 | assert x.dim() == y.dim() == 2 42 | assert x.size(0) == y.size(0) 43 | m = x.size(0) 44 | h = torch.eye(m) - 1/m 45 | h = h.to(x.device) 46 | K_x = gaussian_kernel(x) 47 | K_y = gaussian_kernel(y) 48 | return K_x.mm(h).mm(K_y).mm(h).trace() / (m-1+1e-10) 49 | 50 | 51 | def gaussian_kernel(x, y=None, sigma=5): 52 | if y is None: 53 | y = x 54 | assert x.dim() == y.dim() == 2 55 | assert x.size() == y.size() 56 | z = ((x.unsqueeze(0) - y.unsqueeze(1)) ** 2).sum(-1) 57 | return torch.exp(- 0.5 * z / (sigma * sigma)) 58 | 59 | def LS_dis(score_real, score_fake): 60 | return 0.5 * (torch.mean((score_real-1)**2) + torch.mean(score_fake**2)) 61 | 62 | def LS_model(score_fake): 63 | return 0.5 * (torch.mean((score_fake-1)**2)) 64 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .dice_loss import * 2 | from
.focal_loss import * 3 | from .gan_loss import * -------------------------------------------------------------------------------- /metrics/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/metrics/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/dice_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/metrics/__pycache__/dice_loss.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/focal_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/metrics/__pycache__/focal_loss.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/gan_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/metrics/__pycache__/gan_loss.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | 5 | class DiceCoeff(Function): 6 | """Dice coeff for individual examples""" 7 | 8 | def forward(self, input, target): 9 | self.save_for_backward(input, target) 10 | eps = 0.0001 11 | self.inter = torch.dot(input.view(-1), target.view(-1)) 12 | self.union = torch.sum(input) + torch.sum(target) + eps 13 | 14 | t = (2 * 
self.inter.float() + eps) / self.union.float() 15 | return t 16 | 17 | # This function has only a single output, so it gets only one gradient 18 | def backward(self, grad_output): 19 | 20 | input, target = self.saved_variables 21 | grad_input = grad_target = None 22 | 23 | if self.needs_input_grad[0]: 24 | grad_input = grad_output * 2 * (target * self.union - self.inter) \ 25 | / (self.union * self.union) 26 | if self.needs_input_grad[1]: 27 | grad_target = None 28 | 29 | return grad_input, grad_target 30 | 31 | 32 | def dice_coeff(input, target, device): 33 | """Dice coeff for batches""" 34 | if input.is_cuda: 35 | s = torch.FloatTensor(1).zero_() 36 | s = s.to(device) 37 | else: 38 | s = torch.FloatTensor(1).zero_() 39 | 40 | for i, c in enumerate(zip(input, target)): 41 | s = s + DiceCoeff().forward(c[0], c[1]) 42 | 43 | return s / (i + 1) -------------------------------------------------------------------------------- /metrics/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FocalLoss(nn.Module): 6 | def __init__(self, gamma=0, alpha=None, size_average=True): 7 | super(FocalLoss, self).__init__() 8 | self.gamma = gamma 9 | self.alpha = alpha 10 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) 11 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) 12 | self.size_average = size_average 13 | 14 | def forward(self, input, target): 15 | if input.dim()>2: 16 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W 17 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C 18 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C 19 | target = target.view(-1,1) 20 | 21 | logpt = F.log_softmax(input, dim=1) 22 | logpt = logpt.gather(1,target) 23 | logpt = logpt.view(-1) 24 | pt = logpt.data.exp() 25 | 26 | if self.alpha is not None: 27 | if 
self.alpha.type()!=input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0,target.data.view(-1))
            logpt = logpt * at

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()
--------------------------------------------------------------------------------
/metrics/gan_loss.py:
--------------------------------------------------------------------------------
""" GAN loss functions (least-squares and binary cross-entropy variants) """

import torch


def ls_discriminator_loss(scores_real, scores_fake):
    """
    Compute the Least-Squares GAN loss for the discriminator.

    Inputs:
    - scores_real: PyTorch Variable of shape (N,) giving scores for the real data.
    - scores_fake: PyTorch Variable of shape (N,) giving scores for the fake data.

    Outputs:
    - loss: A PyTorch Variable containing the loss.
    """
    loss = (torch.mean((scores_real - 1) ** 2) + torch.mean(scores_fake ** 2)) / 2
    return loss


def ls_generator_loss(scores_fake):
    """
    Computes the Least-Squares GAN loss for the generator.

    Inputs:
    - scores_fake: PyTorch Variable of shape (N,) giving scores for the fake data.

    Outputs:
    - loss: A PyTorch Variable containing the loss.
    """
    loss = torch.mean((scores_fake - 1) ** 2) / 2
    return loss


# %%

def bce_loss(input, target):
    """
    Numerically stable version of the binary cross-entropy loss function.

    As per https://github.com/pytorch/pytorch/issues/751
    See the TensorFlow docs for a derivation of this formula:
    https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

    Inputs:
    - input: PyTorch Variable of shape (N,) giving scores.
    - target: PyTorch Variable of shape (N,) containing 0 and 1 giving targets.

    Returns:
    - A PyTorch Variable containing the mean BCE loss over the minibatch of input data.
    """
    neg_abs = - input.abs()
    loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
    return loss.mean()


# %%

def discriminator_loss(logits_real, logits_fake, device):
    """
    Computes the binary cross-entropy GAN loss for the discriminator.

    Inputs:
    - logits_real: PyTorch Variable of shape (N,) giving scores for the real data.
    - logits_fake: PyTorch Variable of shape (N,) giving scores for the fake data.

    Returns:
    - loss: PyTorch Variable containing the (scalar) loss for the discriminator.
    """
    true_labels = torch.ones(logits_real.size()).to(device=device, dtype=torch.float32)
    # real scores are pushed towards label 1, fake scores towards label 0
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, true_labels - 1)
    return loss


def generator_loss(logits_fake, device):
    """
    Computes the binary cross-entropy GAN loss for the generator.

    Inputs:
    - logits_fake: PyTorch Variable of shape (N,) giving scores for the fake data.

    Returns:
    - loss: PyTorch Variable containing the (scalar) loss for the generator.
84 | """ 85 | true_labels = torch.ones(logits_fake.size()).to(device=device, dtype=torch.float32) 86 | loss = bce_loss(logits_fake, true_labels) 87 | return loss -------------------------------------------------------------------------------- /metrics/hausdorff.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial.distance import directed_hausdorff 2 | 3 | def hausdorff_distance(x, y): 4 | x = x.cpu().data.numpy() 5 | u = x.reshape(x.shape[1], -1) 6 | y = y.cpu().data.numpy() 7 | v = y.reshape(y.shape[1], -1) 8 | return max(directed_hausdorff(u, v)[0], directed_hausdorff(v, u)[0]) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dgnet import * 2 | from .weight_init import * 3 | 4 | 5 | import sys 6 | 7 | def get_model(name, params): 8 | if name == 'dgnet': 9 | return DGNet(params['width'], params['height'], params['num_classes'], params['ndf'], params['z_length'], 10 | params['norm'], params['upsample'], params['decoder_type'], params['anatomy_out_channels'], 11 | params['num_mask_channels']) 12 | else: 13 | print("Could not find the requested model ({})".format(name), file=sys.stderr) -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/dgnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/dgnet.cpython-37.pyc 
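A note on the `bce_loss` helper in /metrics/gan_loss.py: it evaluates `max(x, 0) - x * t + log(1 + exp(-|x|))` on raw logits instead of applying a sigmoid first. The sketch below is a plain-Python illustration of why, not code from this repository: `stable_bce` and `naive_bce` are hypothetical names, and scalars stand in for tensors. The two forms should agree for moderate logits, while the textbook form breaks down for extreme ones.

```python
import math


def stable_bce(logit, target):
    """BCE on a raw logit: max(x, 0) - x*t + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def naive_bce(logit, target):
    """Textbook form: sigmoid first, then -[t*log(p) + (1-t)*log(1-p)]."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


# The two forms agree for moderate logits.
for logit in (-3.0, -0.5, 0.0, 2.0):
    for target in (0.0, 1.0):
        assert abs(stable_bce(logit, target) - naive_bce(logit, target)) < 1e-9

# For an extreme logit the sigmoid saturates to exactly 1.0, so the naive
# form ends up taking log(0); the stable form never exponentiates a
# positive number and stays finite.
try:
    naive_bce(1000.0, 0.0)
    naive_failed = False
except (ValueError, OverflowError):
    naive_failed = True

extreme = stable_bce(1000.0, 0.0)
```

The same reasoning carries over element-wise to the tensor version, where `input.clamp(min=0)` plays the role of `max(x, 0)`.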
-------------------------------------------------------------------------------- /models/__pycache__/meta_decoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/meta_decoder.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/meta_segmentor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/meta_segmentor.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/meta_styleencoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/meta_styleencoder.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/meta_unet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/meta_unet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/ops.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/sdnet_ada.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/sdnet_ada.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_parts.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/unet_parts.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/weight_init.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/weight_init.cpython-37.pyc
--------------------------------------------------------------------------------
/models/dgnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import time
from models.meta_segmentor import *
from models.meta_styleencoder import *
from models.meta_decoder import *



class DGNet(nn.Module):
    def __init__(self, width, height, num_classes, ndf, z_length, norm, upsample, decoder_type, anatomy_out_channels, num_mask_channels):
        """
        Args:
            width: input width
            height: input height
            upsample: upsampling type (nearest | bilateral)
            num_classes: number of semantic segmentation classes
        """
        super(DGNet, self).__init__()
        self.h = height
        self.w = width
        self.ndf = ndf
        self.z_length = z_length
        self.anatomy_out_channels = anatomy_out_channels
        self.norm = norm
        self.upsample = upsample
        self.num_classes = num_classes
        self.decoder_type = decoder_type
        self.num_mask_channels = num_mask_channels

        self.m_encoder =
StyleEncoder(z_length*2)
        self.a_encoder = ContentEncoder(self.h, self.w, self.ndf, self.anatomy_out_channels, self.norm, self.upsample)
        # self.segmentor = Segmentor(self.anatomy_out_channels, self.num_classes)
        self.decoder = Ada_Decoder(self.decoder_type, self.anatomy_out_channels, self.z_length*2, self.num_mask_channels)

    def forward(self, x, mask, script_type, meta_loss=None, meta_step_size=0.001, stop_gradient=False, a_in=None, z_in=None):
        self.meta_loss = meta_loss
        self.meta_step_size = meta_step_size
        self.stop_gradient = stop_gradient

        # style (modality) code and anatomy (content) factors of the input
        z_out, mu_out, logvar_out, cls_out = self.m_encoder(x, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
                                                            stop_gradient=self.stop_gradient)
        a_out = self.a_encoder(x, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
                               stop_gradient=self.stop_gradient)
        # seg_pred = self.segmentor(a_out, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
        #                           stop_gradient=self.stop_gradient)

        z_out_tilede = None
        cls_out_tild = None

        #t0 = time.time()
        if a_in is None:
            if script_type == 'training':
                # reconstruct from the sampled code, then re-encode the reconstruction
                reco = self.decoder(a_out, z_out, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
                                    stop_gradient=self.stop_gradient)
                z_out_tilede, mu_out_tilde, logvar_tilde, cls_out_tild = self.m_encoder(reco, meta_loss=self.meta_loss,
                                                                                        meta_step_size=self.meta_step_size,
                                                                                        stop_gradient=self.stop_gradient)
            elif script_type == 'val' or script_type == 'test':
                # re-encode the input itself and reconstruct from that code
                z_out_tilede, mu_out_tilde, logvar_tilde, cls_out_tild = self.m_encoder(x, meta_loss=self.meta_loss,
                                                                                        meta_step_size=self.meta_step_size,
                                                                                        stop_gradient=self.stop_gradient)
                reco = self.decoder(a_out, z_out_tilede, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
                                    stop_gradient=self.stop_gradient)
            else:
                # without this branch, reco would be undefined at the return statement
                raise ValueError("unknown script_type: {}".format(script_type))
        else:
            # decode externally supplied anatomy and style codes
            reco = self.decoder(a_in, z_in, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
stop_gradient=self.stop_gradient) 70 | 71 | return reco, z_out, z_out_tilede, a_out, None, mu_out, logvar_out, cls_out, cls_out_tild 72 | -------------------------------------------------------------------------------- /models/meta_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.ops import * 5 | 6 | class LayerNorm(nn.Module): 7 | def __init__(self, num_features, eps=1e-5, affine=True): 8 | super(LayerNorm, self).__init__() 9 | self.num_features = num_features 10 | self.affine = affine 11 | self.eps = eps 12 | 13 | if self.affine: 14 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) 15 | self.beta = nn.Parameter(torch.zeros(num_features)) 16 | 17 | def forward(self, x): 18 | shape = [-1] + [1] * (x.dim() - 1) 19 | # print(x.size()) 20 | if x.size(0) == 1: 21 | # These two lines run much faster in pytorch 0.4 than the two lines listed below. 
22 | mean = x.view(-1).mean().view(*shape) 23 | std = x.view(-1).std().view(*shape) 24 | else: 25 | mean = x.view(x.size(0), -1).mean(1).view(*shape) 26 | std = x.view(x.size(0), -1).std(1).view(*shape) 27 | 28 | x = (x - mean) / (std + self.eps) 29 | 30 | if self.affine: 31 | shape = [1, -1] + [1] * (x.dim() - 2) 32 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 33 | return x 34 | 35 | class MLP(nn.Module): 36 | def __init__(self, input_dim, output_dim, dim, n_blk): 37 | super(MLP, self).__init__() 38 | 39 | self.fc1 = nn.Linear(input_dim, dim) 40 | self.fc2 = nn.Linear(dim, dim) 41 | self.fc3 = nn.Linear(dim, output_dim) 42 | 43 | def forward(self, x, meta_loss, meta_step_size, stop_gradient): 44 | self.meta_loss = meta_loss 45 | self.meta_step_size = meta_step_size 46 | self.stop_gradient = stop_gradient 47 | 48 | x = x.view(x.size(0), -1) 49 | 50 | out = linear(x, self.fc1.weight, self.fc1.bias, meta_loss=self.meta_loss, 51 | meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient) 52 | out = relu(out) 53 | out = linear(out, self.fc2.weight, self.fc2.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 54 | stop_gradient=self.stop_gradient) 55 | out = relu(out) 56 | out = linear(out, self.fc3.weight, self.fc3.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 57 | stop_gradient=self.stop_gradient) 58 | out = relu(out) 59 | 60 | return out 61 | 62 | class AdaptiveInstanceNorm2d(nn.Module): 63 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 64 | super(AdaptiveInstanceNorm2d, self).__init__() 65 | self.num_features = num_features 66 | self.eps = eps 67 | self.momentum = momentum 68 | # weight and bias are dynamically assigned 69 | self.weight = None 70 | self.bias = None 71 | # just dummy buffers, not used 72 | self.register_buffer('running_mean', torch.zeros(num_features)) 73 | self.register_buffer('running_var', torch.ones(num_features)) 74 | 75 | def forward(self, x): 76 | assert self.weight 
is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!" 77 | b, c = x.size(0), x.size(1) 78 | running_mean = self.running_mean.repeat(b) 79 | running_var = self.running_var.repeat(b) 80 | 81 | # Apply instance norm 82 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) 83 | 84 | out = F.batch_norm( 85 | x_reshaped, running_mean, running_var, self.weight, self.bias, 86 | True, self.momentum, self.eps) 87 | 88 | return out.view(b, c, *x.size()[2:]) 89 | 90 | def __repr__(self): 91 | return self.__class__.__name__ + '(' + str(self.num_features) + ')' 92 | 93 | class Decoder(nn.Module): 94 | def __init__(self, dim, output_dim=1): 95 | super(Decoder, self).__init__() 96 | 97 | self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1, bias=True) 98 | self.adain1 = AdaptiveInstanceNorm2d(dim) 99 | self.conv2 = nn.Conv2d(dim, dim, 3, 1, 1, bias=True) 100 | self.adain2 = AdaptiveInstanceNorm2d(dim) 101 | 102 | self.conv3 = nn.Conv2d(dim, dim//2, 3, 1, 1, bias=True) 103 | dim //= 2 104 | # self.ln3 = LayerNorm(dim) 105 | self.bn3 = normalization(dim, norm='bn') 106 | self.conv4 = nn.Conv2d(dim, output_dim, 3, 1, 1, bias=True) 107 | 108 | def forward(self, x, meta_loss, meta_step_size, stop_gradient): 109 | self.meta_loss = meta_loss 110 | self.meta_step_size = meta_step_size 111 | self.stop_gradient = stop_gradient 112 | 113 | # x = F.softmax(x, dim=1) 114 | 115 | out = conv2d(x, self.conv1.weight, self.conv1.bias, stride=1, padding=1, meta_loss=self.meta_loss, 116 | meta_step_size=self.meta_step_size, 117 | stop_gradient=self.stop_gradient) 118 | out = self.adain1(out) 119 | out = conv2d(out, self.conv2.weight, self.conv2.bias, stride=1, padding=1, meta_loss=self.meta_loss, 120 | meta_step_size=self.meta_step_size, 121 | stop_gradient=self.stop_gradient) 122 | out = self.adain2(out) 123 | out = conv2d(out, self.conv3.weight, self.conv3.bias, stride=1, padding=1, meta_loss=self.meta_loss, 124 | meta_step_size=self.meta_step_size, 125 | 
stop_gradient=self.stop_gradient) 126 | # out = self.bn3(out) 127 | out = conv2d(out, self.conv4.weight, self.conv4.bias, stride=1, padding=1, meta_loss=self.meta_loss, 128 | meta_step_size=self.meta_step_size, 129 | stop_gradient=self.stop_gradient) 130 | out = tanh(out) 131 | return out 132 | 133 | # decoder 134 | class Ada_Decoder(nn.Module): 135 | # AdaIN auto-encoder architecture 136 | def __init__(self, decoder_type, anatomy_out_channels, z_length, num_mask_channels): 137 | super(Ada_Decoder, self).__init__() 138 | """ 139 | """ 140 | self.dec = Decoder(anatomy_out_channels) 141 | # MLP to generate AdaIN parameters 142 | self.mlp = MLP(z_length, self.get_num_adain_params(self.dec), 256, 3) 143 | 144 | def forward(self, a, z, meta_loss, meta_step_size, stop_gradient): 145 | self.meta_loss = meta_loss 146 | self.meta_step_size = meta_step_size 147 | self.stop_gradient = stop_gradient 148 | # reconstruct an image 149 | images_recon = self.decode(a, z, meta_loss=self.meta_loss, 150 | meta_step_size=self.meta_step_size, 151 | stop_gradient=self.stop_gradient) 152 | return images_recon 153 | 154 | def decode(self, content, style, meta_loss, meta_step_size, stop_gradient): 155 | self.meta_loss = meta_loss 156 | self.meta_step_size = meta_step_size 157 | self.stop_gradient = stop_gradient 158 | # decode content and style codes to an image 159 | adain_params = self.mlp(style, meta_loss=self.meta_loss, 160 | meta_step_size=self.meta_step_size, 161 | stop_gradient=self.stop_gradient) 162 | self.assign_adain_params(adain_params, self.dec) 163 | images = self.dec(content, meta_loss=self.meta_loss, 164 | meta_step_size=self.meta_step_size, 165 | stop_gradient=self.stop_gradient) 166 | return images 167 | 168 | def assign_adain_params(self, adain_params, model): 169 | # assign the adain_params to the AdaIN layers in model 170 | for m in model.modules(): 171 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d": 172 | mean = adain_params[:, :m.num_features] 173 | std = 
adain_params[:, m.num_features:2*m.num_features] 174 | m.bias = mean.contiguous().view(-1) 175 | m.weight = std.contiguous().view(-1) 176 | if adain_params.size(1) > 2*m.num_features: 177 | adain_params = adain_params[:, 2*m.num_features:] 178 | 179 | def get_num_adain_params(self, model): 180 | # return the number of AdaIN parameters needed by the model 181 | num_adain_params = 0 182 | for m in model.modules(): 183 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d": 184 | num_adain_params += 2*m.num_features 185 | return num_adain_params -------------------------------------------------------------------------------- /models/meta_segmentor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.ops import * 5 | from models.meta_unet import * 6 | 7 | class ContentEncoder(nn.Module): 8 | def __init__(self, width, height, ndf, num_output_channels, norm, upsample): 9 | super(ContentEncoder, self).__init__() 10 | """ 11 | Build an encoder to extract anatomical information from the image. 
12 | """ 13 | self.width = width 14 | self.height = height 15 | self.ndf = ndf 16 | self.num_output_channels = num_output_channels 17 | self.norm = norm 18 | self.upsample = upsample 19 | 20 | self.unet = UNet(c=1, n=32, norm='in', num_classes=self.num_output_channels) 21 | 22 | def forward(self, x, meta_loss, meta_step_size, stop_gradient): 23 | self.meta_loss = meta_loss 24 | self.meta_step_size = meta_step_size 25 | self.stop_gradient = stop_gradient 26 | 27 | out = self.unet(x, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient) 28 | return out 29 | 30 | class Segmentor(nn.Module): 31 | def __init__(self, num_output_channels, num_classes): 32 | super(Segmentor, self).__init__() 33 | """ 34 | """ 35 | self.num_output_channels = num_output_channels 36 | self.num_classes = num_classes+1 # check again 37 | 38 | self.conv1 = nn.Conv2d(self.num_output_channels, 64, 3, 1, 1, bias=True) 39 | self.bn1 = normalization(64, norm='bn') 40 | self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=True) 41 | self.bn2 = normalization(64, norm='bn') 42 | self.pred = nn.Conv2d(64, self.num_classes, 1, 1, 0) 43 | 44 | def forward(self, x, meta_loss, meta_step_size, stop_gradient): 45 | self.meta_loss = meta_loss 46 | self.meta_step_size = meta_step_size 47 | self.stop_gradient = stop_gradient 48 | 49 | out = conv2d(x, self.conv1.weight, self.conv1.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 50 | stop_gradient=self.stop_gradient) 51 | out = self.bn1(out) 52 | out = relu(out) 53 | out = conv2d(out, self.conv2.weight, self.conv2.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 54 | stop_gradient=self.stop_gradient) 55 | out = self.bn2(out) 56 | out = relu(out) 57 | out = conv2d(out, self.pred.weight, self.pred.bias, stride=1, padding=0, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient) 58 | 59 | return out 
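
The `meta_loss`, `meta_step_size` and `stop_gradient` arguments threaded through the `forward` methods above are consumed by the functional ops imported from `models.ops`, which evaluate each layer with weights shifted one gradient step along the meta-train loss. The sketch below shows that idea on a scalar linear layer in plain Python. It is an illustration under stated assumptions rather than the repository's code: `adapted_linear` and `squared_error_grads` are hypothetical helpers, and the hand-written gradient of a squared error stands in for `autograd.grad`.

```python
def adapted_linear(x, w, b, grad_w=None, grad_b=None, meta_step_size=0.001):
    """Scalar y = w*x + b, evaluated with one-step adapted parameters.

    A gradient of None means "no adaptation for this parameter".
    """
    w_adapt = w if grad_w is None else w - meta_step_size * grad_w
    b_adapt = b if grad_b is None else b - meta_step_size * grad_b
    return w_adapt * x + b_adapt


def squared_error_grads(x, target, w, b):
    """Analytic gradients of (w*x + b - target)**2 with respect to w and b."""
    err = (w * x + b) - target
    return 2.0 * err * x, 2.0 * err


w, b = 0.5, 0.0

# "Meta-train" step: gradients of the loss on one source sample.
grad_w, grad_b = squared_error_grads(x=2.0, target=3.0, w=w, b=b)

# "Meta-test" step: the same layer, evaluated on another sample with and
# without the one-step update. The stored parameters w and b are unchanged.
plain = adapted_linear(1.0, w, b)
adapted = adapted_linear(1.0, w, b, grad_w, grad_b, meta_step_size=0.1)
```

The stored parameters are never overwritten here: the adapted values exist only for the duration of the call, which is why the layers can keep ordinary `nn.Conv2d` / `nn.Linear` modules as parameter containers.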
-------------------------------------------------------------------------------- /models/meta_styleencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.ops import * 5 | 6 | class StyleEncoder(nn.Module): 7 | def __init__(self, style_dim): 8 | super(StyleEncoder, self).__init__() 9 | dim = 8 10 | self.style_dim = style_dim // 2 11 | self.conv1 = nn.Conv2d(1, dim, 7, 1, 3, bias=True) 12 | self.conv2 = nn.Conv2d(dim, dim*2, 4, 2, 1, bias=True) 13 | self.conv3 = nn.Conv2d(dim*2, dim*4, 4, 2, 1, bias=True) 14 | self.conv4 = nn.Conv2d(dim*4, dim*8, 4, 2, 1, bias=True) 15 | self.conv5 = nn.Conv2d(dim*8, dim*16, 4, 2, 1, bias=True) 16 | self.conv6 = nn.Conv2d(dim*16, dim*32, 4, 2, 1, bias=True) 17 | 18 | self.fc1 = nn.Linear(256*9*9, 4*9*9) 19 | self.fc2 = nn.Linear(4*9*9, 32) 20 | self.mu = nn.Linear(32, style_dim) 21 | self.logvar = nn.Linear(32, style_dim) 22 | self.classifier = nn.Linear(self.style_dim, 3) 23 | 24 | def reparameterize(self, mu, logvar): 25 | std = torch.exp(0.5 * logvar) 26 | eps = torch.randn_like(std) 27 | 28 | return mu + eps * std 29 | 30 | def forward(self, x, meta_loss, meta_step_size, stop_gradient): 31 | self.meta_loss = meta_loss 32 | self.meta_step_size = meta_step_size 33 | self.stop_gradient = stop_gradient 34 | 35 | out = conv2d(x, self.conv1.weight, self.conv1.bias, stride=1, padding=3, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 36 | stop_gradient=self.stop_gradient) 37 | out = conv2d(out, self.conv2.weight, self.conv2.bias, stride=2, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 38 | stop_gradient=self.stop_gradient) 39 | out = conv2d(out, self.conv3.weight, self.conv3.bias, stride=2, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 40 | stop_gradient=self.stop_gradient) 41 | out = conv2d(out, self.conv4.weight, self.conv4.bias, 
stride=2, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 42 | stop_gradient=self.stop_gradient) 43 | out = conv2d(out, self.conv5.weight, self.conv5.bias, stride=2, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 44 | stop_gradient=self.stop_gradient) 45 | out = conv2d(out, self.conv6.weight, self.conv6.bias, stride=2, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 46 | stop_gradient=self.stop_gradient) 47 | 48 | out = linear(out.view(-1, out.shape[1] * out.shape[2] * out.shape[3]), self.fc1.weight, self.fc1.bias, meta_loss=self.meta_loss, 49 | meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient) 50 | out = lrelu(out) 51 | out = linear(out, self.fc2.weight, self.fc2.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 52 | stop_gradient=self.stop_gradient) 53 | out = lrelu(out) 54 | 55 | mu = linear(out, self.mu.weight, self.mu.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 56 | stop_gradient=self.stop_gradient) 57 | logvar = linear(out, self.logvar.weight, self.logvar.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 58 | stop_gradient=self.stop_gradient) 59 | 60 | zs = self.reparameterize(mu[:,self.style_dim:], logvar[:,self.style_dim:]) 61 | zd = self.reparameterize(mu[:,:self.style_dim], logvar[:,:self.style_dim]) 62 | z = torch.cat((zs,zd), dim=1) 63 | 64 | cls = linear(z[:,:self.style_dim], self.classifier.weight, self.classifier.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, 65 | stop_gradient=self.stop_gradient) 66 | cls = F.softmax(cls, dim=1) 67 | return z, mu, logvar, cls -------------------------------------------------------------------------------- /models/meta_unet.py: -------------------------------------------------------------------------------- 1 | 2 | """ Parts of the U-Net model """ 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from 
models.ops import * 8 | 9 | class ConvD(nn.Module): 10 | def __init__(self, inplanes, planes, meta_loss, meta_step_size, stop_gradient, norm='in', first=False): 11 | super(ConvD, self).__init__() 12 | 13 | self.meta_loss = meta_loss 14 | self.meta_step_size = meta_step_size 15 | self.stop_gradient = stop_gradient 16 | 17 | self.first = first 18 | 19 | self.conv1 = nn.Conv2d(inplanes, planes, 3, 1, 1, bias=True) 20 | self.in1 = normalization(planes, norm) 21 | 22 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=True) 23 | self.in2 = normalization(planes, norm) 24 | 25 | self.conv3 = nn.Conv2d(planes, planes, 3, 1, 1, bias=True) 26 | self.in3 = normalization(planes, norm) 27 | 28 | def forward(self, x): 29 | 30 | if not self.first: 31 | x = maxpool2D(x, kernel_size=2) 32 | 33 | #layer 1 conv, in 34 | x = conv2d(x, self.conv1.weight, self.conv1.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient) 35 | x = self.in1(x) 36 | 37 | #layer 2 conv, in, lrelu 38 | y = conv2d(x, self.conv2.weight, self.conv2.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient) 39 | y = self.in2(y) 40 | y = lrelu(y) 41 | 42 | #layer 3 conv, in, lrelu 43 | z = conv2d(y, self.conv3.weight, self.conv3.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient) 44 | z = self.in3(z) 45 | z = lrelu(z) 46 | 47 | return z 48 | 49 | class ConvU(nn.Module): 50 | def __init__(self, planes, meta_loss, meta_step_size, stop_gradient, norm='in', first=False): 51 | super(ConvU, self).__init__() 52 | 53 | self.meta_loss = meta_loss 54 | self.meta_step_size = meta_step_size 55 | self.stop_gradient = stop_gradient 56 | 57 | self.first = first 58 | if not self.first: 59 | self.conv1 = nn.Conv2d(2*planes, planes, 3, 1, 1, bias=True) 60 | self.in1 = normalization(planes, norm) 61 | 62 | self.conv2 = 
nn.Conv2d(planes, planes//2, 1, 1, 0, bias=True) 63 | self.in2 = normalization(planes//2, norm) 64 | 65 | self.conv3 = nn.Conv2d(planes, planes, 3, 1, 1, bias=True) 66 | self.in3 = normalization(planes, norm) 67 | 68 | def forward(self, x, prev): 69 | #layer 1 conv, in, lrelu 70 | if not self.first: 71 | x = conv2d(x, self.conv1.weight, self.conv1.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient) 72 | x = self.in1(x) 73 | x = lrelu(x) 74 | 75 | #upsample, layer 2 conv, bn, relu 76 | y = upsample(x) 77 | y = conv2d(y, self.conv2.weight, self.conv2.bias, stride=1, padding=0, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient) 78 | y = self.in2(y) 79 | y = lrelu(y) 80 | 81 | #concatenation of two layers 82 | y = torch.cat([prev, y], 1) 83 | 84 | #layer 3 conv, bn 85 | y = conv2d(y, self.conv3.weight, self.conv3.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient) 86 | y = self.in3(y) 87 | y = lrelu(y) 88 | 89 | return y 90 | 91 | class UNet(nn.Module): 92 | def __init__(self, c, n, num_classes, norm='in'): 93 | super(UNet, self).__init__() 94 | 95 | meta_loss = None 96 | meta_step_size = 0.01 97 | stop_gradient = False 98 | 99 | self.convd1 = ConvD(c, n, meta_loss, meta_step_size, stop_gradient, norm, first=True) 100 | self.convd2 = ConvD(n, 2*n, meta_loss, meta_step_size, stop_gradient, norm) 101 | self.convd3 = ConvD(2*n, 4*n, meta_loss, meta_step_size, stop_gradient, norm) 102 | self.convd4 = ConvD(4*n, 8*n, meta_loss, meta_step_size, stop_gradient, norm) 103 | self.convd5 = ConvD(8*n,16*n, meta_loss, meta_step_size, stop_gradient, norm) 104 | 105 | self.convu4 = ConvU(16*n, meta_loss, meta_step_size, stop_gradient, norm, first=True) 106 | self.convu3 = ConvU(8*n, meta_loss, meta_step_size, stop_gradient, norm) 107 | self.convu2 = ConvU(4*n, meta_loss, meta_step_size, 
stop_gradient, norm) 108 | self.convu1 = ConvU(2*n, meta_loss, meta_step_size, stop_gradient, norm) 109 | 110 | self.seg1 = nn.Conv2d(2*n, num_classes, 1) 111 | 112 | def forward(self, x, meta_loss, meta_step_size, stop_gradient): 113 | self.meta_loss = meta_loss 114 | self.meta_step_size = meta_step_size 115 | self.stop_gradient = stop_gradient 116 | 117 | x1 = self.convd1(x) 118 | x2 = self.convd2(x1) 119 | x3 = self.convd3(x2) 120 | x4 = self.convd4(x3) 121 | x5 = self.convd5(x4) 122 | 123 | y4 = self.convu4(x5, x4) 124 | y3 = self.convu3(y4, x3) 125 | y2 = self.convu2(y3, x2) 126 | y1 = self.convu1(y2, x1) 127 | 128 | y1 = conv2d(y1, self.seg1.weight, self.seg1.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient, kernel_size=None, stride=1, padding=0) 129 | 130 | return y1 -------------------------------------------------------------------------------- /models/ops.py: -------------------------------------------------------------------------------- 1 | import torch.autograd as autograd 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | def normalization(planes, norm='in'): 8 | if norm == 'bn': 9 | m = nn.BatchNorm2d(planes) 10 | elif norm == 'gn': 11 | m = nn.GroupNorm(1, planes) 12 | elif norm == 'in': 13 | m = nn.InstanceNorm2d(planes) 14 | else: 15 | raise ValueError('normalization type {} is not supported'.format(norm)) 16 | return m 17 | 18 | def linear(inputs, weight, bias, meta_step_size=0.001, meta_loss=None, stop_gradient=False): 19 | inputs = inputs 20 | weight = weight 21 | bias = bias 22 | 23 | if meta_loss is not None: 24 | 25 | if not stop_gradient: 26 | grad_weight = autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True) [0] 27 | 28 | if bias is not None: 29 | grad_bias = autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True) [0] 30 | bias_adapt = bias - grad_bias * meta_step_size 31 | else: 32 
| bias_adapt = bias 33 | 34 | else: 35 | grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True)[0].data, requires_grad=False) 36 | 37 | if bias is not None: 38 | grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True)[0].data, requires_grad=False) 39 | bias_adapt = bias - grad_bias * meta_step_size 40 | else: 41 | bias_adapt = bias 42 | 43 | return F.linear(inputs, 44 | weight - grad_weight * meta_step_size, 45 | bias_adapt) 46 | else: 47 | return F.linear(inputs, weight, bias) 48 | 49 | 50 | def conv2d(inputs, weight, bias, meta_step_size=0.001, stride=1, padding=0, dilation=1, groups=1, meta_loss=None, 51 | stop_gradient=False, kernel_size=None): 52 | 53 | inputs = inputs 54 | weight = weight 55 | bias = bias 56 | 57 | 58 | if meta_loss is not None: 59 | 60 | if not stop_gradient: 61 | grad_weight = autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True)[0] 62 | 63 | if bias is not None: 64 | grad_bias = autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True)[0] 65 | if grad_bias is not None: 66 | bias_adapt = bias - grad_bias * meta_step_size 67 | else: 68 | bias_adapt = bias 69 | else: 70 | bias_adapt = bias 71 | 72 | else: 73 | grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True)[0].data, 74 | requires_grad=False) 75 | if bias is not None: 76 | grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True)[0].data, requires_grad=False) 77 | bias_adapt = bias - grad_bias * meta_step_size 78 | else: 79 | bias_adapt = bias 80 | if grad_weight is not None: 81 | weight_adapt = weight - grad_weight * meta_step_size 82 | else: 83 | weight_adapt = weight 84 | 85 | return F.conv2d(inputs, 86 | weight_adapt, 87 | bias_adapt, stride, 88 | padding, 89 | dilation, groups) 90 | else: 91 | return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups) 92 | 93 | 94 | def deconv2d(inputs, weight, 
bias, meta_step_size=0.001, stride=2, padding=0, dilation=0, groups=1, meta_loss=None, 95 | stop_gradient=False, kernel_size=None): 96 | 97 | inputs = inputs 98 | weight = weight 99 | bias = bias 100 | 101 | 102 | if meta_loss is not None: 103 | 104 | if not stop_gradient: 105 | grad_weight = autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True)[0] 106 | 107 | if bias is not None: 108 | grad_bias = autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True)[0] 109 | bias_adapt = bias - grad_bias * meta_step_size 110 | else: 111 | bias_adapt = bias 112 | 113 | else: 114 | grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True)[0].data, 115 | requires_grad=False) 116 | if bias is not None: 117 | grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True)[0].data, requires_grad=False) 118 | bias_adapt = bias - grad_bias * meta_step_size 119 | else: 120 | bias_adapt = bias 121 | 122 | return F.conv_transpose2d(inputs, 123 | weight - grad_weight * meta_step_size, 124 | bias_adapt, stride, 125 | padding, 126 | dilation, groups) # note: F.conv_transpose2d's 6th positional argument is output_padding, so `dilation` is used as output_padding here 127 | else: 128 | return F.conv_transpose2d(inputs, weight, bias, stride, padding, dilation, groups) 129 | 130 | def tanh(inputs): 131 | return torch.tanh(inputs) 132 | 133 | def relu(inputs): 134 | return F.relu(inputs, inplace=True) 135 | 136 | def lrelu(inputs): 137 | return F.leaky_relu(inputs, negative_slope=0.01, inplace=False) 138 | 139 | def maxpool(inputs, kernel_size, stride=None, padding=0): 140 | return F.max_pool2d(inputs, kernel_size, stride, padding=padding) 141 | 142 | def dropout(inputs): 143 | return F.dropout(inputs, p=0.5, training=False, inplace=False) 144 | 145 | def batchnorm(inputs, running_mean, running_var): 146 | return F.batch_norm(inputs, running_mean, running_var) 147 | 148 | def instancenorm(input): 149 | return F.instance_norm(input) 150 | 151 | def groupnorm(input, num_groups=1): 152 | return F.group_norm(input, num_groups) # F.group_norm requires num_groups; the default of 1 is an assumption mirroring nn.GroupNorm(1, planes) in normalization() 153 | 154 | def
dropout2D(inputs): 155 | return F.dropout2d(inputs, p=0.5, training=False, inplace=False) 156 | 157 | def maxpool2D(inputs, kernel_size, stride=None, padding=0): 158 | return F.max_pool2d(inputs, kernel_size, stride, padding=padding) 159 | 160 | def upsample(input): 161 | return F.interpolate(input, scale_factor=2, mode='bilinear', align_corners=False) 162 | -------------------------------------------------------------------------------- /models/unet_parts.py: -------------------------------------------------------------------------------- 1 | 2 | """ Parts of the U-Net model """ 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => InstanceNorm => LeakyReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.InstanceNorm2d(mid_channels, momentum=0.1, affine=True), 18 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.InstanceNorm2d(out_channels, momentum=0.1, affine=True), 21 | nn.LeakyReLU(negative_slope=0.01, inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of
channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # input is NCHW: dim 2 is height, dim 3 is width 60 | diffY = x2.size()[2] - x1.size()[2] 61 | diffX = x2.size()[3] - x1.size()[3] 62 | 63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return self.conv(x) 70 | 71 | 72 | class OutConv(nn.Module): 73 | def __init__(self, in_channels, out_channels): 74 | super(OutConv, self).__init__() 75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 76 | 77 | def forward(self, x): 78 | return self.conv(x) 79 | -------------------------------------------------------------------------------- /models/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import os 4 | import sys 5 | 6 | def init_dcgan_weights(model): 7 | for m in model.modules(): 8 | if isinstance(m, torch.nn.Conv2d): 9 | m.weight.data.normal_(0.0, 0.02) 10 | if m.bias is not None: 11 | m.bias.data.zero_() 12 | 13 | def initialize_weights(model, init = "xavier"): 14 | init_func = None 15 | if init == "xavier": 16 | init_func = torch.nn.init.xavier_normal_ 17 | elif init == "kaiming": 18 | init_func = torch.nn.init.kaiming_normal_ 19 | elif init == "gaussian" or init == "normal": 20 | init_func = torch.nn.init.normal_ 21 | 22 | if init_func is not None: 23 | for module in
model.modules(): 24 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 25 | init_func(module.weight) 26 | if module.bias is not None: 27 | module.bias.data.zero_() 28 | elif isinstance(module, torch.nn.BatchNorm2d): 29 | module.weight.data.fill_(1) 30 | module.bias.data.zero_() 31 | else: 32 | print("Error when initializing model's weights, {} either doesn't exist or is not a valid initialization function.".format(init), \ 33 | file=sys.stderr) -------------------------------------------------------------------------------- /preprocess/save_MNMS_2D.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import argparse 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | import configparser 7 | import xlrd 8 | 9 | def safe_mkdir(path): 10 | try: 11 | os.makedirs(path) 12 | except OSError: 13 | pass 14 | 15 | def save_mask(img, path, name): 16 | h = img.shape[0] 17 | w = img.shape[1] 18 | img_1 = (img >= 0.5) * (img<1.5) * 255 19 | img_2 = (img >= 1.5) * (img<2.5) * 255 20 | img_3 = (img >= 2.5) * 255 21 | img = np.concatenate((img_1.reshape(h,w,1), img_2.reshape(h,w,1), img_3.reshape(h,w,1)), axis=2) 22 | img = Image.fromarray(img.astype(np.uint8), 'RGB') 23 | img.save(os.path.join(path, name)) 24 | 25 | def save_image(img, path, name): 26 | if img.min() < 0 or img.max() <= 0: 27 | pass 28 | else: 29 | img -= img.min() 30 | img /= img.max() 31 | 32 | img *= 255 33 | img = Image.fromarray(img.astype(np.uint8), 'L') 34 | img.save(os.path.join(path, name)) 35 | 36 | def save_np(img, path, name): 37 | np.savez_compressed(os.path.join(path, name), img) 38 | 39 | def save_mask_np(img, path, name): 40 | h = img.shape[0] 41 | w = img.shape[1] 42 | img_1 = (img >= 0.5) * (img<1.5) 43 | img_2 = (img >= 1.5) * (img<2.5) 44 | img_3 = (img >= 2.5) 45 | img = np.concatenate((img_1.reshape(h,w,1), img_2.reshape(h,w,1), img_3.reshape(h,w,1)), axis=2) 46 | 
np.savez_compressed(os.path.join(path, name), img) 47 | 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--LabeledVendorA', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorA/', help='The root of dataset.') 51 | 52 | parser.add_argument('--LabeledVendorBcenter2', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorB/center2/', help='The root of dataset.') 53 | parser.add_argument('--LabeledVendorBcenter3', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorB/center3/', help='The root of dataset.') 54 | 55 | parser.add_argument('--LabeledVendorC', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorC/', help='The root of dataset.') 56 | 57 | parser.add_argument('--LabeledVendorD', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorD/', help='The root of dataset.') 58 | 59 | parser.add_argument('--UnlabeledVendorC', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Unlabeled/vendorC/', help='The root of dataset.') 60 | 61 | arg = parser.parse_args() 62 | 63 | ###################################################################################################### 64 | #Load excel information: 65 | # cell_value(1,0) -> cell_value(175,0) 66 | ex_file = '/home/s1575424/xiao/Year2/mnms_split_data/mnms_dataset_info.xlsx' 67 | wb = xlrd.open_workbook(ex_file) 68 | sheet = wb.sheet_by_index(0) 69 | # sheet.cell_value(r, c) 70 | 71 | # Save data dirs 72 | LabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Labeled/vendorA/' 73 | LabeledVendorA_mask_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_mask/Labeled/vendorA/' 74 | 75 | LabeledVendorB2_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Labeled/vendorB/center2/' 76 | LabeledVendorB2_mask_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_mask/Labeled/vendorB/center2/' 77 | 78 | LabeledVendorB3_data_dir = 
'/home/s1575424/xiao/Year2/mnms_split_2D_data/Labeled/vendorB/center3/' 79 | LabeledVendorB3_mask_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_mask/Labeled/vendorB/center3/' 80 | 81 | LabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Labeled/vendorC/' 82 | LabeledVendorC_mask_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_mask/Labeled/vendorC/' 83 | 84 | LabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Labeled/vendorD/' 85 | LabeledVendorD_mask_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_mask/Labeled/vendorD/' 86 | 87 | UnlabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Unlabeled/vendorC/' 88 | 89 | 90 | # Load all the data names 91 | LabeledVendorA_names = sorted(os.listdir(arg.LabeledVendorA)) 92 | 93 | LabeledVendorB2_names = sorted(os.listdir(arg.LabeledVendorBcenter2)) 94 | LabeledVendorB3_names = sorted(os.listdir(arg.LabeledVendorBcenter3)) 95 | 96 | LabeledVendorC_names = sorted(os.listdir(arg.LabeledVendorC)) 97 | 98 | LabeledVendorD_names = sorted(os.listdir(arg.LabeledVendorD)) 99 | 100 | UnlabeledVendorC_names = sorted(os.listdir(arg.UnlabeledVendorC)) 101 | 102 | #### Output: non-normed, non-cropped, no HW transposed npz file. 103 | 104 | ###################################################################################################### 105 | # Load LabeledVendorA data and save them to 2D images 106 | for num_pat in range(0, len(LabeledVendorA_names)): 107 | gz_name = sorted(os.listdir(arg.LabeledVendorA+LabeledVendorA_names[num_pat]+'/')) 108 | patient_root = arg.LabeledVendorA+LabeledVendorA_names[num_pat] + '/' + gz_name[0] 109 | img = nib.load(patient_root) 110 | img_np = img.get_fdata() 111 | 112 | # p5 = np.percentile(img_np.flatten(), 5) 113 | # p95 = np.percentile(img_np.flatten(), 95) 114 | # img_np = np.clip(img_np, p5, p95) 115 | 116 | print('patient%03d...' 
% num_pat) 117 | save_labeled_data_root = LabeledVendorA_data_dir 118 | 119 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 120 | 121 | if num_pat == 0: 122 | print(type(img_np[0,0,0,0])) 123 | 124 | ## save each image of the patient 125 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 126 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 127 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) 128 | 129 | 130 | ###################################################################################################### 131 | # Load LabeledVendorA mask and save them to 2D images 132 | for num_pat in range(0, len(LabeledVendorA_names)): 133 | gz_name = sorted(os.listdir(arg.LabeledVendorA+LabeledVendorA_names[num_pat]+'/')) 134 | patient_root = arg.LabeledVendorA+LabeledVendorA_names[num_pat] + '/' + gz_name[1] 135 | img = nib.load(patient_root) 136 | img_np = img.get_fdata() 137 | 138 | print('patient%03d...' % num_pat) 139 | 140 | save_mask_root = LabeledVendorA_mask_dir 141 | 142 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 143 | 144 | for num_row in range(1, 346): 145 | if sheet.cell_value(num_row, 0) == gz_name[0][0:6]: 146 | ED = sheet.cell_value(num_row, 4) 147 | ES = sheet.cell_value(num_row, 5) 148 | break 149 | 150 | ## save masks of the patient 151 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 152 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 153 | if num_slice//img_np.shape[3] == ED or num_slice//img_np.shape[3] == ES: 154 | save_mask(img_save, save_mask_root, '%03d%03d.png' % (num_pat, num_slice)) 155 | 156 | ###################################################################################################### 157 | # Load LabeledVendorB2 data and save them to 2D images 158 | for num_pat in range(0, len(LabeledVendorB2_names)): 159 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat]+'/')) 160 | patient_root = 
arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat] + '/' + gz_name[0] 161 | img = nib.load(patient_root) 162 | img_np = img.get_fdata() 163 | 164 | # p5 = np.percentile(img_np.flatten(), 5) 165 | # p95 = np.percentile(img_np.flatten(), 95) 166 | # img_np = np.clip(img_np, p5, p95) 167 | 168 | print('patient%03d...' % num_pat) 169 | save_labeled_data_root = LabeledVendorB2_data_dir 170 | 171 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 172 | 173 | if num_pat == 0: 174 | print(type(img_np[0,0,0,0])) 175 | 176 | ## save each image of the patient 177 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 178 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 179 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) 180 | 181 | 182 | ###################################################################################################### 183 | # Load LabeledVendorB2 mask and save them to 2D images 184 | for num_pat in range(0, len(LabeledVendorB2_names)): 185 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat]+'/')) 186 | patient_root = arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat] + '/' + gz_name[1] 187 | img = nib.load(patient_root) 188 | img_np = img.get_fdata() 189 | 190 | print('patient%03d...' 
% num_pat) 191 | 192 | save_mask_root = LabeledVendorB2_mask_dir 193 | 194 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 195 | 196 | for num_row in range(1, 346): 197 | if sheet.cell_value(num_row, 0) == gz_name[0][0:6]: 198 | ED = sheet.cell_value(num_row, 4) 199 | ES = sheet.cell_value(num_row, 5) 200 | break 201 | 202 | ## save masks of the patient 203 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 204 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 205 | if num_slice//img_np.shape[3] == ED or num_slice//img_np.shape[3] == ES: 206 | save_mask(img_save, save_mask_root, '%03d%03d.png' % (num_pat, num_slice)) 207 | 208 | 209 | ###################################################################################################### 210 | # Load LabeledVendorB3 data and save them to 2D images 211 | for num_pat in range(0, len(LabeledVendorB3_names)): 212 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat]+'/')) 213 | patient_root = arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat] + '/' + gz_name[0] 214 | img = nib.load(patient_root) 215 | img_np = img.get_fdata() 216 | 217 | # p5 = np.percentile(img_np.flatten(), 5) 218 | # p95 = np.percentile(img_np.flatten(), 95) 219 | # img_np = np.clip(img_np, p5, p95) 220 | 221 | print('patient%03d...' 
% num_pat) 222 | save_labeled_data_root = LabeledVendorB3_data_dir 223 | 224 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 225 | 226 | if num_pat == 0: 227 | print(type(img_np[0,0,0,0])) 228 | 229 | ## save each image of the patient 230 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 231 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 232 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) 233 | 234 | 235 | ###################################################################################################### 236 | # Load LabeledVendorB3 mask and save them to 2D images 237 | for num_pat in range(0, len(LabeledVendorB3_names)): 238 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat]+'/')) 239 | patient_root = arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat] + '/' + gz_name[1] 240 | img = nib.load(patient_root) 241 | img_np = img.get_fdata() 242 | 243 | print('patient%03d...' % num_pat) 244 | 245 | save_mask_root = LabeledVendorB3_mask_dir 246 | 247 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 248 | 249 | for num_row in range(1, 346): 250 | if sheet.cell_value(num_row, 0) == gz_name[0][0:6]: 251 | ED = sheet.cell_value(num_row, 4) 252 | ES = sheet.cell_value(num_row, 5) 253 | break 254 | 255 | ## save masks of the patient 256 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 257 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 258 | if num_slice//img_np.shape[3] == ED or num_slice//img_np.shape[3] == ES: 259 | save_mask(img_save, save_mask_root, '%03d%03d.png' % (num_pat, num_slice)) 260 | 261 | 262 | ###################################################################################################### 263 | # Load LabeledVendorC data and save them to 2D images 264 | for num_pat in range(0, len(LabeledVendorC_names)): 265 | gz_name = sorted(os.listdir(arg.LabeledVendorC+LabeledVendorC_names[num_pat]+'/')) 266 | 
patient_root = arg.LabeledVendorC+LabeledVendorC_names[num_pat] + '/' + gz_name[0] 267 | img = nib.load(patient_root) 268 | img_np = img.get_fdata() 269 | 270 | # p5 = np.percentile(img_np.flatten(), 5) 271 | # p95 = np.percentile(img_np.flatten(), 95) 272 | # img_np = np.clip(img_np, p5, p95) 273 | 274 | print('patient%03d...' % num_pat) 275 | save_labeled_data_root = LabeledVendorC_data_dir 276 | 277 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 278 | 279 | if num_pat == 0: 280 | print(type(img_np[0,0,0,0])) 281 | 282 | ## save each image of the patient 283 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 284 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 285 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) 286 | 287 | 288 | ###################################################################################################### 289 | # Load LabeledVendorC mask and save them to 2D images 290 | for num_pat in range(0, len(LabeledVendorC_names)): 291 | gz_name = sorted(os.listdir(arg.LabeledVendorC+LabeledVendorC_names[num_pat]+'/')) 292 | patient_root = arg.LabeledVendorC+LabeledVendorC_names[num_pat] + '/' + gz_name[1] 293 | img = nib.load(patient_root) 294 | img_np = img.get_fdata() 295 | 296 | print('patient%03d...' 
% num_pat) 297 | 298 | save_mask_root = LabeledVendorC_mask_dir 299 | 300 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 301 | 302 | for num_row in range(1, 346): 303 | if sheet.cell_value(num_row, 0) == gz_name[0][0:6]: 304 | ED = sheet.cell_value(num_row, 4) 305 | ES = sheet.cell_value(num_row, 5) 306 | break 307 | 308 | ## save masks of the patient 309 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 310 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 311 | if num_slice//img_np.shape[3] == ED or num_slice//img_np.shape[3] == ES: 312 | save_mask(img_save, save_mask_root, '%03d%03d.png' % (num_pat, num_slice)) 313 | 314 | 315 | ###################################################################################################### 316 | # Load LabeledVendorD data and save them to 2D images 317 | for num_pat in range(0, len(LabeledVendorD_names)): 318 | gz_name = sorted(os.listdir(arg.LabeledVendorD+LabeledVendorD_names[num_pat]+'/')) 319 | patient_root = arg.LabeledVendorD+LabeledVendorD_names[num_pat] + '/' + gz_name[0] 320 | img = nib.load(patient_root) 321 | img_np = img.get_fdata() 322 | 323 | # p5 = np.percentile(img_np.flatten(), 5) 324 | # p95 = np.percentile(img_np.flatten(), 95) 325 | # img_np = np.clip(img_np, p5, p95) 326 | 327 | print('patient%03d...' 
% num_pat) 328 | save_labeled_data_root = LabeledVendorD_data_dir 329 | 330 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 331 | 332 | if num_pat == 0: 333 | print(type(img_np[0,0,0,0])) 334 | 335 | ## save each image of the patient 336 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 337 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 338 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) 339 | 340 | 341 | ###################################################################################################### 342 | # Load LabeledVendorD mask and save them to 2D images 343 | for num_pat in range(0, len(LabeledVendorD_names)): 344 | gz_name = sorted(os.listdir(arg.LabeledVendorD+LabeledVendorD_names[num_pat]+'/')) 345 | patient_root = arg.LabeledVendorD+LabeledVendorD_names[num_pat] + '/' + gz_name[1] 346 | img = nib.load(patient_root) 347 | img_np = img.get_fdata() 348 | 349 | print('patient%03d...' % num_pat) 350 | 351 | save_mask_root = LabeledVendorD_mask_dir 352 | 353 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 354 | 355 | for num_row in range(1, 346): 356 | if sheet.cell_value(num_row, 0) == gz_name[0][0:6]: 357 | ED = sheet.cell_value(num_row, 4) 358 | ES = sheet.cell_value(num_row, 5) 359 | break 360 | 361 | ## save masks of the patient 362 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 363 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 364 | if num_slice//img_np.shape[3] == ED or num_slice//img_np.shape[3] == ES: 365 | save_mask(img_save, save_mask_root, '%03d%03d.png' % (num_pat, num_slice)) 366 | 367 | 368 | ###################################################################################################### 369 | # Load UnlabeledVendorC data and save them to 2D images 370 | for num_pat in range(0, len(UnlabeledVendorC_names)): 371 | gz_name = sorted(os.listdir(arg.UnlabeledVendorC+UnlabeledVendorC_names[num_pat]+'/')) 372 | patient_root = 
arg.UnlabeledVendorC+UnlabeledVendorC_names[num_pat] + '/' + gz_name[0] 373 | img = nib.load(patient_root) 374 | img_np = img.get_fdata() 375 | 376 | # p5 = np.percentile(img_np.flatten(), 5) 377 | # p95 = np.percentile(img_np.flatten(), 95) 378 | # img_np = np.clip(img_np, p5, p95) 379 | 380 | print('patient%03d...' % num_pat) 381 | save_labeled_data_root = UnlabeledVendorC_data_dir 382 | 383 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 384 | 385 | if num_pat == 0: 386 | print(type(img_np[0,0,0,0])) 387 | 388 | ## save each image of the patient 389 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 390 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]] 391 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) -------------------------------------------------------------------------------- /preprocess/save_MNMS_re.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import argparse 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | import configparser 7 | import xlrd 8 | 9 | def safe_mkdir(path): 10 | try: 11 | os.makedirs(path) 12 | except OSError: 13 | pass 14 | 15 | def save_mask(img, path, name): 16 | h = img.shape[0] 17 | w = img.shape[1] 18 | img_1 = (img == 1) * 255 19 | img_2 = (img == 2) * 255 20 | img_3 = (img == 3) * 255 21 | img = np.concatenate((img_1.reshape(h,w,1), img_2.reshape(h,w,1), img_3.reshape(h,w,1)), axis=2) 22 | img = Image.fromarray(img.astype(np.uint8), 'RGB') 23 | img.save(os.path.join(path, name)) 24 | 25 | def save_image(img, path, name): 26 | if img.min() < 0 or img.max() <= 0: 27 | pass 28 | else: 29 | img -= img.min() 30 | img /= img.max() 31 | 32 | img *= 255 33 | img = Image.fromarray(img.astype(np.uint8), 'L') 34 | img.save(os.path.join(path, name)) 35 | 36 | def save_np(img, path, name): 37 | np.savez_compressed(os.path.join(path, name), img) 38 | 39 | def save_mask_np(img, path, name): 40 | 
h = img.shape[0] 41 | w = img.shape[1] 42 | img_1 = (img == 1) 43 | img_2 = (img == 2) 44 | img_3 = (img == 3) 45 | img = np.concatenate((img_1.reshape(h,w,1), img_2.reshape(h,w,1), img_3.reshape(h,w,1)), axis=2) 46 | np.savez_compressed(os.path.join(path, name), img) 47 | 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--LabeledVendorA', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorA/', help='The root of dataset.') 51 | 52 | parser.add_argument('--LabeledVendorBcenter2', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorB/center2/', help='The root of dataset.') 53 | parser.add_argument('--LabeledVendorBcenter3', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorB/center3/', help='The root of dataset.') 54 | 55 | parser.add_argument('--LabeledVendorC', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorC/', help='The root of dataset.') 56 | 57 | parser.add_argument('--LabeledVendorD', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorD/', help='The root of dataset.') 58 | 59 | parser.add_argument('--UnlabeledVendorC', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Unlabeled/vendorC/', help='The root of dataset.') 60 | 61 | arg = parser.parse_args() 62 | 63 | ###################################################################################################### 64 | #Load excel information: 65 | # cell_value(1,0) -> cell_value(175,0) 66 | ex_file = '/home/s1575424/xiao/Year2/mnms_split_data/mnms_dataset_info.xlsx' 67 | wb = xlrd.open_workbook(ex_file) 68 | sheet = wb.sheet_by_index(0) 69 | # sheet.cell_value(r, c) 70 | 71 | # Save data dirs 72 | LabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Labeled/vendorA/' 73 | 74 | LabeledVendorB2_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Labeled/vendorB/center2/' 75 | 76 | LabeledVendorB3_data_dir = 
'/home/s1575424/xiao/Year2/mnms_split_2D_re/Labeled/vendorB/center3/' 77 | 78 | LabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Labeled/vendorC/' 79 | 80 | LabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Labeled/vendorD/' 81 | 82 | UnlabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Unlabeled/vendorC/' 83 | 84 | 85 | # Load all the data names 86 | LabeledVendorA_names = sorted(os.listdir(arg.LabeledVendorA)) 87 | 88 | LabeledVendorB2_names = sorted(os.listdir(arg.LabeledVendorBcenter2)) 89 | LabeledVendorB3_names = sorted(os.listdir(arg.LabeledVendorBcenter3)) 90 | 91 | LabeledVendorC_names = sorted(os.listdir(arg.LabeledVendorC)) 92 | 93 | LabeledVendorD_names = sorted(os.listdir(arg.LabeledVendorD)) 94 | 95 | UnlabeledVendorC_names = sorted(os.listdir(arg.UnlabeledVendorC)) 96 | 97 | #### Output: non-normed, non-cropped, no HW transposed npz file. 98 | 99 | ###################################################################################################### 100 | # Load LabeledVendorA data and save them to 2D images 101 | 102 | for num_pat in range(0, len(LabeledVendorA_names)): 103 | gz_name = sorted(os.listdir(arg.LabeledVendorA+LabeledVendorA_names[num_pat]+'/')) 104 | patient_root = arg.LabeledVendorA+LabeledVendorA_names[num_pat] + '/' + gz_name[0] 105 | img = nib.load(patient_root) 106 | img_np = img.get_fdata() 107 | 108 | header = img.header 109 | resolution = header.get_zooms() 110 | X_scan_re = [] 111 | X_scan_re.append(resolution[0]) 112 | X_scan_re_np = np.array(X_scan_re) 113 | 114 | print('patient%03d...' 
% num_pat) 115 | save_labeled_data_root = LabeledVendorA_data_dir 116 | 117 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 118 | 119 | ## save each image of the patient 120 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 121 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) 122 | 123 | 124 | ###################################################################################################### 125 | # Load LabeledVendorB2 data and save them to 2D images 126 | for num_pat in range(0, len(LabeledVendorB2_names)): 127 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat]+'/')) 128 | patient_root = arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat] + '/' + gz_name[0] 129 | img = nib.load(patient_root) 130 | img_np = img.get_fdata() 131 | 132 | header = img.header 133 | resolution = header.get_zooms() 134 | X_scan_re = [] 135 | X_scan_re.append(resolution[0]) 136 | X_scan_re_np = np.array(X_scan_re) 137 | 138 | 139 | print('patient%03d...' 
% num_pat) 140 | save_labeled_data_root = LabeledVendorB2_data_dir 141 | 142 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 143 | 144 | if num_pat == 0: 145 | print(type(img_np[0,0,0,0])) 146 | 147 | ## save each image of the patient 148 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 149 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) 150 | 151 | 152 | 153 | ###################################################################################################### 154 | # Load LabeledVendorB3 data and save them to 2D images 155 | for num_pat in range(0, len(LabeledVendorB3_names)): 156 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat]+'/')) 157 | patient_root = arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat] + '/' + gz_name[0] 158 | img = nib.load(patient_root) 159 | img_np = img.get_fdata() 160 | 161 | header = img.header 162 | resolution = header.get_zooms() 163 | X_scan_re = [] 164 | X_scan_re.append(resolution[0]) 165 | X_scan_re_np = np.array(X_scan_re) 166 | 167 | 168 | print('patient%03d...' 
% num_pat) 169 | save_labeled_data_root = LabeledVendorB3_data_dir 170 | 171 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 172 | 173 | if num_pat == 0: 174 | print(type(img_np[0,0,0,0])) 175 | 176 | ## save each image of the patient 177 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 178 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) 179 | 180 | 181 | ###################################################################################################### 182 | # Load LabeledVendorC data and save them to 2D images 183 | for num_pat in range(0, len(LabeledVendorC_names)): 184 | gz_name = sorted(os.listdir(arg.LabeledVendorC+LabeledVendorC_names[num_pat]+'/')) 185 | patient_root = arg.LabeledVendorC+LabeledVendorC_names[num_pat] + '/' + gz_name[0] 186 | img = nib.load(patient_root) 187 | img_np = img.get_fdata() 188 | 189 | header = img.header 190 | resolution = header.get_zooms() 191 | X_scan_re = [] 192 | X_scan_re.append(resolution[0]) 193 | X_scan_re_np = np.array(X_scan_re) 194 | 195 | print('patient%03d...' 
% num_pat) 196 | save_labeled_data_root = LabeledVendorC_data_dir 197 | 198 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 199 | 200 | if num_pat == 0: 201 | print(type(img_np[0,0,0,0])) 202 | 203 | ## save each image of the patient 204 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 205 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) 206 | 207 | 208 | ###################################################################################################### 209 | # Load LabeledVendorD data and save them to 2D images 210 | for num_pat in range(0, len(LabeledVendorD_names)): 211 | gz_name = sorted(os.listdir(arg.LabeledVendorD+LabeledVendorD_names[num_pat]+'/')) 212 | patient_root = arg.LabeledVendorD+LabeledVendorD_names[num_pat] + '/' + gz_name[0] 213 | img = nib.load(patient_root) 214 | img_np = img.get_fdata() 215 | 216 | header = img.header 217 | resolution = header.get_zooms() 218 | X_scan_re = [] 219 | X_scan_re.append(resolution[0]) 220 | X_scan_re_np = np.array(X_scan_re) 221 | 222 | print('patient%03d...' 
% num_pat) 223 | save_labeled_data_root = LabeledVendorD_data_dir 224 | 225 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 226 | 227 | if num_pat == 0: 228 | print(type(img_np[0,0,0,0])) 229 | 230 | ## save each image of the patient 231 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 232 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) 233 | 234 | 235 | ###################################################################################################### 236 | # Load UnlabeledVendorC data and save them to 2D images 237 | for num_pat in range(0, len(UnlabeledVendorC_names)): 238 | gz_name = sorted(os.listdir(arg.UnlabeledVendorC+UnlabeledVendorC_names[num_pat]+'/')) 239 | patient_root = arg.UnlabeledVendorC+UnlabeledVendorC_names[num_pat] + '/' + gz_name[0] 240 | img = nib.load(patient_root) 241 | img_np = img.get_fdata() 242 | 243 | header = img.header 244 | resolution = header.get_zooms() 245 | X_scan_re = [] 246 | X_scan_re.append(resolution[0]) 247 | X_scan_re_np = np.array(X_scan_re) 248 | 249 | print('patient%03d...' 
% num_pat) 250 | save_labeled_data_root = UnlabeledVendorC_data_dir 251 | 252 | img_np = np.transpose(img_np, (0, 1, 3, 2)) 253 | 254 | if num_pat == 0: 255 | print(type(img_np[0,0,0,0])) 256 | 257 | ## save each image of the patient 258 | for num_slice in range(img_np.shape[2]*img_np.shape[3]): 259 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice)) -------------------------------------------------------------------------------- /preprocess/save_SCGM_2D.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import argparse 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | import configparser 7 | 8 | import re 9 | import SimpleITK as stik 10 | from collections import OrderedDict 11 | from torchvision import transforms 12 | import random 13 | import torchvision.transforms.functional as F 14 | import cv2 15 | import torch 16 | 17 | def safe_mkdir(path): 18 | try: 19 | os.makedirs(path) 20 | except OSError: 21 | pass 22 | 23 | def save_np(img, path, name): 24 | np.savez_compressed(os.path.join(path, name), img) 25 | 26 | def save_mask_np(mask1, mask2, path, name): 27 | h = mask1.shape[0] 28 | w = mask1.shape[1] 29 | mask = np.concatenate((mask1.reshape(h,w,1), mask2.reshape(h,w,1)), axis=2) 30 | np.savez_compressed(os.path.join(path, name), mask) 31 | 32 | path_train = '/home/s1575424/xiao/Year2/scgm_rawdata/train/' 33 | past_test = '/home/s1575424/xiao/Year2/scgm_rawdata/test/' 34 | 35 | ###################################################################################################### 36 | 37 | # Save data dirs 38 | LabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Labeled/vendorA/' 39 | LabeledVendorA_mask_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_mask/Labeled/vendorA/' 40 | 41 | LabeledVendorB_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Labeled/vendorB/' 42 | LabeledVendorB_mask_dir = 
'/home/s1575424/xiao/Year2/scgm_split_2D_mask/Labeled/vendorB/' 43 | 44 | LabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Labeled/vendorC/' 45 | LabeledVendorC_mask_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_mask/Labeled/vendorC/' 46 | 47 | LabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Labeled/vendorD/' 48 | LabeledVendorD_mask_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_mask/Labeled/vendorD/' 49 | 50 | UnlabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Unlabeled/vendorA/' 51 | UnlabeledVendorB_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Unlabeled/vendorB/' 52 | UnlabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Unlabeled/vendorC/' 53 | UnlabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Unlabeled/vendorD/' 54 | 55 | labeled_data_dir = [LabeledVendorA_data_dir, LabeledVendorB_data_dir, LabeledVendorC_data_dir, LabeledVendorD_data_dir] 56 | labeled_mask_dir = [LabeledVendorA_mask_dir, LabeledVendorB_mask_dir, LabeledVendorC_mask_dir, LabeledVendorD_mask_dir] 57 | un_labeled_data_dir = [UnlabeledVendorA_data_dir, UnlabeledVendorB_data_dir, UnlabeledVendorC_data_dir, UnlabeledVendorD_data_dir] 58 | 59 | safe_mkdir(LabeledVendorA_data_dir) 60 | safe_mkdir(LabeledVendorA_mask_dir) 61 | safe_mkdir(LabeledVendorB_data_dir) 62 | safe_mkdir(LabeledVendorB_mask_dir) 63 | safe_mkdir(LabeledVendorC_data_dir) 64 | safe_mkdir(LabeledVendorC_mask_dir) 65 | safe_mkdir(LabeledVendorD_data_dir) 66 | safe_mkdir(LabeledVendorD_mask_dir) 67 | safe_mkdir(UnlabeledVendorA_data_dir) 68 | safe_mkdir(UnlabeledVendorB_data_dir) 69 | safe_mkdir(UnlabeledVendorC_data_dir) 70 | safe_mkdir(UnlabeledVendorD_data_dir) 71 | 72 | 73 | def read_numpy(file_name): 74 | reader = stik.ImageFileReader() 75 | reader.SetImageIO("NiftiImageIO") 76 | reader.SetFileName(file_name) 77 | data = reader.Execute() 78 | return stik.GetArrayFromImage(data) 79 | 80 | def 
read_dataset_into_memory(data_list, labeled=True): 81 | for val in data_list.values(): 82 | val['input'] = read_numpy(val['input']) 83 | if labeled: 84 | for idx, gt in enumerate(val['gt']): 85 | val['gt'][idx] = read_numpy(gt) 86 | else: 87 | pass 88 | return data_list 89 | 90 | def get_index_map(data_list, labeled=True): 91 | map_list = [] 92 | num_sbj = 0 93 | for data in data_list.values(): 94 | slice_num = data['input'].shape[0] 95 | print(data['input'].shape) 96 | for i in range(slice_num): 97 | if labeled: 98 | map_list.append([data['input'][i], np.stack([data['gt'][idx][i] for idx in range(4)], axis=0), num_sbj]) 99 | else: 100 | map_list.append([data['input'][i], num_sbj]) 101 | num_sbj += 1 102 | return map_list 103 | 104 | # data_list: data_list['input'] path, data_list['gt'] path 105 | def Save_vendor_data(data_list, labeled=True, data_dir='', mask_dir=''): 106 | data_list = read_dataset_into_memory(data_list, labeled) 107 | # data_list: data_list['input'] images, data_list['gt'] ground-truth labels 108 | map_list = get_index_map(data_list,labeled) 109 | # map_list: map_list[i]: [images, gt] 110 | if labeled: 111 | num_sbj = 0 112 | num_slice = 0 113 | for idx in range(len(map_list)): 114 | img, gt_list, flag = map_list[idx] 115 | if flag != num_sbj: 116 | num_sbj = flag 117 | num_slice = 0 118 | img = img / (img.max() if img.max() > 0 else 1) 119 | gt_list = torch.tensor(gt_list, dtype=torch.uint8) 120 | spinal_cord_mask = (torch.mean(((gt_list > 0)).float(), dim=0) > 0.5).float() 121 | spinal_cord_mask = spinal_cord_mask.numpy() 122 | gm_mask = (torch.mean((gt_list == 1).float(), dim=0) > 0.5).float() 123 | gm_mask = gm_mask.numpy() 124 | save_np(img, data_dir, '%03d%03d' % (num_sbj, num_slice)) 125 | save_mask_np(spinal_cord_mask, gm_mask, mask_dir, '%03d%03d' % (num_sbj, num_slice)) 126 | num_slice += 1 127 | else: 128 | num_sbj = 0 129 | num_slice = 0 130 | for idx in range(len(map_list)): 131 | img, flag = map_list[idx] 132 | if flag != num_sbj: 
133 | num_sbj = flag 134 | num_slice = 0 135 | img = img / (img.max() if img.max() > 0 else 1) 136 | save_np(img, data_dir, '%03d%03d' % (num_sbj, num_slice)) 137 | num_slice += 1 138 | 139 | resolution = { 140 | 'site1': [5, 0.5, 0.5], 141 | 'site2': [5, 0.5, 0.5], 142 | 'site3': [2.5, 0.5, 0.5], 143 | 'site4': [5, 0.29, 0.29], 144 | } 145 | 146 | labeled_imageFileList = [os.path.join(path_train, f) for f in os.listdir(path_train) if 'site' in f and '.txt' not in f] 147 | labeled_data_dict = {'site1': OrderedDict(), 'site2': OrderedDict(), 'site3': OrderedDict(), 'site4': OrderedDict()} 148 | for file in sorted(labeled_imageFileList): 149 | res = re.search(r'site(\d)-sc(\d*)-(image|mask)', file).groups() 150 | if res[1] not in labeled_data_dict['site' + res[0]].keys(): 151 | labeled_data_dict['site' + res[0]][res[1]] = {'input': None, 'gt': []} 152 | if res[2] == 'image': 153 | labeled_data_dict['site' + res[0]][res[1]]['input'] = file 154 | if res[2] == 'mask': 155 | labeled_data_dict['site' + res[0]][res[1]]['gt'].append(file) 156 | i = 0 157 | for domain, data_list in labeled_data_dict.items(): 158 | print(domain) 159 | Save_vendor_data(data_list, labeled=True, data_dir=labeled_data_dir[i], mask_dir=labeled_mask_dir[i]) 160 | i += 1 161 | 162 | unlabeled_imageFileList = [os.path.join(past_test, f) for f in os.listdir(past_test) if 'site' in f and '.txt' not in f] 163 | unlabeled_data_dict = {'site1': OrderedDict(), 'site2': OrderedDict(), 'site3': OrderedDict(), 'site4': OrderedDict()} 164 | for file in sorted(unlabeled_imageFileList): 165 | res = re.search(r'site(\d)-sc(\d*)-(image|mask)', file).groups() 166 | if res[1] not in unlabeled_data_dict['site' + res[0]].keys(): 167 | unlabeled_data_dict['site' + res[0]][res[1]] = {'input': None, 'gt': []} 168 | if res[2] == 'image': 169 | unlabeled_data_dict['site' + res[0]][res[1]]['input'] = file 170 | if res[2] == 'mask': 171 | unlabeled_data_dict['site' + res[0]][res[1]]['gt'].append(file) 172 | i = 0 173 | for
domain, data_list in unlabeled_data_dict.items(): 174 | print(domain) 175 | Save_vendor_data(data_list, labeled=False, data_dir=un_labeled_data_dir[i], mask_dir=None) 176 | i += 1 177 | # two items: 1. domain name, 2. data['input'], data[gt] -------------------------------------------------------------------------------- /preprocess/split_MNMS_data.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import shutil 4 | import xlrd 5 | 6 | def walk_path(dir): 7 | dir = dir 8 | paths = [] 9 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 10 | 11 | for root, dirs, _ in sorted(os.walk(dir)): 12 | for name in dirs: 13 | paths.append(os.path.join(root, name)) 14 | # print(paths) 15 | return paths 16 | 17 | 18 | ###################################################################################################### 19 | #Load excel information: 20 | # cell_value(1,0) -> cell_value(175,0) 21 | ex_file = '/home/s1575424/xiao/Year2/mnms_split_data/mnms_dataset_info.xlsx' 22 | wb = xlrd.open_workbook(ex_file) 23 | sheet = wb.sheet_by_index(0) 24 | # sheet.cell_value(r, c) 25 | 26 | vendor_A = [] 27 | vendor_B = [] 28 | vendor_C = [] 29 | vendor_D = [] 30 | 31 | center_2 = [] 32 | center_3 = [] 33 | 34 | for i in range(1, 346): 35 | if sheet.cell_value(i, 2)=='A': 36 | vendor_A.append(sheet.cell_value(i, 0)) 37 | elif sheet.cell_value(i, 2)=='B': 38 | vendor_B.append(sheet.cell_value(i, 0)) 39 | if sheet.cell_value(i, 3)==2: 40 | center_2.append(sheet.cell_value(i, 0)) 41 | else: 42 | center_3.append(sheet.cell_value(i, 0)) 43 | elif sheet.cell_value(i, 2) == 'C': 44 | vendor_C.append(sheet.cell_value(i, 0)) 45 | elif sheet.cell_value(i, 2) == 'D': 46 | vendor_D.append(sheet.cell_value(i, 0)) 47 | else: 48 | break 49 | 50 | ###################################################################################################### 51 | # move data to the corresponding folders: mnms_split_data/Labeled/(vendorA, 
vendorB, vendorC, vendorD), mnms_split_data/Unlabeled/vendorC 52 | path_vendorA = '/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorA/' 53 | path_vendorB = '/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorB/' 54 | path_vendorC = '/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorC/' 55 | path_vendorD = '/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorD/' 56 | 57 | labeled_train_paths = walk_path('/home/s1575424/xiao/Year2/OpenDataset/Training/Labeled') 58 | testing_paths = walk_path('/home/s1575424/xiao/Year2/OpenDataset/Testing') 59 | val_paths = walk_path('/home/s1575424/xiao/Year2/OpenDataset/Validation') 60 | 61 | i = 0 62 | for train_path in labeled_train_paths: 63 | if train_path[-6:] in vendor_A: 64 | shutil.move(train_path, path_vendorA) 65 | elif train_path[-6:] in vendor_B: 66 | if train_path[-6:] in center_2: 67 | shutil.move(train_path, path_vendorB + '/center2') 68 | else: 69 | shutil.move(train_path, path_vendorB + '/center3') 70 | else: 71 | continue 72 | i += 1 73 | print(i) 74 | 75 | i = 0 76 | for test_path in testing_paths: 77 | if test_path[-6:] in vendor_A: 78 | shutil.move(test_path, path_vendorA) 79 | elif test_path[-6:] in vendor_B: 80 | if test_path[-6:] in center_2: 81 | shutil.move(test_path, path_vendorB + '/center2') 82 | else: 83 | shutil.move(test_path, path_vendorB + '/center3') 84 | elif test_path[-6:] in vendor_C: 85 | shutil.move(test_path, path_vendorC) 86 | elif test_path[-6:] in vendor_D: 87 | shutil.move(test_path, path_vendorD) 88 | else: 89 | continue 90 | i += 1 91 | print(i) 92 | 93 | i = 0 94 | for val_path in val_paths: 95 | if val_path[-6:] in vendor_A: 96 | shutil.move(val_path, path_vendorA) 97 | elif val_path[-6:] in vendor_B: 98 | if val_path[-6:] in center_2: 99 | shutil.move(val_path, path_vendorB + '/center2') 100 | else: 101 | shutil.move(val_path, path_vendorB + '/center3') 102 | elif val_path[-6:] in vendor_C: 103 | shutil.move(val_path, path_vendorC) 104 | elif val_path[-6:] in 
vendor_D: 105 | shutil.move(val_path, path_vendorD) 106 | else: 107 | continue 108 | i += 1 109 | print(i) -------------------------------------------------------------------------------- /train_meta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import numpy as np 5 | import os 6 | import argparse 7 | from eval import eval_dgnet 8 | from tqdm import tqdm 9 | import logging 10 | from metrics.focal_loss import FocalLoss 11 | from torch.utils.data import DataLoader, random_split, ConcatDataset 12 | import torch.nn.functional as F 13 | import utils 14 | from loaders.mms_dataloader_meta_split import get_meta_split_data_loaders 15 | import models 16 | import losses 17 | from torch.utils.tensorboard import SummaryWriter 18 | import time 19 | 20 | 21 | def get_args(): 22 | usage_text = ( 23 | "DGNet PyTorch implementation. " 24 | "Usage: python train_meta.py [options]," 25 | " with [options]:" 26 | ) 27 | parser = argparse.ArgumentParser(description=usage_text) 28 | #training details 29 | parser.add_argument('-e','--epochs', type=int, default=100, help='Number of epochs') 30 | parser.add_argument('-bs','--batch_size', type=int, default=4, help='Number of inputs per batch') 31 | parser.add_argument('-c', '--cp', type=str, default='checkpoints/', help='The directory for the checkpoints.') 32 | parser.add_argument('-tc', '--tcp', type=str, default='temp_checkpoints/', help='The directory for the temporary checkpoints.') 33 | parser.add_argument('-t', '--tv', type=str, default='D', help='The name of the target vendor.') 34 | parser.add_argument('-w', '--wc', type=str, default='DGNet_LR00002_LDv5', help='The comment suffix for the TensorBoard summary writer.') 35 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test.
Used when storing information.') 36 | parser.add_argument('-mn','--model_name', type=str, default='dgnet', help='Name of the model architecture to be used for training/testing.') 37 | parser.add_argument('-lr','--learning_rate', type=float, default=0.00004, help='The learning rate for model training') 38 | parser.add_argument('-wi','--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)') 39 | parser.add_argument('--save_path', type=str, default='checkpoints', help='Path to save model checkpoints') 40 | parser.add_argument('--decoder_type', type=str, default='film', help='Choose decoder type between FiLM and SPADE') 41 | #hardware 42 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.') 43 | parser.add_argument('--num_workers', type=int, default=0, help='Number of workers to use for data loading') 44 | 45 | return parser.parse_args() 46 | 47 | # python train_meta.py -e 80 -c cp_dgnet_meta_100_tvA/ -t A -w DGNetRE_COM_META_100_tvA -g 0 48 | # python train_meta.py -e 80 -c cp_dgnet_meta_100_tvB/ -t B -w DGNetRE_COM_META_100_tvB -g 1 49 | # python train_meta.py -e 80 -c cp_dgnet_meta_100_tvC/ -t C -w DGNetRE_COM_META_100_tvC -g 2 50 | # python train_meta.py -e 80 -c cp_dgnet_meta_100_tvD/ -t D -w DGNetRE_COM_META_100_tvD -g 3 51 | # k_un = 1 52 | # k1 = 20 53 | # k2 = 2 54 | # opt_patience = 4 55 | 56 | # python train_meta.py -e 100 -c cp_dgnet_meta_50_tvA/ -t A -w DGNetRE_COM_META_50_tvA -g 0 57 | # python train_meta.py -e 100 -c cp_dgnet_meta_50_tvB/ -t B -w DGNetRE_COM_META_50_tvB -g 1 58 | # python train_meta.py -e 100 -c cp_dgnet_meta_50_tvC/ -t C -w DGNetRE_COM_META_50_tvC -g 2 59 | # python train_meta.py -e 100 -c cp_dgnet_meta_50_tvD/ -t D -w DGNetRE_COM_META_50_tvD -g 3 60 | # k_un = 1 61 | # k1 = 20 62 | # k2 = 3 63 | # opt_patience = 4 64 | 65 | # python train_meta.py
-e 120 -c cp_dgnet_meta_20_tvA/ -t A -w DGNetRE_COM_META_20_tvA -g 0 66 | # python train_meta.py -e 120 -c cp_dgnet_meta_20_tvB/ -t B -w DGNetRE_COM_META_20_tvB -g 1 67 | # python train_meta.py -e 120 -c cp_dgnet_meta_20_tvC/ -t C -w DGNetRE_COM_META_20_tvC -g 2 68 | # python train_meta.py -e 120 -c cp_dgnet_meta_20_tvD/ -t D -w DGNetRE_COM_META_20_tvD -g 3 69 | # k_un = 1 70 | # k1 = 30 71 | # k2 = 3 72 | # opt_patience = 4 73 | 74 | # python train_meta.py -e 150 -c cp_dgnet_meta_5_tvA/ -t A -w DGNetRE_COM_META_5_tvA -g 0 75 | # python train_meta.py -e 150 -c cp_dgnet_meta_5_tvB/ -t B -w DGNetRE_COM_META_5_tvB -g 1 76 | # python train_meta.py -e 150 -c cp_dgnet_meta_5_tvC/ -t C -w DGNetRE_COM_META_5_tvC -g 2 77 | # python train_meta.py -e 150 -c cp_dgnet_meta_5_tvD/ -t D -w DGNetRE_COM_META_5_tvD -g 3 78 | k_un = 1 79 | k1 = 30 80 | k2 = 3 81 | opt_patience = 4 82 | 83 | def latent_norm(a): 84 | n_batch, n_channel, _, _ = a.size() 85 | for batch in range(n_batch): 86 | for channel in range(n_channel): 87 | a_min = a[batch,channel,:,:].min() 88 | a_max = a[batch, channel, :, :].max() 89 | a[batch,channel,:,:] -= a_min # shift so the channel minimum is 0 (min-max normalisation) 90 | a[batch, channel, :, :] /= a_max - a_min 91 | return a 92 | 93 | def train_net(args): 94 | best_dice = 0 95 | best_lv = 0 96 | best_myo = 0 97 | best_rv = 0 98 | 99 | epochs = args.epochs 100 | batch_size = args.batch_size 101 | lr = args.learning_rate 102 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') 103 | 104 | dir_checkpoint = args.cp 105 | test_vendor = args.tv 106 | wc = args.wc 107 | 108 | #Model selection and initialization 109 | model_params = { 110 | 'width': 288, 111 | 'height': 288, 112 | 'ndf': 64, 113 | 'norm': "batchnorm", 114 | 'upsample': "nearest", 115 | 'num_classes': 3, 116 | 'decoder_type': args.decoder_type, 117 | 'anatomy_out_channels': 8, 118 | 'z_length': 8, 119 | 'num_mask_channels': 8, 120 | 121 | } 122 | model = models.get_model(args.model_name, model_params) 123 | num_params =
utils.count_parameters(model) 124 | print('Model Parameters: ', num_params) 125 | models.initialize_weights(model, args.weight_init) 126 | model.to(device) 127 | 128 | # size: 129 | # X: N, 1, 224, 224 130 | # Y: N, 3, 224, 224 131 | 132 | _, domain_1_unlabeled_loader, \ 133 | _, domain_2_unlabeled_loader,\ 134 | _, domain_3_unlabeled_loader, \ 135 | test_loader, \ 136 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset = get_meta_split_data_loaders(batch_size//2, test_vendor=test_vendor, image_size=224) 137 | 138 | 139 | val_dataset = ConcatDataset([domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset]) 140 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True, num_workers=2) 141 | n_val = len(val_dataset) 142 | print(n_val) 143 | print(len(test_loader)) 144 | print(len(domain_1_unlabeled_loader)) 145 | print(len(domain_2_unlabeled_loader)) 146 | print(len(domain_3_unlabeled_loader)) 147 | 148 | d_len = [] 149 | d_len.append(len(domain_1_labeled_dataset)) 150 | d_len.append(len(domain_2_labeled_dataset)) 151 | d_len.append(len(domain_3_labeled_dataset)) 152 | long_len = d_len[0] 153 | for i in range(len(d_len)): 154 | long_len = d_len[i] if d_len[i]>=long_len else long_len 155 | print(long_len) 156 | 157 | new_d_1 = domain_1_labeled_dataset 158 | for i in range(long_len//d_len[0]+1): 159 | if long_len == d_len[0]: 160 | break 161 | new_d_1 = ConcatDataset([new_d_1, domain_1_labeled_dataset]) 162 | domain_1_labeled_dataset = new_d_1 163 | domain_1_labeled_loader = DataLoader(dataset=domain_1_labeled_dataset, batch_size=batch_size//2, shuffle=False, 164 | drop_last=True, num_workers=2, pin_memory=True) 165 | 166 | new_d_2 = domain_2_labeled_dataset 167 | for i in range(long_len//d_len[1]+1): 168 | if long_len == d_len[1]: 169 | break 170 | new_d_2 = ConcatDataset([new_d_2, domain_2_labeled_dataset]) 171 | domain_2_labeled_dataset = new_d_2 172 | 
domain_2_labeled_loader = DataLoader(dataset=domain_2_labeled_dataset, batch_size=batch_size//2, shuffle=False, 173 | drop_last=True, num_workers=2, pin_memory=True) 174 | 175 | new_d_3 = domain_3_labeled_dataset 176 | for i in range(long_len//d_len[2]+1): 177 | if long_len == d_len[2]: 178 | break 179 | new_d_3 = ConcatDataset([new_d_3, domain_3_labeled_dataset]) 180 | domain_3_labeled_dataset = new_d_3 181 | domain_3_labeled_loader = DataLoader(dataset=domain_3_labeled_dataset, batch_size=batch_size//2, shuffle=False, 182 | drop_last=True, num_workers=2, pin_memory=True) 183 | 184 | print(len(domain_1_labeled_loader)) 185 | print(len(domain_2_labeled_loader)) 186 | print(len(domain_3_labeled_loader)) 187 | 188 | #metrics initialization 189 | # l2_distance = nn.MSELoss().to(device) 190 | criterion = nn.BCEWithLogitsLoss().to(device) 191 | l1_distance = nn.L1Loss().to(device) 192 | focal = FocalLoss() 193 | 194 | #optimizer initialization 195 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) 196 | # need to use a more useful lr_scheduler 197 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=opt_patience) 198 | 199 | writer = SummaryWriter(comment=wc) 200 | 201 | global_step = 0 202 | for epoch in range(epochs): 203 | model.train() 204 | with tqdm(total=long_len, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar: 205 | domain_1_labeled_itr = iter(domain_1_labeled_loader) 206 | domain_2_labeled_itr = iter(domain_2_labeled_loader) 207 | domain_3_labeled_itr = iter(domain_3_labeled_loader) 208 | domain_labeled_iter_list = [domain_1_labeled_itr, domain_2_labeled_itr, domain_3_labeled_itr] 209 | 210 | domain_1_unlabeled_itr = iter(domain_1_unlabeled_loader) 211 | domain_2_unlabeled_itr = iter(domain_2_unlabeled_loader) 212 | domain_3_unlabeled_itr = iter(domain_3_unlabeled_loader) 213 | domain_unlabeled_iter_list = [domain_1_unlabeled_itr, domain_2_unlabeled_itr, domain_3_unlabeled_itr] 214 | 215 | 216 | for 
num_itr in range(long_len//batch_size): 217 | # Randomly choosing meta train and meta test domains 218 | domain_list = np.random.permutation(3) 219 | meta_train_domain_list = domain_list[:2] 220 | meta_test_domain_list = domain_list[2] 221 | 222 | meta_train_imgs = [] 223 | meta_train_masks = [] 224 | meta_train_labels = [] 225 | meta_test_imgs = [] 226 | meta_test_masks = [] 227 | meta_test_labels = [] 228 | meta_test_un_imgs = [] 229 | meta_test_un_labels = [] 230 | 231 | imgs, true_masks, labels = next(domain_labeled_iter_list[meta_train_domain_list[0]]) 232 | meta_train_imgs.append(imgs) 233 | meta_train_masks.append(true_masks) 234 | meta_train_labels.append(labels) 235 | 236 | imgs, true_masks, labels = next(domain_labeled_iter_list[meta_train_domain_list[1]]) 237 | meta_train_imgs.append(imgs) 238 | meta_train_masks.append(true_masks) 239 | meta_train_labels.append(labels) 240 | 241 | imgs, true_masks, labels = next(domain_labeled_iter_list[meta_test_domain_list]) 242 | meta_test_imgs.append(imgs) 243 | meta_test_masks.append(true_masks) 244 | meta_test_labels.append(labels) 245 | imgs, true_masks, labels = next(domain_labeled_iter_list[meta_test_domain_list]) 246 | meta_test_imgs.append(imgs) 247 | meta_test_masks.append(true_masks) 248 | meta_test_labels.append(labels) 249 | 250 | imgs, labels = next(domain_unlabeled_iter_list[meta_test_domain_list]) 251 | meta_test_un_imgs.append(imgs) 252 | meta_test_un_labels.append(labels) 253 | imgs, labels = next(domain_unlabeled_iter_list[meta_test_domain_list]) 254 | meta_test_un_imgs.append(imgs) 255 | meta_test_un_labels.append(labels) 256 | 257 | meta_train_imgs = torch.cat((meta_train_imgs[0], meta_train_imgs[1]), dim=0) 258 | meta_train_masks = torch.cat((meta_train_masks[0], meta_train_masks[1]), dim=0) 259 | meta_train_labels = torch.cat((meta_train_labels[0], meta_train_labels[1]), dim=0) 260 | meta_test_imgs = torch.cat((meta_test_imgs[0], meta_test_imgs[1]), dim=0) 261 | meta_test_masks = 
torch.cat((meta_test_masks[0], meta_test_masks[1]), dim=0) 262 | meta_test_labels = torch.cat((meta_test_labels[0], meta_test_labels[1]), dim=0) 263 | meta_test_un_imgs = torch.cat((meta_test_un_imgs[0], meta_test_un_imgs[1]), dim=0) 264 | meta_test_un_labels = torch.cat((meta_test_un_labels[0], meta_test_un_labels[1]), dim=0) 265 | 266 | meta_train_un_imgs = [] 267 | meta_train_un_labels = [] 268 | for i in range(k_un): 269 | train_un_imgs = [] 270 | train_un_labels = [] 271 | un_imgs, un_labels = next(domain_unlabeled_iter_list[meta_train_domain_list[0]]) 272 | train_un_imgs.append(un_imgs) 273 | train_un_labels.append(un_labels) 274 | un_imgs, un_labels = next(domain_unlabeled_iter_list[meta_train_domain_list[1]]) 275 | train_un_imgs.append(un_imgs) 276 | train_un_labels.append(un_labels) 277 | meta_train_un_imgs.append(torch.cat((train_un_imgs[0], train_un_imgs[1]), dim=0)) 278 | meta_train_un_labels.append(torch.cat((train_un_labels[0], train_un_labels[1]), dim=0)) 279 | 280 | total_meta_un_loss = 0.0 281 | for i in range(k_un): 282 | # meta-train: 1. load meta-train data 2. 
calculate meta-train loss 283 | ###############################Meta train####################################################### 284 | un_imgs = meta_train_un_imgs[i].to(device=device, dtype=torch.float32) 285 | un_labels = meta_train_un_labels[i].to(device=device, dtype=torch.float32) 286 | 287 | un_reco, un_z_out, un_z_tilde, un_a_out, _, un_mu, un_logvar, un_cls_out, _ = model(un_imgs, true_masks, 'training') 288 | 289 | un_a_feature = F.softmax(un_a_out, dim=1) 290 | # un_a_feature = un_a_feature[:,4:,:,:] 291 | # un_seg_pred = un_a_out[:,:4,:,:] 292 | 293 | latent_dim = un_a_feature.size(1) 294 | un_a_feature = un_a_out.permute(0, 2, 3, 1).contiguous().view(-1, latent_dim) 295 | un_a_feature = un_a_feature[torch.randperm(len(un_a_feature))] 296 | un_U_a, un_S_a, un_V_a = torch.svd(un_a_feature[0:2000]) 297 | 298 | # loss_low_rank_Un_a = 0.1*torch.sum(un_S_a) 299 | loss_low_rank_Un_a = un_S_a[4] 300 | 301 | un_reco_loss = l1_distance(un_reco, un_imgs) 302 | un_regression_loss = l1_distance(un_z_tilde, un_z_out) 303 | 304 | kl_loss1 = losses.KL_divergence(un_logvar[:, :8], un_mu[:, :8]) 305 | kl_loss2 = losses.KL_divergence(un_logvar[:, 8:], un_mu[:, 8:]) 306 | hsic_loss = losses.HSIC_lossfunc(un_z_out[:, :8], un_z_out[:, 8:]) 307 | un_kl_loss = kl_loss1 + kl_loss2 + hsic_loss 308 | 309 | d_cls = criterion(un_cls_out, un_labels) 310 | un_batch_loss = un_reco_loss + (0.1*un_regression_loss) + 0.1*un_kl_loss + d_cls + 0.1*loss_low_rank_Un_a 311 | 312 | total_meta_un_loss += un_batch_loss 313 | 314 | # meta-test: 1. load meta-test data 2. 
calculate meta-test loss 315 | ###############################Meta test####################################################### 316 | un_imgs = meta_test_un_imgs.to(device=device, dtype=torch.float32) 317 | un_labels = meta_test_un_labels.to(device=device, dtype=torch.float32) 318 | un_reco, un_z_out, un_z_tilde, un_a_out, _, un_mu, un_logvar, un_cls_out, _ = model( 319 | un_imgs, true_masks, 'training', meta_loss=un_batch_loss) 320 | 321 | un_seg_pred = un_a_out[:, :4, :, :] 322 | sf_un_seg_pred = F.softmax(un_seg_pred, dim=1) 323 | 324 | un_reco_loss = l1_distance(un_reco, un_imgs) 325 | un_regression_loss = l1_distance(un_z_tilde, un_z_out) 326 | 327 | # kl_loss1 = losses.KL_divergence(un_logvar[:, :8], un_mu[:, :8]) 328 | # kl_loss2 = losses.KL_divergence(un_logvar[:, 8:], un_mu[:, 8:]) 329 | # hsic_loss = losses.HSIC_lossfunc(un_z_out[:, :8], un_z_out[:, 8:]) 330 | # un_kl_loss = kl_loss1 + kl_loss2 + hsic_loss 331 | 332 | d_cls = criterion(un_cls_out, un_labels) 333 | un_batch_loss = un_reco_loss + d_cls 334 | 335 | total_meta_un_loss += un_batch_loss 336 | 337 | writer.add_scalar('Meta_test_loss/un_reco_loss', un_reco_loss.item(), global_step) 338 | writer.add_scalar('Meta_test_loss/un_regression_loss', un_regression_loss.item(), global_step) 339 | # writer.add_scalar('Meta_test_loss/un_kl_loss', un_kl_loss.item(), global_step) 340 | writer.add_scalar('Meta_test_loss/d_cls', d_cls.item(), global_step) 341 | # writer.add_scalar('Meta_test_loss/loss_low_rank_Un_a', loss_low_rank_Un_a.item(), un_step) 342 | writer.add_scalar('Meta_test_loss/un_batch_loss', un_batch_loss.item(), global_step) 343 | 344 | optimizer.zero_grad() 345 | total_meta_un_loss.backward() 346 | nn.utils.clip_grad_value_(model.parameters(), 0.1) 347 | optimizer.step() 348 | 349 | total_meta_loss = 0.0 350 | # meta-train: 1. load meta-train data 2. 
calculate meta-train loss 351 | ###############################Meta train####################################################### 352 | imgs = meta_train_imgs.to(device=device, dtype=torch.float32) 353 | mask_type = torch.float32 354 | ce_mask = meta_train_masks.clone().to(device=device, dtype=torch.long) 355 | true_masks = meta_train_masks.to(device=device, dtype=mask_type) 356 | labels = meta_train_labels.to(device=device, dtype=torch.float32) 357 | 358 | reco, z_out, z_out_tilde, a_out, _, mu, logvar, cls_out, _ = model(imgs, true_masks, 'training') 359 | 360 | # mode-1 flattening: reshape the original 4,8,224,224 features to 4x224x224, 8 361 | # randomly pick 2000 of the 8-dim feature vectors to calculate the singular values 362 | a_feature = F.softmax(a_out, dim=1) 363 | # a_feature = a_feature[:, 4:, :, :] 364 | seg_pred = a_out[:, :4, :, :] 365 | 366 | latent_dim = a_feature.size(1) 367 | a_feature = a_feature.permute(0, 2, 3, 1).contiguous().view(-1, latent_dim) 368 | a_feature = a_feature[torch.randperm(len(a_feature))] 369 | U_a, S_a, V_a = torch.svd(a_feature[0:2000]) 370 | 371 | # loss_low_rank_a = 0.1*torch.sum(S_a) 372 | loss_low_rank_a = S_a[4] 373 | 374 | reco_loss = l1_distance(reco, imgs) 375 | kl_loss1 = losses.KL_divergence(logvar[:,:8], mu[:,:8]) 376 | kl_loss2 = losses.KL_divergence(logvar[:,8:], mu[:,8:]) 377 | hsic_loss = losses.HSIC_lossfunc(z_out[:,:8], z_out[:,8:]) 378 | kl_loss = kl_loss1 + kl_loss2 + hsic_loss 379 | regression_loss = l1_distance(z_out_tilde, z_out) 380 | 381 | sf_seg = F.softmax(seg_pred, dim=1) 382 | dice_loss_lv = losses.dice_loss(sf_seg[:,0,:,:], true_masks[:,0,:,:]) 383 | dice_loss_myo = losses.dice_loss(sf_seg[:,1,:,:], true_masks[:,1,:,:]) 384 | dice_loss_rv = losses.dice_loss(sf_seg[:,2,:,:], true_masks[:,2,:,:]) 385 | dice_loss_bg = losses.dice_loss(sf_seg[:, 3, :, :], true_masks[:, 3, :, :]) 386 | loss_dice = dice_loss_lv + dice_loss_myo + dice_loss_rv + dice_loss_bg 387 | 388 | ce_target = ce_mask[:, 3, :, :]*0 + ce_mask[:,
0, :, :]*1 + ce_mask[:, 1, :, :]*2 + ce_mask[:, 2, :, :]*3

                seg_pred_swap = torch.cat((seg_pred[:,3,:,:].unsqueeze(1), seg_pred[:,:3,:,:]), dim=1)

                loss_focal = focal(seg_pred_swap, ce_target)

                d_cls = criterion(cls_out, labels)
                d_losses = d_cls

                batch_loss = reco_loss + (0.1 * regression_loss) + 0.1*kl_loss + 5*loss_dice + 5*loss_focal + d_losses + 0.1*loss_low_rank_a

                total_meta_loss += batch_loss

                writer.add_scalar('Meta_train_Loss/loss_dice', loss_dice.item(), global_step)
                writer.add_scalar('Meta_train_Loss/dice_loss_lv', dice_loss_lv.item(), global_step)
                writer.add_scalar('Meta_train_Loss/dice_loss_myo', dice_loss_myo.item(), global_step)
                writer.add_scalar('Meta_train_Loss/dice_loss_rv', dice_loss_rv.item(), global_step)
                writer.add_scalar('Meta_train_Loss/loss_focal', loss_focal.item(), global_step)
                writer.add_scalar('Meta_train_Loss/kl_loss', kl_loss.item(), global_step)
                writer.add_scalar('Meta_train_Loss/loss_low_rank_a', loss_low_rank_a.item(), global_step)
                writer.add_scalar('Meta_train_Loss/batch_loss', batch_loss.item(), global_step)

                # meta-test: 1. load meta-test data 2.
calculate meta-test loss
                ###############################Meta test#######################################################
                imgs = meta_test_imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32
                ce_mask = meta_test_masks.clone().to(device=device, dtype=torch.long)
                true_masks = meta_test_masks.to(device=device, dtype=mask_type)
                labels = meta_test_labels.to(device=device, dtype=torch.float32)
                reco, z_out, z_out_tilde, a_out, _, mu, logvar, cls_out, _ = model(imgs, true_masks, 'training', meta_loss=batch_loss)

                # mode-1 flattening: change the original 4,8,224,224 features to 4x224x224, 8
                # randomly pick 2000, 8 features to calculate the singular values
                # latent_dim = a_out.size(1)
                # a_feature = a_out.permute(0, 2, 3, 1).contiguous().view(-1, latent_dim)
                # a_feature = a_feature[torch.randperm(len(a_feature))]
                # U_a, S_a, V_a = torch.svd(a_feature[0:2000])

                seg_pred = a_out[:, :4, :, :]

                reco_loss = l1_distance(reco, imgs)
                # kl_loss = losses.KL_divergence(logvar, mu)
                # regression_loss = l1_distance(z_out_tilde, z_out)

                sf_seg = F.softmax(seg_pred, dim=1)
                dice_loss_lv = losses.dice_loss(sf_seg[:,0,:,:], true_masks[:,0,:,:])
                dice_loss_myo = losses.dice_loss(sf_seg[:,1,:,:], true_masks[:,1,:,:])
                dice_loss_rv = losses.dice_loss(sf_seg[:,2,:,:], true_masks[:,2,:,:])
                dice_loss_bg = losses.dice_loss(sf_seg[:, 3, :, :], true_masks[:, 3, :, :])
                loss_dice = dice_loss_lv + dice_loss_myo + dice_loss_rv + dice_loss_bg

                ce_target = ce_mask[:, 3, :, :]*0 + ce_mask[:, 0, :, :]*1 + ce_mask[:, 1, :, :]*2 + ce_mask[:, 2, :, :]*3

                seg_pred_swap = torch.cat((seg_pred[:,3,:,:].unsqueeze(1), seg_pred[:,:3,:,:]), dim=1)

                loss_focal = focal(seg_pred_swap, ce_target)

                d_cls = criterion(cls_out, labels)
                d_losses = d_cls

                batch_loss = 5*loss_dice + 5*loss_focal +
reco_loss + d_losses
                total_meta_loss += batch_loss

                optimizer.zero_grad()
                total_meta_loss.backward()
                nn.utils.clip_grad_value_(model.parameters(), 0.1)
                optimizer.step()

                pbar.set_postfix(**{'loss (batch)': total_meta_loss.item()})
                pbar.update(imgs.shape[0])

                if (epoch + 1) > (k1) and (epoch + 1) % k2 == 0:
                    if global_step % ((long_len//batch_size) // 2) == 0:
                        a_feature = F.softmax(a_out, dim=1)
                        a_feature = latent_norm(a_feature)
                        writer.add_images('Meta_train_images/train', imgs, global_step)
                        writer.add_images('Meta_train_images/a_out0', a_feature[:,0,:,:].unsqueeze(1), global_step)
                        writer.add_images('Meta_train_images/a_out1', a_feature[:, 1, :, :].unsqueeze(1), global_step)
                        writer.add_images('Meta_train_images/a_out2', a_feature[:, 2, :, :].unsqueeze(1), global_step)
                        writer.add_images('Meta_train_images/a_out3', a_feature[:, 3, :, :].unsqueeze(1), global_step)
                        writer.add_images('Meta_train_images/a_out4', a_feature[:, 4, :, :].unsqueeze(1), global_step)
                        writer.add_images('Meta_train_images/a_out5', a_feature[:, 5, :, :].unsqueeze(1), global_step)
                        writer.add_images('Meta_train_images/a_out6', a_feature[:, 6, :, :].unsqueeze(1), global_step)
                        writer.add_images('Meta_train_images/a_out7', a_feature[:, 7, :, :].unsqueeze(1), global_step)
                        writer.add_images('Meta_train_images/train_reco', reco, global_step)
                        writer.add_images('Meta_train_images/train_true', true_masks[:,0:3,:,:], global_step)
                        writer.add_images('Meta_train_images/train_pred', sf_seg[:,0:3,:,:] > 0.5, global_step)
                        writer.add_images('Meta_test_images/train_un_img', un_imgs, global_step)
                        writer.add_images('Meta_test_images/train_un_mask', sf_un_seg_pred[:, 0:3, :, :] > 0.5, global_step)

                global_step += 1

            if optimizer.param_groups[0]['lr']<=2e-8:
                print('Converge')
            if (epoch + 1) > k1 and (epoch + 1) % k2 == 0:
                val_score, val_lv, val_myo, val_rv = eval_dgnet(model, val_loader, device, mode='val')
                scheduler.step(val_score)
                writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

                logging.info('Validation Dice Coeff: {}'.format(val_score))
                logging.info('Validation LV Dice Coeff: {}'.format(val_lv))
                logging.info('Validation MYO Dice Coeff: {}'.format(val_myo))
                logging.info('Validation RV Dice Coeff: {}'.format(val_rv))

                writer.add_scalar('Dice/val', val_score, epoch)
                writer.add_scalar('Dice/val_lv', val_lv, epoch)
                writer.add_scalar('Dice/val_myo', val_myo, epoch)
                writer.add_scalar('Dice/val_rv', val_rv, epoch)

                initial_itr = 0
                for imgs, true_masks in test_loader:
                    if initial_itr == 5:
                        model.eval()
                        imgs = imgs.to(device=device, dtype=torch.float32)
                        with torch.no_grad():
                            reco, z_out, z_out_tilde, a_out, seg_pred, mu, logvar, _, _ = model(imgs, true_masks,
                                                                                               'test')
                        seg_pred = a_out[:, :4, :, :]
                        mask_type = torch.float32
                        true_masks = true_masks.to(device=device, dtype=mask_type)
                        sf_seg_pred = F.softmax(seg_pred, dim=1)
                        writer.add_images('Test_images/test', imgs, epoch)
                        writer.add_images('Test_images/test_reco', reco, epoch)
                        writer.add_images('Test_images/test_true', true_masks[:, 0:3, :, :], epoch)
                        writer.add_images('Test_images/test_pred', sf_seg_pred[:, 0:3, :, :] > 0.5, epoch)
                        model.train()
                        break
                    else:
                        pass
                    initial_itr += 1
                test_score, test_lv, test_myo, test_rv = eval_dgnet(model, test_loader, device, mode='test')

                if best_dice < test_score:
                    best_dice = test_score
                    best_lv = test_lv
                    best_myo = test_myo
                    best_rv = test_rv
                    print("Epoch checkpoint")
                    try:
                        os.mkdir(dir_checkpoint)
                        logging.info('Created checkpoint directory')
                    except OSError:
                        pass
                    torch.save(model.state_dict(),
                               dir_checkpoint + 'CP_epoch.pth')
                    logging.info('Checkpoint saved !')
                else:
                    pass
                logging.info('Best Dice Coeff: {}'.format(best_dice))
                logging.info('Best LV Dice Coeff: {}'.format(best_lv))
                logging.info('Best MYO Dice Coeff: {}'.format(best_myo))
                logging.info('Best RV Dice Coeff: {}'.format(best_rv))
                writer.add_scalar('Dice/test', test_score, epoch)
                writer.add_scalar('Dice/test_lv', test_lv, epoch)
                writer.add_scalar('Dice/test_myo', test_myo, epoch)
                writer.add_scalar('Dice/test_rv', test_rv, epoch)
    writer.close()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    device = torch.device('cuda:'+str(args.gpu) if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')


    torch.manual_seed(14)
    if device.type == 'cuda':
        torch.cuda.manual_seed(14)

    train_net(args)

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
from .model_utils import *
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/model_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/utils/__pycache__/model_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/model_utils.py:
--------------------------------------------------------------------------------
import torch
from os import path

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def save_network_state(model, width, height, ndf, norm, upsample, num_classes, decoder_type, anatomy_out_channels, z_length, num_mask_channels, optimizer, epoch , name , save_path):
    if not path.exists(save_path):
        raise ValueError("{} not a valid path to save model state".format(save_path))
    torch.save(
        {
            'epoch' : epoch,
            'width': width,
            'height': height,
            'ndf' : ndf,
            'norm' : norm,
            'upsample' : upsample,
            'num_classes': num_classes,
            'decoder_type' : decoder_type,
            'anatomy_out_channels' : anatomy_out_channels,
            'z_length' : z_length,
            'num_mask_channels' : num_mask_channels,
            'model_state_dict' : model.state_dict(),
            'optimizer_state_dict' : optimizer.state_dict()
        }, path.join(save_path, name))
--------------------------------------------------------------------------------
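A minimal sketch of how a checkpoint laid out like the one `save_network_state` writes could be read back. `load_network_state` is a hypothetical helper and is not part of this repository; the stand-in `nn.Linear` model, the file name `demo.pth`, and the reduced set of metadata keys are assumptions made only to keep the example small.

```python
import tempfile
from os import path

import torch
import torch.nn as nn


def load_network_state(model, optimizer, checkpoint_path, device="cpu"):
    """Hypothetical counterpart to save_network_state (not in the repo)."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Restore the two state dicts stored under the keys used by save_network_state.
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Everything else is construction metadata (epoch, width, height, ...).
    return {k: v for k, v in checkpoint.items() if not k.endswith("_state_dict")}


# Stand-in model and optimizer; the real code would build the DGNet model instead.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

with tempfile.TemporaryDirectory() as save_path:
    ckpt_path = path.join(save_path, "demo.pth")
    # Write a checkpoint with a subset of the keys save_network_state uses.
    torch.save(
        {
            "epoch": 3,
            "width": 224,
            "height": 224,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        ckpt_path,
    )

    # Fresh instances receive the saved weights and optimizer state.
    restored = nn.Linear(4, 2)
    restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-3)
    meta = load_network_state(restored, restored_opt, ckpt_path)

print(meta)
```

Returning the metadata separately from the state dicts lets a caller rebuild the network with the stored hyperparameters before loading weights, which is presumably why `save_network_state` stores them alongside the tensors.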