├── .idea
│   ├── .gitignore
│   ├── DGNet.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── eval.py
├── figures
│   ├── model.png
│   └── result.png
├── inference.py
├── loaders
│   ├── M&Ms_Dataset_Information.xlsx
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── mms_dataloader_meta_split.cpython-37.pyc
│   ├── mms_dataloader_meta_split.py
│   └── mms_dataloader_meta_split_test.py
├── losses
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── losses.cpython-37.pyc
│   └── losses.py
├── metrics
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── dice_loss.cpython-37.pyc
│   │   ├── focal_loss.cpython-37.pyc
│   │   └── gan_loss.cpython-37.pyc
│   ├── dice_loss.py
│   ├── focal_loss.py
│   ├── gan_loss.py
│   └── hausdorff.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── dgnet.cpython-37.pyc
│   │   ├── meta_decoder.cpython-37.pyc
│   │   ├── meta_segmentor.cpython-37.pyc
│   │   ├── meta_styleencoder.cpython-37.pyc
│   │   ├── meta_unet.cpython-37.pyc
│   │   ├── ops.cpython-37.pyc
│   │   ├── sdnet_ada.cpython-37.pyc
│   │   ├── unet_parts.cpython-37.pyc
│   │   └── weight_init.cpython-37.pyc
│   ├── dgnet.py
│   ├── meta_decoder.py
│   ├── meta_segmentor.py
│   ├── meta_styleencoder.py
│   ├── meta_unet.py
│   ├── ops.py
│   ├── unet_parts.py
│   └── weight_init.py
├── preprocess
│   ├── save_MNMS_2D.py
│   ├── save_MNMS_re.py
│   ├── save_SCGM_2D.py
│   └── split_MNMS_data.py
├── train_meta.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   └── model_utils.cpython-37.pyc
    └── model_utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------

# Default ignored files
/workspace.xml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Semi-supervised Meta-learning with Disentanglement for Domain-generalised Medical Image Segmentation

![DGNet model overview](figures/model.png)

This repository contains the official PyTorch implementation of [Semi-supervised Meta-learning with Disentanglement for Domain-generalised Medical Image Segmentation](https://arxiv.org/abs/2106.13292), accepted as an oral presentation at [MICCAI 2021](https://miccai2021.org/en/). Check out the [presentation](https://www.youtube.com/watch?v=IgXAcO8Zyj8) on our official YouTube channel.

The repository is created by [Xiao Liu](https://github.com/xxxliu95), [Spyridon Thermos](https://github.com/spthermo), [Alison O'Neil](https://vios.science/team/oneil), and [Sotirios A. Tsaftaris](https://www.eng.ed.ac.uk/about/people/dr-sotirios-tsaftaris), as a result of the collaboration between [The University of Edinburgh](https://www.eng.ed.ac.uk/) and [Canon Medical Systems Europe](https://eu.medical.canon/). You are welcome to visit our group website: [vios.s](https://vios.science/)

# System Requirements
* PyTorch 1.5.1 or higher with GPU support
* Python 3.7.2 or higher
* SciPy 1.5.2 or higher
* CUDA toolkit 10 or newer
* Nibabel
* Pillow
* Scikit-image
* TensorBoard
* Tqdm


# Datasets
We used two datasets in the paper: the [Multi-Centre, Multi-Vendor & Multi-Disease Cardiac Image Segmentation Challenge (M&Ms) dataset](https://www.ub.edu/mnms/) and the [Spinal cord grey matter segmentation challenge dataset](http://niftyweb.cs.ucl.ac.uk/challenge/index.php). The dataloaders in this repo support only the M&Ms dataset.

# Preprocessing

First, change the data directories in the scripts in the ```preprocess``` folder. Download the M&Ms data and run ```split_MNMS_data.py``` to split the original dataset into different domains. Then run ```save_MNMS_2D.py``` to save the original 4D data as 2D numpy arrays. Finally, run ```save_MNMS_re.py``` to save the resolution of each datum.
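
In short, assuming the data paths inside the scripts have been edited and the commands are run from the repository root, the whole preprocessing sequence is:
```
python preprocess/split_MNMS_data.py
python preprocess/save_MNMS_2D.py
python preprocess/save_MNMS_re.py
```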

# Training
Note that the hyperparameters in the current version are tuned for the BCD to A case. For other cases, the hyperparameters and a few specific layers of the model are slightly different. To train the model with 5% labeled data, run:
```
python train_meta.py -e 150 -c cp_dgnet_meta_5_tvA/ -t A -w DGNetRE_COM_META_5_tvA -g 0
```
Here the default learning rate is 4e-5. You can change it by adding ```-lr 0.00002``` (sometimes this gives better results).

To train the model with 100% labeled data, change the training parameters to:
```
k_un = 1
k1 = 20
k2 = 2
```
The first parameter controls how many iterations the model is trained with unlabeled data for every iteration of training with labeled data. ```k1 = 20``` means the learning rate starts to decay after 20 epochs, and ```k2 = 2``` means the decay condition is checked every 2 epochs.
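
For intuition, here is a minimal, illustrative sketch of how ```k1``` and ```k2``` gate the learning-rate decay (this is not the exact logic in ```train_meta.py```; the 0.9 decay factor below is a placeholder assumption):
```
# k1 delays the first decay check; k2 sets how often the check runs
k1, k2 = 20, 2
lr = 4e-5
for epoch in range(150):
    if epoch >= k1 and (epoch - k1) % k2 == 0:
        lr *= 0.9  # hypothetical decay step, not the repo's value
```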

Also, change the ratio ```k=0.05``` (line 221) to ```k=1``` in ```mms_dataloader_meta_split.py```.

Then, run:
```
python train_meta.py -e 80 -c cp_dgnet_meta_100_tvA/ -t A -w DGNetRE_COM_META_100_tvA -g 0
```
Finally, when training the model, changing ```resampling_rate=1.2``` (line 47) in ```mms_dataloader_meta_split.py``` to a value between 1.1 and 1.3 may yield better results. This changes the rescale ratio used when preprocessing the images, which affects the size of the anatomy of interest.

# Inference
After training, you can test the model:
```
python inference.py -bs 1 -c cp_dgnet_meta_100_tvA/ -t A -g 0
```
This will output the Dice and Hausdorff distance results as well as their standard deviations. Similarly, changing ```resampling_rate=1.2``` (line 47) in ```mms_dataloader_meta_split_test.py``` to a value between 1.1 and 1.3 may yield better results.
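
For reference, ```inference.py``` groups 2D slices into subjects via the three subject-ID characters in each image path (```path_img[0][-10:-7]```), averages the per-slice scores within each subject, and then reports the mean and standard deviation over subjects. A simplified sketch of that aggregation, with made-up file names and Dice scores, is:
```
import statistics

# (path, dice) pairs: three slices of subject '000', two of subject '001'
slice_scores = [('000_01.npz', 0.90), ('000_02.npz', 0.92), ('000_03.npz', 0.91),
                ('001_01.npz', 0.85), ('001_02.npz', 0.87)]

per_subject = {}
for path, dice in slice_scores:
    per_subject.setdefault(path[:3], []).append(dice)  # group by subject ID

subject_means = [sum(v) / len(v) for v in per_subject.values()]
print(sum(subject_means) / len(subject_means))  # mean Dice over subjects
print(statistics.stdev(subject_means))          # std dev over subjects
```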

# Qualitative results
![Qualitative results](figures/result.png)

# Citation
```
@inproceedings{liu2021semi,
  title={Semi-supervised Meta-learning with Disentanglement for Domain-generalised Medical Image Segmentation},
  author={Liu, Xiao and Thermos, Spyridon and O'Neil, Alison and Tsaftaris, Sotirios A},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={307--317},
  year={2021},
  organization={Springer}
}
```

# Acknowledgement
Part of the code is based on [SDNet](https://github.com/spthermo/SDNet), [MLDG](https://github.com/HAHA-DL/MLDG), [medical-mldg-seg](https://github.com/Pulkit-Khandelwal/medical-mldg-seg) and [Pytorch-UNet](https://github.com/milesial/Pytorch-UNet).

# License
All scripts are released under the MIT License.
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import torch
from tqdm import tqdm
import torch.nn.functional as F
from metrics.dice_loss import dice_coeff

def eval_dgnet(net, loader, device, mode):
    """Evaluation without the densecrf with the dice coefficient"""
    net.eval()
    mask_type = torch.float32
    n_val = len(loader)  # the number of batch
    tot = 0
    tot_lv = 0
    tot_myo = 0
    tot_rv = 0
    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        if mode == 'val':
            for imgs, true_masks, _ in loader:
                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=mask_type)

                with torch.no_grad():
                    reco, z_out, mu_tilde, a_out, mask_pred, mu, logvar, _, _ = net(imgs, true_masks, 'test')

                mask_pred = a_out[:, :4, :, :]
                pred = F.softmax(mask_pred, dim=1)
                pred = (pred > 0.5).float()
                tot += dice_coeff(pred[:, 0:3, :, :], true_masks[:, 0:3, :, :], device).item()
                tot_lv += dice_coeff(pred[:, 0, :, :], true_masks[:, 0, :, :], device).item()
                tot_myo += dice_coeff(pred[:, 1, :, :], true_masks[:, 1, :, :], device).item()
                tot_rv += dice_coeff(pred[:, 2, :, :], true_masks[:, 2, :, :], device).item()
                pbar.update()
        else:
            for imgs, true_masks in loader:
                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=mask_type)

                with torch.no_grad():
                    reco, z_out, mu_tilde, a_out, mask_pred, mu, logvar, _, _ = net(imgs, true_masks, 'test')

                mask_pred = a_out[:, :4, :, :]
                pred = F.softmax(mask_pred, dim=1)
                pred = (pred > 0.5).float()
                tot += dice_coeff(pred[:, 0:3, :, :], true_masks[:, 0:3, :, :], device).item()
                tot_lv += dice_coeff(pred[:, 0, :, :], true_masks[:, 0, :, :], device).item()
                tot_myo += dice_coeff(pred[:, 1, :, :], true_masks[:, 1, :, :], device).item()
                tot_rv += dice_coeff(pred[:, 2, :, :], true_masks[:, 2, :, :], device).item()
                pbar.update()

    net.train()
    return tot / n_val, tot_lv / n_val, tot_myo / n_val, tot_rv / n_val
--------------------------------------------------------------------------------
/figures/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/figures/model.png
--------------------------------------------------------------------------------
/figures/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/figures/result.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import torch
import argparse
import torch.nn.functional as F
import statistics
import utils
from loaders.mms_dataloader_meta_split_test import get_meta_split_data_loaders
import models
from metrics.dice_loss import dice_coeff
from metrics.hausdorff import hausdorff_distance

# python inference.py -bs 1 -c cp_dgnet_gan_meta_dir_5_tvA/ -t A -mn dgnet -g 0


def get_args():
    usage_text = (
        "SNet Pytorch Implementation"
        "Usage: python train.py [options],"
        " with [options]:"
    )
    parser = argparse.ArgumentParser(description=usage_text)
    # training details
    parser.add_argument('-bs', '--batch_size', type=int, default=4, help='Number of inputs per batch')
    parser.add_argument('-c', '--cp', type=str, default='checkpoints/', help='The name of the checkpoints.')
    parser.add_argument('-t', '--tv', type=str, default='D', help='The name of the target vendor.')
    parser.add_argument('-w', '--wc', type=str, default='DGNet_LR00002_LDv5', help='The name of the writer summary.')
    parser.add_argument('-n', '--name', type=str, default='default_name', help='The name of this train/test. Used when storing information.')
    parser.add_argument('-mn', '--model_name', type=str, default='dgnet', help='Name of the model architecture to be used for training/testing.')
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.00002, help='The learning rate for model training')
    parser.add_argument('-wi', '--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)')
    parser.add_argument('--save_path', type=str, default='checkpoints', help='Path to save model checkpoints')
    parser.add_argument('--decoder_type', type=str, default='film', help='Choose decoder type between FiLM and SPADE')
    # hardware
    parser.add_argument('-g', '--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of workers to use for dataload')

    return parser.parse_args()

args = get_args()

batch_size = args.batch_size
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

dir_checkpoint = args.cp
test_vendor = args.tv
# wc = args.wc
model_name = args.model_name

# Model selection and initialization
model_params = {
    'width': 288,
    'height': 288,
    'ndf': 64,
    'norm': "batchnorm",
    'upsample': "nearest",
    'num_classes': 3,
    'decoder_type': args.decoder_type,
    'anatomy_out_channels': 8,
    'z_length': 8,
    'num_mask_channels': 8,
}
model = models.get_model(model_name, model_params)
num_params = utils.count_parameters(model)
print('Model Parameters: ', num_params)
model.load_state_dict(torch.load(dir_checkpoint + 'CP_epoch.pth', map_location=device))
model.to(device)

# writer = SummaryWriter(comment=wc)

_, _, \
_, _, \
_, _, \
test_loader, \
_, _, _ = get_meta_split_data_loaders(
    batch_size, test_vendor=test_vendor, image_size=224)

step = 0
tot = []
tot_sub = []
tot_hsd = []
tot_sub_hsd = []
flag = '000'
# i = 0
for imgs, true_masks, path_img in test_loader:
    model.eval()
    imgs = imgs.to(device=device, dtype=torch.float32)
    mask_type = torch.float32
    true_masks = true_masks.to(device=device, dtype=mask_type)
    print(flag)
    if path_img[0][-10: -7] != flag:
        # if i > 10:
        #     break
        # i += 1
        flag = path_img[0][-10: -7]
        tot.append(sum(tot_sub) / len(tot_sub))
        tot_sub = []
        tot_hsd.append(sum(tot_sub_hsd) / len(tot_sub_hsd))
        tot_sub_hsd = []
    with torch.no_grad():
        reco, z_out, z_out_tilde, a_out, _, mu, logvar, cls_out, _ = model(imgs, true_masks, 'test')

    mask_pred = a_out[:, :4, :, :]
    pred = F.softmax(mask_pred, dim=1)
    pred = (pred > 0.5).float()
    dice = dice_coeff(pred[:, 0:3, :, :], true_masks[:, 0:3, :, :], device).item()
    hsd = hausdorff_distance(pred[:, 0:3, :, :], true_masks[:, 0:3, :, :])
    tot_sub.append(dice)
    tot_sub_hsd.append(hsd)
    print(step)
    step += 1

print(tot)

print(sum(tot) / len(tot))
print(statistics.stdev(tot))

print(tot_hsd)

print(sum(tot_hsd) / len(tot_hsd))
print(statistics.stdev(tot_hsd))
--------------------------------------------------------------------------------
/loaders/M&Ms_Dataset_Information.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/loaders/M&Ms_Dataset_Information.xlsx
--------------------------------------------------------------------------------
/loaders/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/loaders/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/loaders/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/loaders/__pycache__/mms_dataloader_meta_split.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/loaders/__pycache__/mms_dataloader_meta_split.cpython-37.pyc
--------------------------------------------------------------------------------
/loaders/mms_dataloader_meta_split.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torchfile
3 | from torch.utils.data import DataLoader, TensorDataset
4 | from torchvision import transforms
5 | import torchvision.transforms.functional as F
6 | import torch
7 | import torch.nn as nn
8 | import os
9 | import torchvision.utils as vutils
10 | import numpy as np
11 | import torch.nn.init as init
12 | import torch.utils.data as data
13 | import torch
14 | import random
15 | import xlrd
16 | import math
17 | from skimage.exposure import match_histograms
18 |
19 | # Data directories
20 | LabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorA/'
21 | LabeledVendorA_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorA/'
22 | ReA_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorA/'
23 |
24 | LabeledVendorB2_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorB/center2/'
25 | LabeledVendorB2_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorB/center2/'
26 | ReB2_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorB/center2/'
27 |
28 | LabeledVendorB3_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorB/center3/'
29 | LabeledVendorB3_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorB/center3/'
30 | ReB3_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorB/center3/'
31 |
32 | LabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorC/'
33 | LabeledVendorC_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorC/'
34 | ReC_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorC/'
35 |
36 | LabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorD/'
37 | LabeledVendorD_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorD/'
38 | ReD_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorD/'
39 |
40 | UnlabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Unlabeled/vendorC/'
41 | UnReC_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Unlabeled/vendorC/'
42 |
43 | Re_dir = [ReA_dir, ReB2_dir, ReB3_dir, ReC_dir, ReD_dir]
44 | Labeled_data_dir = [LabeledVendorA_data_dir, LabeledVendorB2_data_dir, LabeledVendorB3_data_dir, LabeledVendorC_data_dir, LabeledVendorD_data_dir]
45 | Labeled_mask_dir = [LabeledVendorA_mask_dir, LabeledVendorB2_mask_dir, LabeledVendorB3_mask_dir, LabeledVendorC_mask_dir, LabeledVendorD_mask_dir]
46 |
47 | resampling_rate = 1.2
48 |
49 | def get_meta_split_data_loaders(batch_size, test_vendor='D', image_size=224):
50 |
51 |     random.seed(14)
52 |
53 |     domain_1_labeled_loader, domain_1_unlabeled_loader, \
54 |     domain_2_labeled_loader, domain_2_unlabeled_loader,\
55 |     domain_3_labeled_loader, domain_3_unlabeled_loader, \
56 |     test_loader, \
57 |     domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset = \
58 |         get_data_loader_folder(Labeled_data_dir, Labeled_mask_dir, batch_size, image_size, test_num=test_vendor)
59 |
60 |     return domain_1_labeled_loader, domain_1_unlabeled_loader, \
61 |            domain_2_labeled_loader, domain_2_unlabeled_loader,\
62 |            domain_3_labeled_loader, domain_3_unlabeled_loader, \
63 |            test_loader, \
64 |            domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset
65 |
66 |
67 | def get_data_loader_folder(data_folders, mask_folders, batch_size, new_size=288, test_num='D', num_workers=2):
68 |     if test_num=='A':
69 |         domain_1_img_dirs = [data_folders[1], data_folders[2]]
70 |         domain_1_mask_dirs = [mask_folders[1], mask_folders[2]]
71 |         domain_2_img_dirs = [data_folders[3]]
72 |         domain_2_mask_dirs = [mask_folders[3]]
73 |         domain_3_img_dirs = [data_folders[4]]
74 |         domain_3_mask_dirs = [mask_folders[4]]
75 |
76 |         test_data_dirs = [data_folders[0]]
77 |         test_mask_dirs = [mask_folders[0]]
78 |
79 |         domain_1_re = [Re_dir[1], Re_dir[2]]
80 |         domain_2_re = [Re_dir[3]]
81 |         domain_3_re = [Re_dir[4]]
82 |
83 |         test_re = [Re_dir[0]]
84 |
85 |         domain_1_num = [74, 51]
86 |         domain_2_num = [50]
87 |         domain_3_num = [50]
88 |         test_num = [95]
89 |
90 |     elif test_num=='B':
91 |         domain_1_img_dirs = [data_folders[0]]
92 |         domain_1_mask_dirs = [mask_folders[0]]
93 |         domain_2_img_dirs = [data_folders[3]]
94 |         domain_2_mask_dirs = [mask_folders[3]]
95 |         domain_3_img_dirs = [data_folders[4]]
96 |         domain_3_mask_dirs = [mask_folders[4]]
97 |
98 |         test_data_dirs = [data_folders[1], data_folders[2]]
99 |         test_mask_dirs = [mask_folders[1], mask_folders[2]]
100 |
101 |         domain_1_re = [Re_dir[0]]
102 |         domain_2_re = [Re_dir[3]]
103 |         domain_3_re = [Re_dir[4]]
104 |         test_re = [Re_dir[1], Re_dir[2]]
105 |
106 |         domain_1_num = [95]
107 |         domain_2_num = [50]
108 |         domain_3_num = [50]
109 |         test_num = [74, 51]
110 |
111 |     elif test_num=='C':
112 |         domain_1_img_dirs = [data_folders[0]]
113 |         domain_1_mask_dirs = [mask_folders[0]]
114 |         domain_2_img_dirs = [data_folders[1], data_folders[2]]
115 |         domain_2_mask_dirs = [mask_folders[1], mask_folders[2]]
116 |         domain_3_img_dirs = [data_folders[4]]
117 |         domain_3_mask_dirs = [mask_folders[4]]
118 |
119 |         test_data_dirs = [data_folders[3]]
120 |         test_mask_dirs = [mask_folders[3]]
121 |
122 |         domain_1_re = [Re_dir[0]]
123 |         domain_2_re = [Re_dir[1], Re_dir[2]]
124 |         domain_3_re = [Re_dir[4]]
125 |         test_re = [Re_dir[3]]
126 |
127 |         domain_1_num = [95]
128 |         domain_2_num = [74, 51]
129 |         domain_3_num = [50]
130 |         test_num = [50]
131 |
132 |     elif test_num=='D':
133 |         domain_1_img_dirs = [data_folders[0]]
134 |         domain_1_mask_dirs = [mask_folders[0]]
135 |         domain_2_img_dirs = [data_folders[1], data_folders[2]]
136 |         domain_2_mask_dirs = [mask_folders[1], mask_folders[2]]
137 |         domain_3_img_dirs = [data_folders[3]]
138 |         domain_3_mask_dirs = [mask_folders[3]]
139 |
140 |         test_data_dirs = [data_folders[4]]
141 |         test_mask_dirs = [mask_folders[4]]
142 |
143 |         domain_1_re = [Re_dir[0]]
144 |         domain_2_re = [Re_dir[1], Re_dir[2]]
145 |         domain_3_re = [Re_dir[3]]
146 |         test_re = [Re_dir[4]]
147 |
148 |         domain_1_num = [95]
149 |         domain_2_num = [74, 51]
150 |         domain_3_num = [50]
151 |         test_num = [50]
152 |
153 |     else:
154 |         print('Wrong test vendor!')
155 |
156 |
157 |     domain_1_labeled_dataset = ImageFolder(domain_1_img_dirs, domain_1_mask_dirs, domain_1_img_dirs, domain_1_re, label=0, num_label=domain_1_num, train=True, labeled=True)
158 |     domain_2_labeled_dataset = ImageFolder(domain_2_img_dirs, domain_2_mask_dirs, domain_1_img_dirs, domain_2_re, label=1, num_label=domain_2_num, train=True, labeled=True)
159 |     domain_3_labeled_dataset = ImageFolder(domain_3_img_dirs, domain_3_mask_dirs, domain_1_img_dirs, domain_3_re, label=2, num_label=domain_3_num, train=True, labeled=True)
160 |
161 |
162 |     # domain_1_labeled_loader = DataLoader(dataset=domain_1_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True)
163 |     # domain_2_labeled_loader = DataLoader(dataset=domain_2_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True)
164 |     # domain_3_labeled_loader = DataLoader(dataset=domain_3_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True)
165 |
166 |     domain_1_labeled_loader = None
167 |     domain_2_labeled_loader = None
168 |     domain_3_labeled_loader = None
169 |
170 |     domain_1_unlabeled_dataset = ImageFolder(domain_1_img_dirs, domain_1_mask_dirs, domain_1_img_dirs, domain_1_re, label=0, train=True, labeled=False)
171 |     domain_2_unlabeled_dataset = ImageFolder(domain_2_img_dirs, domain_2_mask_dirs, domain_1_img_dirs, domain_2_re, label=1, train=True, labeled=False)
172 |     domain_3_unlabeled_dataset = ImageFolder(domain_3_img_dirs, domain_3_mask_dirs, domain_1_img_dirs, domain_3_re, label=2, train=True, labeled=False)
173 |
174 |     domain_1_unlabeled_loader = DataLoader(dataset=domain_1_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True)
175 |     domain_2_unlabeled_loader = DataLoader(dataset=domain_2_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True)
176 |     domain_3_unlabeled_loader = DataLoader(dataset=domain_3_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True)
177 |
178 |
179 |     test_dataset = ImageFolder(test_data_dirs, test_mask_dirs, domain_1_img_dirs, test_re, train=False, labeled=True)
180 |     test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True)
181 |
182 |     return domain_1_labeled_loader, domain_1_unlabeled_loader, \
183 |            domain_2_labeled_loader, domain_2_unlabeled_loader,\
184 |            domain_3_labeled_loader, domain_3_unlabeled_loader, \
185 |            test_loader, \
186 |            domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset
187 |
188 | def default_loader(path):
189 |     return np.load(path)['arr_0']
190 |
191 | def make_dataset(dir):
192 |     images = []
193 |     assert os.path.isdir(dir), '%s is not a valid directory' % dir
194 |
195 |     for root, _, fnames in sorted(os.walk(dir)):
196 |         for fname in fnames:
197 |             path = os.path.join(root, fname)
198 |             images.append(path)
199 |     return images
200 |
201 | class ImageFolder(data.Dataset):
202 |     def __init__(self, data_dirs, mask_dirs, ref_dir, re, train=True, label=None, num_label=None, labeled=True, loader=default_loader):
203 |
204 |         print(data_dirs)
205 |         print(mask_dirs)
206 |
207 |         reso_dir = re
208 |         temp_imgs = []
209 |         temp_masks = []
210 |         temp_re = []
211 |         domain_labels = []
212 |
213 |         tem_ref_imgs = []
214 |
215 |         if train:
216 |             #100%
217 |             # k = 1
218 |             #10%
219 |             # k = 0.1
220 |             #5%
221 |             k = 0.05
222 |             #2%
223 |             # k = 0.02
224 |         else:
225 |             k = 1
226 |
227 |
228 |         for num_set in range(len(data_dirs)):
229 |             re_roots = sorted(make_dataset(reso_dir[num_set]))
230 |             data_roots = sorted(make_dataset(data_dirs[num_set]))
231 |             mask_roots = sorted(make_dataset(mask_dirs[num_set]))
232 |             num_label_data = 0
233 |             for num_data in range(len(data_roots)):
234 |                 if labeled:
235 |                     if train:
236 |                         n_label = str(math.ceil(num_label[num_set] * k + 1))
237 |                         if '00'+n_label==data_roots[num_data][-10:-7] or '0'+n_label==data_roots[num_data][-10:-7]:
238 |                             print(n_label)
239 |                             print(data_roots[num_data][-10:-7])
240 |                             break
241 |
242 |                     for num_mask in range(len(mask_roots)):
243 |                         if data_roots[num_data][-10:-4] == mask_roots[num_mask][-10:-4]:
244 |                             temp_re.append(re_roots[num_data])
245 |                             temp_imgs.append(data_roots[num_data])
246 |                             temp_masks.append(mask_roots[num_mask])
247 |                             domain_labels.append(label)
248 |                             num_label_data += 1
249 |                         else:
250 |                             pass
251 |                 else:
252 |                     temp_re.append(re_roots[num_data])
253 |                     temp_imgs.append(data_roots[num_data])
254 |                     domain_labels.append(label)
255 |
256 |         for num_set in range(len(ref_dir)):
257 |             data_roots = sorted(make_dataset(ref_dir[num_set]))
258 |             for num_data in range(len(data_roots)):
259 |                 tem_ref_imgs.append(data_roots[num_data])
260 |
261 |         reso = temp_re
262 |         imgs = temp_imgs
263 |         masks = temp_masks
264 |         labels = domain_labels
265 |
266 |         print(len(masks))
267 |
268 |         ref_imgs = tem_ref_imgs
269 |
270 |         # add something here to index, such that can split the data
271 |         # index = random.sample(range(len(temp_img)), len(temp_mask))
272 |
273 |         self.reso = reso
274 |         self.imgs = imgs
275 |         self.masks = masks
276 |         self.labels = labels
277 |         self.new_size = 288
278 |         self.loader = loader
279 |         self.labeled = labeled
280 |         self.train = train
281 |         self.ref = ref_imgs
282 |
283 |     def __getitem__(self, index):
284 |         if self.train:
285 |             index = random.randrange(len(self.imgs))
286 |         else:
287 |             pass
288 |
289 |         path_re = self.reso[index]
290 |         re = self.loader(path_re)
291 |         re = re[0]
292 |
293 |         path_img = self.imgs[index]
294 |         img = self.loader(path_img) # numpy, HxW, numpy.Float64
295 |         #
296 |         # ref_paths = random.sample(self.ref, 1)
297 |         # ref_img = self.loader(ref_paths[0])
298 |         #
299 |         # img = match_histograms(img, ref_img)
300 |
301 |         label = self.labels[index]
302 |
303 |
304 |         if label==0:
305 |             one_hot_label = torch.tensor([[1], [0], [0]])
306 |         elif label==1:
307 |             one_hot_label = torch.tensor([[0], [1], [0]])
308 |         elif label==2:
309 |             one_hot_label = torch.tensor([[0], [0], [1]])
310 |         else:
311 |             one_hot_label = torch.tensor([[0], [0], [0]])
312 |
313 |         # Intensity cropping:
314 |         p5 = np.percentile(img.flatten(), 0.5)
315 |         p95 = np.percentile(img.flatten(), 99.5)
316 |         img = np.clip(img, p5, p95)
317 |
318 |         img -= img.min()
319 |         img /= img.max()
320 |         img = img.astype('float32')
321 |
322 |         crop_size = 300
323 |
324 |         # Augmentations:
325 |         # 1. random rotation
326 |         # 2. random scaling 0.8 - 1.2
327 |         # 3. random crop from 280x280
328 |         # 4. random hflip
329 |         # 5. random vflip
330 |         # 6. color jitter
331 |         # 7. Gaussian filtering
332 |
333 |         img_tensor = F.to_tensor(np.array(img))
334 |         img_size = img_tensor.size()
335 |
336 |         if self.labeled:
337 |             if self.train:
338 |                 img = F.to_pil_image(img)
339 |                 # rotate, random angle between 0 - 90
340 |                 angle = random.randint(0, 90)
341 |                 img = F.rotate(img, angle, Image.BILINEAR)
342 |
343 |                 path_mask = self.masks[index]
344 |                 mask = Image.open(path_mask) # numpy, HxWx3
345 |                 # rotate, random angle between 0 - 90
346 |                 mask = F.rotate(mask, angle, Image.NEAREST)
347 |
348 |                 ## Find the region of mask
349 |                 norm_mask = F.to_tensor(np.array(mask))
350 |                 region = norm_mask[0] + norm_mask[1] + norm_mask[2]
351 |                 non_zero_index = torch.nonzero(region == 1, as_tuple=False)
352 |                 if region.sum() > 0:
353 |                     len_m = len(non_zero_index[0])
354 |                     x_region = non_zero_index[len_m//2][0]
355 |                     y_region = non_zero_index[len_m//2][1]
356 |                     x_region = int(x_region.item())
357 |                     y_region = int(y_region.item())
358 |                 else:
359 |                     x_region = norm_mask.size(-2) // 2
360 |                     y_region = norm_mask.size(-1) // 2
361 |
362 |                 # resize and center-crop to 280x280
363 |                 resize_order = re / resampling_rate
364 |                 resize_size_h = int(img_size[-2] * resize_order)
365 |                 resize_size_w = int(img_size[-1] * resize_order)
366 |
367 |                 left_size = 0
368 |                 top_size = 0
369 |                 right_size = 0
370 |                 bot_size = 0
371 |                 if resize_size_h < self.new_size:
372 |                     top_size = (self.new_size - resize_size_h) // 2
373 |                     bot_size = (self.new_size - resize_size_h) - top_size
374 |                 if resize_size_w < self.new_size:
375 |                     left_size = (self.new_size - resize_size_w) // 2
376 |                     right_size = (self.new_size - resize_size_w) - left_size
377 |
378 |                 transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))]
379 |                 transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list
380 |                 transform = transforms.Compose(transform_list)
381 |
382 |                 img = transform(img)
383 |
384 |
385 |                 ## Define the crop index
386 |                 if top_size >= 0:
387 |                     top_crop = 0
388 |                 else:
389 |                     if x_region > self.new_size//2:
390 |                         if x_region - self.new_size//2 + self.new_size <= norm_mask.size(-2):
391 |                             top_crop = x_region - self.new_size//2
392 |                         else:
393 |                             top_crop = norm_mask.size(-2) - self.new_size
394 |                     else:
395 |                         top_crop = 0
396 |
397 |                 if left_size >= 0:
398 |                     left_crop = 0
399 |                 else:
400 |                     if y_region > self.new_size//2:
401 |                         if y_region - self.new_size//2 + self.new_size <= norm_mask.size(-1):
402 |                             left_crop = y_region - self.new_size//2
403 |                         else:
404 |                             left_crop = norm_mask.size(-1) - self.new_size
405 |                     else:
406 |                         left_crop = 0
407 |
408 |                 # random crop to 224x224
409 |                 img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size)
410 |
411 |                 # random flip
412 |                 hflip_p = random.random()
413 |                 img = F.hflip(img) if hflip_p >= 0.5 else img
414 |                 vflip_p = random.random()
415 |                 img = F.vflip(img) if vflip_p >= 0.5 else img
416 |
417 |                 img = F.to_tensor(np.array(img))
418 |
419 |                 # # Gamma correction: random gamma from [0.5, 1.5]
420 |                 # gamma = 0.5 + random.random()
421 |                 # img_max = img.max()
422 |                 # img = img_max * torch.pow((img / img_max), gamma)
423 |
424 |                 # Gaussian bluring:
425 |                 transform_list = [transforms.GaussianBlur(5, sigma=(0.25, 1.25))]
426 |                 transform = transforms.Compose(transform_list)
427 |                 img = transform(img)
428 |
429 |                 # resize and center-crop to 280x280
430 |                 transform_mask_list = [transforms.Pad(
431 |                     (left_size, top_size, right_size, bot_size))]
432 |                 transform_mask_list = [transforms.Resize((resize_size_h, resize_size_w),
433 |                                                          interpolation=Image.NEAREST)] + transform_mask_list
434 |                 transform_mask = transforms.Compose(transform_mask_list)
435 |
436 |                 mask = transform_mask(mask) # C,H,W
437 |
438 |                 # random crop to 224x224
439 |                 mask = F.crop(mask, top_crop, left_crop, self.new_size, self.new_size)
440 |
441 |                 # random flip
442 |                 mask = F.hflip(mask) if hflip_p >= 0.5 else mask
443 |                 mask = F.vflip(mask) if vflip_p >= 0.5 else mask
444 |
445 |                 mask = F.to_tensor(np.array(mask))
446 |
447 |                 mask_bg = (mask.sum(0) == 0).type_as(mask) # H,W
448 |                 mask_bg = mask_bg.reshape((1, mask_bg.size(0), mask_bg.size(1)))
449 |                 mask = torch.cat((mask, mask_bg), dim=0)
450 |
451 |                 return img, mask, one_hot_label.squeeze() # pytorch: N,C,H,W
452 |
453 |             else:
454 |                 path_mask = self.masks[index]
455 |                 mask = Image.open(path_mask) # numpy, HxWx3
456 |                 # resize and center-crop to 280x280
457 |
458 |                 ## Find the region of mask
459 |                 norm_mask = F.to_tensor(np.array(mask))
460 |                 region = norm_mask[0] + norm_mask[1] + norm_mask[2]
461 |                 non_zero_index = torch.nonzero(region == 1, as_tuple=False)
462 |                 if region.sum() > 0:
463 |                     len_m = len(non_zero_index[0])
464 |                     x_region = non_zero_index[len_m//2][0]
465 |                     y_region = non_zero_index[len_m//2][1]
466 |                     x_region = int(x_region.item())
467 |                     y_region = int(y_region.item())
468 |                 else:
469 |                     x_region = norm_mask.size(-2) // 2
470 |                     y_region = norm_mask.size(-1) // 2
471 |
472 |                 resize_order = re / resampling_rate
473 |                 resize_size_h = int(img_size[-2] * resize_order)
474 |                 resize_size_w = int(img_size[-1] * resize_order)
475 |
476 |                 left_size = 0
477 |                 top_size = 0
478 |                 right_size = 0
479 |                 bot_size = 0
480 |                 if resize_size_h < self.new_size:
481 |                     top_size = (self.new_size - resize_size_h) // 2
482 |                     bot_size = (self.new_size - resize_size_h) - top_size
483 |                 if resize_size_w < self.new_size:
484 |                     left_size = (self.new_size - resize_size_w) // 2
485 |                     right_size = (self.new_size - resize_size_w) - left_size
486 |
487 |
488 |                 # transform_list = [transforms.CenterCrop((crop_size, crop_size))]
489 |                 transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))]
490 |                 transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list
491 |                 transform_list = [transforms.ToPILImage()] + transform_list
492 |                 transform = transforms.Compose(transform_list)
493 |                 img = transform(img)
494 |                 img = F.to_tensor(np.array(img))
495 |
496 |                 ## Define the crop index
497 |                 if top_size >= 0:
498 |                     top_crop = 0
499 |                 else:
500 |                     if x_region > self.new_size//2:
501 |                         if x_region - self.new_size//2 + self.new_size <= norm_mask.size(-2):
502 |                             top_crop = x_region - self.new_size//2
503 |                         else:
504 |                             top_crop = norm_mask.size(-2) - self.new_size
505 |                     else:
506 |                         top_crop = 0
507 |
508 |                 if left_size >= 0:
509 |                     left_crop = 0
510 |                 else:
511 |                     if y_region > self.new_size//2:
512 |                         if y_region - self.new_size//2 + self.new_size <= norm_mask.size(-1):
513 |                             left_crop = y_region - self.new_size//2
514 |                         else:
515 |                             left_crop = norm_mask.size(-1) - self.new_size
516 |                     else:
517 |                         left_crop = 0
518 |
519 |                 # random crop to 224x224
520 |                 img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size)
521 |
522 |                 # resize and center-crop to 280x280
523 |                 # transform_mask_list = [transforms.CenterCrop((crop_size, crop_size))]
524 |                 transform_mask_list = [transforms.Pad(
525 |                     (left_size, top_size, right_size, bot_size))]
526 |                 transform_mask_list = [transforms.Resize((resize_size_h, resize_size_w),
527 |                                                          interpolation=Image.NEAREST)] + transform_mask_list
528 |                 transform_mask = transforms.Compose(transform_mask_list)
529 |
530 |                 mask = transform_mask(mask) # C,H,W
531 |                 mask = F.crop(mask, top_crop, left_crop, self.new_size, self.new_size)
532 |                 mask = F.to_tensor(np.array(mask))
533 |
534 |                 mask_bg = (mask.sum(0) == 0).type_as(mask) # H,W
535 |                 mask_bg = mask_bg.reshape((1, mask_bg.size(0), mask_bg.size(1)))
536 |                 mask = torch.cat((mask, mask_bg), dim=0)
537 |
538 |                 return img, mask
539 |
540 |         else:
541 |             img = F.to_pil_image(img)
542 |             # rotate, random angle between 0 - 90
543 |             angle = random.randint(0, 90)
544 |             img = F.rotate(img, angle, Image.BILINEAR)
545 |
546 |             # resize and center-crop to 280x280
547 |             resize_order = re / resampling_rate
548 |             resize_size_h = int(img_size[-2] * resize_order)
549 |             resize_size_w = int(img_size[-1] * resize_order)
550 |
551 |             left_size = 0
552 |             top_size = 0
553 |             right_size = 0
554 |             bot_size = 0
555 |             if resize_size_h < crop_size:
556 |                 top_size = (crop_size - resize_size_h) // 2
557 |                 bot_size = (crop_size - resize_size_h) - top_size
558 |             if resize_size_w < crop_size:
559 |                 left_size = (crop_size - resize_size_w) // 2
560 |                 right_size = (crop_size - resize_size_w) - left_size
561 |
562 |             transform_list = [transforms.CenterCrop((crop_size, crop_size))]
563 |             transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))] + transform_list
564 |             transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list
565 |             transform = transforms.Compose(transform_list)
566 |
567 |             img = transform(img)
568 |
569 |             # random crop to 224x224
570 |             top_crop = random.randint(0, crop_size - self.new_size)
571 |             left_crop = random.randint(0, crop_size - self.new_size)
572 |             img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size)
573 |
574 |             # random flip
575 |             hflip_p = random.random()
576 |             img = F.hflip(img) if hflip_p >= 0.5 else img
577 |             vflip_p = random.random()
578 |             img = F.vflip(img) if vflip_p >= 0.5 else img
579 |
580 |             img = F.to_tensor(np.array(img))
581 |
582 |             # # Gamma correction: random gamma from [0.5, 1.5]
583 |             # gamma = 0.5 + random.random()
584 |             # img_max = img.max()
585 |             # img = img_max*torch.pow((img/img_max), gamma)
586 |
587 |             # Gaussian bluring:
588 |             transform_list = [transforms.GaussianBlur(5, sigma=(0.25, 1.25))]
589 |             transform = transforms.Compose(transform_list)
590 |             img = transform(img)
591 |
592 |             return img, one_hot_label.squeeze() # pytorch: N,C,H,W
593 |
594 |     def __len__(self):
595 |         return len(self.imgs)
596 |
597 |
598 |
--------------------------------------------------------------------------------
/loaders/mms_dataloader_meta_split_test.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torchfile
3 | from torch.utils.data import DataLoader, TensorDataset
4 | from torchvision import transforms
5 | import torchvision.transforms.functional as F
6 | import torch
7 | import torch.nn as nn
8 | import os
9 | import torchvision.utils as vutils
10 | import numpy as np
11 | import torch.nn.init as init
12 | import torch.utils.data as data
13 | import torch
14 | import random
15 | import xlrd
16 | import math
17 | from skimage.exposure import match_histograms
18 |
19 | # Data directories
20 | LabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorA/'
21 | LabeledVendorA_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorA/'
22 | ReA_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorA/'
23 |
24 | LabeledVendorB2_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorB/center2/'
25 | LabeledVendorB2_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorB/center2/'
26 | ReB2_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorB/center2/'
27 |
28 | LabeledVendorB3_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorB/center3/'
29 | LabeledVendorB3_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorB/center3/'
30 | ReB3_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorB/center3/'
31 |
32 | LabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorC/'
33 | LabeledVendorC_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorC/'
34 | ReC_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorC/'
35 |
36 | LabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Labeled/vendorD/'
37 | LabeledVendorD_mask_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_mask/Labeled/vendorD/'
38 | ReD_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Labeled/vendorD/'
39 |
40 | UnlabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_data/Unlabeled/vendorC/'
41 | UnReC_dir = '/home/s1575424/xiao/Year2/miccai2021/mnms_split_2D_re/Unlabeled/vendorC/'
42 |
43 | Re_dir = [ReA_dir, ReB2_dir, ReB3_dir, ReC_dir, ReD_dir]
44 | Labeled_data_dir = [LabeledVendorA_data_dir, LabeledVendorB2_data_dir, LabeledVendorB3_data_dir, LabeledVendorC_data_dir, LabeledVendorD_data_dir]
45 | Labeled_mask_dir = [LabeledVendorA_mask_dir, LabeledVendorB2_mask_dir, LabeledVendorB3_mask_dir, LabeledVendorC_mask_dir, LabeledVendorD_mask_dir]
46 |
47 | resampling_rate = 1.2
48 |
49 | def get_meta_split_data_loaders(batch_size, test_vendor='D', image_size=224):
50 |
51 | random.seed(14)
52 |
53 | domain_1_labeled_loader, domain_1_unlabeled_loader, \
54 | domain_2_labeled_loader, domain_2_unlabeled_loader,\
55 | domain_3_labeled_loader, domain_3_unlabeled_loader, \
56 | test_loader, \
57 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset = \
58 | get_data_loader_folder(Labeled_data_dir, Labeled_mask_dir, batch_size, image_size, test_num=test_vendor)
59 |
60 | return domain_1_labeled_loader, domain_1_unlabeled_loader, \
61 | domain_2_labeled_loader, domain_2_unlabeled_loader,\
62 | domain_3_labeled_loader, domain_3_unlabeled_loader, \
63 | test_loader, \
64 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset
65 |
66 |
67 | def get_data_loader_folder(data_folders, mask_folders, batch_size, new_size=288, test_num='D', num_workers=2):
68 | if test_num=='A':
69 | domain_1_img_dirs = [data_folders[1], data_folders[2]]
70 | domain_1_mask_dirs = [mask_folders[1], mask_folders[2]]
71 | domain_2_img_dirs = [data_folders[3]]
72 | domain_2_mask_dirs = [mask_folders[3]]
73 | domain_3_img_dirs = [data_folders[4]]
74 | domain_3_mask_dirs = [mask_folders[4]]
75 |
76 | test_data_dirs = [data_folders[0]]
77 | test_mask_dirs = [mask_folders[0]]
78 |
79 | domain_1_re = [Re_dir[1], Re_dir[2]]
80 | domain_2_re = [Re_dir[3]]
81 | domain_3_re = [Re_dir[4]]
82 |
83 | test_re = [Re_dir[0]]
84 |
85 | domain_1_num = [74, 51]
86 | domain_2_num = [50]
87 | domain_3_num = [50]
88 | test_num = [95]
89 |
90 | elif test_num=='B':
91 | domain_1_img_dirs = [data_folders[0]]
92 | domain_1_mask_dirs = [mask_folders[0]]
93 | domain_2_img_dirs = [data_folders[3]]
94 | domain_2_mask_dirs = [mask_folders[3]]
95 | domain_3_img_dirs = [data_folders[4]]
96 | domain_3_mask_dirs = [mask_folders[4]]
97 |
98 | test_data_dirs = [data_folders[1], data_folders[2]]
99 | test_mask_dirs = [mask_folders[1], mask_folders[2]]
100 |
101 | domain_1_re = [Re_dir[0]]
102 | domain_2_re = [Re_dir[3]]
103 | domain_3_re = [Re_dir[4]]
104 | test_re = [Re_dir[1], Re_dir[2]]
105 |
106 | domain_1_num = [95]
107 | domain_2_num = [50]
108 | domain_3_num = [50]
109 | test_num = [74, 51]
110 |
111 | elif test_num=='C':
112 | domain_1_img_dirs = [data_folders[0]]
113 | domain_1_mask_dirs = [mask_folders[0]]
114 | domain_2_img_dirs = [data_folders[1], data_folders[2]]
115 | domain_2_mask_dirs = [mask_folders[1], mask_folders[2]]
116 | domain_3_img_dirs = [data_folders[4]]
117 | domain_3_mask_dirs = [mask_folders[4]]
118 |
119 | test_data_dirs = [data_folders[3]]
120 | test_mask_dirs = [mask_folders[3]]
121 |
122 | domain_1_re = [Re_dir[0]]
123 | domain_2_re = [Re_dir[1], Re_dir[2]]
124 | domain_3_re = [Re_dir[4]]
125 | test_re = [Re_dir[3]]
126 |
127 | domain_1_num = [95]
128 | domain_2_num = [74, 51]
129 | domain_3_num = [50]
130 | test_num = [50]
131 |
132 | elif test_num=='D':
133 | domain_1_img_dirs = [data_folders[0]]
134 | domain_1_mask_dirs = [mask_folders[0]]
135 | domain_2_img_dirs = [data_folders[1], data_folders[2]]
136 | domain_2_mask_dirs = [mask_folders[1], mask_folders[2]]
137 | domain_3_img_dirs = [data_folders[3]]
138 | domain_3_mask_dirs = [mask_folders[3]]
139 |
140 | test_data_dirs = [data_folders[4]]
141 | test_mask_dirs = [mask_folders[4]]
142 |
143 | domain_1_re = [Re_dir[0]]
144 | domain_2_re = [Re_dir[1], Re_dir[2]]
145 | domain_3_re = [Re_dir[3]]
146 | test_re = [Re_dir[4]]
147 |
148 | domain_1_num = [95]
149 | domain_2_num = [74, 51]
150 | domain_3_num = [50]
151 | test_num = [50]
152 |
153 | else:
154 | print('Wrong test vendor!')
155 |
156 |
157 | domain_1_labeled_dataset = ImageFolder(domain_1_img_dirs, domain_1_mask_dirs, domain_1_img_dirs, domain_1_re, label=0, num_label=domain_1_num, train=True, labeled=True)
158 | domain_2_labeled_dataset = ImageFolder(domain_2_img_dirs, domain_2_mask_dirs, domain_1_img_dirs, domain_2_re, label=1, num_label=domain_2_num, train=True, labeled=True)
159 | domain_3_labeled_dataset = ImageFolder(domain_3_img_dirs, domain_3_mask_dirs, domain_1_img_dirs, domain_3_re, label=2, num_label=domain_3_num, train=True, labeled=True)
160 |
161 |
162 | # domain_1_labeled_loader = DataLoader(dataset=domain_1_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True)
163 | # domain_2_labeled_loader = DataLoader(dataset=domain_2_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True)
164 | # domain_3_labeled_loader = DataLoader(dataset=domain_3_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True)
165 |
166 | domain_1_labeled_loader = None
167 | domain_2_labeled_loader = None
168 | domain_3_labeled_loader = None
169 |
170 | domain_1_unlabeled_dataset = ImageFolder(domain_1_img_dirs, domain_1_mask_dirs, domain_1_img_dirs, domain_1_re, label=0, train=True, labeled=False)
171 | domain_2_unlabeled_dataset = ImageFolder(domain_2_img_dirs, domain_2_mask_dirs, domain_1_img_dirs, domain_2_re, label=1, train=True, labeled=False)
172 | domain_3_unlabeled_dataset = ImageFolder(domain_3_img_dirs, domain_3_mask_dirs, domain_1_img_dirs, domain_3_re, label=2, train=True, labeled=False)
173 |
174 | domain_1_unlabeled_loader = DataLoader(dataset=domain_1_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True)
175 | domain_2_unlabeled_loader = DataLoader(dataset=domain_2_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True)
176 | domain_3_unlabeled_loader = DataLoader(dataset=domain_3_unlabeled_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True)
177 |
178 |
179 | test_dataset = ImageFolder(test_data_dirs, test_mask_dirs, domain_1_img_dirs, test_re, train=False, labeled=True)
180 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers, pin_memory=True)
181 |
182 | return domain_1_labeled_loader, domain_1_unlabeled_loader, \
183 | domain_2_labeled_loader, domain_2_unlabeled_loader,\
184 | domain_3_labeled_loader, domain_3_unlabeled_loader, \
185 | test_loader, \
186 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset
187 |
188 | def default_loader(path):
189 | return np.load(path)['arr_0']
190 |
191 | def make_dataset(dir):
192 | images = []
193 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
194 |
195 | for root, _, fnames in sorted(os.walk(dir)):
196 | for fname in fnames:
197 | path = os.path.join(root, fname)
198 | images.append(path)
199 | return images
200 |
201 | class ImageFolder(data.Dataset):
202 | def __init__(self, data_dirs, mask_dirs, ref_dir, re, train=True, label=None, num_label=None, labeled=True, loader=default_loader):
203 |
204 | print(data_dirs)
205 | print(mask_dirs)
206 |
207 | reso_dir = re
208 | temp_imgs = []
209 | temp_masks = []
210 | temp_re = []
211 | domain_labels = []
212 |
213 | tem_ref_imgs = []
214 |
215 | if train:
216 | #100%
217 | # k = 1
218 | #10%
219 | # k = 0.1
220 | #5%
221 | # k = 0.05
222 | #2%
223 | k = 0.02
224 | else:
225 | k = 1
226 |
227 |
228 | for num_set in range(len(data_dirs)):
229 | re_roots = sorted(make_dataset(reso_dir[num_set]))
230 | data_roots = sorted(make_dataset(data_dirs[num_set]))
231 | mask_roots = sorted(make_dataset(mask_dirs[num_set]))
232 | num_label_data = 0
233 | for num_data in range(len(data_roots)):
234 | if labeled:
235 | if train:
236 | n_label = str(math.ceil(num_label[num_set] * k + 1))
237 | if '00'+n_label==data_roots[num_data][-10:-7] or '0'+n_label==data_roots[num_data][-10:-7]:
238 | print(n_label)
239 | print(data_roots[num_data][-10:-7])
240 | break
241 |
242 | for num_mask in range(len(mask_roots)):
243 | if data_roots[num_data][-10:-4] == mask_roots[num_mask][-10:-4]:
244 | temp_re.append(re_roots[num_data])
245 | temp_imgs.append(data_roots[num_data])
246 | temp_masks.append(mask_roots[num_mask])
247 | domain_labels.append(label)
248 | num_label_data += 1
249 | else:
250 | pass
251 | else:
252 | temp_re.append(re_roots[num_data])
253 | temp_imgs.append(data_roots[num_data])
254 | domain_labels.append(label)
255 |
256 | for num_set in range(len(ref_dir)):
257 | data_roots = sorted(make_dataset(ref_dir[num_set]))
258 | for num_data in range(len(data_roots)):
259 | tem_ref_imgs.append(data_roots[num_data])
260 |
261 | reso = temp_re
262 | imgs = temp_imgs
263 | masks = temp_masks
264 | labels = domain_labels
265 |
266 | print(len(masks))
267 |
268 | ref_imgs = tem_ref_imgs
269 |
270 | # add something here to index, such that can split the data
271 | # index = random.sample(range(len(temp_img)), len(temp_mask))
272 |
273 | self.reso = reso
274 | self.imgs = imgs
275 | self.masks = masks
276 | self.labels = labels
277 | self.new_size = 288
278 | self.loader = loader
279 | self.labeled = labeled
280 | self.train = train
281 | self.ref = ref_imgs
282 |
283 | def __getitem__(self, index):
284 | if self.train:
285 | index = random.randrange(len(self.imgs))
286 | else:
287 | pass
288 |
289 | path_re = self.reso[index]
290 | re = self.loader(path_re)
291 | re = re[0]
292 |
293 | path_img = self.imgs[index]
294 | img = self.loader(path_img) # numpy, HxW, numpy.Float64
295 |
296 | # ref_paths = random.sample(self.ref, 1)
297 | # ref_img = self.loader(ref_paths[0])
298 | #
299 | # img = match_histograms(img, ref_img)
300 |
301 | label = self.labels[index]
302 |
303 |
304 | if label==0:
305 | one_hot_label = torch.tensor([[1], [0], [0]])
306 | elif label==1:
307 | one_hot_label = torch.tensor([[0], [1], [0]])
308 | elif label==2:
309 | one_hot_label = torch.tensor([[0], [0], [1]])
310 | else:
311 | one_hot_label = torch.tensor([[0], [0], [0]])
312 |
313 | # Intensity cropping:
314 | p5 = np.percentile(img.flatten(), 0.5)
315 | p95 = np.percentile(img.flatten(), 99.5)
316 | img = np.clip(img, p5, p95)
317 |
318 | img -= img.min()
319 | img /= img.max()
320 | img = img.astype('float32')
321 |
322 | crop_size = 300
323 |
324 | # Augmentations:
325 | # 1. random rotation
326 | # 2. random scaling 0.8 - 1.2
327 | # 3. random crop from 280x280
328 | # 4. random hflip
329 | # 5. random vflip
330 | # 6. color jitter
331 | # 7. Gaussian filtering
332 |
333 | img_tensor = F.to_tensor(np.array(img))
334 | img_size = img_tensor.size()
335 |
336 |
337 | if self.labeled:
338 | if self.train:
339 |
340 | img = F.to_pil_image(img)
341 | # rotate, random angle between 0 - 90
342 | angle = random.randint(0, 90)
343 | img = F.rotate(img, angle, Image.BILINEAR)
344 |
345 | path_mask = self.masks[index]
346 | mask = Image.open(path_mask) # numpy, HxWx3
347 | # rotate, random angle between 0 - 90
348 | mask = F.rotate(mask, angle, Image.NEAREST)
349 |
350 | ## Find the region of mask
351 | norm_mask = F.to_tensor(np.array(mask))
352 | region = norm_mask[0] + norm_mask[1] + norm_mask[2]
353 | non_zero_index = torch.nonzero(region == 1, as_tuple=False)
354 | if region.sum() > 0:
355 | len_m = len(non_zero_index[0])
356 | x_region = non_zero_index[len_m//2][0]
357 | y_region = non_zero_index[len_m//2][1]
358 | x_region = int(x_region.item())
359 | y_region = int(y_region.item())
360 | else:
361 | x_region = norm_mask.size(-2) // 2
362 | y_region = norm_mask.size(-1) // 2
363 |
364 | # resize and center-crop to 280x280
365 | resize_order = re / resampling_rate
366 | resize_size_h = int(img_size[-2] * resize_order)
367 | resize_size_w = int(img_size[-1] * resize_order)
368 |
369 | left_size = 0
370 | top_size = 0
371 | right_size = 0
372 | bot_size = 0
373 | if resize_size_h < self.new_size:
374 | top_size = (self.new_size - resize_size_h) // 2
375 | bot_size = (self.new_size - resize_size_h) - top_size
376 | if resize_size_w < self.new_size:
377 | left_size = (self.new_size - resize_size_w) // 2
378 | right_size = (self.new_size - resize_size_w) - left_size
379 |
380 | transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))]
381 | transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list
382 | transform = transforms.Compose(transform_list)
383 |
384 | img = transform(img)
385 |
386 |
387 | ## Define the crop index
388 | if top_size >= 0:
389 | top_crop = 0
390 | else:
391 | if x_region > self.new_size//2:
392 | if x_region - self.new_size//2 + self.new_size <= norm_mask.size(-2):
393 | top_crop = x_region - self.new_size//2
394 | else:
395 | top_crop = norm_mask.size(-2) - self.new_size
396 | else:
397 | top_crop = 0
398 |
399 | if left_size >= 0:
400 | left_crop = 0
401 | else:
402 | if y_region > self.new_size//2:
403 | if y_region - self.new_size//2 + self.new_size <= norm_mask.size(-1):
404 | left_crop = y_region - self.new_size//2
405 | else:
406 | left_crop = norm_mask.size(-1) - self.new_size
407 | else:
408 | left_crop = 0
409 |
410 | # random crop to 224x224
411 | img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size)
412 |
413 | # random flip
414 | hflip_p = random.random()
415 | img = F.hflip(img) if hflip_p >= 0.5 else img
416 | vflip_p = random.random()
417 | img = F.vflip(img) if vflip_p >= 0.5 else img
418 |
419 | img = F.to_tensor(np.array(img))
420 |
421 | # # Gamma correction: random gamma from [0.5, 1.5]
422 | # gamma = 0.5 + random.random()
423 | # img_max = img.max()
424 | # img = img_max * torch.pow((img / img_max), gamma)
425 |
426 | # Gaussian bluring:
427 | transform_list = [transforms.GaussianBlur(5, sigma=(0.25, 1.25))]
428 | transform = transforms.Compose(transform_list)
429 | img = transform(img)
430 |
431 | # resize and center-crop to 280x280
432 | transform_mask_list = [transforms.Pad(
433 | (left_size, top_size, right_size, bot_size))]
434 | transform_mask_list = [transforms.Resize((resize_size_h, resize_size_w),
435 | interpolation=Image.NEAREST)] + transform_mask_list
436 | transform_mask = transforms.Compose(transform_mask_list)
437 |
438 | mask = transform_mask(mask) # C,H,W
439 |
440 | # random crop to 224x224
441 | mask = F.crop(mask, top_crop, left_crop, self.new_size, self.new_size)
442 |
443 | # random flip
444 | mask = F.hflip(mask) if hflip_p >= 0.5 else mask
445 | mask = F.vflip(mask) if vflip_p >= 0.5 else mask
446 |
447 | mask = F.to_tensor(np.array(mask))
448 |
449 | mask_bg = (mask.sum(0) == 0).type_as(mask) # H,W
450 | mask_bg = mask_bg.reshape((1, mask_bg.size(0), mask_bg.size(1)))
451 | mask = torch.cat((mask, mask_bg), dim=0)
452 |
453 | return img, mask, one_hot_label.squeeze() # pytorch: N,C,H,W
454 |
455 | else:
456 | path_mask = self.masks[index]
457 | mask = Image.open(path_mask) # numpy, HxWx3
458 | # resize, pad, and crop with the same geometry as the training branch
459 |
460 | ## Find the region of mask
461 | norm_mask = F.to_tensor(np.array(mask))
462 | region = norm_mask[0] + norm_mask[1] + norm_mask[2]
463 | non_zero_index = torch.nonzero(region == 1, as_tuple=False)
464 | if region.sum() > 0:
465 | len_m = len(non_zero_index)  # number of foreground pixels
466 | x_region = non_zero_index[len_m//2][0]
467 | y_region = non_zero_index[len_m//2][1]
468 | x_region = int(x_region.item())
469 | y_region = int(y_region.item())
470 | else:
471 | x_region = norm_mask.size(-2) // 2
472 | y_region = norm_mask.size(-1) // 2
473 |
474 | resize_order = re / resampling_rate
475 | resize_size_h = int(img_size[-2] * resize_order)
476 | resize_size_w = int(img_size[-1] * resize_order)
477 |
478 | left_size = 0
479 | top_size = 0
480 | right_size = 0
481 | bot_size = 0
482 | if resize_size_h < self.new_size:
483 | top_size = (self.new_size - resize_size_h) // 2
484 | bot_size = (self.new_size - resize_size_h) - top_size
485 | if resize_size_w < self.new_size:
486 | left_size = (self.new_size - resize_size_w) // 2
487 | right_size = (self.new_size - resize_size_w) - left_size
488 |
489 |
490 | # transform_list = [transforms.CenterCrop((crop_size, crop_size))]
491 | transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))]
492 | transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list
493 | transform_list = [transforms.ToPILImage()] + transform_list
494 | transform = transforms.Compose(transform_list)
495 | img = transform(img)
496 | img = F.to_tensor(np.array(img))
497 |
498 | ## Define the crop index
499 | if top_size >= 0:  # note: always true as written, so top_crop stays 0
500 | top_crop = 0
501 | else:
502 | if x_region > self.new_size//2:
503 | if x_region - self.new_size//2 + self.new_size <= norm_mask.size(-2):
504 | top_crop = x_region - self.new_size//2
505 | else:
506 | top_crop = norm_mask.size(-2) - self.new_size
507 | else:
508 | top_crop = 0
509 |
510 | if left_size >= 0:  # note: always true as written, so left_crop stays 0
511 | left_crop = 0
512 | else:
513 | if y_region > self.new_size//2:
514 | if y_region - self.new_size//2 + self.new_size <= norm_mask.size(-1):
515 | left_crop = y_region - self.new_size//2
516 | else:
517 | left_crop = norm_mask.size(-1) - self.new_size
518 | else:
519 | left_crop = 0
520 |
521 | # deterministic crop to new_size x new_size
522 | img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size)
523 |
524 | # same resize + pad for the mask (nearest-neighbour interpolation)
525 | # transform_mask_list = [transforms.CenterCrop((crop_size, crop_size))]
526 | transform_mask_list = [transforms.Pad(
527 | (left_size, top_size, right_size, bot_size))]
528 | transform_mask_list = [transforms.Resize((resize_size_h, resize_size_w),
529 | interpolation=Image.NEAREST)] + transform_mask_list
530 | transform_mask = transforms.Compose(transform_mask_list)
531 |
532 | mask = transform_mask(mask)  # still a PIL image here; converted to a C,H,W tensor below
533 | mask = F.crop(mask, top_crop, left_crop, self.new_size, self.new_size)
534 | mask = F.to_tensor(np.array(mask))
535 |
536 | mask_bg = (mask.sum(0) == 0).type_as(mask) # H,W
537 | mask_bg = mask_bg.reshape((1, mask_bg.size(0), mask_bg.size(1)))
538 | mask = torch.cat((mask, mask_bg), dim=0)
539 |
540 | return img, mask, path_img
541 |
542 | else:
543 | img = F.to_pil_image(img)
544 | # rotate by a random angle between 0 and 90 degrees
545 | angle = random.randint(0, 90)
546 | img = F.rotate(img, angle, Image.BILINEAR)
547 |
548 | # resize, pad, and center-crop to crop_size x crop_size
549 | resize_order = re / resampling_rate
550 | resize_size_h = int(img_size[-2] * resize_order)
551 | resize_size_w = int(img_size[-1] * resize_order)
552 |
553 | left_size = 0
554 | top_size = 0
555 | right_size = 0
556 | bot_size = 0
557 | if resize_size_h < crop_size:
558 | top_size = (crop_size - resize_size_h) // 2
559 | bot_size = (crop_size - resize_size_h) - top_size
560 | if resize_size_w < crop_size:
561 | left_size = (crop_size - resize_size_w) // 2
562 | right_size = (crop_size - resize_size_w) - left_size
563 |
564 | transform_list = [transforms.CenterCrop((crop_size, crop_size))]
565 | transform_list = [transforms.Pad((left_size, top_size, right_size, bot_size))] + transform_list
566 | transform_list = [transforms.Resize((resize_size_h, resize_size_w))] + transform_list
567 | transform = transforms.Compose(transform_list)
568 |
569 | img = transform(img)
570 |
571 | # random crop to new_size x new_size
572 | top_crop = random.randint(0, crop_size - self.new_size)
573 | left_crop = random.randint(0, crop_size - self.new_size)
574 | img = F.crop(img, top_crop, left_crop, self.new_size, self.new_size)
575 |
576 | # random flip
577 | hflip_p = random.random()
578 | img = F.hflip(img) if hflip_p >= 0.5 else img
579 | vflip_p = random.random()
580 | img = F.vflip(img) if vflip_p >= 0.5 else img
581 |
582 | img = F.to_tensor(np.array(img))
583 |
584 | # # Gamma correction: random gamma from [0.5, 1.5]
585 | # gamma = 0.5 + random.random()
586 | # img_max = img.max()
587 | # img = img_max*torch.pow((img/img_max), gamma)
588 |
589 | # Gaussian blurring:
590 | transform_list = [transforms.GaussianBlur(5, sigma=(0.25, 1.25))]
591 | transform = transforms.Compose(transform_list)
592 | img = transform(img)
593 |
594 | return img, one_hot_label.squeeze() # pytorch: N,C,H,W
595 |
596 | def __len__(self):
597 | return len(self.imgs)
598 |
599 |
600 |
--------------------------------------------------------------------------------
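
For reference, a minimal standalone sketch of the resize-and-pad geometry the loader above applies before cropping. `re` and `resampling_rate` appear to be pixel-spacing values; the concrete numbers here are illustrative, not taken from the repository's configuration:

import torch  # not required for the arithmetic; shown for context

def resize_and_pad_geometry(h, w, re, resampling_rate, new_size):
    # mirror the loader: rescale by the spacing ratio, then pad
    # symmetrically up to new_size when a rescaled side is smaller
    scale = re / resampling_rate
    rh, rw = int(h * scale), int(w * scale)
    top = (new_size - rh) // 2 if rh < new_size else 0
    bot = (new_size - rh) - top if rh < new_size else 0
    left = (new_size - rw) // 2 if rw < new_size else 0
    right = (new_size - rw) - left if rw < new_size else 0
    return (rh, rw), (left, top, right, bot)  # pad order matches transforms.Pad

# e.g. a 192x160 slice rescaled by 1.45/1.2 and padded toward 224:
print(resize_and_pad_geometry(192, 160, 1.45, 1.2, 224))
# -> ((232, 193), (15, 0, 16, 0))
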
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import *
2 |
--------------------------------------------------------------------------------
/losses/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/losses/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/losses/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def beta_vae_loss(reco_x, x, logvar, mu, beta, type='L2'):
6 | if type == 'BCE':
7 | reco_x_loss = F.binary_cross_entropy(reco_x, x, reduction='sum')
8 | elif type == 'L1':
9 | reco_x_loss = F.l1_loss(reco_x, x, reduction='sum')  # size_average=False is deprecated
10 | else:
11 | reco_x_loss = F.mse_loss(reco_x, x, reduction='sum')
12 | kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
13 |
14 | return (reco_x_loss + beta*kld)/x.shape[0]
15 |
16 | def KL_divergence(logvar, mu):
17 | kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
18 | return kld.mean()
19 |
20 | def dice_loss(pred, target):
21 | """
22 | This definition generalizes to real-valued pred and target vectors.
23 | This should be differentiable.
24 | pred: tensor with first dimension as batch
25 | target: tensor with first dimension as batch
26 | """
27 | smooth = 0.1 #1e-12
28 |
29 | # have to use contiguous since they may come from a torch.view op
30 | iflat = pred.contiguous().view(-1)
31 | tflat = target.contiguous().view(-1)
32 | intersection = (iflat * tflat).sum()
33 |
34 | #A_sum = torch.sum(tflat * iflat)
35 | #B_sum = torch.sum(tflat * tflat)
36 | loss = ((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)).mean()
37 |
38 | return 1 - loss
39 |
40 | def HSIC_lossfunc(x, y):
41 | assert x.dim() == y.dim() == 2
42 | assert x.size(0) == y.size(0)
43 | m = x.size(0)
44 | h = torch.eye(m) - 1/m
45 | h = h.to(x.device)
46 | K_x = gaussian_kernel(x)
47 | K_y = gaussian_kernel(y)
48 | return K_x.mm(h).mm(K_y).mm(h).trace() / (m-1+1e-10)
49 |
50 |
51 | def gaussian_kernel(x, y=None, sigma=5):
52 | if y is None:
53 | y = x
54 | assert x.dim() == y.dim() == 2
55 | assert x.size() == y.size()
56 | z = ((x.unsqueeze(0) - y.unsqueeze(1)) ** 2).sum(-1)
57 | return torch.exp(- 0.5 * z / (sigma * sigma))
58 |
59 | def LS_dis(score_real, score_fake):
60 | return 0.5 * (torch.mean((score_real-1)**2) + torch.mean(score_fake**2))
61 |
62 | def LS_model(score_fake):
63 | return 0.5 * (torch.mean((score_fake-1)**2))
64 |
--------------------------------------------------------------------------------
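
A quick usage sketch for the losses above; the shapes are hypothetical:

import torch
from losses.losses import dice_loss, KL_divergence

pred = torch.rand(2, 4, 224, 224)                    # soft predictions in [0, 1]
target = (torch.rand(2, 4, 224, 224) > 0.5).float()  # binary ground truth
print(dice_loss(pred, target))                       # scalar in [0, 1]

mu, logvar = torch.zeros(2, 16), torch.zeros(2, 16)
print(KL_divergence(logvar, mu))                     # 0 for a standard normal posterior
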
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .dice_loss import *
2 | from .focal_loss import *
3 | from .gan_loss import *
--------------------------------------------------------------------------------
/metrics/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/metrics/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/dice_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/metrics/__pycache__/dice_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/focal_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/metrics/__pycache__/focal_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/gan_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/metrics/__pycache__/gan_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/dice_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 |
4 |
5 | class DiceCoeff(Function):
6 | """Dice coeff for individual examples"""
7 |
8 | def forward(self, input, target):
9 | self.save_for_backward(input, target)
10 | eps = 0.0001
11 | self.inter = torch.dot(input.view(-1), target.view(-1))
12 | self.union = torch.sum(input) + torch.sum(target) + eps
13 |
14 | t = (2 * self.inter.float() + eps) / self.union.float()
15 | return t
16 |
17 | # This function has only a single output, so it gets only one gradient
18 | def backward(self, grad_output):
19 |
20 | input, target = self.saved_tensors  # saved_variables was removed in newer PyTorch
21 | grad_input = grad_target = None
22 |
23 | if self.needs_input_grad[0]:
24 | grad_input = grad_output * 2 * (target * self.union - self.inter) \
25 | / (self.union * self.union)
26 | if self.needs_input_grad[1]:
27 | grad_target = None
28 |
29 | return grad_input, grad_target
30 |
31 |
32 | def dice_coeff(input, target, device):
33 | """Dice coeff for batches"""
34 | if input.is_cuda:
35 | s = torch.FloatTensor(1).zero_()
36 | s = s.to(device)
37 | else:
38 | s = torch.FloatTensor(1).zero_()
39 |
40 | for i, c in enumerate(zip(input, target)):
41 | s = s + DiceCoeff().forward(c[0], c[1])
42 |
43 | return s / (i + 1)
--------------------------------------------------------------------------------
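
Usage sketch with hypothetical shapes; note that the legacy-style Function above is invoked through its forward method directly rather than through autograd:

import torch
from metrics.dice_loss import dice_coeff

pred = (torch.rand(4, 1, 224, 224) > 0.5).float()
gt = (torch.rand(4, 1, 224, 224) > 0.5).float()
print(dice_coeff(pred, gt, torch.device('cpu')))  # mean Dice over the batch, shape (1,)
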
/metrics/focal_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class FocalLoss(nn.Module):
6 | def __init__(self, gamma=0, alpha=None, size_average=True):
7 | super(FocalLoss, self).__init__()
8 | self.gamma = gamma
9 | self.alpha = alpha
10 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha])
11 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
12 | self.size_average = size_average
13 |
14 | def forward(self, input, target):
15 | if input.dim()>2:
16 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W
17 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C
18 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C
19 | target = target.view(-1,1)
20 |
21 | logpt = F.log_softmax(input, dim=1)
22 | logpt = logpt.gather(1,target)
23 | logpt = logpt.view(-1)
24 | pt = logpt.data.exp()
25 |
26 | if self.alpha is not None:
27 | if self.alpha.type()!=input.data.type():
28 | self.alpha = self.alpha.type_as(input.data)
29 | at = self.alpha.gather(0,target.data.view(-1))
30 | logpt = logpt * at
31 |
32 | loss = -1 * (1-pt)**self.gamma * logpt
33 | if self.size_average: return loss.mean()
34 | else: return loss.sum()
--------------------------------------------------------------------------------
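
A usage sketch; the per-class alpha weights and shapes are illustrative. With a float alpha the class-weight tensor has only two entries, so multi-class targets need a list with one weight per class:

import torch
from metrics.focal_loss import FocalLoss

criterion = FocalLoss(gamma=2, alpha=[0.25, 0.75, 0.75, 0.75])  # one weight per class
logits = torch.randn(8, 4, 224, 224)            # N,C,H,W raw scores
targets = torch.randint(0, 4, (8, 224, 224))    # integer class map
print(criterion(logits, targets))               # scalar mean focal loss
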
/metrics/gan_loss.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 |
3 | import torch
4 |
5 |
6 | def ls_discriminator_loss(scores_real, scores_fake):
7 | """
8 | Compute the Least-Squares GAN loss for the discriminator.
9 |
10 | Inputs:
11 | - scores_real: PyTorch Variable of shape (N,) giving scores for the real data.
12 | - scores_fake: PyTorch Variable of shape (N,) giving scores for the fake data.
13 |
14 | Outputs:
15 | - loss: A PyTorch Variable containing the loss.
16 | """
17 | loss = (torch.mean((scores_real - 1) ** 2) + torch.mean(scores_fake ** 2)) / 2
18 | return loss
19 |
20 |
21 | def ls_generator_loss(scores_fake):
22 | """
23 | Computes the Least-Squares GAN loss for the generator.
24 |
25 | Inputs:
26 | - scores_fake: PyTorch Variable of shape (N,) giving scores for the fake data.
27 |
28 | Outputs:
29 | - loss: A PyTorch Variable containing the loss.
30 | """
31 | loss = torch.mean((scores_fake - 1) ** 2) / 2
32 | return loss
33 |
34 |
35 | # %%
36 |
37 | def bce_loss(input, target):
38 | """
39 | Numerically stable version of the binary cross-entropy loss function.
40 |
41 | As per https://github.com/pytorch/pytorch/issues/751
42 | See the TensorFlow docs for a derivation of this formula:
43 | https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
44 |
45 | Inputs:
46 | - input: PyTorch Variable of shape (N, ) giving scores.
47 | - target: PyTorch Variable of shape (N,) containing 0 and 1 giving targets.
48 |
49 | Returns:
50 | - A PyTorch Variable containing the mean BCE loss over the minibatch of input data.
51 | """
52 | neg_abs = - input.abs()
53 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
54 | return loss.mean()
55 |
56 |
57 | # %%
58 |
59 | def discriminator_loss(logits_real, logits_fake, device):
60 | """
61 | Computes the discriminator loss using the numerically stable BCE above.
62 |
63 | Inputs:
64 | - logits_real: PyTorch Variable of shape (N,) giving scores for the real data.
65 | - logits_fake: PyTorch Variable of shape (N,) giving scores for the fake data.
66 |
67 | Returns:
68 | - loss: PyTorch Variable containing the (scalar) loss for the discriminator.
69 | """
70 | true_labels = torch.ones(logits_real.size()).to(device=device, dtype=torch.float32)
71 | loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, true_labels - 1)
72 | return loss
73 |
74 |
75 | def generator_loss(logits_fake, device):
76 | """
77 | Computes the generator loss using the numerically stable BCE above.
78 |
79 | Inputs:
80 | - logits_fake: PyTorch Variable of shape (N,) giving scores for the fake data.
81 |
82 | Returns:
83 | - loss: PyTorch Variable containing the (scalar) loss for the generator.
84 | """
85 | true_labels = torch.ones(logits_fake.size()).to(device=device, dtype=torch.float32)
86 | loss = bce_loss(logits_fake, true_labels)
87 | return loss
--------------------------------------------------------------------------------
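
Both loss flavours side by side, on hypothetical score tensors:

import torch
from metrics.gan_loss import ls_discriminator_loss, ls_generator_loss, discriminator_loss, generator_loss

device = torch.device('cpu')
scores_real, scores_fake = torch.randn(16), torch.randn(16)
d_ls = ls_discriminator_loss(scores_real, scores_fake)        # least-squares variant
g_ls = ls_generator_loss(scores_fake)
d_bce = discriminator_loss(scores_real, scores_fake, device)  # stable-BCE variant, raw logits
g_bce = generator_loss(scores_fake, device)
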
/metrics/hausdorff.py:
--------------------------------------------------------------------------------
1 | from scipy.spatial.distance import directed_hausdorff
2 |
3 | def hausdorff_distance(x, y):
4 | x = x.cpu().data.numpy()
5 | u = x.reshape(x.shape[1], -1)
6 | y = y.cpu().data.numpy()
7 | v = y.reshape(y.shape[1], -1)
8 | return max(directed_hausdorff(u, v)[0], directed_hausdorff(v, u)[0])
--------------------------------------------------------------------------------
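
Usage sketch: each channel is flattened to one point in H*W-dimensional space, so this measures a distance between channel sets rather than a per-pixel contour distance (shapes hypothetical):

import torch
from metrics.hausdorff import hausdorff_distance

a = (torch.rand(1, 4, 64, 64) > 0.5).float()
b = (torch.rand(1, 4, 64, 64) > 0.5).float()
print(hausdorff_distance(a, b))
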
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .dgnet import *
2 | from .weight_init import *
3 |
4 |
5 | import sys
6 |
7 | def get_model(name, params):
8 | if name == 'dgnet':
9 | return DGNet(params['width'], params['height'], params['num_classes'], params['ndf'], params['z_length'],
10 | params['norm'], params['upsample'], params['decoder_type'], params['anatomy_out_channels'],
11 | params['num_mask_channels'])
12 | else:
13 | print("Could not find the requested model ({})".format(name), file=sys.stderr)
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/dgnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/dgnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/meta_decoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/meta_decoder.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/meta_segmentor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/meta_segmentor.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/meta_styleencoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/meta_styleencoder.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/meta_unet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/meta_unet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ops.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/ops.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/sdnet_ada.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/sdnet_ada.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_parts.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/unet_parts.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/weight_init.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/models/__pycache__/weight_init.cpython-37.pyc
--------------------------------------------------------------------------------
/models/dgnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import sys
5 | import time
6 | from models.meta_segmentor import *
7 | from models.meta_styleencoder import *
8 | from models.meta_decoder import *
9 |
10 |
11 |
12 | class DGNet(nn.Module):
13 | def __init__(self, width, height, num_classes, ndf, z_length, norm, upsample, decoder_type, anatomy_out_channels, num_mask_channels):
14 | super(DGNet, self).__init__()
15 | """
16 | Args:
17 | width: input width
18 | height: input height
19 | upsample: upsampling type (nearest | bilinear)
20 | nclasses: number of semantic segmentation classes
21 | """
22 | self.h = height
23 | self.w = width
24 | self.ndf = ndf
25 | self.z_length = z_length
26 | self.anatomy_out_channels = anatomy_out_channels
27 | self.norm = norm
28 | self.upsample = upsample
29 | self.num_classes = num_classes
30 | self.decoder_type = decoder_type
31 | self.num_mask_channels = num_mask_channels
32 |
33 | self.m_encoder = StyleEncoder(z_length*2)
34 | self.a_encoder = ContentEncoder(self.h, self.w, self.ndf, self.anatomy_out_channels, self.norm, self.upsample)
35 | # self.segmentor = Segmentor(self.anatomy_out_channels, self.num_classes)
36 | self.decoder = Ada_Decoder(self.decoder_type, self.anatomy_out_channels, self.z_length*2, self.num_mask_channels)
37 |
38 | def forward(self, x, mask, script_type, meta_loss=None, meta_step_size=0.001, stop_gradient=False, a_in=None, z_in=None):
39 | self.meta_loss = meta_loss
40 | self.meta_step_size = meta_step_size
41 | self.stop_gradient = stop_gradient
42 |
43 | z_out, mu_out, logvar_out, cls_out = self.m_encoder(x, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
44 | stop_gradient=self.stop_gradient)
45 | a_out = self.a_encoder(x, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
46 | stop_gradient=self.stop_gradient)
47 | # seg_pred = self.segmentor(a_out, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
48 | # stop_gradient=self.stop_gradient)
49 |
50 | z_out_tilde = None
51 | cls_out_tilde = None
52 |
53 | #t0 = time.time()
54 | if a_in is None:
55 | if script_type == 'training':
56 | reco = self.decoder(a_out, z_out, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
57 | stop_gradient=self.stop_gradient)
58 | z_out_tilde, mu_out_tilde, logvar_tilde, cls_out_tilde = self.m_encoder(reco, meta_loss=self.meta_loss,
59 | meta_step_size=self.meta_step_size,
60 | stop_gradient=self.stop_gradient)
61 | elif script_type == 'val' or script_type == 'test':
62 | z_out_tilde, mu_out_tilde, logvar_tilde, cls_out_tilde = self.m_encoder(x, meta_loss=self.meta_loss,
63 | meta_step_size=self.meta_step_size,
64 | stop_gradient=self.stop_gradient)
65 | reco = self.decoder(a_out, z_out_tilde, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
66 | stop_gradient=self.stop_gradient)
67 | else:
68 | reco = self.decoder(a_in, z_in, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
69 | stop_gradient=self.stop_gradient)
70 |
71 | return reco, z_out, z_out_tilde, a_out, None, mu_out, logvar_out, cls_out, cls_out_tilde
72 |
--------------------------------------------------------------------------------
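
A forward-pass sketch with illustrative hyper-parameters (train_meta.py defines the real ones). StyleEncoder's fc1 expects a 9x9 feature map after five stride-2 convolutions, so the input here is 288x288:

import torch
from models import get_model

# hypothetical values chosen only to make the shapes line up
params = dict(width=288, height=288, num_classes=4, ndf=64, z_length=8,
              norm='in', upsample='nearest', decoder_type='adain',
              anatomy_out_channels=8, num_mask_channels=4)
model = get_model('dgnet', params)

x = torch.randn(2, 1, 288, 288)
reco, z, z_tilde, a, _, mu, logvar, cls, cls_tilde = model(x, None, 'training')
print(reco.shape, a.shape, z.shape)  # (2,1,288,288) (2,8,288,288) (2,16)
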
/models/meta_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.ops import *
5 |
6 | class LayerNorm(nn.Module):
7 | def __init__(self, num_features, eps=1e-5, affine=True):
8 | super(LayerNorm, self).__init__()
9 | self.num_features = num_features
10 | self.affine = affine
11 | self.eps = eps
12 |
13 | if self.affine:
14 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
15 | self.beta = nn.Parameter(torch.zeros(num_features))
16 |
17 | def forward(self, x):
18 | shape = [-1] + [1] * (x.dim() - 1)
19 | # print(x.size())
20 | if x.size(0) == 1:
21 | # These two lines run much faster in pytorch 0.4 than the two lines listed below.
22 | mean = x.view(-1).mean().view(*shape)
23 | std = x.view(-1).std().view(*shape)
24 | else:
25 | mean = x.view(x.size(0), -1).mean(1).view(*shape)
26 | std = x.view(x.size(0), -1).std(1).view(*shape)
27 |
28 | x = (x - mean) / (std + self.eps)
29 |
30 | if self.affine:
31 | shape = [1, -1] + [1] * (x.dim() - 2)
32 | x = x * self.gamma.view(*shape) + self.beta.view(*shape)
33 | return x
34 |
35 | class MLP(nn.Module):
36 | def __init__(self, input_dim, output_dim, dim, n_blk):  # n_blk is accepted but unused
37 | super(MLP, self).__init__()
38 |
39 | self.fc1 = nn.Linear(input_dim, dim)
40 | self.fc2 = nn.Linear(dim, dim)
41 | self.fc3 = nn.Linear(dim, output_dim)
42 |
43 | def forward(self, x, meta_loss, meta_step_size, stop_gradient):
44 | self.meta_loss = meta_loss
45 | self.meta_step_size = meta_step_size
46 | self.stop_gradient = stop_gradient
47 |
48 | x = x.view(x.size(0), -1)
49 |
50 | out = linear(x, self.fc1.weight, self.fc1.bias, meta_loss=self.meta_loss,
51 | meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient)
52 | out = relu(out)
53 | out = linear(out, self.fc2.weight, self.fc2.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
54 | stop_gradient=self.stop_gradient)
55 | out = relu(out)
56 | out = linear(out, self.fc3.weight, self.fc3.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
57 | stop_gradient=self.stop_gradient)
58 | out = relu(out)  # keeps the generated AdaIN means/stds non-negative
59 |
60 | return out
61 |
62 | class AdaptiveInstanceNorm2d(nn.Module):
63 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
64 | super(AdaptiveInstanceNorm2d, self).__init__()
65 | self.num_features = num_features
66 | self.eps = eps
67 | self.momentum = momentum
68 | # weight and bias are dynamically assigned
69 | self.weight = None
70 | self.bias = None
71 | # just dummy buffers, not used
72 | self.register_buffer('running_mean', torch.zeros(num_features))
73 | self.register_buffer('running_var', torch.ones(num_features))
74 |
75 | def forward(self, x):
76 | assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
77 | b, c = x.size(0), x.size(1)
78 | running_mean = self.running_mean.repeat(b)
79 | running_var = self.running_var.repeat(b)
80 |
81 | # Apply instance norm
82 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
83 |
84 | out = F.batch_norm(
85 | x_reshaped, running_mean, running_var, self.weight, self.bias,
86 | True, self.momentum, self.eps)
87 |
88 | return out.view(b, c, *x.size()[2:])
89 |
90 | def __repr__(self):
91 | return self.__class__.__name__ + '(' + str(self.num_features) + ')'
92 |
93 | class Decoder(nn.Module):
94 | def __init__(self, dim, output_dim=1):
95 | super(Decoder, self).__init__()
96 |
97 | self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1, bias=True)
98 | self.adain1 = AdaptiveInstanceNorm2d(dim)
99 | self.conv2 = nn.Conv2d(dim, dim, 3, 1, 1, bias=True)
100 | self.adain2 = AdaptiveInstanceNorm2d(dim)
101 |
102 | self.conv3 = nn.Conv2d(dim, dim//2, 3, 1, 1, bias=True)
103 | dim //= 2
104 | # self.ln3 = LayerNorm(dim)
105 | self.bn3 = normalization(dim, norm='bn')
106 | self.conv4 = nn.Conv2d(dim, output_dim, 3, 1, 1, bias=True)
107 |
108 | def forward(self, x, meta_loss, meta_step_size, stop_gradient):
109 | self.meta_loss = meta_loss
110 | self.meta_step_size = meta_step_size
111 | self.stop_gradient = stop_gradient
112 |
113 | # x = F.softmax(x, dim=1)
114 |
115 | out = conv2d(x, self.conv1.weight, self.conv1.bias, stride=1, padding=1, meta_loss=self.meta_loss,
116 | meta_step_size=self.meta_step_size,
117 | stop_gradient=self.stop_gradient)
118 | out = self.adain1(out)
119 | out = conv2d(out, self.conv2.weight, self.conv2.bias, stride=1, padding=1, meta_loss=self.meta_loss,
120 | meta_step_size=self.meta_step_size,
121 | stop_gradient=self.stop_gradient)
122 | out = self.adain2(out)
123 | out = conv2d(out, self.conv3.weight, self.conv3.bias, stride=1, padding=1, meta_loss=self.meta_loss,
124 | meta_step_size=self.meta_step_size,
125 | stop_gradient=self.stop_gradient)
126 | # out = self.bn3(out)
127 | out = conv2d(out, self.conv4.weight, self.conv4.bias, stride=1, padding=1, meta_loss=self.meta_loss,
128 | meta_step_size=self.meta_step_size,
129 | stop_gradient=self.stop_gradient)
130 | out = tanh(out)
131 | return out
132 |
133 | # decoder
134 | class Ada_Decoder(nn.Module):
135 | # AdaIN auto-encoder architecture
136 | def __init__(self, decoder_type, anatomy_out_channels, z_length, num_mask_channels):
137 | super(Ada_Decoder, self).__init__()
138 | """
139 | """
140 | self.dec = Decoder(anatomy_out_channels)
141 | # MLP to generate AdaIN parameters
142 | self.mlp = MLP(z_length, self.get_num_adain_params(self.dec), 256, 3)
143 |
144 | def forward(self, a, z, meta_loss, meta_step_size, stop_gradient):
145 | self.meta_loss = meta_loss
146 | self.meta_step_size = meta_step_size
147 | self.stop_gradient = stop_gradient
148 | # reconstruct an image
149 | images_recon = self.decode(a, z, meta_loss=self.meta_loss,
150 | meta_step_size=self.meta_step_size,
151 | stop_gradient=self.stop_gradient)
152 | return images_recon
153 |
154 | def decode(self, content, style, meta_loss, meta_step_size, stop_gradient):
155 | self.meta_loss = meta_loss
156 | self.meta_step_size = meta_step_size
157 | self.stop_gradient = stop_gradient
158 | # decode content and style codes to an image
159 | adain_params = self.mlp(style, meta_loss=self.meta_loss,
160 | meta_step_size=self.meta_step_size,
161 | stop_gradient=self.stop_gradient)
162 | self.assign_adain_params(adain_params, self.dec)
163 | images = self.dec(content, meta_loss=self.meta_loss,
164 | meta_step_size=self.meta_step_size,
165 | stop_gradient=self.stop_gradient)
166 | return images
167 |
168 | def assign_adain_params(self, adain_params, model):
169 | # assign the adain_params to the AdaIN layers in model
170 | for m in model.modules():
171 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
172 | mean = adain_params[:, :m.num_features]
173 | std = adain_params[:, m.num_features:2*m.num_features]
174 | m.bias = mean.contiguous().view(-1)
175 | m.weight = std.contiguous().view(-1)
176 | if adain_params.size(1) > 2*m.num_features:
177 | adain_params = adain_params[:, 2*m.num_features:]
178 |
179 | def get_num_adain_params(self, model):
180 | # return the number of AdaIN parameters needed by the model
181 | num_adain_params = 0
182 | for m in model.modules():
183 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
184 | num_adain_params += 2*m.num_features
185 | return num_adain_params
--------------------------------------------------------------------------------
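
How the style code becomes AdaIN parameters, with hypothetical sizes (8 anatomy channels, 16-dim style): the MLP emits 2 x num_features values per AdaIN layer, which assign_adain_params slices off in order:

import torch
from models.meta_decoder import Ada_Decoder

ada = Ada_Decoder('adain', anatomy_out_channels=8, z_length=16, num_mask_channels=4)
print(ada.get_num_adain_params(ada.dec))  # 32: two AdaIN layers x (8 means + 8 stds)

content = torch.randn(2, 8, 64, 64)  # anatomy factors
style = torch.randn(2, 16)           # style code
img = ada(content, style, meta_loss=None, meta_step_size=0.001, stop_gradient=False)
print(img.shape)                     # (2, 1, 64, 64)
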
/models/meta_segmentor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.ops import *
5 | from models.meta_unet import *
6 |
7 | class ContentEncoder(nn.Module):
8 | def __init__(self, width, height, ndf, num_output_channels, norm, upsample):
9 | super(ContentEncoder, self).__init__()
10 | """
11 | Build an encoder to extract anatomical information from the image.
12 | """
13 | self.width = width
14 | self.height = height
15 | self.ndf = ndf
16 | self.num_output_channels = num_output_channels
17 | self.norm = norm
18 | self.upsample = upsample
19 |
20 | self.unet = UNet(c=1, n=32, norm='in', num_classes=self.num_output_channels)
21 |
22 | def forward(self, x, meta_loss, meta_step_size, stop_gradient):
23 | self.meta_loss = meta_loss
24 | self.meta_step_size = meta_step_size
25 | self.stop_gradient = stop_gradient
26 |
27 | out = self.unet(x, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient)
28 | return out
29 |
30 | class Segmentor(nn.Module):
31 | def __init__(self, num_output_channels, num_classes):
32 | super(Segmentor, self).__init__()
33 | """
34 | """
35 | self.num_output_channels = num_output_channels
36 | self.num_classes = num_classes+1 # check again
37 |
38 | self.conv1 = nn.Conv2d(self.num_output_channels, 64, 3, 1, 1, bias=True)
39 | self.bn1 = normalization(64, norm='bn')
40 | self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
41 | self.bn2 = normalization(64, norm='bn')
42 | self.pred = nn.Conv2d(64, self.num_classes, 1, 1, 0)
43 |
44 | def forward(self, x, meta_loss, meta_step_size, stop_gradient):
45 | self.meta_loss = meta_loss
46 | self.meta_step_size = meta_step_size
47 | self.stop_gradient = stop_gradient
48 |
49 | out = conv2d(x, self.conv1.weight, self.conv1.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
50 | stop_gradient=self.stop_gradient)
51 | out = self.bn1(out)
52 | out = relu(out)
53 | out = conv2d(out, self.conv2.weight, self.conv2.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
54 | stop_gradient=self.stop_gradient)
55 | out = self.bn2(out)
56 | out = relu(out)
57 | out = conv2d(out, self.pred.weight, self.pred.bias, stride=1, padding=0, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient)
58 |
59 | return out
--------------------------------------------------------------------------------
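
A shape check for the content encoder, with illustrative arguments (width/height/ndf/upsample are stored, but the wrapped UNet is what determines the behaviour):

import torch
from models.meta_segmentor import ContentEncoder

enc = ContentEncoder(width=224, height=224, ndf=64, num_output_channels=8,
                     norm='in', upsample='nearest')
a = enc(torch.randn(2, 1, 224, 224), meta_loss=None,
        meta_step_size=0.001, stop_gradient=False)
print(a.shape)  # (2, 8, 224, 224)
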
/models/meta_styleencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.ops import *
5 |
6 | class StyleEncoder(nn.Module):
7 | def __init__(self, style_dim):
8 | super(StyleEncoder, self).__init__()
9 | dim = 8
10 | self.style_dim = style_dim // 2
11 | self.conv1 = nn.Conv2d(1, dim, 7, 1, 3, bias=True)
12 | self.conv2 = nn.Conv2d(dim, dim*2, 4, 2, 1, bias=True)
13 | self.conv3 = nn.Conv2d(dim*2, dim*4, 4, 2, 1, bias=True)
14 | self.conv4 = nn.Conv2d(dim*4, dim*8, 4, 2, 1, bias=True)
15 | self.conv5 = nn.Conv2d(dim*8, dim*16, 4, 2, 1, bias=True)
16 | self.conv6 = nn.Conv2d(dim*16, dim*32, 4, 2, 1, bias=True)
17 |
18 | self.fc1 = nn.Linear(256*9*9, 4*9*9)
19 | self.fc2 = nn.Linear(4*9*9, 32)
20 | self.mu = nn.Linear(32, style_dim)
21 | self.logvar = nn.Linear(32, style_dim)
22 | self.classifier = nn.Linear(self.style_dim, 3)
23 |
24 | def reparameterize(self, mu, logvar):
25 | std = torch.exp(0.5 * logvar)
26 | eps = torch.randn_like(std)
27 |
28 | return mu + eps * std
29 |
30 | def forward(self, x, meta_loss, meta_step_size, stop_gradient):
31 | self.meta_loss = meta_loss
32 | self.meta_step_size = meta_step_size
33 | self.stop_gradient = stop_gradient
34 |
35 | out = conv2d(x, self.conv1.weight, self.conv1.bias, stride=1, padding=3, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
36 | stop_gradient=self.stop_gradient)
37 | out = conv2d(out, self.conv2.weight, self.conv2.bias, stride=2, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
38 | stop_gradient=self.stop_gradient)
39 | out = conv2d(out, self.conv3.weight, self.conv3.bias, stride=2, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
40 | stop_gradient=self.stop_gradient)
41 | out = conv2d(out, self.conv4.weight, self.conv4.bias, stride=2, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
42 | stop_gradient=self.stop_gradient)
43 | out = conv2d(out, self.conv5.weight, self.conv5.bias, stride=2, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
44 | stop_gradient=self.stop_gradient)
45 | out = conv2d(out, self.conv6.weight, self.conv6.bias, stride=2, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
46 | stop_gradient=self.stop_gradient)
47 |
48 | out = linear(out.view(-1, out.shape[1] * out.shape[2] * out.shape[3]), self.fc1.weight, self.fc1.bias, meta_loss=self.meta_loss,
49 | meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient)
50 | out = lrelu(out)
51 | out = linear(out, self.fc2.weight, self.fc2.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
52 | stop_gradient=self.stop_gradient)
53 | out = lrelu(out)
54 |
55 | mu = linear(out, self.mu.weight, self.mu.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
56 | stop_gradient=self.stop_gradient)
57 | logvar = linear(out, self.logvar.weight, self.logvar.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
58 | stop_gradient=self.stop_gradient)
59 |
60 | zs = self.reparameterize(mu[:,self.style_dim:], logvar[:,self.style_dim:])
61 | zd = self.reparameterize(mu[:,:self.style_dim], logvar[:,:self.style_dim])
62 | z = torch.cat((zs,zd), dim=1)
63 |
64 | cls = linear(z[:,:self.style_dim], self.classifier.weight, self.classifier.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size,
65 | stop_gradient=self.stop_gradient)
66 | cls = F.softmax(cls, dim=1)
67 | return z, mu, logvar, cls
--------------------------------------------------------------------------------
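
A shape check, assuming a 16-dim style code (8 dims feed the domain classifier, 8 are the reparameterised style dims). fc1 fixes the spatial size: five stride-2 convolutions must land on 9x9, i.e. a 288x288 input:

import torch
from models.meta_styleencoder import StyleEncoder

enc = StyleEncoder(style_dim=16)
x = torch.randn(2, 1, 288, 288)
z, mu, logvar, cls = enc(x, meta_loss=None, meta_step_size=0.001, stop_gradient=False)
print(z.shape, cls.shape)  # (2, 16), (2, 3): softmax over what appear to be 3 source domains
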
/models/meta_unet.py:
--------------------------------------------------------------------------------
1 |
2 | """ Parts of the U-Net model """
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from models.ops import *
8 |
9 | class ConvD(nn.Module):
10 | def __init__(self, inplanes, planes, meta_loss, meta_step_size, stop_gradient, norm='in', first=False):
11 | super(ConvD, self).__init__()
12 |
13 | self.meta_loss = meta_loss
14 | self.meta_step_size = meta_step_size
15 | self.stop_gradient = stop_gradient
16 |
17 | self.first = first
18 |
19 | self.conv1 = nn.Conv2d(inplanes, planes, 3, 1, 1, bias=True)
20 | self.in1 = normalization(planes, norm)
21 |
22 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=True)
23 | self.in2 = normalization(planes, norm)
24 |
25 | self.conv3 = nn.Conv2d(planes, planes, 3, 1, 1, bias=True)
26 | self.in3 = normalization(planes, norm)
27 |
28 | def forward(self, x):
29 |
30 | if not self.first:
31 | x = maxpool2D(x, kernel_size=2)
32 |
33 | #layer 1 conv, in
34 | x = conv2d(x, self.conv1.weight, self.conv1.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient)
35 | x = self.in1(x)
36 |
37 | #layer 2 conv, in, lrelu
38 | y = conv2d(x, self.conv2.weight, self.conv2.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient)
39 | y = self.in2(y)
40 | y = lrelu(y)
41 |
42 | #layer 3 conv, in, lrelu
43 | z = conv2d(y, self.conv3.weight, self.conv3.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient)
44 | z = self.in3(z)
45 | z = lrelu(z)
46 |
47 | return z
48 |
49 | class ConvU(nn.Module):
50 | def __init__(self, planes, meta_loss, meta_step_size, stop_gradient, norm='in', first=False):
51 | super(ConvU, self).__init__()
52 |
53 | self.meta_loss = meta_loss
54 | self.meta_step_size = meta_step_size
55 | self.stop_gradient = stop_gradient
56 |
57 | self.first = first
58 | if not self.first:
59 | self.conv1 = nn.Conv2d(2*planes, planes, 3, 1, 1, bias=True)
60 | self.in1 = normalization(planes, norm)
61 |
62 | self.conv2 = nn.Conv2d(planes, planes//2, 1, 1, 0, bias=True)
63 | self.in2 = normalization(planes//2, norm)
64 |
65 | self.conv3 = nn.Conv2d(planes, planes, 3, 1, 1, bias=True)
66 | self.in3 = normalization(planes, norm)
67 |
68 | def forward(self, x, prev):
69 | #layer 1 conv, in, lrelu
70 | if not self.first:
71 | x = conv2d(x, self.conv1.weight, self.conv1.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient)
72 | x = self.in1(x)
73 | x = lrelu(x)
74 |
75 | #upsample, layer 2 conv, in, lrelu
76 | y = upsample(x)
77 | y = conv2d(y, self.conv2.weight, self.conv2.bias, stride=1, padding=0, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient)
78 | y = self.in2(y)
79 | y = lrelu(y)
80 |
81 | #concatenation of two layers
82 | y = torch.cat([prev, y], 1)
83 |
84 | #layer 3 conv, in, lrelu
85 | y = conv2d(y, self.conv3.weight, self.conv3.bias, stride=1, padding=1, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient)
86 | y = self.in3(y)
87 | y = lrelu(y)
88 |
89 | return y
90 |
91 | class UNet(nn.Module):
92 | def __init__(self, c, n, num_classes, norm='in'):
93 | super(UNet, self).__init__()
94 |
95 | meta_loss = None       # note: ConvD/ConvU capture these values at construction time;
96 | meta_step_size = 0.01  # the meta_* arguments passed to forward() below only reach
97 | stop_gradient = False  # the final seg1 convolution
98 |
99 | self.convd1 = ConvD(c, n, meta_loss, meta_step_size, stop_gradient, norm, first=True)
100 | self.convd2 = ConvD(n, 2*n, meta_loss, meta_step_size, stop_gradient, norm)
101 | self.convd3 = ConvD(2*n, 4*n, meta_loss, meta_step_size, stop_gradient, norm)
102 | self.convd4 = ConvD(4*n, 8*n, meta_loss, meta_step_size, stop_gradient, norm)
103 | self.convd5 = ConvD(8*n,16*n, meta_loss, meta_step_size, stop_gradient, norm)
104 |
105 | self.convu4 = ConvU(16*n, meta_loss, meta_step_size, stop_gradient, norm, first=True)
106 | self.convu3 = ConvU(8*n, meta_loss, meta_step_size, stop_gradient, norm)
107 | self.convu2 = ConvU(4*n, meta_loss, meta_step_size, stop_gradient, norm)
108 | self.convu1 = ConvU(2*n, meta_loss, meta_step_size, stop_gradient, norm)
109 |
110 | self.seg1 = nn.Conv2d(2*n, num_classes, 1)
111 |
112 | def forward(self, x, meta_loss, meta_step_size, stop_gradient):
113 | self.meta_loss = meta_loss
114 | self.meta_step_size = meta_step_size
115 | self.stop_gradient = stop_gradient
116 |
117 | x1 = self.convd1(x)
118 | x2 = self.convd2(x1)
119 | x3 = self.convd3(x2)
120 | x4 = self.convd4(x3)
121 | x5 = self.convd5(x4)
122 |
123 | y4 = self.convu4(x5, x4)
124 | y3 = self.convu3(y4, x3)
125 | y2 = self.convu2(y3, x2)
126 | y1 = self.convu1(y2, x1)
127 |
128 | y1 = conv2d(y1, self.seg1.weight, self.seg1.bias, meta_loss=self.meta_loss, meta_step_size=self.meta_step_size, stop_gradient=self.stop_gradient, kernel_size=None, stride=1, padding=0)
129 |
130 | return y1
--------------------------------------------------------------------------------
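
A shape check with hypothetical sizes; H and W must be divisible by 16 because of the four poolings:

import torch
from models.meta_unet import UNet

net = UNet(c=1, n=32, num_classes=8, norm='in')
x = torch.randn(2, 1, 224, 224)
y = net(x, meta_loss=None, meta_step_size=0.01, stop_gradient=False)
print(y.shape)  # (2, 8, 224, 224)
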
/models/ops.py:
--------------------------------------------------------------------------------
1 | import torch.autograd as autograd
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import torch
5 | from torch.autograd import Variable
6 |
7 | def normalization(planes, norm='in'):
8 | if norm == 'bn':
9 | m = nn.BatchNorm2d(planes)
10 | elif norm == 'gn':
11 | m = nn.GroupNorm(1, planes)
12 | elif norm == 'in':
13 | m = nn.InstanceNorm2d(planes)
14 | else:
15 | raise ValueError('normalization type {} is not supported'.format(norm))
16 | return m
17 |
18 | def linear(inputs, weight, bias, meta_step_size=0.001, meta_loss=None, stop_gradient=False):
19 | inputs = inputs
20 | weight = weight
21 | bias = bias
22 |
23 | if meta_loss is not None:
24 |
25 | if not stop_gradient:
26 | grad_weight = autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True) [0]
27 |
28 | if bias is not None:
29 | grad_bias = autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True) [0]
30 | bias_adapt = bias - grad_bias * meta_step_size if grad_bias is not None else bias  # guard: allow_unused=True may yield None
31 | else:
32 | bias_adapt = bias
33 |
34 | else:
35 | grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True)[0].data, requires_grad=False)
36 |
37 | if bias is not None:
38 | grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True)[0].data, requires_grad=False)
39 | bias_adapt = bias - grad_bias * meta_step_size
40 | else:
41 | bias_adapt = bias
42 |
43 | return F.linear(inputs,
44 | weight - grad_weight * meta_step_size if grad_weight is not None else weight,
45 | bias_adapt)
46 | else:
47 | return F.linear(inputs, weight, bias)
48 |
49 |
50 | def conv2d(inputs, weight, bias, meta_step_size=0.001, stride=1, padding=0, dilation=1, groups=1, meta_loss=None,
51 | stop_gradient=False, kernel_size=None):
52 |
53 | inputs = inputs
54 | weight = weight
55 | bias = bias
56 |
57 |
58 | if meta_loss is not None:
59 |
60 | if not stop_gradient:
61 | grad_weight = autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True)[0]
62 |
63 | if bias is not None:
64 | grad_bias = autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True)[0]
65 | if grad_bias is not None:
66 | bias_adapt = bias - grad_bias * meta_step_size
67 | else:
68 | bias_adapt = bias
69 | else:
70 | bias_adapt = bias
71 |
72 | else:
73 | grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True)[0].data,
74 | requires_grad=False)
75 | if bias is not None:
76 | grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True)[0].data, requires_grad=False)
77 | bias_adapt = bias - grad_bias * meta_step_size
78 | else:
79 | bias_adapt = bias
80 | if grad_weight is not None:
81 | weight_adapt = weight - grad_weight * meta_step_size
82 | else:
83 | weight_adapt = weight
84 |
85 | return F.conv2d(inputs,
86 | weight_adapt,
87 | bias_adapt, stride,
88 | padding,
89 | dilation, groups)
90 | else:
91 | return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups)
92 |
93 |
94 | def deconv2d(inputs, weight, bias, meta_step_size=0.001, stride=2, padding=0, dilation=1, groups=1, meta_loss=None,
95 | stop_gradient=False, kernel_size=None):
96 |
97 | inputs = inputs
98 | weight = weight
99 | bias = bias
100 |
101 |
102 | if meta_loss is not None:
103 |
104 | if not stop_gradient:
105 | grad_weight = autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True)[0]
106 |
107 | if bias is not None:
108 | grad_bias = autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True)[0]
109 | bias_adapt = bias - grad_bias * meta_step_size
110 | else:
111 | bias_adapt = bias
112 |
113 | else:
114 | grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True, allow_unused=True)[0].data,
115 | requires_grad=False)
116 | if bias is not None:
117 | grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True, allow_unused=True)[0].data, requires_grad=False)
118 | bias_adapt = bias - grad_bias * meta_step_size
119 | else:
120 | bias_adapt = bias
121 |
122 | return F.conv_transpose2d(inputs,
123 | weight - grad_weight * meta_step_size,
124 | bias_adapt, stride,
125 | padding,
126 | output_padding=0, groups=groups, dilation=dilation)  # conv_transpose2d's 6th positional arg is output_padding, not dilation
127 | else:
128 | return F.conv_transpose2d(inputs, weight, bias, stride, padding, output_padding=0, groups=groups, dilation=dilation)
129 |
130 | def tanh(inputs):
131 | return torch.tanh(inputs)
132 |
133 | def relu(inputs):
134 | return F.relu(inputs, inplace=True)
135 |
136 | def lrelu(inputs):
137 | return F.leaky_relu(inputs, negative_slope=0.01, inplace=False)
138 |
139 | def maxpool(inputs, kernel_size, stride=None, padding=0):
140 | return F.max_pool2d(inputs, kernel_size, stride, padding=padding)
141 |
142 | def dropout(inputs):
143 | return F.dropout(inputs, p=0.5, training=False, inplace=False)
144 |
145 | def batchnorm(inputs, running_mean, running_var):
146 | return F.batch_norm(inputs, running_mean, running_var)
147 |
148 | def instancenorm(input):
149 | return F.instance_norm(input)
150 |
151 | def groupnorm(input, num_groups=1):
152 | return F.group_norm(input, num_groups)  # F.group_norm requires num_groups
153 |
154 | def dropout2D(inputs):
155 | return F.dropout2d(inputs, p=0.5, training=False, inplace=False)
156 |
157 | def maxpool2D(inputs, kernel_size, stride=None, padding=0):
158 | return F.max_pool2d(inputs, kernel_size, stride, padding=padding)
159 |
160 | def upsample(input):
161 | return F.interpolate(input, scale_factor=2, mode='bilinear', align_corners=False)
162 |
--------------------------------------------------------------------------------
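
The ops above make each layer functional so a "virtual" SGD step on a meta-loss can be taken before the layer is applied. A minimal sketch with a throwaway surrogate loss (the layer and shapes are hypothetical):

import torch
import torch.nn as nn
from models.ops import conv2d

conv = nn.Conv2d(3, 3, 3, padding=1)
x = torch.randn(1, 3, 8, 8)

y = conv2d(x, conv.weight, conv.bias, stride=1, padding=1)  # plain path: same as F.conv2d

surrogate = y.pow(2).mean()  # any scalar that depends on conv.weight through the graph
y_adapted = conv2d(x, conv.weight, conv.bias, stride=1, padding=1,
                   meta_loss=surrogate, meta_step_size=0.001)  # conv applied with weight - lr * grad
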
/models/unet_parts.py:
--------------------------------------------------------------------------------
1 |
2 | """ Parts of the U-Net model """
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | class DoubleConv(nn.Module):
9 | """(convolution => [BN] => ReLU) * 2"""
10 |
11 | def __init__(self, in_channels, out_channels, mid_channels=None):
12 | super().__init__()
13 | if not mid_channels:
14 | mid_channels = out_channels
15 | self.double_conv = nn.Sequential(
16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
17 | nn.InstanceNorm2d(mid_channels, momentum=0.1, affine=True),
18 | nn.LeakyReLU(negative_slope=0.01, inplace=True),
19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
20 | nn.InstanceNorm2d(out_channels, momentum=0.1, affine=True),
21 | nn.LeakyReLU(negative_slope=0.01, inplace=True)
22 | )
23 |
24 | def forward(self, x):
25 | return self.double_conv(x)
26 |
27 |
28 | class Down(nn.Module):
29 | """Downscaling with maxpool then double conv"""
30 |
31 | def __init__(self, in_channels, out_channels):
32 | super().__init__()
33 | self.maxpool_conv = nn.Sequential(
34 | nn.MaxPool2d(2),
35 | DoubleConv(in_channels, out_channels)
36 | )
37 |
38 | def forward(self, x):
39 | return self.maxpool_conv(x)
40 |
41 |
42 | class Up(nn.Module):
43 | """Upscaling then double conv"""
44 |
45 | def __init__(self, in_channels, out_channels, bilinear=True):
46 | super().__init__()
47 |
48 | # if bilinear, use the normal convolutions to reduce the number of channels
49 | if bilinear:
50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52 | else:
53 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
54 | self.conv = DoubleConv(in_channels, out_channels)
55 |
56 |
57 | def forward(self, x1, x2):
58 | x1 = self.up(x1)
59 | # input is CHW
60 | diffY = x2.size()[2] - x1.size()[2]
61 | diffX = x2.size()[3] - x1.size()[3]
62 |
63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
64 | diffY // 2, diffY - diffY // 2])
65 | # if you have padding issues, see
66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
68 | x = torch.cat([x2, x1], dim=1)
69 | return self.conv(x)
70 |
71 |
72 | class OutConv(nn.Module):
73 | def __init__(self, in_channels, out_channels):
74 | super(OutConv, self).__init__()
75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
76 |
77 | def forward(self, x):
78 | return self.conv(x)
79 |
--------------------------------------------------------------------------------
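
A shape check for the bilinear Up block; the channel sizes are hypothetical (both inputs carry in_channels//2 channels, which the concatenation restores to in_channels):

import torch
from models.unet_parts import Up

up = Up(in_channels=128, out_channels=64, bilinear=True)
x1 = torch.randn(1, 64, 28, 28)  # decoder feature to upsample
x2 = torch.randn(1, 64, 56, 56)  # encoder skip connection
print(up(x1, x2).shape)          # (1, 64, 56, 56)
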
/models/weight_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import os
4 | import sys
5 |
6 | def init_dcgan_weights(model):
7 | for m in model.modules():
8 | if isinstance(m, torch.nn.Conv2d):
9 | m.weight.data.normal_(0.0, 0.02)
10 | if m.bias is not None:
11 | m.bias.data.zero_()
12 |
13 | def initialize_weights(model, init = "xavier"):
14 | init_func = None
15 | if init == "xavier":
16 | init_func = torch.nn.init.xavier_normal_
17 | elif init == "kaiming":
18 | init_func = torch.nn.init.kaiming_normal_
19 | elif init == "gaussian" or init == "normal":
20 | init_func = torch.nn.init.normal_
21 |
22 | if init_func is not None:
23 | for module in model.modules():
24 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
25 | init_func(module.weight)
26 | if module.bias is not None:
27 | module.bias.data.zero_()
28 | elif isinstance(module, torch.nn.BatchNorm2d):
29 | module.weight.data.fill_(1)
30 | module.bias.data.zero_()
31 | else:
32 | print("Error when initializing model's weights, {} either doesn't exist or is not a valid initialization function.".format(init), \
33 | file=sys.stderr)
--------------------------------------------------------------------------------
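
Usage sketch on a throwaway network:

import torch
from models.weight_init import initialize_weights

net = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 3, padding=1),
    torch.nn.BatchNorm2d(8),
    torch.nn.Conv2d(8, 4, 1),
)
initialize_weights(net, init="xavier")  # also accepts "kaiming" and "normal"/"gaussian"
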
/preprocess/save_MNMS_2D.py:
--------------------------------------------------------------------------------
1 | import nibabel as nib
2 | import argparse
3 | import os
4 | import numpy as np
5 | from PIL import Image
6 | import configparser
7 | import xlrd
8 |
9 | def safe_mkdir(path):
10 | try:
11 | os.makedirs(path)
12 | except OSError:
13 | pass
14 |
15 | def save_mask(img, path, name):
16 | h = img.shape[0]
17 | w = img.shape[1]
18 | img_1 = (img >= 0.5) * (img<1.5) * 255
19 | img_2 = (img >= 1.5) * (img<2.5) * 255
20 | img_3 = (img >= 2.5) * 255
21 | img = np.concatenate((img_1.reshape(h,w,1), img_2.reshape(h,w,1), img_3.reshape(h,w,1)), axis=2)
22 | img = Image.fromarray(img.astype(np.uint8), 'RGB')
23 | img.save(os.path.join(path, name))
24 |
25 | def save_image(img, path, name):
26 | # normalise to [0, 1] only when the image has a valid positive range
27 | if not (img.min() < 0 or img.max() <= 0):
28 | img -= img.min()
29 | img /= img.max()
30 |
31 |
32 | img *= 255
33 | img = Image.fromarray(img.astype(np.uint8), 'L')
34 | img.save(os.path.join(path, name))
35 |
36 | def save_np(img, path, name):
37 | np.savez_compressed(os.path.join(path, name), img)
38 |
39 | def save_mask_np(img, path, name):
40 | h = img.shape[0]
41 | w = img.shape[1]
42 | img_1 = (img >= 0.5) * (img<1.5)
43 | img_2 = (img >= 1.5) * (img<2.5)
44 | img_3 = (img >= 2.5)
45 | img = np.concatenate((img_1.reshape(h,w,1), img_2.reshape(h,w,1), img_3.reshape(h,w,1)), axis=2)
46 | np.savez_compressed(os.path.join(path, name), img)
47 |
48 |
49 | parser = argparse.ArgumentParser()
50 | parser.add_argument('--LabeledVendorA', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorA/', help='The root of dataset.')
51 |
52 | parser.add_argument('--LabeledVendorBcenter2', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorB/center2/', help='The root of dataset.')
53 | parser.add_argument('--LabeledVendorBcenter3', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorB/center3/', help='The root of dataset.')
54 |
55 | parser.add_argument('--LabeledVendorC', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorC/', help='The root of dataset.')
56 |
57 | parser.add_argument('--LabeledVendorD', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorD/', help='The root of dataset.')
58 |
59 | parser.add_argument('--UnlabeledVendorC', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Unlabeled/vendorC/', help='The root of dataset.')
60 |
61 | arg = parser.parse_args()
62 |
63 | ######################################################################################################
64 | #Load excel information:
65 | # cell_value(1,0) -> cell_value(175,0)
66 | ex_file = '/home/s1575424/xiao/Year2/mnms_split_data/mnms_dataset_info.xlsx'
67 | wb = xlrd.open_workbook(ex_file)
68 | sheet = wb.sheet_by_index(0)
69 | # sheet.cell_value(r, c)
70 |
71 | # Save data dirs
72 | LabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Labeled/vendorA/'
73 | LabeledVendorA_mask_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_mask/Labeled/vendorA/'
74 |
75 | LabeledVendorB2_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Labeled/vendorB/center2/'
76 | LabeledVendorB2_mask_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_mask/Labeled/vendorB/center2/'
77 |
78 | LabeledVendorB3_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Labeled/vendorB/center3/'
79 | LabeledVendorB3_mask_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_mask/Labeled/vendorB/center3/'
80 |
81 | LabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Labeled/vendorC/'
82 | LabeledVendorC_mask_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_mask/Labeled/vendorC/'
83 |
84 | LabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Labeled/vendorD/'
85 | LabeledVendorD_mask_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_mask/Labeled/vendorD/'
86 |
87 | UnlabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_data/Unlabeled/vendorC/'
88 |
89 |
90 | # Load all the data names
91 | LabeledVendorA_names = sorted(os.listdir(arg.LabeledVendorA))
92 |
93 | LabeledVendorB2_names = sorted(os.listdir(arg.LabeledVendorBcenter2))
94 | LabeledVendorB3_names = sorted(os.listdir(arg.LabeledVendorBcenter3))
95 |
96 | LabeledVendorC_names = sorted(os.listdir(arg.LabeledVendorC))
97 |
98 | LabeledVendorD_names = sorted(os.listdir(arg.LabeledVendorD))
99 |
100 | UnlabeledVendorC_names = sorted(os.listdir(arg.UnlabeledVendorC))
101 |
102 | #### Output: npz files that are not normalised and not cropped, with the H/W axes untouched.
103 |
104 | ######################################################################################################
105 | # Load LabeledVendorA data and save them to 2D images
106 | for num_pat in range(0, len(LabeledVendorA_names)):
107 | gz_name = sorted(os.listdir(arg.LabeledVendorA+LabeledVendorA_names[num_pat]+'/'))
108 | patient_root = arg.LabeledVendorA+LabeledVendorA_names[num_pat] + '/' + gz_name[0]
109 | img = nib.load(patient_root)
110 | img_np = img.get_fdata()
111 |
112 | # p5 = np.percentile(img_np.flatten(), 5)
113 | # p95 = np.percentile(img_np.flatten(), 95)
114 | # img_np = np.clip(img_np, p5, p95)
115 |
116 | print('patient%03d...' % num_pat)
117 | save_labeled_data_root = LabeledVendorA_data_dir
118 |
119 | img_np = np.transpose(img_np, (0, 1, 3, 2))
120 |
121 | if num_pat == 0:
122 | print(type(img_np[0,0,0,0]))
123 |
124 | ## save each image of the patient
125 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
126 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
127 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
128 |
129 |
130 | ######################################################################################################
131 | # Load LabeledVendorA mask and save them to 2D images
132 | for num_pat in range(0, len(LabeledVendorA_names)):
133 | gz_name = sorted(os.listdir(arg.LabeledVendorA+LabeledVendorA_names[num_pat]+'/'))
134 | patient_root = arg.LabeledVendorA+LabeledVendorA_names[num_pat] + '/' + gz_name[1]
135 | img = nib.load(patient_root)
136 | img_np = img.get_fdata()
137 |
138 | print('patient%03d...' % num_pat)
139 |
140 | save_mask_root = LabeledVendorA_mask_dir
141 |
142 | img_np = np.transpose(img_np, (0, 1, 3, 2))
143 |
144 | for num_row in range(1, 346):
145 | if sheet.cell_value(num_row, 0) == gz_name[0][0:6]:
146 | ED = sheet.cell_value(num_row, 4)
147 | ES = sheet.cell_value(num_row, 5)
148 | break
149 |
150 | ## save masks of the patient
151 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
152 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
153 | if num_slice//img_np.shape[3] == ED or num_slice//img_np.shape[3] == ES:
154 | save_mask(img_save, save_mask_root, '%03d%03d.png' % (num_pat, num_slice))
155 |
156 | ######################################################################################################
157 | # Load LabeledVendorB2 data and save them to 2D images
158 | for num_pat in range(0, len(LabeledVendorB2_names)):
159 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat]+'/'))
160 | patient_root = arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat] + '/' + gz_name[0]
161 | img = nib.load(patient_root)
162 | img_np = img.get_fdata()
163 |
164 | # p5 = np.percentile(img_np.flatten(), 5)
165 | # p95 = np.percentile(img_np.flatten(), 95)
166 | # img_np = np.clip(img_np, p5, p95)
167 |
168 | print('patient%03d...' % num_pat)
169 | save_labeled_data_root = LabeledVendorB2_data_dir
170 |
171 | img_np = np.transpose(img_np, (0, 1, 3, 2))
172 |
173 | if num_pat == 0:
174 | print(type(img_np[0,0,0,0]))
175 |
176 | ## save each image of the patient
177 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
178 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
179 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
180 |
181 |
182 | ######################################################################################################
183 | # Load LabeledVendorB2 mask and save them to 2D images
184 | for num_pat in range(0, len(LabeledVendorB2_names)):
185 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat]+'/'))
186 | patient_root = arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat] + '/' + gz_name[1]
187 | img = nib.load(patient_root)
188 | img_np = img.get_fdata()
189 |
190 | print('patient%03d...' % num_pat)
191 |
192 | save_mask_root = LabeledVendorB2_mask_dir
193 |
194 | img_np = np.transpose(img_np, (0, 1, 3, 2))
195 |
196 | for num_row in range(1, 346):
197 | if sheet.cell_value(num_row, 0) == gz_name[0][0:6]:
198 | ED = sheet.cell_value(num_row, 4)
199 | ES = sheet.cell_value(num_row, 5)
200 | break
201 |
202 | ## save masks of the patient
203 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
204 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
205 | if num_slice//img_np.shape[3] == ED or num_slice//img_np.shape[3] == ES:
206 | save_mask(img_save, save_mask_root, '%03d%03d.png' % (num_pat, num_slice))
207 |
208 |
209 | ######################################################################################################
210 | # Load LabeledVendorB3 data and save them to 2D images
211 | for num_pat in range(0, len(LabeledVendorB3_names)):
212 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat]+'/'))
213 | patient_root = arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat] + '/' + gz_name[0]
214 | img = nib.load(patient_root)
215 | img_np = img.get_fdata()
216 |
217 | # p5 = np.percentile(img_np.flatten(), 5)
218 | # p95 = np.percentile(img_np.flatten(), 95)
219 | # img_np = np.clip(img_np, p5, p95)
220 |
221 | print('patient%03d...' % num_pat)
222 | save_labeled_data_root = LabeledVendorB3_data_dir
223 |
224 | img_np = np.transpose(img_np, (0, 1, 3, 2))
225 |
226 | if num_pat == 0:
227 | print(type(img_np[0,0,0,0]))
228 |
229 | ## save each image of the patient
230 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
231 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
232 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
233 |
234 |
235 | ######################################################################################################
236 | # Load LabeledVendorB3 mask and save them to 2D images
237 | for num_pat in range(0, len(LabeledVendorB3_names)):
238 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat]+'/'))
239 | patient_root = arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat] + '/' + gz_name[1]
240 | img = nib.load(patient_root)
241 | img_np = img.get_fdata()
242 |
243 | print('patient%03d...' % num_pat)
244 |
245 | save_mask_root = LabeledVendorB3_mask_dir
246 |
247 | img_np = np.transpose(img_np, (0, 1, 3, 2))
248 |
249 | for num_row in range(1, 346):
250 | if sheet.cell_value(num_row, 0) == gz_name[0][0:6]:
251 | ED = sheet.cell_value(num_row, 4)
252 | ES = sheet.cell_value(num_row, 5)
253 | break
254 |
255 | ## save masks of the patient
256 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
257 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
258 | if num_slice//img_np.shape[3] == ED or num_slice//img_np.shape[3] == ES:
259 | save_mask(img_save, save_mask_root, '%03d%03d.png' % (num_pat, num_slice))
260 |
261 |
262 | ######################################################################################################
263 | # Load LabeledVendorC data and save them to 2D images
264 | for num_pat in range(0, len(LabeledVendorC_names)):
265 | gz_name = sorted(os.listdir(arg.LabeledVendorC+LabeledVendorC_names[num_pat]+'/'))
266 | patient_root = arg.LabeledVendorC+LabeledVendorC_names[num_pat] + '/' + gz_name[0]
267 | img = nib.load(patient_root)
268 | img_np = img.get_fdata()
269 |
270 | # p5 = np.percentile(img_np.flatten(), 5)
271 | # p95 = np.percentile(img_np.flatten(), 95)
272 | # img_np = np.clip(img_np, p5, p95)
273 |
274 | print('patient%03d...' % num_pat)
275 | save_labeled_data_root = LabeledVendorC_data_dir
276 |
277 | img_np = np.transpose(img_np, (0, 1, 3, 2))
278 |
279 | if num_pat == 0:
280 | print(type(img_np[0,0,0,0]))
281 |
282 | ## save each image of the patient
283 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
284 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
285 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
286 |
287 |
288 | ######################################################################################################
289 | # Load LabeledVendorC mask and save them to 2D images
290 | for num_pat in range(0, len(LabeledVendorC_names)):
291 | gz_name = sorted(os.listdir(arg.LabeledVendorC+LabeledVendorC_names[num_pat]+'/'))
292 | patient_root = arg.LabeledVendorC+LabeledVendorC_names[num_pat] + '/' + gz_name[1]
293 | img = nib.load(patient_root)
294 | img_np = img.get_fdata()
295 |
296 | print('patient%03d...' % num_pat)
297 |
298 | save_mask_root = LabeledVendorC_mask_dir
299 |
300 | img_np = np.transpose(img_np, (0, 1, 3, 2))
301 |
302 | for num_row in range(1, 346):
303 | if sheet.cell_value(num_row, 0) == gz_name[0][0:6]:
304 | ED = sheet.cell_value(num_row, 4)
305 | ES = sheet.cell_value(num_row, 5)
306 | break
307 |
308 | ## save masks of the patient
309 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
310 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
311 | if num_slice//img_np.shape[3] == ED or num_slice//img_np.shape[3] == ES:
312 | save_mask(img_save, save_mask_root, '%03d%03d.png' % (num_pat, num_slice))
313 |
314 |
315 | ######################################################################################################
316 | # Load LabeledVendorD data and save them to 2D images
317 | for num_pat in range(0, len(LabeledVendorD_names)):
318 | gz_name = sorted(os.listdir(arg.LabeledVendorD+LabeledVendorD_names[num_pat]+'/'))
319 | patient_root = arg.LabeledVendorD+LabeledVendorD_names[num_pat] + '/' + gz_name[0]
320 | img = nib.load(patient_root)
321 | img_np = img.get_fdata()
322 |
323 | # p5 = np.percentile(img_np.flatten(), 5)
324 | # p95 = np.percentile(img_np.flatten(), 95)
325 | # img_np = np.clip(img_np, p5, p95)
326 |
327 | print('patient%03d...' % num_pat)
328 | save_labeled_data_root = LabeledVendorD_data_dir
329 |
330 | img_np = np.transpose(img_np, (0, 1, 3, 2))
331 |
332 | if num_pat == 0:
333 | print(type(img_np[0,0,0,0]))
334 |
335 | ## save each image of the patient
336 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
337 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
338 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
339 |
340 |
341 | ######################################################################################################
342 | # Load LabeledVendorD mask and save them to 2D images
343 | for num_pat in range(0, len(LabeledVendorD_names)):
344 | gz_name = sorted(os.listdir(arg.LabeledVendorD+LabeledVendorD_names[num_pat]+'/'))
345 | patient_root = arg.LabeledVendorD+LabeledVendorD_names[num_pat] + '/' + gz_name[1]
346 | img = nib.load(patient_root)
347 | img_np = img.get_fdata()
348 |
349 | print('patient%03d...' % num_pat)
350 |
351 | save_mask_root = LabeledVendorD_mask_dir
352 |
353 | img_np = np.transpose(img_np, (0, 1, 3, 2))
354 |
355 | for num_row in range(1, 346):
356 | if sheet.cell_value(num_row, 0) == gz_name[0][0:6]:
357 | ED = sheet.cell_value(num_row, 4)
358 | ES = sheet.cell_value(num_row, 5)
359 | break
360 |
361 | ## save masks of the patient
362 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
363 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
364 | if num_slice//img_np.shape[3] == ED or num_slice//img_np.shape[3] == ES:
365 | save_mask(img_save, save_mask_root, '%03d%03d.png' % (num_pat, num_slice))
366 |
367 |
368 | ######################################################################################################
369 | # Load UnlabeledVendorC data and save them to 2D images
370 | for num_pat in range(0, len(UnlabeledVendorC_names)):
371 | gz_name = sorted(os.listdir(arg.UnlabeledVendorC+UnlabeledVendorC_names[num_pat]+'/'))
372 | patient_root = arg.UnlabeledVendorC+UnlabeledVendorC_names[num_pat] + '/' + gz_name[0]
373 | img = nib.load(patient_root)
374 | img_np = img.get_fdata()
375 |
376 | # p5 = np.percentile(img_np.flatten(), 5)
377 | # p95 = np.percentile(img_np.flatten(), 95)
378 | # img_np = np.clip(img_np, p5, p95)
379 |
380 | print('patient%03d...' % num_pat)
381 | save_labeled_data_root = UnlabeledVendorC_data_dir
382 |
383 | img_np = np.transpose(img_np, (0, 1, 3, 2))
384 |
385 | if num_pat == 0:
386 | print(type(img_np[0,0,0,0]))
387 |
388 | ## save each image of the patient
389 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
390 | img_save = img_np[:,:,num_slice//img_np.shape[3],num_slice%img_np.shape[3]]
391 | save_np(img_save, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
--------------------------------------------------------------------------------
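A note on the slice indexing above: after np.transpose(img_np, (0, 1, 3, 2)) each volume is
(H, W, n_phases, n_slices), assuming the usual (H, W, slices, phases) NIfTI layout, and the single
loop index encodes phase and slice together; a mask frame is kept only when its phase matches the
ED/ES values looked up in the spreadsheet. A minimal sketch of that index arithmetic on a toy
array (all sizes here are hypothetical):

import numpy as np

# Toy volume laid out like img_np after the transpose: (H, W, n_phases, n_slices).
vol = np.arange(4 * 4 * 3 * 2).reshape(4, 4, 3, 2)
n_phases, n_slices = vol.shape[2], vol.shape[3]

for num_slice in range(n_phases * n_slices):
    t = num_slice // n_slices  # cardiac phase, the value compared against ED/ES above
    z = num_slice % n_slices   # anatomical slice within that phase
    frame = vol[:, :, t, z]    # one 2D frame, saved as '%03d%03d' % (num_pat, num_slice)
    assert frame.shape == (4, 4)
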
/preprocess/save_MNMS_re.py:
--------------------------------------------------------------------------------
1 | import nibabel as nib
2 | import argparse
3 | import os
4 | import numpy as np
5 | from PIL import Image
6 | import configparser
7 | import xlrd
8 |
9 | def safe_mkdir(path):
10 | try:
11 | os.makedirs(path)
12 | except OSError:
13 | pass
14 |
15 | def save_mask(img, path, name):
16 | h = img.shape[0]
17 | w = img.shape[1]
18 | img_1 = (img == 1) * 255
19 | img_2 = (img == 2) * 255
20 | img_3 = (img == 3) * 255
21 | img = np.concatenate((img_1.reshape(h,w,1), img_2.reshape(h,w,1), img_3.reshape(h,w,1)), axis=2)
22 | img = Image.fromarray(img.astype(np.uint8), 'RGB')
23 | img.save(os.path.join(path, name))
24 |
25 | def save_image(img, path, name):
26 | if img.min() < 0 or img.max() <= 0:  # skip min-max normalisation for negative-valued or empty slices
27 | pass
28 | else:
29 | img -= img.min()
30 | img /= img.max()
31 |
32 | img *= 255
33 | img = Image.fromarray(img.astype(np.uint8), 'L')
34 | img.save(os.path.join(path, name))
35 |
36 | def save_np(img, path, name):
37 | np.savez_compressed(os.path.join(path, name), img)
38 |
39 | def save_mask_np(img, path, name):
40 | h = img.shape[0]
41 | w = img.shape[1]
42 | img_1 = (img == 1)
43 | img_2 = (img == 2)
44 | img_3 = (img == 3)
45 | img = np.concatenate((img_1.reshape(h,w,1), img_2.reshape(h,w,1), img_3.reshape(h,w,1)), axis=2)
46 | np.savez_compressed(os.path.join(path, name), img)
47 |
48 |
49 | parser = argparse.ArgumentParser()
50 | parser.add_argument('--LabeledVendorA', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorA/', help='The root of the dataset.')
51 |
52 | parser.add_argument('--LabeledVendorBcenter2', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorB/center2/', help='The root of the dataset.')
53 | parser.add_argument('--LabeledVendorBcenter3', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorB/center3/', help='The root of the dataset.')
54 |
55 | parser.add_argument('--LabeledVendorC', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorC/', help='The root of the dataset.')
56 |
57 | parser.add_argument('--LabeledVendorD', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorD/', help='The root of the dataset.')
58 |
59 | parser.add_argument('--UnlabeledVendorC', type=str, default='/home/s1575424/xiao/Year2/mnms_split_data/Unlabeled/vendorC/', help='The root of the dataset.')
60 |
61 | arg = parser.parse_args()
62 |
63 | ######################################################################################################
64 | #Load excel information:
65 | # patient codes live in column 0, rows 1-345 (cf. the range(1, 346) loops below)
66 | ex_file = '/home/s1575424/xiao/Year2/mnms_split_data/mnms_dataset_info.xlsx'
67 | wb = xlrd.open_workbook(ex_file)
68 | sheet = wb.sheet_by_index(0)
69 | # sheet.cell_value(r, c)
70 |
71 | # Save data dirs
72 | LabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Labeled/vendorA/'
73 |
74 | LabeledVendorB2_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Labeled/vendorB/center2/'
75 |
76 | LabeledVendorB3_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Labeled/vendorB/center3/'
77 |
78 | LabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Labeled/vendorC/'
79 |
80 | LabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Labeled/vendorD/'
81 |
82 | UnlabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/mnms_split_2D_re/Unlabeled/vendorC/'
83 |
84 |
85 | # Load all the data names
86 | LabeledVendorA_names = sorted(os.listdir(arg.LabeledVendorA))
87 |
88 | LabeledVendorB2_names = sorted(os.listdir(arg.LabeledVendorBcenter2))
89 | LabeledVendorB3_names = sorted(os.listdir(arg.LabeledVendorBcenter3))
90 |
91 | LabeledVendorC_names = sorted(os.listdir(arg.LabeledVendorC))
92 |
93 | LabeledVendorD_names = sorted(os.listdir(arg.LabeledVendorD))
94 |
95 | UnlabeledVendorC_names = sorted(os.listdir(arg.UnlabeledVendorC))
96 |
97 | #### Output: per-slice npz files holding the scan's in-plane pixel spacing (resolution[0]).
98 |
99 | ######################################################################################################
100 | # Load LabeledVendorA data and save them to 2D images
101 |
102 | for num_pat in range(0, len(LabeledVendorA_names)):
103 | gz_name = sorted(os.listdir(arg.LabeledVendorA+LabeledVendorA_names[num_pat]+'/'))
104 | patient_root = arg.LabeledVendorA+LabeledVendorA_names[num_pat] + '/' + gz_name[0]
105 | img = nib.load(patient_root)
106 | img_np = img.get_fdata()
107 |
108 | header = img.header
109 | resolution = header.get_zooms()
110 | X_scan_re = []
111 | X_scan_re.append(resolution[0])
112 | X_scan_re_np = np.array(X_scan_re)
113 |
114 | print('patient%03d...' % num_pat)
115 | save_labeled_data_root = LabeledVendorA_data_dir
116 |
117 | img_np = np.transpose(img_np, (0, 1, 3, 2))
118 |
119 | ## save each image of the patient
120 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
121 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
122 |
123 |
124 | ######################################################################################################
125 | # Load LabeledVendorB2 data and save them to 2D images
126 | for num_pat in range(0, len(LabeledVendorB2_names)):
127 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat]+'/'))
128 | patient_root = arg.LabeledVendorBcenter2+LabeledVendorB2_names[num_pat] + '/' + gz_name[0]
129 | img = nib.load(patient_root)
130 | img_np = img.get_fdata()
131 |
132 | header = img.header
133 | resolution = header.get_zooms()
134 | X_scan_re = []
135 | X_scan_re.append(resolution[0])
136 | X_scan_re_np = np.array(X_scan_re)
137 |
138 |
139 | print('patient%03d...' % num_pat)
140 | save_labeled_data_root = LabeledVendorB2_data_dir
141 |
142 | img_np = np.transpose(img_np, (0, 1, 3, 2))
143 |
144 | if num_pat == 0:
145 | print(type(img_np[0,0,0,0]))
146 |
147 | ## save each image of the patient
148 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
149 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
150 |
151 |
152 |
153 | ######################################################################################################
154 | # Load LabeledVendorB3 data and save them to 2D images
155 | for num_pat in range(0, len(LabeledVendorB3_names)):
156 | gz_name = sorted(os.listdir(arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat]+'/'))
157 | patient_root = arg.LabeledVendorBcenter3+LabeledVendorB3_names[num_pat] + '/' + gz_name[0]
158 | img = nib.load(patient_root)
159 | img_np = img.get_fdata()
160 |
161 | header = img.header
162 | resolution = header.get_zooms()
163 | X_scan_re = []
164 | X_scan_re.append(resolution[0])
165 | X_scan_re_np = np.array(X_scan_re)
166 |
167 |
168 | print('patient%03d...' % num_pat)
169 | save_labeled_data_root = LabeledVendorB3_data_dir
170 |
171 | img_np = np.transpose(img_np, (0, 1, 3, 2))
172 |
173 | if num_pat == 0:
174 | print(type(img_np[0,0,0,0]))
175 |
176 | ## save each image of the patient
177 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
178 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
179 |
180 |
181 | ######################################################################################################
182 | # Load LabeledVendorC data and save them to 2D images
183 | for num_pat in range(0, len(LabeledVendorC_names)):
184 | gz_name = sorted(os.listdir(arg.LabeledVendorC+LabeledVendorC_names[num_pat]+'/'))
185 | patient_root = arg.LabeledVendorC+LabeledVendorC_names[num_pat] + '/' + gz_name[0]
186 | img = nib.load(patient_root)
187 | img_np = img.get_fdata()
188 |
189 | header = img.header
190 | resolution = header.get_zooms()
191 | X_scan_re = []
192 | X_scan_re.append(resolution[0])
193 | X_scan_re_np = np.array(X_scan_re)
194 |
195 | print('patient%03d...' % num_pat)
196 | save_labeled_data_root = LabeledVendorC_data_dir
197 |
198 | img_np = np.transpose(img_np, (0, 1, 3, 2))
199 |
200 | if num_pat == 0:
201 | print(type(img_np[0,0,0,0]))
202 |
203 | ## save each image of the patient
204 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
205 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
206 |
207 |
208 | ######################################################################################################
209 | # Load LabeledVendorD data and save them to 2D images
210 | for num_pat in range(0, len(LabeledVendorD_names)):
211 | gz_name = sorted(os.listdir(arg.LabeledVendorD+LabeledVendorD_names[num_pat]+'/'))
212 | patient_root = arg.LabeledVendorD+LabeledVendorD_names[num_pat] + '/' + gz_name[0]
213 | img = nib.load(patient_root)
214 | img_np = img.get_fdata()
215 |
216 | header = img.header
217 | resolution = header.get_zooms()
218 | X_scan_re = []
219 | X_scan_re.append(resolution[0])
220 | X_scan_re_np = np.array(X_scan_re)
221 |
222 | print('patient%03d...' % num_pat)
223 | save_labeled_data_root = LabeledVendorD_data_dir
224 |
225 | img_np = np.transpose(img_np, (0, 1, 3, 2))
226 |
227 | if num_pat == 0:
228 | print(type(img_np[0,0,0,0]))
229 |
230 | ## save each image of the patient
231 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
232 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
233 |
234 |
235 | ######################################################################################################
236 | # Load UnlabeledVendorC data and save them to 2D images
237 | for num_pat in range(0, len(UnlabeledVendorC_names)):
238 | gz_name = sorted(os.listdir(arg.UnlabeledVendorC+UnlabeledVendorC_names[num_pat]+'/'))
239 | patient_root = arg.UnlabeledVendorC+UnlabeledVendorC_names[num_pat] + '/' + gz_name[0]
240 | img = nib.load(patient_root)
241 | img_np = img.get_fdata()
242 |
243 | header = img.header
244 | resolution = header.get_zooms()
245 | X_scan_re = []
246 | X_scan_re.append(resolution[0])
247 | X_scan_re_np = np.array(X_scan_re)
248 |
249 | print('patient%03d...' % num_pat)
250 | save_labeled_data_root = UnlabeledVendorC_data_dir
251 |
252 | img_np = np.transpose(img_np, (0, 1, 3, 2))
253 |
254 | if num_pat == 0:
255 | print(type(img_np[0,0,0,0]))
256 |
257 | ## save each image of the patient
258 | for num_slice in range(img_np.shape[2]*img_np.shape[3]):
259 | save_np(X_scan_re_np, save_labeled_data_root, '%03d%03d' % (num_pat, num_slice))
--------------------------------------------------------------------------------
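save_MNMS_re.py deliberately saves no pixel data: for every 2D frame it stores a one-element
array holding the scan's in-plane pixel spacing, read from the NIfTI header. A minimal sketch of
that header read (the file name is a hypothetical example):

import nibabel as nib
import numpy as np

img = nib.load('patient_sa.nii.gz')  # hypothetical M&Ms short-axis volume
zooms = img.header.get_zooms()       # per-axis spacing, e.g. (dx, dy, dz, dt) for a 4D volume
X_scan_re_np = np.array([zooms[0]])  # the single value saved once per slice above
print(X_scan_re_np)
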
/preprocess/save_SCGM_2D.py:
--------------------------------------------------------------------------------
1 | import nibabel as nib
2 | import argparse
3 | import os
4 | import numpy as np
5 | from PIL import Image
6 | import configparser
7 |
8 | import re
9 | import SimpleITK as stik
10 | from collections import OrderedDict
11 | from torchvision import transforms
12 | import random
13 | import torchvision.transforms.functional as F
14 | import cv2
15 | import torch
16 |
17 | def safe_mkdir(path):
18 | try:
19 | os.makedirs(path)
20 | except OSError:
21 | pass
22 |
23 | def save_np(img, path, name):
24 | np.savez_compressed(os.path.join(path, name), img)
25 |
26 | def save_mask_np(mask1, mask2, path, name):
27 | h = mask1.shape[0]
28 | w = mask1.shape[1]
29 | mask = np.concatenate((mask1.reshape(h,w,1), mask2.reshape(h,w,1)), axis=2)
30 | np.savez_compressed(os.path.join(path, name), mask)
31 |
32 | path_train = '/home/s1575424/xiao/Year2/scgm_rawdata/train/'
33 | path_test = '/home/s1575424/xiao/Year2/scgm_rawdata/test/'
34 |
35 | ######################################################################################################
36 |
37 | # Save data dirs
38 | LabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Labeled/vendorA/'
39 | LabeledVendorA_mask_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_mask/Labeled/vendorA/'
40 |
41 | LabeledVendorB_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Labeled/vendorB/'
42 | LabeledVendorB_mask_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_mask/Labeled/vendorB/'
43 |
44 | LabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Labeled/vendorC/'
45 | LabeledVendorC_mask_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_mask/Labeled/vendorC/'
46 |
47 | LabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Labeled/vendorD/'
48 | LabeledVendorD_mask_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_mask/Labeled/vendorD/'
49 |
50 | UnlabeledVendorA_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Unlabeled/vendorA/'
51 | UnlabeledVendorB_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Unlabeled/vendorB/'
52 | UnlabeledVendorC_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Unlabeled/vendorC/'
53 | UnlabeledVendorD_data_dir = '/home/s1575424/xiao/Year2/scgm_split_2D_data/Unlabeled/vendorD/'
54 |
55 | labeled_data_dir = [LabeledVendorA_data_dir, LabeledVendorB_data_dir, LabeledVendorC_data_dir, LabeledVendorD_data_dir]
56 | labeled_mask_dir = [LabeledVendorA_mask_dir, LabeledVendorB_mask_dir, LabeledVendorC_mask_dir, LabeledVendorD_mask_dir]
57 | un_labeled_data_dir = [UnlabeledVendorA_data_dir, UnlabeledVendorB_data_dir, UnlabeledVendorC_data_dir, UnlabeledVendorD_data_dir]
58 |
59 | safe_mkdir(LabeledVendorA_data_dir)
60 | safe_mkdir(LabeledVendorA_mask_dir)
61 | safe_mkdir(LabeledVendorB_data_dir)
62 | safe_mkdir(LabeledVendorB_mask_dir)
63 | safe_mkdir(LabeledVendorC_data_dir)
64 | safe_mkdir(LabeledVendorC_mask_dir)
65 | safe_mkdir(LabeledVendorD_data_dir)
66 | safe_mkdir(LabeledVendorD_mask_dir)
67 | safe_mkdir(UnlabeledVendorA_data_dir)
68 | safe_mkdir(UnlabeledVendorB_data_dir)
69 | safe_mkdir(UnlabeledVendorC_data_dir)
70 | safe_mkdir(UnlabeledVendorD_data_dir)
71 |
72 |
73 | def read_numpy(file_name):
74 | reader = stik.ImageFileReader()
75 | reader.SetImageIO("NiftiImageIO")
76 | reader.SetFileName(file_name)
77 | data = reader.Execute()
78 | return stik.GetArrayFromImage(data)
79 |
80 | def read_dataset_into_memory(data_list, labeled=True):
81 | for val in data_list.values():
82 | val['input'] = read_numpy(val['input'])
83 | if labeled:
84 | for idx, gt in enumerate(val['gt']):
85 | val['gt'][idx] = read_numpy(gt)
86 | else:
87 | pass
88 | return data_list
89 |
90 | def get_index_map(data_list, labeled=True):
91 | map_list = []
92 | num_sbj = 0
93 | for data in data_list.values():
94 | slice_num = data['input'].shape[0]
95 | print(data['input'].shape)
96 | for i in range(slice_num):
97 | if labeled:
98 | map_list.append([data['input'][i], np.stack([data['gt'][idx][i] for idx in range(4)], axis=0), num_sbj])
99 | else:
100 | map_list.append([data['input'][i], num_sbj])
101 | num_sbj += 1
102 | return map_list
103 |
104 | # data_list: data_list['input'] path, data_list['gt'] path
105 | def Save_vendor_data(data_list, labeled=True, data_dir='', mask_dir=''):
106 | data_list = read_dataset_into_memory(data_list, labeled)
107 | # data_list: data_list['input'] images, data_list['gt'] ground-truth labels
108 | map_list = get_index_map(data_list,labeled)
109 | # map_list[i]: [image slice, stacked gt of the 4 raters, subject index] (labeled case)
110 | if labeled:
111 | num_sbj = 0
112 | num_slice = 0
113 | for idx in range(len(map_list)):
114 | img, gt_list, flag = map_list[idx]
115 | if flag != num_sbj:
116 | num_sbj = flag
117 | num_slice = 0
118 | img = img / (img.max() if img.max() > 0 else 1)
119 | gt_list = torch.tensor(gt_list, dtype=torch.uint8)
120 | spinal_cord_mask = (torch.mean((gt_list > 0).float(), dim=0) > 0.5).float()  # majority vote of the 4 raters: cord
121 | spinal_cord_mask = spinal_cord_mask.numpy()
122 | gm_mask = (torch.mean((gt_list == 1).float(), dim=0) > 0.5).float()  # majority vote: gray matter
123 | gm_mask = gm_mask.numpy()
124 | save_np(img, data_dir, '%03d%03d' % (num_sbj, num_slice))
125 | save_mask_np(spinal_cord_mask, gm_mask, mask_dir, '%03d%03d' % (num_sbj, num_slice))
126 | num_slice += 1
127 | else:
128 | num_sbj = 0
129 | num_slice = 0
130 | for idx in range(len(map_list)):
131 | img, flag = map_list[idx]
132 | if flag != num_sbj:
133 | num_sbj = flag
134 | num_slice = 0
135 | img = img / (img.max() if img.max() > 0 else 1)
136 | save_np(img, data_dir, '%03d%03d' % (num_sbj, num_slice))
137 | num_slice += 1
138 |
139 | resolution = {
140 | 'site1': [5, 0.5, 0.5],
141 | 'site2': [5, 0.5, 0.5],
142 | 'site3': [2.5, 0.5, 0.5],
143 | 'site4': [5, 0.29, 0.29],
144 | }
145 |
146 | labeled_imageFileList = [os.path.join(path_train, f) for f in os.listdir(path_train) if 'site' in f and '.txt' not in f]
147 | labeled_data_dict = {'site1': OrderedDict(), 'site2': OrderedDict(), 'site3': OrderedDict(), 'site4': OrderedDict()}
148 | for file in sorted(labeled_imageFileList):
149 | res = re.search(r'site(\d)-sc(\d*)-(image|mask)', file).groups()
150 | if res[1] not in labeled_data_dict['site' + res[0]].keys():
151 | labeled_data_dict['site' + res[0]][res[1]] = {'input': None, 'gt': []}
152 | if res[2] == 'image':
153 | labeled_data_dict['site' + res[0]][res[1]]['input'] = file
154 | if res[2] == 'mask':
155 | labeled_data_dict['site' + res[0]][res[1]]['gt'].append(file)
156 | i = 0
157 | for domain, data_list in labeled_data_dict.items():
158 | print(domain)
159 | Save_vendor_data(data_list, labeled=True, data_dir=labeled_data_dir[i], mask_dir=labeled_mask_dir[i])
160 | i += 1
161 |
162 | unlabeled_imageFileList = [os.path.join(path_test, f) for f in os.listdir(path_test) if 'site' in f and '.txt' not in f]
163 | unlabeled_data_dict = {'site1': OrderedDict(), 'site2': OrderedDict(), 'site3': OrderedDict(), 'site4': OrderedDict()}
164 | for file in sorted(unlabeled_imageFileList):
165 | res = re.search(r'site(\d)-sc(\d*)-(image|mask)', file).groups()
166 | if res[1] not in unlabeled_data_dict['site' + res[0]].keys():
167 | unlabeled_data_dict['site' + res[0]][res[1]] = {'input': None, 'gt': []}
168 | if res[2] == 'image':
169 | unlabeled_data_dict['site' + res[0]][res[1]]['input'] = file
170 | if res[2] == 'mask':
171 | unlabeled_data_dict['site' + res[0]][res[1]]['gt'].append(file)
172 | i = 0
173 | for domain, data_list in unlabeled_data_dict.items():
174 | print(domain)
175 | Save_vendor_data(data_list, labeled=False, data_dir=un_labeled_data_dir[i], mask_dir=None)
176 | i += 1
177 | # each item yields: 1. the domain name, 2. a data_list whose entries hold data['input'] and data['gt']
--------------------------------------------------------------------------------
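The rater fusion in Save_vendor_data works per pixel: a pixel counts as spinal cord when more
than half of the four SCGM raters marked any label, and as gray matter when more than half marked
label 1. A self-contained sketch of that majority vote (random masks stand in for real
annotations):

import torch

gt_list = torch.randint(0, 3, (4, 6, 6), dtype=torch.uint8)  # 4 raters, labels in {0, 1, 2}

spinal_cord_mask = (torch.mean((gt_list > 0).float(), dim=0) > 0.5).float()
gm_mask = (torch.mean((gt_list == 1).float(), dim=0) > 0.5).float()
print(spinal_cord_mask.shape, gm_mask.shape)  # torch.Size([6, 6]) twice
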
/preprocess/split_MNMS_data.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import shutil
4 | import xlrd
5 |
6 | def walk_path(dir):
7 | # collect every sub-directory path under dir (the patient folders)
8 | paths = []
9 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
10 |
11 | for root, dirs, _ in sorted(os.walk(dir)):
12 | for name in dirs:
13 | paths.append(os.path.join(root, name))
14 | # print(paths)
15 | return paths
16 |
17 |
18 | ######################################################################################################
19 | #Load excel information:
20 | # patient codes live in column 0, rows 1-345 (cf. the range(1, 346) loop below)
21 | ex_file = '/home/s1575424/xiao/Year2/mnms_split_data/mnms_dataset_info.xlsx'
22 | wb = xlrd.open_workbook(ex_file)
23 | sheet = wb.sheet_by_index(0)
24 | # sheet.cell_value(r, c)
25 |
26 | vendor_A = []
27 | vendor_B = []
28 | vendor_C = []
29 | vendor_D = []
30 |
31 | center_2 = []
32 | center_3 = []
33 |
34 | for i in range(1, 346):
35 | if sheet.cell_value(i, 2)=='A':
36 | vendor_A.append(sheet.cell_value(i, 0))
37 | elif sheet.cell_value(i, 2)=='B':
38 | vendor_B.append(sheet.cell_value(i, 0))
39 | if sheet.cell_value(i, 3)==2:
40 | center_2.append(sheet.cell_value(i, 0))
41 | else:
42 | center_3.append(sheet.cell_value(i, 0))
43 | elif sheet.cell_value(i, 2) == 'C':
44 | vendor_C.append(sheet.cell_value(i, 0))
45 | elif sheet.cell_value(i, 2) == 'D':
46 | vendor_D.append(sheet.cell_value(i, 0))
47 | else:
48 | break
49 |
50 | ######################################################################################################
51 | # move data to the corresponding folders: mnms_split_data/Labeled/(vendorA, vendorB, vendorC, vendorD), mnms_split_data/Unlabeled/vendorC
52 | path_vendorA = '/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorA/'
53 | path_vendorB = '/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorB/'
54 | path_vendorC = '/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorC/'
55 | path_vendorD = '/home/s1575424/xiao/Year2/mnms_split_data/Labeled/vendorD/'
56 |
57 | labeled_train_paths = walk_path('/home/s1575424/xiao/Year2/OpenDataset/Training/Labeled')
58 | testing_paths = walk_path('/home/s1575424/xiao/Year2/OpenDataset/Testing')
59 | val_paths = walk_path('/home/s1575424/xiao/Year2/OpenDataset/Validation')
60 |
61 | i = 0
62 | for train_path in labeled_train_paths:
63 | if train_path[-6:] in vendor_A:
64 | shutil.move(train_path, path_vendorA)
65 | elif train_path[-6:] in vendor_B:
66 | if train_path[-6:] in center_2:
67 | shutil.move(train_path, path_vendorB + '/center2')
68 | else:
69 | shutil.move(train_path, path_vendorB + '/center3')
70 | else:
71 | continue
72 | i += 1
73 | print(i)
74 |
75 | i = 0
76 | for test_path in testing_paths:
77 | if test_path[-6:] in vendor_A:
78 | shutil.move(test_path, path_vendorA)
79 | elif test_path[-6:] in vendor_B:
80 | if test_path[-6:] in center_2:
81 | shutil.move(test_path, path_vendorB + '/center2')
82 | else:
83 | shutil.move(test_path, path_vendorB + '/center3')
84 | elif test_path[-6:] in vendor_C:
85 | shutil.move(test_path, path_vendorC)
86 | elif test_path[-6:] in vendor_D:
87 | shutil.move(test_path, path_vendorD)
88 | else:
89 | continue
90 | i += 1
91 | print(i)
92 |
93 | i = 0
94 | for val_path in val_paths:
95 | if val_path[-6:] in vendor_A:
96 | shutil.move(val_path, path_vendorA)
97 | elif val_path[-6:] in vendor_B:
98 | if val_path[-6:] in center_2:
99 | shutil.move(val_path, path_vendorB + '/center2')
100 | else:
101 | shutil.move(val_path, path_vendorB + '/center3')
102 | elif val_path[-6:] in vendor_C:
103 | shutil.move(val_path, path_vendorC)
104 | elif val_path[-6:] in vendor_D:
105 | shutil.move(val_path, path_vendorD)
106 | else:
107 | continue
108 | i += 1
109 | print(i)
--------------------------------------------------------------------------------
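split_MNMS_data.py routes each patient folder by matching its 6-character code (the last six
characters of the folder path) against vendor lists built from the dataset spreadsheet. A minimal
sketch of building that lookup with xlrd, under the same column layout the loops above assume:

import xlrd

wb = xlrd.open_workbook('mnms_dataset_info.xlsx')  # hypothetical local copy of the info sheet
sheet = wb.sheet_by_index(0)

vendor_of = {}
for row in range(1, sheet.nrows):        # row 0 is the header row
    code = sheet.cell_value(row, 0)      # patient code (column 0)
    if not code:
        break
    vendor_of[code] = sheet.cell_value(row, 2)  # vendor letter: 'A', 'B', 'C' or 'D'
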
/train_meta.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import numpy as np
5 | import os
6 | import argparse
7 | from eval import eval_dgnet
8 | from tqdm import tqdm
9 | import logging
10 | from metrics.focal_loss import FocalLoss
11 | from torch.utils.data import DataLoader, random_split, ConcatDataset
12 | import torch.nn.functional as F
13 | import utils
14 | from loaders.mms_dataloader_meta_split import get_meta_split_data_loaders
15 | import models
16 | import losses
17 | from torch.utils.tensorboard import SummaryWriter
18 | import time
19 |
20 |
21 | def get_args():
22 | usage_text = (
23 | "DGNet Pytorch Implementation"
24 | "Usage: python train_meta.py [options],"
25 | " with [options]:"
26 | )
27 | parser = argparse.ArgumentParser(description=usage_text)
28 | #training details
29 | parser.add_argument('-e','--epochs', type=int, default=100, help='Number of epochs')
30 | parser.add_argument('-bs','--batch_size', type=int, default=4, help='Number of inputs per batch')
31 | parser.add_argument('-c', '--cp', type=str, default='checkpoints/', help='Directory for saving checkpoints.')
32 | parser.add_argument('-tc', '--tcp', type=str, default='temp_checkpoints/', help='Directory for saving temporary checkpoints.')
33 | parser.add_argument('-t', '--tv', type=str, default='D', help='The name of the target vendor.')
34 | parser.add_argument('-w', '--wc', type=str, default='DGNet_LR00002_LDv5', help='The comment tag of the TensorBoard summary writer.')
35 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. Used when storing information.')
36 | parser.add_argument('-mn','--model_name', type=str, default='dgnet', help='Name of the model architecture to be used for training/testing.')
37 | parser.add_argument('-lr','--learning_rate', type=float, default=0.00004, help='The learning rate for model training')
38 | parser.add_argument('-wi','--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)')
39 | parser.add_argument('--save_path', type=str, default='checkpoints', help= 'Path to save model checkpoints')
40 | parser.add_argument('--decoder_type', type=str, default='film', help='Choose decoder type between FiLM and SPADE')
41 | #hardware
42 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
43 | parser.add_argument('--num_workers', type=int, default=0, help='Number of workers to use for data loading')
44 |
45 | return parser.parse_args()
46 |
47 | # python train_meta.py -e 80 -c cp_dgnet_meta_100_tvA/ -t A -w DGNetRE_COM_META_100_tvA -g 0
48 | # python train_meta.py -e 80 -c cp_dgnet_meta_100_tvB/ -t B -w DGNetRE_COM_META_100_tvB -g 1
49 | # python train_meta.py -e 80 -c cp_dgnet_meta_100_tvC/ -t C -w DGNetRE_COM_META_100_tvC -g 2
50 | # python train_meta.py -e 80 -c cp_dgnet_meta_100_tvD/ -t D -w DGNetRE_COM_META_100_tvD -g 3
51 | # k_un = 1
52 | # k1 = 20
53 | # k2 = 2
54 | # opt_patience = 4
55 |
56 | # python train_meta.py -e 100 -c cp_dgnet_meta_50_tvA/ -t A -w DGNetRE_COM_META_50_tvA -g 0
57 | # python train_meta.py -e 100 -c cp_dgnet_meta_50_tvB/ -t B -w DGNetRE_COM_META_50_tvB -g 1
58 | # python train_meta.py -e 100 -c cp_dgnet_meta_50_tvC/ -t C -w DGNetRE_COM_META_50_tvC -g 2
59 | # python train_meta.py -e 100 -c cp_dgnet_meta_50_tvD/ -t D -w DGNetRE_COM_META_50_tvD -g 3
60 | # k_un = 1
61 | # k1 = 20
62 | # k2 = 3
63 | # opt_patience = 4
64 |
65 | # python train_meta.py -e 120 -c cp_dgnet_meta_20_tvA/ -t A -w DGNetRE_COM_META_20_tvA -g 0
66 | # python train_meta.py -e 120 -c cp_dgnet_meta_20_tvB/ -t B -w DGNetRE_COM_META_20_tvB -g 1
67 | # python train_meta.py -e 120 -c cp_dgnet_meta_20_tvC/ -t C -w DGNetRE_COM_META_20_tvC -g 2
68 | # python train_meta.py -e 120 -c cp_dgnet_meta_20_tvD/ -t D -w DGNetRE_COM_META_20_tvD -g 3
69 | # k_un = 1
70 | # k1 = 30
71 | # k2 = 3
72 | # opt_patience = 4
73 |
74 | # python train_meta.py -e 150 -c cp_dgnet_meta_5_tvA/ -t A -w DGNetRE_COM_META_5_tvA -g 0
75 | # python train_meta.py -e 150 -c cp_dgnet_meta_5_tvB/ -t B -w DGNetRE_COM_META_5_tvB -g 1
76 | # python train_meta.py -e 150 -c cp_dgnet_meta_5_tvC/ -t C -w DGNetRE_COM_META_5_tvC -g 2
77 | # python train_meta.py -e 150 -c cp_dgnet_meta_5_tvD/ -t D -w DGNetRE_COM_META_5_tvD -g 3
78 | k_un = 1
79 | k1 = 30
80 | k2 = 3
81 | opt_patience = 4
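# Hyper-parameters of the meta-learning schedule, as used below: k_un is the number of rounds of
# unlabeled meta-train batches per iteration; validation/testing only starts after epoch k1 and
# then runs every k2 epochs; opt_patience is the patience of the ReduceLROnPlateau scheduler.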
82 |
83 | def latent_norm(a):
84 | n_batch, n_channel, _, _ = a.size()
85 | for batch in range(n_batch):
86 | for channel in range(n_channel):
87 | a_min = a[batch,channel,:,:].min()
88 | a_max = a[batch, channel, :, :].max()
89 | a[batch,channel,:,:] -= a_min  # subtract the minimum so each channel starts at 0
90 | a[batch, channel, :, :] /= a_max - a_min  # each channel now spans [0, 1] for visualisation
91 | return a
92 |
93 | def train_net(args):
94 | best_dice = 0
95 | best_lv = 0
96 | best_myo = 0
97 | best_rv = 0
98 |
99 | epochs = args.epochs
100 | batch_size = args.batch_size
101 | lr = args.learning_rate
102 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
103 |
104 | dir_checkpoint = args.cp
105 | test_vendor = args.tv
106 | wc = args.wc
107 |
108 | #Model selection and initialization
109 | model_params = {
110 | 'width': 288,
111 | 'height': 288,
112 | 'ndf': 64,
113 | 'norm': "batchnorm",
114 | 'upsample': "nearest",
115 | 'num_classes': 3,
116 | 'decoder_type': args.decoder_type,
117 | 'anatomy_out_channels': 8,
118 | 'z_length': 8,
119 | 'num_mask_channels': 8,
120 |
121 | }
122 | model = models.get_model(args.model_name, model_params)
123 | num_params = utils.count_parameters(model)
124 | print('Model Parameters: ', num_params)
125 | models.initialize_weights(model, args.weight_init)
126 | model.to(device)
127 |
128 | # size:
129 | # X: N, 1, 224, 224
130 | # Y: N, 3, 224, 224
131 |
132 | _, domain_1_unlabeled_loader, \
133 | _, domain_2_unlabeled_loader,\
134 | _, domain_3_unlabeled_loader, \
135 | test_loader, \
136 | domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset = get_meta_split_data_loaders(batch_size//2, test_vendor=test_vendor, image_size=224)
137 |
138 |
139 | val_dataset = ConcatDataset([domain_1_labeled_dataset, domain_2_labeled_dataset, domain_3_labeled_dataset])
140 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True, num_workers=2)
141 | n_val = len(val_dataset)
142 | print(n_val)
143 | print(len(test_loader))
144 | print(len(domain_1_unlabeled_loader))
145 | print(len(domain_2_unlabeled_loader))
146 | print(len(domain_3_unlabeled_loader))
147 |
148 | d_len = []
149 | d_len.append(len(domain_1_labeled_dataset))
150 | d_len.append(len(domain_2_labeled_dataset))
151 | d_len.append(len(domain_3_labeled_dataset))
152 | # the longest labeled dataset sets the epoch length; the shorter ones are repeated below to match
153 | long_len = max(d_len)
154 |
155 | print(long_len)
156 |
157 | new_d_1 = domain_1_labeled_dataset  # repeat domain 1 until it holds at least long_len samples
158 | for i in range(long_len//d_len[0]+1):
159 | if long_len == d_len[0]:
160 | break
161 | new_d_1 = ConcatDataset([new_d_1, domain_1_labeled_dataset])
162 | domain_1_labeled_dataset = new_d_1
163 | domain_1_labeled_loader = DataLoader(dataset=domain_1_labeled_dataset, batch_size=batch_size//2, shuffle=False,
164 | drop_last=True, num_workers=2, pin_memory=True)
165 |
166 | new_d_2 = domain_2_labeled_dataset
167 | for i in range(long_len//d_len[1]+1):
168 | if long_len == d_len[1]:
169 | break
170 | new_d_2 = ConcatDataset([new_d_2, domain_2_labeled_dataset])
171 | domain_2_labeled_dataset = new_d_2
172 | domain_2_labeled_loader = DataLoader(dataset=domain_2_labeled_dataset, batch_size=batch_size//2, shuffle=False,
173 | drop_last=True, num_workers=2, pin_memory=True)
174 |
175 | new_d_3 = domain_3_labeled_dataset
176 | for i in range(long_len//d_len[2]+1):
177 | if long_len == d_len[2]:
178 | break
179 | new_d_3 = ConcatDataset([new_d_3, domain_3_labeled_dataset])
180 | domain_3_labeled_dataset = new_d_3
181 | domain_3_labeled_loader = DataLoader(dataset=domain_3_labeled_dataset, batch_size=batch_size//2, shuffle=False,
182 | drop_last=True, num_workers=2, pin_memory=True)
183 |
184 | print(len(domain_1_labeled_loader))
185 | print(len(domain_2_labeled_loader))
186 | print(len(domain_3_labeled_loader))
187 |
188 | #metrics initialization
189 | # l2_distance = nn.MSELoss().to(device)
190 | criterion = nn.BCEWithLogitsLoss().to(device)
191 | l1_distance = nn.L1Loss().to(device)
192 | focal = FocalLoss()
193 |
194 | #optimizer initialization
195 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
196 | # need to use a more useful lr_scheduler
197 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=opt_patience)
198 |
199 | writer = SummaryWriter(comment=wc)
200 |
201 | global_step = 0
202 | for epoch in range(epochs):
203 | model.train()
204 | with tqdm(total=long_len, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
205 | domain_1_labeled_itr = iter(domain_1_labeled_loader)
206 | domain_2_labeled_itr = iter(domain_2_labeled_loader)
207 | domain_3_labeled_itr = iter(domain_3_labeled_loader)
208 | domain_labeled_iter_list = [domain_1_labeled_itr, domain_2_labeled_itr, domain_3_labeled_itr]
209 |
210 | domain_1_unlabeled_itr = iter(domain_1_unlabeled_loader)
211 | domain_2_unlabeled_itr = iter(domain_2_unlabeled_loader)
212 | domain_3_unlabeled_itr = iter(domain_3_unlabeled_loader)
213 | domain_unlabeled_iter_list = [domain_1_unlabeled_itr, domain_2_unlabeled_itr, domain_3_unlabeled_itr]
214 |
215 |
216 | for num_itr in range(long_len//batch_size):
217 | # Randomly choosing meta train and meta test domains
218 | domain_list = np.random.permutation(3)
219 | meta_train_domain_list = domain_list[:2]
220 | meta_test_domain_list = domain_list[2]
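# Each iteration is one episode: two randomly chosen source domains act as meta-train and the
# held-out third simulates the unseen test domain (meta-test).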
221 |
222 | meta_train_imgs = []
223 | meta_train_masks = []
224 | meta_train_labels = []
225 | meta_test_imgs = []
226 | meta_test_masks = []
227 | meta_test_labels = []
228 | meta_test_un_imgs = []
229 | meta_test_un_labels = []
230 |
231 | imgs, true_masks, labels = next(domain_labeled_iter_list[meta_train_domain_list[0]])
232 | meta_train_imgs.append(imgs)
233 | meta_train_masks.append(true_masks)
234 | meta_train_labels.append(labels)
235 |
236 | imgs, true_masks, labels = next(domain_labeled_iter_list[meta_train_domain_list[1]])
237 | meta_train_imgs.append(imgs)
238 | meta_train_masks.append(true_masks)
239 | meta_train_labels.append(labels)
240 |
241 | imgs, true_masks, labels = next(domain_labeled_iter_list[meta_test_domain_list])
242 | meta_test_imgs.append(imgs)
243 | meta_test_masks.append(true_masks)
244 | meta_test_labels.append(labels)
245 | imgs, true_masks, labels = next(domain_labeled_iter_list[meta_test_domain_list])
246 | meta_test_imgs.append(imgs)
247 | meta_test_masks.append(true_masks)
248 | meta_test_labels.append(labels)
249 |
250 | imgs, labels = next(domain_unlabeled_iter_list[meta_test_domain_list])
251 | meta_test_un_imgs.append(imgs)
252 | meta_test_un_labels.append(labels)
253 | imgs, labels = next(domain_unlabeled_iter_list[meta_test_domain_list])
254 | meta_test_un_imgs.append(imgs)
255 | meta_test_un_labels.append(labels)
256 |
257 | meta_train_imgs = torch.cat((meta_train_imgs[0], meta_train_imgs[1]), dim=0)
258 | meta_train_masks = torch.cat((meta_train_masks[0], meta_train_masks[1]), dim=0)
259 | meta_train_labels = torch.cat((meta_train_labels[0], meta_train_labels[1]), dim=0)
260 | meta_test_imgs = torch.cat((meta_test_imgs[0], meta_test_imgs[1]), dim=0)
261 | meta_test_masks = torch.cat((meta_test_masks[0], meta_test_masks[1]), dim=0)
262 | meta_test_labels = torch.cat((meta_test_labels[0], meta_test_labels[1]), dim=0)
263 | meta_test_un_imgs = torch.cat((meta_test_un_imgs[0], meta_test_un_imgs[1]), dim=0)
264 | meta_test_un_labels = torch.cat((meta_test_un_labels[0], meta_test_un_labels[1]), dim=0)
265 |
266 | meta_train_un_imgs = []
267 | meta_train_un_labels = []
268 | for i in range(k_un):
269 | train_un_imgs = []
270 | train_un_labels = []
271 | un_imgs, un_labels = next(domain_unlabeled_iter_list[meta_train_domain_list[0]])
272 | train_un_imgs.append(un_imgs)
273 | train_un_labels.append(un_labels)
274 | un_imgs, un_labels = next(domain_unlabeled_iter_list[meta_train_domain_list[1]])
275 | train_un_imgs.append(un_imgs)
276 | train_un_labels.append(un_labels)
277 | meta_train_un_imgs.append(torch.cat((train_un_imgs[0], train_un_imgs[1]), dim=0))
278 | meta_train_un_labels.append(torch.cat((train_un_labels[0], train_un_labels[1]), dim=0))
279 |
280 | total_meta_un_loss = 0.0
281 | for i in range(k_un):
282 | # meta-train: 1. load meta-train data 2. calculate meta-train loss
283 | ###############################Meta train#######################################################
284 | un_imgs = meta_train_un_imgs[i].to(device=device, dtype=torch.float32)
285 | un_labels = meta_train_un_labels[i].to(device=device, dtype=torch.float32)
286 |
287 | un_reco, un_z_out, un_z_tilde, un_a_out, _, un_mu, un_logvar, un_cls_out, _ = model(un_imgs, true_masks, 'training')
288 |
289 | un_a_feature = F.softmax(un_a_out, dim=1)
290 | # un_a_feature = un_a_feature[:,4:,:,:]
291 | # un_seg_pred = un_a_out[:,:4,:,:]
292 |
293 | latent_dim = un_a_feature.size(1)
294 | un_a_feature = un_a_feature.permute(0, 2, 3, 1).contiguous().view(-1, latent_dim)  # flatten the softmaxed features, mirroring the labeled branch below
295 | un_a_feature = un_a_feature[torch.randperm(len(un_a_feature))]
296 | un_U_a, un_S_a, un_V_a = torch.svd(un_a_feature[0:2000])
297 |
298 | # loss_low_rank_Un_a = 0.1*torch.sum(un_S_a)
299 | loss_low_rank_Un_a = un_S_a[4]
300 |
301 | un_reco_loss = l1_distance(un_reco, un_imgs)
302 | un_regression_loss = l1_distance(un_z_tilde, un_z_out)
303 |
304 | kl_loss1 = losses.KL_divergence(un_logvar[:, :8], un_mu[:, :8])
305 | kl_loss2 = losses.KL_divergence(un_logvar[:, 8:], un_mu[:, 8:])
306 | hsic_loss = losses.HSIC_lossfunc(un_z_out[:, :8], un_z_out[:, 8:])
307 | un_kl_loss = kl_loss1 + kl_loss2 + hsic_loss
308 |
309 | d_cls = criterion(un_cls_out, un_labels)
310 | un_batch_loss = un_reco_loss + (0.1*un_regression_loss) + 0.1*un_kl_loss + d_cls + 0.1*loss_low_rank_Un_a
311 |
312 | total_meta_un_loss += un_batch_loss
313 |
314 | # meta-test: 1. load meta-test data 2. calculate meta-test loss
315 | ###############################Meta test#######################################################
316 | un_imgs = meta_test_un_imgs.to(device=device, dtype=torch.float32)
317 | un_labels = meta_test_un_labels.to(device=device, dtype=torch.float32)
318 | un_reco, un_z_out, un_z_tilde, un_a_out, _, un_mu, un_logvar, un_cls_out, _ = model(
319 | un_imgs, true_masks, 'training', meta_loss=un_batch_loss)
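# Assumption about models/dgnet.py (not shown in this file): passing meta_loss makes this forward
# pass use parameters virtually updated from the meta-train loss, i.e. a MAML-style inner step.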
320 |
321 | un_seg_pred = un_a_out[:, :4, :, :]
322 | sf_un_seg_pred = F.softmax(un_seg_pred, dim=1)
323 |
324 | un_reco_loss = l1_distance(un_reco, un_imgs)
325 | un_regression_loss = l1_distance(un_z_tilde, un_z_out)
326 |
327 | # kl_loss1 = losses.KL_divergence(un_logvar[:, :8], un_mu[:, :8])
328 | # kl_loss2 = losses.KL_divergence(un_logvar[:, 8:], un_mu[:, 8:])
329 | # hsic_loss = losses.HSIC_lossfunc(un_z_out[:, :8], un_z_out[:, 8:])
330 | # un_kl_loss = kl_loss1 + kl_loss2 + hsic_loss
331 |
332 | d_cls = criterion(un_cls_out, un_labels)
333 | un_batch_loss = un_reco_loss + d_cls
334 |
335 | total_meta_un_loss += un_batch_loss
336 |
337 | writer.add_scalar('Meta_test_loss/un_reco_loss', un_reco_loss.item(), global_step)
338 | writer.add_scalar('Meta_test_loss/un_regression_loss', un_regression_loss.item(), global_step)
339 | # writer.add_scalar('Meta_test_loss/un_kl_loss', un_kl_loss.item(), global_step)
340 | writer.add_scalar('Meta_test_loss/d_cls', d_cls.item(), global_step)
341 | # writer.add_scalar('Meta_test_loss/loss_low_rank_Un_a', loss_low_rank_Un_a.item(), un_step)
342 | writer.add_scalar('Meta_test_loss/un_batch_loss', un_batch_loss.item(), global_step)
343 |
344 | optimizer.zero_grad()
345 | total_meta_un_loss.backward()
346 | nn.utils.clip_grad_value_(model.parameters(), 0.1)
347 | optimizer.step()
348 |
349 | total_meta_loss = 0.0
350 | # meta-train: 1. load meta-train data 2. calculate meta-train loss
351 | ###############################Meta train#######################################################
352 | imgs = meta_train_imgs.to(device=device, dtype=torch.float32)
353 | mask_type = torch.float32
354 | ce_mask = meta_train_masks.clone().to(device=device, dtype=torch.long)
355 | true_masks = meta_train_masks.to(device=device, dtype=mask_type)
356 | labels = meta_train_labels.to(device=device, dtype=torch.float32)
357 |
358 | reco, z_out, z_out_tilde, a_out, _, mu, logvar, cls_out, _ = model(imgs, true_masks, 'training')
359 |
360 | # mode-1 flattening: reshape the (B, 8, 224, 224) anatomy features to a (B*224*224, 8) matrix,
361 | # then randomly pick 2000 of those 8-dim rows to estimate the singular values
362 | a_feature = F.softmax(a_out, dim=1)
363 | # a_feature = a_feature[:, 4:, :, :]
364 | seg_pred = a_out[:, :4, :, :]
365 |
366 | latent_dim = a_feature.size(1)
367 | a_feature = a_feature.permute(0, 2, 3, 1).contiguous().view(-1, latent_dim)
368 | a_feature = a_feature[torch.randperm(len(a_feature))]
369 | U_a, S_a, V_a = torch.svd(a_feature[0:2000])
370 |
371 | # loss_low_rank_a = 0.1*torch.sum(S_a)
372 | loss_low_rank_a = S_a[4]
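# torch.svd returns singular values in descending order, so S_a[4] is the 5th largest;
# driving it towards zero pushes the sampled (2000, 8) feature matrix towards rank <= 4.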
373 |
374 | reco_loss = l1_distance(reco, imgs)
375 | kl_loss1 = losses.KL_divergence(logvar[:,:8], mu[:,:8])
376 | kl_loss2 = losses.KL_divergence(logvar[:,8:], mu[:,8:])
377 | hsic_loss = losses.HSIC_lossfunc(z_out[:,:8], z_out[:,8:])
378 | kl_loss = kl_loss1 + kl_loss2 + hsic_loss
379 | regression_loss = l1_distance(z_out_tilde, z_out)
380 |
381 | sf_seg = F.softmax(seg_pred, dim=1)
382 | dice_loss_lv = losses.dice_loss(sf_seg[:,0,:,:], true_masks[:,0,:,:])
383 | dice_loss_myo = losses.dice_loss(sf_seg[:,1,:,:], true_masks[:,1,:,:])
384 | dice_loss_rv = losses.dice_loss(sf_seg[:,2,:,:], true_masks[:,2,:,:])
385 | dice_loss_bg = losses.dice_loss(sf_seg[:, 3, :, :], true_masks[:, 3, :, :])
386 | loss_dice = dice_loss_lv + dice_loss_myo + dice_loss_rv + dice_loss_bg
387 |
388 | ce_target = ce_mask[:, 3, :, :]*0 + ce_mask[:, 0, :, :]*1 + ce_mask[:, 1, :, :]*2 + ce_mask[:, 2, :, :]*3  # one-hot channels -> class indices: BG=0, LV=1, MYO=2, RV=3
389 |
390 | seg_pred_swap = torch.cat((seg_pred[:,3,:,:].unsqueeze(1), seg_pred[:,:3,:,:]), dim=1)  # put the background channel first to match ce_target
391 |
392 | loss_focal = focal(seg_pred_swap, ce_target)
393 |
394 | d_cls = criterion(cls_out, labels)
395 | d_losses = d_cls
396 |
397 | batch_loss = reco_loss + (0.1 * regression_loss) + 0.1*kl_loss + 5*loss_dice + 5*loss_focal + d_losses + 0.1*loss_low_rank_a
398 |
399 | total_meta_loss += batch_loss
400 |
401 | writer.add_scalar('Meta_train_Loss/loss_dice', loss_dice.item(), global_step)
402 | writer.add_scalar('Meta_train_Loss/dice_loss_lv', dice_loss_lv.item(), global_step)
403 | writer.add_scalar('Meta_train_Loss/dice_loss_myo', dice_loss_myo.item(), global_step)
404 | writer.add_scalar('Meta_train_Loss/dice_loss_rv', dice_loss_rv.item(), global_step)
405 | writer.add_scalar('Meta_train_Loss/loss_focal', loss_focal.item(), global_step)
406 | writer.add_scalar('Meta_train_Loss/kl_loss', kl_loss.item(), global_step)
407 | writer.add_scalar('Meta_train_Loss/loss_low_rank_a', loss_low_rank_a.item(), global_step)
408 | writer.add_scalar('Meta_train_Loss/batch_loss', batch_loss.item(), global_step)
409 |
410 | # meta-test: 1. load meta-test data 2. calculate meta-test loss
411 | ###############################Meta test#######################################################
412 | imgs = meta_test_imgs.to(device=device, dtype=torch.float32)
413 | mask_type = torch.float32
414 | ce_mask = meta_test_masks.clone().to(device=device, dtype=torch.long)
415 | true_masks = meta_test_masks.to(device=device, dtype=mask_type)
416 | labels = meta_test_labels.to(device=device, dtype=torch.float32)
417 | reco, z_out, z_out_tilde, a_out, _, mu, logvar, cls_out, _ = model(imgs, true_masks, 'training', meta_loss=batch_loss)
418 |
419 | # mode-1 flattening: reshape the (B, 8, 224, 224) features to (B*224*224, 8) and randomly
420 | # pick 2000 rows to estimate the singular values (this penalty is disabled for meta-test)
421 | # latent_dim = a_out.size(1)
422 | # a_feature = a_out.permute(0, 2, 3, 1).contiguous().view(-1, latent_dim)
423 | # a_feature = a_feature[torch.randperm(len(a_feature))]
424 | # U_a, S_a, V_a = torch.svd(a_feature[0:2000])
425 |
426 | seg_pred = a_out[:, :4, :, :]
427 |
428 | reco_loss = l1_distance(reco, imgs)
429 | # kl_loss = losses.KL_divergence(logvar, mu)
430 | # regression_loss = l1_distance(z_out_tilde, z_out)
431 |
432 | sf_seg = F.softmax(seg_pred, dim=1)
433 | dice_loss_lv = losses.dice_loss(sf_seg[:,0,:,:], true_masks[:,0,:,:])
434 | dice_loss_myo = losses.dice_loss(sf_seg[:,1,:,:], true_masks[:,1,:,:])
435 | dice_loss_rv = losses.dice_loss(sf_seg[:,2,:,:], true_masks[:,2,:,:])
436 | dice_loss_bg = losses.dice_loss(sf_seg[:, 3, :, :], true_masks[:, 3, :, :])
437 | loss_dice = dice_loss_lv + dice_loss_myo + dice_loss_rv + dice_loss_bg
438 |
439 | ce_target = ce_mask[:, 3, :, :]*0 + ce_mask[:, 0, :, :]*1 + ce_mask[:, 1, :, :]*2 + ce_mask[:, 2, :, :]*3  # one-hot channels -> class indices: BG=0, LV=1, MYO=2, RV=3
440 |
441 | seg_pred_swap = torch.cat((seg_pred[:,3,:,:].unsqueeze(1), seg_pred[:,:3,:,:]), dim=1)  # put the background channel first to match ce_target
442 |
443 | loss_focal = focal(seg_pred_swap, ce_target)
444 |
445 | d_cls = criterion(cls_out, labels)
446 | d_losses = d_cls
447 |
448 | batch_loss = 5*loss_dice + 5*loss_focal + reco_loss + d_losses
449 | total_meta_loss += batch_loss
450 |
451 | optimizer.zero_grad()
452 | total_meta_loss.backward()
453 | nn.utils.clip_grad_value_(model.parameters(), 0.1)
454 | optimizer.step()
455 |
456 | pbar.set_postfix(**{'loss (batch)': total_meta_loss.item()})
457 | pbar.update(imgs.shape[0])
458 |
459 | if (epoch + 1) > (k1) and (epoch + 1) % k2 == 0:
460 | if global_step % ((long_len//batch_size) // 2) == 0:
461 | a_feature = F.softmax(a_out, dim=1)
462 | a_feature = latent_norm(a_feature)
463 | writer.add_images('Meta_train_images/train', imgs, global_step)
464 | writer.add_images('Meta_train_images/a_out0', a_feature[:,0,:,:].unsqueeze(1), global_step)
465 | writer.add_images('Meta_train_images/a_out1', a_feature[:, 1, :, :].unsqueeze(1), global_step)
466 | writer.add_images('Meta_train_images/a_out2', a_feature[:, 2, :, :].unsqueeze(1), global_step)
467 | writer.add_images('Meta_train_images/a_out3', a_feature[:, 3, :, :].unsqueeze(1), global_step)
468 | writer.add_images('Meta_train_images/a_out4', a_feature[:, 4, :, :].unsqueeze(1), global_step)
469 | writer.add_images('Meta_train_images/a_out5', a_feature[:, 5, :, :].unsqueeze(1), global_step)
470 | writer.add_images('Meta_train_images/a_out6', a_feature[:, 6, :, :].unsqueeze(1), global_step)
471 | writer.add_images('Meta_train_images/a_out7', a_feature[:, 7, :, :].unsqueeze(1), global_step)
472 | writer.add_images('Meta_train_images/train_reco', reco, global_step)
473 | writer.add_images('Meta_train_images/train_true', true_masks[:,0:3,:,:], global_step)
474 | writer.add_images('Meta_train_images/train_pred', sf_seg[:,0:3,:,:] > 0.5, global_step)
475 | writer.add_images('Meta_test_images/train_un_img', un_imgs, global_step)
476 | writer.add_images('Meta_test_images/train_un_mask', sf_un_seg_pred[:, 0:3, :, :] > 0.5, global_step)
477 |
478 | global_step += 1
479 |
480 | if optimizer.param_groups[0]['lr'] <= 2e-8:
481 | print('Converged')
482 | if (epoch + 1) > k1 and (epoch + 1) % k2 == 0:
483 |
484 | val_score, val_lv, val_myo, val_rv = eval_dgnet(model, val_loader, device, mode='val')
485 | scheduler.step(val_score)
486 | writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
487 |
488 | logging.info('Validation Dice Coeff: {}'.format(val_score))
489 | logging.info('Validation LV Dice Coeff: {}'.format(val_lv))
490 | logging.info('Validation MYO Dice Coeff: {}'.format(val_myo))
491 | logging.info('Validation RV Dice Coeff: {}'.format(val_rv))
492 |
493 | writer.add_scalar('Dice/val', val_score, epoch)
494 | writer.add_scalar('Dice/val_lv', val_lv, epoch)
495 | writer.add_scalar('Dice/val_myo', val_myo, epoch)
496 | writer.add_scalar('Dice/val_rv', val_rv, epoch)
497 |
498 | initial_itr = 0
499 | for imgs, true_masks in test_loader:
500 | if initial_itr == 5:  # visualise a single (the 6th) test batch each validation round
501 | model.eval()
502 | imgs = imgs.to(device=device, dtype=torch.float32)
503 | with torch.no_grad():
504 | reco, z_out, z_out_tilde, a_out, seg_pred, mu, logvar, _, _ = model(imgs, true_masks,
505 | 'test')
506 | seg_pred = a_out[:, :4, :, :]  # the first 4 anatomy channels act as segmentation logits
507 | mask_type = torch.float32
508 | true_masks = true_masks.to(device=device, dtype=mask_type)
509 | sf_seg_pred = F.softmax(seg_pred, dim=1)
510 | writer.add_images('Test_images/test', imgs, epoch)
511 | writer.add_images('Test_images/test_reco', reco, epoch)
512 | writer.add_images('Test_images/test_true', true_masks[:, 0:3, :, :], epoch)
513 | writer.add_images('Test_images/test_pred', sf_seg_pred[:, 0:3, :, :] > 0.5, epoch)
514 | model.train()
515 | break
516 | else:
517 | pass
518 | initial_itr += 1
519 | test_score, test_lv, test_myo, test_rv = eval_dgnet(model, test_loader, device, mode='test')
520 |
521 | if best_dice < test_score:
522 | best_dice = test_score
523 | best_lv = test_lv
524 | best_myo = test_myo
525 | best_rv = test_rv
526 | print("Epoch checkpoint")
527 | try:
528 | os.mkdir(dir_checkpoint)
529 | logging.info('Created checkpoint directory')
530 | except OSError:
531 | pass  # the checkpoint directory most likely exists already
532 | torch.save(model.state_dict(),
533 | dir_checkpoint + 'CP_epoch.pth')
534 | logging.info('Checkpoint saved !')
535 | else:
536 | pass
537 | logging.info('Best Dice Coeff: {}'.format(best_dice))
538 | logging.info('Best LV Dice Coeff: {}'.format(best_lv))
539 | logging.info('Best MYO Dice Coeff: {}'.format(best_myo))
540 | logging.info('Best RV Dice Coeff: {}'.format(best_rv))
541 | writer.add_scalar('Dice/test', test_score, epoch)
542 | writer.add_scalar('Dice/test_lv', test_lv, epoch)
543 | writer.add_scalar('Dice/test_myo', test_myo, epoch)
544 | writer.add_scalar('Dice/test_rv', test_rv, epoch)
545 | writer.close()
546 |
547 |
548 | if __name__ == '__main__':
549 | logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
550 | args = get_args()
551 | device = torch.device('cuda:'+str(args.gpu) if torch.cuda.is_available() else 'cpu')
552 | logging.info(f'Using device {device}')
553 |
554 |
555 | torch.manual_seed(14)
556 | if device.type == 'cuda':
557 | torch.cuda.manual_seed(14)
558 |
559 | train_net(args)
560 |
--------------------------------------------------------------------------------
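
The episodic loop in `train_meta.py` above reduces to a simple pattern: compute the full objective (reconstruction, KL, regression, dice, focal, domain classification, low-rank) on the meta-train domains, recompute only the task losses on the held-out meta-test domain through weights virtually updated with the meta-train loss, and back-propagate the sum once. Below is a minimal sketch of that pattern, assuming a model whose `forward` accepts a `meta_loss` argument that triggers the inner-loop update; `compute_losses` is a hypothetical helper, not part of this repository.

```python
import torch.nn as nn

def meta_step(model, optimizer, meta_train_batch, meta_test_batch, compute_losses):
    """One episodic update: meta-train loss + meta-test loss, single backward."""
    imgs, masks, labels = meta_train_batch
    # Meta-train: full objective on the source (meta-train) domains.
    outputs = model(imgs, masks, 'training', meta_loss=None)
    train_loss = compute_losses(outputs, imgs, masks, labels, full=True)

    imgs, masks, labels = meta_test_batch
    # Meta-test: forward through weights virtually updated with train_loss,
    # keeping only the task losses (dice, focal, reconstruction, classifier).
    outputs = model(imgs, masks, 'training', meta_loss=train_loss)
    test_loss = compute_losses(outputs, imgs, masks, labels, full=False)

    total_meta_loss = train_loss + test_loss
    optimizer.zero_grad()
    total_meta_loss.backward()  # a single graph for both terms
    nn.utils.clip_grad_value_(model.parameters(), 0.1)
    optimizer.step()
    return total_meta_loss.item()
```

Keeping both terms in one graph is what lets the meta-test gradient flow back through the inner-loop update; this is the role of the `meta_loss=batch_loss` argument in the real script.
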
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .model_utils import *
--------------------------------------------------------------------------------
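
Because of this wildcard import, everything defined in `model_utils.py` (shown below) is importable from the package root. A minimal, illustrative usage of `count_parameters` with a toy module standing in for the real model:

```python
import torch.nn as nn
from utils import count_parameters  # re-exported via utils/__init__.py

model = nn.Linear(8, 4)  # toy stand-in for DGNet
print(f'{count_parameters(model):,} trainable parameters')  # 8*4 weights + 4 biases = 36
```
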
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/model_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vios-s/DGNet/8af4b82b62b53e29e96084113a5d379774c11b12/utils/__pycache__/model_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from os import path
3 |
4 | def count_parameters(model):
5 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
6 |
7 | def save_network_state(model, width, height, ndf, norm, upsample, num_classes, decoder_type, anatomy_out_channels, z_length, num_mask_channels, optimizer, epoch, name, save_path):
8 | if not path.exists(save_path):
9 | raise ValueError("{} is not a valid path to save the model state".format(save_path))
10 | torch.save(
11 | {
12 | 'epoch': epoch,
13 | 'width': width,
14 | 'height': height,
15 | 'ndf': ndf,
16 | 'norm': norm,
17 | 'upsample': upsample,
18 | 'num_classes': num_classes,
19 | 'decoder_type': decoder_type,
20 | 'anatomy_out_channels': anatomy_out_channels,
21 | 'z_length': z_length,
22 | 'num_mask_channels': num_mask_channels,
23 | 'model_state_dict': model.state_dict(),
24 | 'optimizer_state_dict': optimizer.state_dict()
25 | }, path.join(save_path, name))
--------------------------------------------------------------------------------
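
`save_network_state` stores the architecture hyper-parameters alongside the weights, so a checkpoint can be rebuilt without re-supplying the original command line. A minimal loading sketch under that assumption; `build_model` is a hypothetical factory taking the same hyper-parameters and is not part of this repository:

```python
import torch
from os import path

def load_network_state(build_model, optimizer, save_path, name, device='cpu'):
    """Inverse of save_network_state: rebuild the model and restore both states."""
    state = torch.load(path.join(save_path, name), map_location=device)
    model = build_model(
        width=state['width'], height=state['height'], ndf=state['ndf'],
        norm=state['norm'], upsample=state['upsample'],
        num_classes=state['num_classes'], decoder_type=state['decoder_type'],
        anatomy_out_channels=state['anatomy_out_channels'],
        z_length=state['z_length'],
        num_mask_channels=state['num_mask_channels'],
    )
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return model, state['epoch']
```
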