├── .gitignore ├── LICENSE ├── README.md ├── config ├── prostateCT_3dresunet_train.yml ├── prostateCT_deeper3dresunet_dev.yml └── prostateCT_deeper3dresunet_train.yml ├── libs ├── data_loaders │ ├── __init__.py │ └── prostate_ct_volume_loader.py ├── data_processing │ ├── shrink_size.py │ └── slices_to_volumes.py ├── loss_funcs │ ├── __init__.py │ └── loss.py ├── metrics │ ├── __init__.py │ └── metrics.py ├── models │ ├── __init__.py │ ├── deeper_resunet_3d.py │ └── resunet_3d.py ├── optimizers │ └── __init__.py ├── schedulers │ ├── __init__.py │ └── schedulers.py └── utils │ ├── device.py │ └── logging.py ├── main.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 DonDzundza 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-3D-Medical-Image-Semantic-Segmentation 2 | 3 | This is the release version of my private research repository. It will be updated as my research progresses. 4 | 5 | # Why do we need AI for medical image semantic segmentation? 6 | Radiotherapy treatment planning requires accurate contours to maximize target coverage while minimizing toxicity to the surrounding organs at risk (OARs). Physicians' diverse expertise and experience levels introduce large interobserver variations in manual contouring, and both interobserver and intraobserver variation in delineation create uncertainty in treatment planning that can compromise treatment outcomes. Manual contouring in current clinical practice is also time-consuming, so it cannot support adaptive treatment while the patient is on the couch.
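
## Quick Start

Training is launched through main.py, which reads one of the YAML files under config (edit DATASET.PATH there to point at your own data first). A minimal invocation: `python main.py --mode train --config config/prostateCT_deeper3dresunet_train.yml`. Logs, TensorBoard summaries, and the best checkpoint are written to a timestamped directory under runs/.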
7 | 8 | ## Example 9 | 10 | |![ezgif com-gif-maker](https://user-images.githubusercontent.com/24512849/87363829-a0cec500-c537-11ea-9c74-7c94d8ba0687.gif)|![ezgif com-gif-maker (1)](https://user-images.githubusercontent.com/24512849/87363843-a6c4a600-c537-11ea-80be-4c18407cba61.gif)|![ezgif com-gif-maker (2)](https://user-images.githubusercontent.com/24512849/87363872-bc39d000-c537-11ea-866e-6f37e3ee2615.gif)| 11 | |:-:|:-:|:-:| 12 | |![ezgif com-gif-maker (3)](https://user-images.githubusercontent.com/24512849/87364053-31a5a080-c538-11ea-918a-4aa45dcae14e.gif)|![ezgif com-gif-maker (4)](https://user-images.githubusercontent.com/24512849/87364058-35d1be00-c538-11ea-9ffd-d2f9dcc2ca7c.gif)|![ezgif com-gif-maker (5)](https://user-images.githubusercontent.com/24512849/87364085-47b36100-c538-11ea-92ca-983231dbe1a3.gif)| 13 | |CT Slice|Ground Truth|Prediction| 14 | 15 | # Update Log 16 | 17 | 7/11/2020 Update 18 | 19 | - Basic training/validation function 20 | - Model: Deeper 3D Residual U-net 21 | 22 | 7/13/2020 Update 23 | 24 | - Model: 3D Residual U-net 25 | - Normalization control in dataloader 26 | 27 | # Consider citing our paper: 28 | Zhang, Z., Zhao, T., Gay, H., Zhang, W., & Sun, B. (2020). ARPM‐net: A novel CNN‐based adversarial method with Markov Random Field enhancement for prostate and organs at risk segmentation in pelvic CT images. Medical Physics. [https://aapm.onlinelibrary.wiley.com/doi/abs/10.1002/mp.14580] 29 | 30 | Zhang, Z., Zhao, T., Gay, H., Zhang, W. and Sun, B., 2021. Weaving attention U‐net: A novel hybrid CNN and attention‐based method for organs‐at‐risk segmentation in head and neck CT images. Medical physics, 48(11), pp.7052-7062. [https://aapm.onlinelibrary.wiley.com/doi/abs/10.1002/mp.15287] 31 | 32 | 33 | -------------------------------------------------------------------------------- /config/prostateCT_3dresunet_train.yml: -------------------------------------------------------------------------------- 1 | JOB: seg 2 | CUDA: True 3 | PARALLEL: True 4 | 5 | MODEL: 6 | NAME: ResUnet_3D 7 | BASE_FILTERS: 64 8 | CHANNEL_IN: 1 9 | INIT_MODEL: 10 | ADV: False 11 | 12 | DATASET: 13 | NAME: prostateCT_vol 14 | PATH: C:\Users\zhzhang\Desktop\data\prostate_CT_split 15 | N_CLASSES: 6 16 | TRAIN_SPLIT: train 17 | VAL_SPLIT: val 18 | NORMALIZE: False 19 | COLOR_MAP: [[229, 255, 204],[0, 255, 255],[204, 0, 102],[255, 0, 0],[0, 255, 0]] 20 | 21 | TRAINING: 22 | ITER_MAX: 30000 23 | WORKERS: 8 24 | BATCH_SIZE: 2 25 | VAL_INTERVAL: 1500 26 | PRINT_INTERVAL: 300 27 | OPTIM: 28 | name: adam 29 | lr: 0.001 30 | LR_SCHEDULER: 31 | name: poly_lr 32 | max_iter: 30000 33 | LOSS_FUNC: 34 | name: cross_entropy3d 35 | reduction: mean -------------------------------------------------------------------------------- /config/prostateCT_deeper3dresunet_dev.yml: -------------------------------------------------------------------------------- 1 | JOB: seg 2 | CUDA: True 3 | PARALLEL: True 4 | 5 | MODEL: 6 | NAME: Deeper_ResUnet_3D 7 | BASE_FILTERS: 64 8 | CHANNEL_IN: 1 9 | INIT_MODEL: 10 | ADV: False 11 | 12 | DATASET: 13 | NAME: prostateCT_vol 14 | PATH: C:\Users\zhzhang\Desktop\data\prostate_CT_split 15 | N_CLASSES: 6 16 | TRAIN_SPLIT: train 17 | VAL_SPLIT: val 18 | NORMALIZE: False 19 | COLOR_MAP: [[229, 255, 204],[0, 255, 255],[204, 0, 102],[255, 0, 0],[0, 255, 0]] 20 | 21 | TRAINING: 22 | ITER_MAX: 200 23 | WORKERS: 8 24 | BATCH_SIZE: 2 25 | VAL_INTERVAL: 100 26 | PRINT_INTERVAL: 50 27 | OPTIM: 28 | name: adam 29 | lr: 0.0001 30 | LR_SCHEDULER: 31 | name: poly_lr 32 | max_iter: 200 33 | 
LOSS_FUNC: 34 | name: cross_entropy3d 35 | reduction: mean -------------------------------------------------------------------------------- /config/prostateCT_deeper3dresunet_train.yml: -------------------------------------------------------------------------------- 1 | JOB: seg 2 | CUDA: True 3 | PARALLEL: True 4 | 5 | MODEL: 6 | NAME: Deeper_ResUnet_3D 7 | BASE_FILTERS: 64 8 | CHANNEL_IN: 1 9 | INIT_MODEL: 10 | ADV: False 11 | 12 | DATASET: 13 | NAME: prostateCT_vol 14 | PATH: C:\Users\zhzhang\Desktop\data\prostate_CT_split 15 | N_CLASSES: 6 16 | TRAIN_SPLIT: train 17 | VAL_SPLIT: val 18 | NORMALIZE: False 19 | COLOR_MAP: [[229, 255, 204],[0, 255, 255],[204, 0, 102],[255, 0, 0],[0, 255, 0]] 20 | 21 | TRAINING: 22 | ITER_MAX: 20000 23 | WORKERS: 4 24 | BATCH_SIZE: 2 25 | VAL_INTERVAL: 1000 26 | PRINT_INTERVAL: 100 27 | OPTIM: 28 | name: adam 29 | lr: 0.001 30 | LR_SCHEDULER: 31 | name: poly_lr 32 | max_iter: 20000 33 | LOSS_FUNC: 34 | name: cross_entropy3d 35 | reduction: mean -------------------------------------------------------------------------------- /libs/data_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/loader/__init__.py 2 | 3 | import json 4 | 5 | from torch.utils import data as torch_data 6 | 7 | from .prostate_ct_volume_loader import Prostate_CT_Volume_Loader 8 | 9 | 10 | def get_loader(name): 11 | 12 | return { 13 | "prostateCT_vol": Prostate_CT_Volume_Loader 14 | }[name] 15 | 16 | 17 | def build_data_loader(config, write, logger): 18 | 19 | # setup data_loader 20 | data_loader = get_loader(config.DATASET.NAME) 21 | data_path = config.DATASET.PATH 22 | 23 | t_loader = data_loader( 24 | root_dir=data_path, 25 | split=config.DATASET.TRAIN_SPLIT, 26 | normalize=config.DATASET.NORMALIZE 27 | ) 28 | v_loader = data_loader( 29 | root_dir=data_path, 30 | split=config.DATASET.VAL_SPLIT, 31 | normalize=config.DATASET.NORMALIZE 32 | ) 33 | 34 | train_loader = torch_data.DataLoader( 35 | t_loader, 36 | batch_size=config.TRAINING.BATCH_SIZE, 37 | num_workers=config.TRAINING.WORKERS, 38 | shuffle=True 39 | ) 40 | val_loader = torch_data.DataLoader( 41 | v_loader, 42 | batch_size=config.TRAINING.BATCH_SIZE, 43 | num_workers=config.TRAINING.WORKERS, 44 | shuffle=False 45 | ) 46 | 47 | logger.info("train_loader, val_loader ready.") 48 | 49 | return t_loader, v_loader, train_loader, val_loader -------------------------------------------------------------------------------- /libs/data_loaders/prostate_ct_volume_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils import data as torch_data 6 | 7 | class Prostate_CT_Volume_Loader(torch_data.Dataset): 8 | 9 | def __init__( 10 | self, 11 | root_dir, 12 | split, 13 | normalize=False 14 | ): 15 | 16 | self.normalize = normalize 17 | self.n_classes = 6 18 | 19 | self.split_dir = os.path.join(root_dir, split) 20 | self.split_dir_ct = os.path.join(self.split_dir, "CT") 21 | self.split_dir_seg = os.path.join(self.split_dir, "SEG") 22 | self.split_dir_onehot = os.path.join(self.split_dir, "SEG_onehot") 23 | 24 | self.volume_id_list = [x for x in os.listdir(self.split_dir_ct)] 25 | self.volume_list = [os.path.join(self.split_dir_ct, x) for x in os.listdir(self.split_dir_ct)] 26 | self.label_list = [os.path.join(self.split_dir_seg, x) for x in os.listdir(self.split_dir_seg)] 27 | self.onehot_list = 
[os.path.join(self.split_dir_onehot, x) for x in os.listdir(self.split_dir_onehot)] 28 | 29 | def __len__(self): 30 | return len(self.volume_id_list) 31 | 32 | def __getitem__(self, index): 33 | 34 | patient_id = self.volume_id_list[index] 35 | ct_vol_path = self.volume_list[index] 36 | label_path = self.label_list[index] 37 | onehot_path = self.onehot_list[index] 38 | 39 | ct_vol = np.load(ct_vol_path) 40 | if self.normalize: 41 | norm = np.linalg.norm(ct_vol) 42 | ct_vol = ct_vol/norm 43 | 44 | ct_vol = np.expand_dims(ct_vol, axis=0) 45 | ct = torch.from_numpy(ct_vol).float() 46 | 47 | label = np.load(label_path) 48 | label = torch.from_numpy(label).long() 49 | 50 | onehot = np.load(onehot_path) 51 | onehot = torch.from_numpy(onehot).long() 52 | 53 | return patient_id, ct, label, onehot 54 | 55 | 56 | -------------------------------------------------------------------------------- /libs/data_processing/shrink_size.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import glob 4 | import numpy as np 5 | 6 | def shrink_ct_size(vol): 7 | s, h, w = vol.shape # (64, 448, 448) 8 | 9 | new_h = int(h/7) 10 | new_w = int(w/7) 11 | print((s, h, w), "to", (s, new_h, new_w)) 12 | new_vol = np.zeros((s, new_h, new_w), dtype=vol.dtype) # keep the source dtype: int8 would wrap CT intensities above 127 13 | for z in range(0, s): 14 | for r in range(0, new_h): 15 | for c in range(0, new_w): 16 | new_vol[z][r][c] = vol[z][7*r][7*c] # stride 7 matches the 7x shrink; the old 2*r/2*c only sampled the top-left corner 17 | assert new_vol.shape == (64, 64, 64) 18 | return new_vol 19 | 20 | 21 | def shrink_label_size(vol): 22 | s, h, w = vol.shape # (64, 448, 448) 23 | new_h = int(h/7) 24 | new_w = int(w/7) 25 | print((s, h, w), "to", (s, new_h, new_w)) 26 | new_vol = np.zeros((s, new_h, new_w), dtype=np.int8) 27 | for z in range(0, s): 28 | for r in range(0, new_h): 29 | for c in range(0, new_w): 30 | new_vol[z][r][c] = vol[z][7*r][7*c] 31 | assert new_vol.shape == (64, 64, 64) 32 | return new_vol 33 | 34 | def shrink_onehot_size(vol): 35 | vol = torch.from_numpy(vol).permute(1, 2, 3, 0).data.numpy() 36 | s, h, w, n = vol.shape 37 | new_h = int(h/7) 38 | new_w = int(w/7) 39 | 40 | new_vol = np.zeros((s, new_h, new_w, n), dtype=np.int8) 41 | for z in range(0, s): 42 | for r in range(0, new_h): 43 | for c in range(0, new_w): 44 | new_vol[z][r][c] = vol[z][7*r][7*c] 45 | new_vol = torch.from_numpy(new_vol).permute(3, 0, 1, 2).data.numpy() 46 | assert new_vol.shape == (6, 64, 64, 64) 47 | return new_vol 48 | 49 | 50 | if __name__ == "__main__": 51 | 52 | ct_vol_dir = "C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_448_448\\CT" 53 | label_vol_dir = "C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_448_448\\SEG" 54 | onehot_vol_dir = "C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_448_448\\SEG\\onehot_encoded" 55 | 56 | ct_vol_list = glob.glob(ct_vol_dir+"\\*.npy") 57 | label_vol_list = glob.glob(label_vol_dir+"\\*.npy") 58 | onehot_vol_list = glob.glob(onehot_vol_dir+"\\*.npy") 59 | 60 | for ct_vol in ct_vol_list: 61 | file_name = ct_vol.split("\\")[-1] 62 | print("Working on {}.".format(file_name)) 63 | vol = np.load(ct_vol) 64 | vol = shrink_ct_size(vol) 65 | np.save(os.path.join("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_64_64\\CT", file_name), vol) 66 | print("{} ct volume done.".format(file_name)) 67 | print("-"*30) 68 | 69 | 70 | for label_vol in label_vol_list: 71 | file_name = label_vol.split("\\")[-1] 72 | print("Working on {}.".format(file_name)) 73 | vol = np.load(label_vol) 74 | vol = shrink_label_size(vol) 75 | 
np.save(os.path.join("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_64_64\\SEG", file_name), vol) 76 | print("{} label volume done.".format(file_name)) 77 | print("-"*30) 78 | 79 | 80 | for onehot_vol in onehot_vol_list: 81 | file_name = onehot_vol.split("\\")[-1] 82 | print("Working on {}.".format(file_name)) 83 | vol = np.load(onehot_vol) 84 | vol = shrink_onehot_size(vol) 85 | np.save(os.path.join("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_64_64\\SEG\\onehot_encoded", file_name), vol) 86 | print("{} onehot label volume done.".format(file_name)) 87 | print("-"*30) -------------------------------------------------------------------------------- /libs/data_processing/slices_to_volumes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | def trim(imgs, masks): 9 | assert len(imgs) == len(masks) 10 | 11 | n = len(imgs) 12 | if n > 64: 13 | start_i = int((n-64)/2)-2 14 | 15 | output_imgs = [] 16 | output_masks = [] 17 | 18 | for i in range(0,64): 19 | output_imgs.append(imgs[start_i+i]) 20 | output_masks.append(masks[start_i+i]) 21 | 22 | return output_imgs, output_masks 23 | else: 24 | img_padding = np.zeros((imgs[0].shape), dtype=np.int8) 25 | mask_padding = np.zeros((masks[0].shape), dtype=np.int8) 26 | for _ in range(0, 64-n): 27 | imgs.append(img_padding) 28 | masks.append(mask_padding) 29 | 30 | return imgs, masks 31 | 32 | 33 | 34 | 35 | 36 | def list_to_patient_volumes(img_list, color_mask_list): 37 | 38 | ''' 39 | input: img_list, color_mask_list under a dir 40 | output: [[patient0_img_volume, patient0_label_volume], ...] 41 | ''' 42 | 43 | patients = [] 44 | 45 | patient_id = int(img_list[0].split('\\')[-1].split('_')[1]) 46 | patient_imgs = [] 47 | patient_masks = [] 48 | 49 | for img_id in range(0, len(img_list)): 50 | img_path = img_list[img_id] 51 | 52 | img_array = np.asarray(Image.open(img_path).convert('L').crop((32, 32, 480, 480))) 53 | color_array = np.asarray(Image.open(color_mask_list[img_id]).convert('RGB').crop((32, 32, 480, 480))) 54 | 55 | cur_id = int(img_path.split('\\')[-1].split('_')[1]) 56 | 57 | if img_id != len(img_list)-1: 58 | if cur_id == patient_id: 59 | patient_imgs.append(img_array) 60 | patient_masks.append(color_array) 61 | else: 62 | patient_imgs, patient_masks = trim(patient_imgs, patient_masks) 63 | 64 | imgs = [] 65 | labels = [] 66 | 67 | for slice_id in range(0, 64): 68 | img = patient_imgs[slice_id] 69 | mask = patient_masks[slice_id] 70 | label = color_to_label(mask, [[229, 255, 204],[0, 255, 255],[204, 0, 102],[255, 0, 0],[0, 255, 0]]) 71 | 72 | imgs.append(img) 73 | labels.append(label) 74 | print("Slice {} done for patient {}".format(slice_id, patient_id)) 75 | 76 | patients.append([imgs, labels]) 77 | print("Patient {} done.".format(patient_id)) 78 | np.save("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_448_448\\CT\\patient_{}_volume.npy".format(patient_id),imgs) 79 | np.save("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_448_448\\seg\\seg_patient_{}_volume.npy".format(patient_id),labels) 80 | print('-'*25) 81 | 82 | patient_id = cur_id 83 | patient_imgs = [img_array] 84 | patient_masks = [color_array] 85 | else: 86 | patient_imgs.append(img_array) 87 | patient_masks.append(color_array) 88 | 89 | patient_imgs, patient_masks = trim(patient_imgs, patient_masks) 90 | 91 | imgs = [] 92 | labels = [] 93 | 94 | for slice_id in range(0, 64): 95 | img = patient_imgs[slice_id] 
96 | mask = patient_masks[slice_id] 97 | label = color_to_label(mask, [[229, 255, 204],[0, 255, 255],[204, 0, 102],[255, 0, 0],[0, 255, 0]]) 98 | 99 | imgs.append(img) 100 | labels.append(label) 101 | print("Slice {} done for patient {}".format(slice_id, patient_id)) 102 | 103 | patients.append([imgs, labels]) 104 | print("Patient {} done.".format(patient_id)) 105 | np.save("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_448_448\\CT\\patient_{}_volume.npy".format(patient_id),imgs) 106 | np.save("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_448_448\\seg\\seg_patient_{}_volume.npy".format(patient_id),labels) 107 | print('-'*25) 108 | 109 | 110 | return patients 111 | 112 | 113 | def color_to_label(color_mask, color_map): 114 | 115 | ''' 116 | input: color_mask in rgb, color_map 117 | output: label_mask with class_id 118 | ''' 119 | 120 | h, w, c = color_mask.shape 121 | output = np.zeros((h,w), dtype=np.int8) 122 | 123 | for r in range(0,h): 124 | for c in range(0,w): 125 | for i in range(0, len(color_map)): 126 | if (color_mask[r][c] == color_map[i]).all(): 127 | output[r][c] = i+1 128 | break 129 | 130 | return output 131 | 132 | def pre_encode_label_vol(label_vol): 133 | 134 | ''' 135 | input: label_vol with class_id 136 | output: one-hot encoded label vol 137 | ''' 138 | z, h, w = label_vol.shape 139 | onehot_vol = np.zeros((z, h, w, 6), dtype=np.int8) 140 | 141 | for s in range(0,z): 142 | for r in range(0, h): 143 | for c in range(0, w): 144 | onehot = np.zeros(6, dtype=np.int8) 145 | onehot[label_vol[s][r][c]] = 1 146 | onehot_vol[s][r][c] = onehot 147 | onehot_vol = torch.from_numpy(onehot_vol).permute(3, 0, 1, 2).data.numpy() 148 | assert onehot_vol.shape == (6, 64, 448, 448) 149 | 150 | return onehot_vol 151 | 152 | 153 | if __name__ == "__main__": 154 | 155 | 156 | # # Slices to Volumes 157 | # img_dir = os.path.expanduser("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_2d\\orig_CT\\*.png") 158 | # color_mask_dir = os.path.expanduser("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_2d\\seg_mask\\*.png") 159 | 160 | # img_list = glob.glob(img_dir) 161 | # color_mask_list = glob.glob(color_mask_dir) 162 | # assert len(img_list) == len(color_mask_list) 163 | 164 | # patients = list_to_patient_volumes(img_list, color_mask_list) 165 | 166 | # Test Block 167 | test_vol = np.load("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_224_224\\CT\\patient_1_volume.npy") 168 | test_seg_vol = np.load("C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_224_224\\SEG\\seg_patient_1_volume.npy") 169 | 170 | print(test_vol.shape, test_seg_vol.shape) 171 | 172 | 173 | # # Onehot-encode Mask Volumes 174 | # label_vol_dir = "C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_448_448\\seg\\*.npy" 175 | # onehot_vol_dir = "C:\\Users\\zhzhang\\Desktop\\data\\prostate_CT_vol_64_448_448\\seg\\onehot_encoded" 176 | # label_vol_list = glob.glob(label_vol_dir) 177 | 178 | # for label_vol in label_vol_list: 179 | # patient_id = label_vol.split("\\")[-1].split('_')[2] 180 | # vol = np.load(label_vol) 181 | # onehot_vol = pre_encode_label_vol(vol) 182 | # np.save(os.path.join(onehot_vol_dir, "encoded_seg_patient_{}_vol.npy".format(patient_id)), onehot_vol) 183 | # print("Patient {} done.".format(patient_id)) 184 | -------------------------------------------------------------------------------- /libs/loss_funcs/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from 
https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/loss/__init__.py 2 | 3 | import logging 4 | import functools 5 | import numpy as np 6 | 7 | from .loss import ( 8 | cross_entropy3d 9 | ) 10 | 11 | key2loss = { 12 | "cross_entropy3d": cross_entropy3d 13 | } 14 | 15 | def get_loss_function(cfg): 16 | 17 | if cfg.TRAINING.LOSS_FUNC is None: 18 | return cross_entropy3d 19 | else: 20 | loss_dict = cfg.TRAINING.LOSS_FUNC 21 | loss_name = loss_dict.name 22 | loss_params = {k: v for k, v in loss_dict.items() if k != "name"} 23 | 24 | if loss_name not in key2loss: 25 | raise NotImplementedError("Loss {} not implemented".format(loss_name)) 26 | return functools.partial(key2loss[loss_name], **loss_params) -------------------------------------------------------------------------------- /libs/loss_funcs/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | def cross_entropy3d(input, target, weight=None, reduction="mean"): 8 | _, c, h, w, z = input.size() 9 | _, ht, wt, zt = target.size() 10 | 11 | # Handle inconsistent size between input and target 12 | if h != ht or w != wt or z != zt: # upsample logits to the label resolution (5D tensors need trilinear, not bilinear) 13 | input = F.interpolate(input, size=(ht, wt, zt), mode="trilinear", align_corners=True) 14 | 15 | input = input.permute(0, 2, 3, 4, 1).contiguous().view(-1, c) # (N*D*H*W, C) logits 16 | target = target.view(-1) # (N*D*H*W,) class ids 17 | 18 | loss = F.cross_entropy( 19 | input, target, weight=weight, reduction=reduction, ignore_index=250 20 | ) 21 | return loss -------------------------------------------------------------------------------- /libs/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import * -------------------------------------------------------------------------------- /libs/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class running_seg_score(object): 9 | def __init__(self, n_classes): 10 | self.n_classes = n_classes 11 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 12 | 13 | def _fast_hist(self, label_true, label_pred, n_class): 14 | mask = (label_true >= 0) & (label_true < n_class) 15 | hist = np.bincount( 16 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2 17 | ).reshape(n_class, n_class) 18 | return hist 19 | 20 | def update(self, label_trues, label_preds): 21 | for lt, lp in zip(label_trues, label_preds): 22 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 23 | 24 | def get_scores(self): 25 | """Returns accuracy score evaluation result. 
26 | - overall accuracy and mean per-class accuracy 27 | - frequency-weighted accuracy and mean Dice coefficient (per-class Dice in the second return value) 28 | """ 29 | hist = self.confusion_matrix 30 | acc = np.diag(hist).sum() / hist.sum() 31 | acc_cls = np.diag(hist) / hist.sum(axis=1) 32 | acc_cls = np.nanmean(acc_cls) 33 | 34 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 35 | dsc = 2 * iu/(1+iu) # convert per-class IoU to Dice 36 | mean_dsc = np.nanmean(dsc) 37 | 38 | freq = hist.sum(axis=1) / hist.sum() 39 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 40 | cls_dsc = dict(zip(range(self.n_classes), dsc)) 41 | 42 | return ( 43 | { 44 | "Overall Acc: \t": acc, 45 | "Mean Acc : \t": acc_cls, 46 | "FreqW Acc : \t": fwavacc, 47 | "Mean Dice Coefficient: \t": mean_dsc 48 | }, 49 | cls_dsc 50 | ) 51 | 52 | def reset(self): 53 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 54 | 55 | 56 | 57 | class averageMeter(object): 58 | """Computes and stores the average and current value""" 59 | 60 | def __init__(self): 61 | self.reset() 62 | 63 | def reset(self): 64 | self.val = 0 65 | self.avg = 0 66 | self.sum = 0 67 | self.count = 0 68 | 69 | def update(self, val, n=1): 70 | self.val = val 71 | self.sum += val * n 72 | self.count += n 73 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /libs/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .deeper_resunet_3d import deeper_resunet_3d 3 | from .resunet_3d import resunet_3d 4 | 5 | 6 | def Deeper_ResUnet_3D(n_classes, base_filters, channel_in): 7 | return deeper_resunet_3d(n_classes, base_filters, channel_in) 8 | 9 | def ResUnet_3D(n_classes, base_filters, channel_in): 10 | return resunet_3d(n_classes, base_filters, channel_in) -------------------------------------------------------------------------------- /libs/models/deeper_resunet_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def deeper_resunet_3d(n_classes, base_filters, channel_in): 6 | model_in_block = in_block(channel_in=channel_in, channel_out=base_filters) 7 | model_encoder = encoder(base_filters=base_filters) 8 | model_decoder = decoder(base_filters=base_filters) 9 | model_seg_out_block = seg_out_block(base_filters=base_filters, n_classes=n_classes) 10 | 11 | model = seg_path( 12 | model_in_block, 13 | model_encoder, 14 | model_decoder, 15 | model_seg_out_block 16 | ) 17 | 18 | return model 19 | 20 | class in_block(nn.Module): 21 | ''' 22 | in_block connects the raw input volume to the network. 23 | 24 | The number of channels is set by conv1 and then stays the same for all 25 | following layers in the block. 26 | 27 | parameters: 28 | channel_in: int 29 | the number of channels of the input. 30 | RGB images have 3, greyscale images have 1, etc. 
31 | channel_out: int 32 | the number of filters for conv1; keeps unchanged for all layers following 33 | conv1 34 | 35 | ''' 36 | def __init__(self, channel_in, channel_out): 37 | super(in_block, self).__init__() 38 | 39 | self.channel_in = channel_in 40 | self.channel_out = channel_out 41 | 42 | self.conv1 = nn.Conv3d( 43 | kernel_size=3, 44 | in_channels=self.channel_in, 45 | out_channels=self.channel_out, 46 | padding=1 47 | ) 48 | self.bn1 = nn.BatchNorm3d(num_features=self.channel_out) 49 | 50 | self.conv2 = nn.Conv3d( 51 | kernel_size=3, 52 | in_channels=self.channel_out, 53 | out_channels=self.channel_out, 54 | padding=1 55 | ) 56 | 57 | self.conv3 = nn.Conv3d( 58 | kernel_size=3, 59 | in_channels=self.channel_in, 60 | out_channels=self.channel_out, 61 | padding=1 62 | ) 63 | self.bn3 = nn.BatchNorm3d(num_features=self.channel_out) 64 | 65 | def forward(self, x): 66 | path = self.conv1(x) 67 | path = self.bn1(path) 68 | path = F.leaky_relu(path) 69 | path = F.dropout(path, p=0.2) 70 | 71 | path = self.conv2(path) 72 | 73 | residual = self.conv3(x) 74 | residual = self.bn3(residual) 75 | 76 | self.down_level1 = path + residual 77 | 78 | return self.down_level1 79 | 80 | class res_block(nn.Module): 81 | ''' 82 | res_block used for down and up, toggled by downsample. 83 | 84 | "input" -> bn1 -> relu1 -> conv1 -> bn2 -> relu2 -> conv2 -> "path" 85 | -> conv3 -> bn3 -> "residual" 86 | 87 | return "output" = "path" + "residual" 88 | 89 | downsampling (if any) is done by conv1 90 | 91 | parameters: 92 | channel_in: int 93 | downsample: boolean 94 | if downsample is true, the block is used for encoding path, 95 | during which the channels out are doubled by the conv1. 96 | conv1 will have stride 2. 97 | 98 | if downsample is false, the block is used for segmenting/restoring 99 | path, during which the channels keep the same through the block. 100 | conv1 will have stride 1. 
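For example, with channel_in=64 and downsample=True, an input of shape (N, 64, 32, 32, 32) comes out as (N, 128, 16, 16, 16): the stride-2 conv1 halves every spatial dimension while the channel count doubles. With downsample=False the output shape equals the input shape.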
101 | 102 | ''' 103 | def __init__( 104 | self, 105 | channel_in, 106 | downsample=False, 107 | ): 108 | super(res_block, self).__init__() 109 | 110 | self.channel_in = channel_in 111 | 112 | if downsample: 113 | self.channel_out = 2*self.channel_in 114 | self.conv1_stride = 2 115 | self.conv3_stride = 2 116 | else: 117 | self.channel_out = self.channel_in 118 | self.conv1_stride = 1 119 | self.conv3_stride = 1 120 | 121 | self.bn1 = nn.BatchNorm3d(num_features=self.channel_in) 122 | self.conv1 = nn.Conv3d( 123 | in_channels=self.channel_in, 124 | kernel_size=3, 125 | out_channels=self.channel_out, 126 | stride=self.conv1_stride, 127 | padding=1 128 | ) 129 | self.bn2 = nn.BatchNorm3d(num_features=self.channel_out) 130 | self.conv2 = nn.Conv3d( 131 | in_channels=self.channel_out, 132 | out_channels=self.channel_out, 133 | kernel_size=3, 134 | padding=1 135 | ) 136 | 137 | self.conv3 = nn.Conv3d( 138 | in_channels=self.channel_in, 139 | out_channels=self.channel_out, 140 | stride=self.conv3_stride, 141 | padding=1, 142 | kernel_size=3 143 | ) 144 | self.bn3 = nn.BatchNorm3d(num_features=self.channel_out) 145 | 146 | def forward(self, x): 147 | 148 | path = self.bn1(x) 149 | path = F.leaky_relu(path) 150 | path = F.dropout(path, p=0.2) 151 | 152 | path = self.conv1(path) 153 | path = self.bn2(path) 154 | path = F.leaky_relu(path) 155 | path = F.dropout(path, p=0.2) 156 | 157 | path = self.conv2(path) 158 | 159 | residual = self.conv3(x) 160 | residual = self.bn3(residual) 161 | 162 | output = path + residual 163 | 164 | return output 165 | 166 | class encoder(nn.Module): 167 | 168 | ''' 169 | encoder 170 | 171 | dataflow: 172 | x --down_block2--> down_level2 173 | --down_block3--> down_level3 174 | --down_block4--> down_level4 175 | --down_bridge--> codes 176 | 177 | parameters: 178 | base_filters: number of filters received from in_block; 16 by default. 179 | 180 | ''' 181 | def __init__( 182 | self, 183 | base_filters 184 | ): 185 | super(encoder, self).__init__() 186 | 187 | self.bf = base_filters 188 | 189 | self.down_block2 = res_block( 190 | channel_in=self.bf , 191 | downsample=True 192 | ) 193 | self.down_block3 = res_block( 194 | channel_in=self.bf *2, 195 | downsample=True 196 | ) 197 | self.down_block4 = res_block( 198 | channel_in=self.bf *4, 199 | downsample=True 200 | ) 201 | self.down_bridge = res_block( 202 | channel_in=self.bf *8, 203 | downsample=True 204 | ) 205 | 206 | def forward(self, x): 207 | 208 | self.down_level2 = self.down_block2(x) 209 | self.down_level3 = self.down_block3(self.down_level2) 210 | self.down_level4 = self.down_block4(self.down_level3) 211 | self.codes = self.down_bridge(self.down_level4) 212 | 213 | return self.codes 214 | 215 | class decoder(nn.Module): 216 | ''' 217 | decoder 218 | 219 | dataflow: 220 | x --upsample4--> up4 --up_block4--> up_level4 221 | --upsample3--> up3 --up_block3--> up_level3 222 | --upsample2--> up2 --up_block2--> up_level2 223 | --upsample1--> up1 --up_block1--> up_level1 224 | 225 | parameters: 226 | base_filters: number of filters consistent with encoder; 16 by default. 
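Note that the skip connections are not wired inside this class: seg_path.forward concatenates each upsampled map with the matching encoder level and reduces the doubled channels with the 1x1 convolutions (conv4 ... conv1) before each up_block. The provided configs set BASE_FILTERS to 64 rather than the 16 mentioned above.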
227 | 228 | ''' 229 | def __init__( 230 | self, 231 | base_filters 232 | ): 233 | super(decoder, self).__init__() 234 | self.bf = base_filters 235 | 236 | self.upsample4 = nn.ConvTranspose3d( 237 | in_channels=self.bf*16 , 238 | out_channels=self.bf*8 , 239 | kernel_size=2, 240 | stride=2 241 | ) 242 | self.conv4 = nn.Conv3d( 243 | in_channels=self.bf*16, 244 | out_channels=self.bf*8, 245 | kernel_size=1 246 | ) 247 | self.up_block4 = res_block( 248 | channel_in=self.bf*8, 249 | downsample=False 250 | ) 251 | 252 | self.upsample3 = nn.ConvTranspose3d( 253 | in_channels=self.bf*8, 254 | out_channels=self.bf*4, 255 | kernel_size=2, 256 | stride=2 257 | ) 258 | self.conv3 = nn.Conv3d( 259 | in_channels=self.bf*8, 260 | out_channels=self.bf*4, 261 | kernel_size=1 262 | ) 263 | self.up_block3 = res_block( 264 | channel_in=self.bf*4, 265 | downsample=False 266 | ) 267 | 268 | self.upsample2 = nn.ConvTranspose3d( 269 | in_channels=self.bf*4, 270 | out_channels=self.bf*2, 271 | kernel_size=2, 272 | stride=2 273 | ) 274 | self.conv2 = nn.Conv3d( 275 | in_channels=self.bf*4, 276 | out_channels=self.bf*2, 277 | kernel_size=1 278 | ) 279 | self.up_block2 = res_block( 280 | channel_in=self.bf*2, 281 | downsample=False 282 | ) 283 | 284 | self.upsample1 = nn.ConvTranspose3d( 285 | in_channels=self.bf*2, 286 | out_channels=self.bf, 287 | kernel_size=2, 288 | stride=2 289 | ) 290 | self.conv1 = nn.Conv3d( 291 | in_channels=self.bf*2, 292 | out_channels=self.bf, 293 | kernel_size=1 294 | ) 295 | self.up_block1 = res_block( 296 | channel_in=self.bf, 297 | downsample=False 298 | ) 299 | 300 | def forward(self, x): 301 | 302 | up4 = self.upsample4(x) 303 | self.up_level4 = self.up_block4(up4) 304 | 305 | up3 = self.upsample3(self.up_level4) 306 | self.up_level3 = self.up_block3(up3) 307 | 308 | up2 = self.upsample2(self.up_level3) 309 | self.up_level2 = self.up_block2(up2) 310 | 311 | up1 = self.upsample1(self.up_level2) 312 | self.up_level1 = self.up_block1(up1) 313 | 314 | return self.up_level1 315 | 316 | class seg_out_block(nn.Module): 317 | ''' 318 | seg_out_block, receive data from decoder and output the segmentation mask 319 | 320 | parameters: 321 | base_filters: number of filters received from in_block. 
322 | n_classes: number of classes 323 | 324 | ''' 325 | def __init__( 326 | self, 327 | base_filters, 328 | n_classes=6 329 | ): 330 | super(seg_out_block, self).__init__() 331 | 332 | self.bf = base_filters 333 | self.n_classes = n_classes 334 | self.conv = nn.Conv3d( 335 | in_channels=self.bf, 336 | out_channels=self.n_classes, 337 | kernel_size=1 338 | ) 339 | 340 | def forward(self, x): 341 | self.output = self.conv(x) 342 | return self.output 343 | 344 | class seg_path(nn.Module): 345 | def __init__( 346 | self, 347 | in_block, 348 | encoder, 349 | decoder, 350 | seg_out_block 351 | ): 352 | super(seg_path, self).__init__() 353 | 354 | self.in_block = in_block 355 | self.encoder = encoder 356 | self.decoder = decoder 357 | self.seg_out_block = seg_out_block 358 | 359 | def forward(self, x): 360 | 361 | self.down_level1 = self.in_block(x) 362 | 363 | self.down_level2 = self.encoder.down_block2(self.down_level1) 364 | self.down_level3 = self.encoder.down_block3(self.down_level2) 365 | self.down_level4 = self.encoder.down_block4(self.down_level3) 366 | self.codes = self.encoder.down_bridge(self.down_level4) 367 | 368 | self.up4 = self.decoder.upsample4(self.codes) 369 | up4_dummy = torch.cat([self.up4, self.down_level4],1) 370 | up4_dummy = self.decoder.conv4(up4_dummy) 371 | self.up_level4 = self.decoder.up_block4(up4_dummy) 372 | 373 | self.up3 = self.decoder.upsample3(self.up_level4) 374 | up3_dummy = torch.cat([self.up3, self.down_level3], 1) 375 | up3_dummy = self.decoder.conv3(up3_dummy) 376 | self.up_level3 = self.decoder.up_block3(up3_dummy) 377 | 378 | self.up2 = self.decoder.upsample2(self.up_level3) 379 | up2_dummy = torch.cat([self.up2, self.down_level2], 1) 380 | up2_dummy = self.decoder.conv2(up2_dummy) 381 | self.up_level2 = self.decoder.up_block2(up2_dummy) 382 | 383 | self.up1 = self.decoder.upsample1(self.up_level2) 384 | up1_dummy = torch.cat([self.up1, self.down_level1], 1) 385 | up1_dummy = self.decoder.conv1(up1_dummy) 386 | self.up_level1 = self.decoder.up_block1(up1_dummy) 387 | 388 | self.output = self.seg_out_block(self.up_level1) 389 | 390 | return self.output -------------------------------------------------------------------------------- /libs/models/resunet_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def resunet_3d(n_classes, base_filters, channel_in): 6 | model_in_block = in_block(channel_in=channel_in, channel_out=base_filters) 7 | model_encoder = encoder(base_filters=base_filters) 8 | model_decoder = decoder(base_filters=base_filters) 9 | model_seg_out_block = seg_out_block(base_filters=base_filters, n_classes=n_classes) 10 | 11 | model = seg_path( 12 | model_in_block, 13 | model_encoder, 14 | model_decoder, 15 | model_seg_out_block 16 | ) 17 | 18 | return model 19 | 20 | class in_block(nn.Module): 21 | ''' 22 | in_block is used to connect the input of the whole network. 23 | 24 | number of channels is changed by conv1, and then it keeps the same for all 25 | following layers. 26 | 27 | parameters: 28 | channel_in: int 29 | the number of channels of the input. 30 | RGB images have 3, greyscale images have 1, etc. 
31 | channel_out: int 32 | the number of filters for conv1; keeps unchanged for all layers following 33 | conv1 34 | 35 | ''' 36 | def __init__(self, channel_in, channel_out): 37 | super(in_block, self).__init__() 38 | 39 | self.channel_in = channel_in 40 | self.channel_out = channel_out 41 | 42 | self.conv1 = nn.Conv3d( 43 | kernel_size=3, 44 | in_channels=self.channel_in, 45 | out_channels=self.channel_out, 46 | padding=1 47 | ) 48 | self.bn1 = nn.BatchNorm3d(num_features=self.channel_out) 49 | 50 | self.conv2 = nn.Conv3d( 51 | kernel_size=3, 52 | in_channels=self.channel_out, 53 | out_channels=self.channel_out, 54 | padding=1 55 | ) 56 | 57 | self.conv3 = nn.Conv3d( 58 | kernel_size=3, 59 | in_channels=self.channel_in, 60 | out_channels=self.channel_out, 61 | padding=1 62 | ) 63 | self.bn3 = nn.BatchNorm3d(num_features=self.channel_out) 64 | 65 | def forward(self, x): 66 | path = self.conv1(x) 67 | path = self.bn1(path) 68 | path = F.leaky_relu(path) 69 | path = F.dropout(path, p=0.2) 70 | 71 | path = self.conv2(path) 72 | 73 | residual = self.conv3(x) 74 | residual = self.bn3(residual) 75 | 76 | self.down_level1 = path + residual 77 | 78 | return self.down_level1 79 | 80 | class res_block(nn.Module): 81 | ''' 82 | res_block used for down and up, toggled by downsample. 83 | 84 | "input" -> bn1 -> relu1 -> conv1 -> bn2 -> relu2 -> conv2 -> "path" 85 | -> conv3 -> bn3 -> "residual" 86 | 87 | return "output" = "path" + "residual" 88 | 89 | downsampling (if any) is done by conv1 90 | 91 | parameters: 92 | channel_in: int 93 | downsample: boolean 94 | if downsample is true, the block is used for encoding path, 95 | during which the channels out are doubled by the conv1. 96 | conv1 will have stride 2. 97 | 98 | if downsample is false, the block is used for segmenting/restoring 99 | path, during which the channels keep the same through the block. 100 | conv1 will have stride 1. 
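The shape arithmetic is the same as in deeper_resunet_3d: downsample=True halves every spatial dimension and doubles the channels (e.g. (N, 64, 32, 32, 32) -> (N, 128, 16, 16, 16)), while downsample=False preserves the input shape.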
101 | 102 | ''' 103 | def __init__( 104 | self, 105 | channel_in, 106 | downsample=False, 107 | ): 108 | super(res_block, self).__init__() 109 | 110 | self.channel_in = channel_in 111 | 112 | if downsample: 113 | self.channel_out = 2*self.channel_in 114 | self.conv1_stride = 2 115 | self.conv3_stride = 2 116 | else: 117 | self.channel_out = self.channel_in 118 | self.conv1_stride = 1 119 | self.conv3_stride = 1 120 | 121 | self.bn1 = nn.BatchNorm3d(num_features=self.channel_in) 122 | self.conv1 = nn.Conv3d( 123 | in_channels=self.channel_in, 124 | kernel_size=3, 125 | out_channels=self.channel_out, 126 | stride=self.conv1_stride, 127 | padding=1 128 | ) 129 | self.bn2 = nn.BatchNorm3d(num_features=self.channel_out) 130 | self.conv2 = nn.Conv3d( 131 | in_channels=self.channel_out, 132 | out_channels=self.channel_out, 133 | kernel_size=3, 134 | padding=1 135 | ) 136 | 137 | self.conv3 = nn.Conv3d( 138 | in_channels=self.channel_in, 139 | out_channels=self.channel_out, 140 | stride=self.conv3_stride, 141 | padding=1, 142 | kernel_size=3 143 | ) 144 | self.bn3 = nn.BatchNorm3d(num_features=self.channel_out) 145 | 146 | def forward(self, x): 147 | 148 | path = self.bn1(x) 149 | path = F.leaky_relu(path) 150 | path = F.dropout(path, p=0.2) 151 | 152 | path = self.conv1(path) 153 | path = self.bn2(path) 154 | path = F.leaky_relu(path) 155 | path = F.dropout(path, p=0.2) 156 | 157 | path = self.conv2(path) 158 | 159 | residual = self.conv3(x) 160 | residual = self.bn3(residual) 161 | 162 | output = path + residual 163 | 164 | return output 165 | 166 | class encoder(nn.Module): 167 | 168 | ''' 169 | encoder 170 | 171 | dataflow: 172 | x --down_block2--> down_level2 173 | --down_block3--> down_level3 174 | --down_block4--> codes 175 | 176 | parameters: 177 | base_filters: number of filters received from in_block; 16 by default. 178 | 179 | ''' 180 | def __init__( 181 | self, 182 | base_filters 183 | ): 184 | super(encoder, self).__init__() 185 | 186 | self.bf = base_filters 187 | 188 | self.down_block2 = res_block( 189 | channel_in=self.bf , 190 | downsample=True 191 | ) 192 | self.down_block3 = res_block( 193 | channel_in=self.bf *2, 194 | downsample=True 195 | ) 196 | self.down_block4 = res_block( 197 | channel_in=self.bf *4, 198 | downsample=True 199 | ) 200 | 201 | def forward(self, x): 202 | 203 | self.down_level2 = self.down_block2(x) 204 | self.down_level3 = self.down_block3(self.down_level2) 205 | self.codes = self.down_block4(self.down_level3) 206 | 207 | return self.codes 208 | 209 | class decoder(nn.Module): 210 | ''' 211 | decoder 212 | 213 | dataflow: 214 | x --upsample3--> up3 --up_block3--> up_level3 215 | --upsample2--> up2 --up_block2--> up_level2 216 | --upsample1--> up1 --up_block1--> up_level1 217 | 218 | parameters: 219 | base_filters: number of filters consistent with encoder; 16 by default. 
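As in the deeper variant, the skip connections live in seg_path.forward, which concatenates each upsampled map with the matching encoder level and reduces channels with the 1x1 convolutions (conv3, conv2, conv1) before each up_block; the provided configs set BASE_FILTERS to 64 rather than 16.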
220 | 221 | ''' 222 | def __init__( 223 | self, 224 | base_filters 225 | ): 226 | super(decoder, self).__init__() 227 | self.bf = base_filters 228 | 229 | 230 | self.upsample3 = nn.ConvTranspose3d( 231 | in_channels=self.bf*8, 232 | out_channels=self.bf*4, 233 | kernel_size=2, 234 | stride=2 235 | ) 236 | self.conv3 = nn.Conv3d( 237 | in_channels=self.bf*8, 238 | out_channels=self.bf*4, 239 | kernel_size=1 240 | ) 241 | self.up_block3 = res_block( 242 | channel_in=self.bf*4, 243 | downsample=False 244 | ) 245 | 246 | self.upsample2 = nn.ConvTranspose3d( 247 | in_channels=self.bf*4, 248 | out_channels=self.bf*2, 249 | kernel_size=2, 250 | stride=2 251 | ) 252 | self.conv2 = nn.Conv3d( 253 | in_channels=self.bf*4, 254 | out_channels=self.bf*2, 255 | kernel_size=1 256 | ) 257 | self.up_block2 = res_block( 258 | channel_in=self.bf*2, 259 | downsample=False 260 | ) 261 | 262 | self.upsample1 = nn.ConvTranspose3d( 263 | in_channels=self.bf*2, 264 | out_channels=self.bf, 265 | kernel_size=2, 266 | stride=2 267 | ) 268 | self.conv1 = nn.Conv3d( 269 | in_channels=self.bf*2, 270 | out_channels=self.bf, 271 | kernel_size=1 272 | ) 273 | self.up_block1 = res_block( 274 | channel_in=self.bf, 275 | downsample=False 276 | ) 277 | 278 | def forward(self, x): 279 | 280 | 281 | up3 = self.upsample3(x) 282 | self.up_level3 = self.up_block3(up3) 283 | 284 | up2 = self.upsample2(self.up_level3) 285 | self.up_level2 = self.up_block2(up2) 286 | 287 | up1 = self.upsample1(self.up_level2) 288 | self.up_level1 = self.up_block1(up1) 289 | 290 | return self.up_level1 291 | 292 | class seg_out_block(nn.Module): 293 | ''' 294 | seg_out_block, receive data from decoder and output the segmentation mask 295 | 296 | parameters: 297 | base_filters: number of filters received from in_block. 
298 | n_classes: number of classes 299 | 300 | ''' 301 | def __init__( 302 | self, 303 | base_filters, 304 | n_classes=6 305 | ): 306 | super(seg_out_block, self).__init__() 307 | 308 | self.bf = base_filters 309 | self.n_classes = n_classes 310 | self.conv = nn.Conv3d( 311 | in_channels=self.bf, 312 | out_channels=self.n_classes, 313 | kernel_size=1 314 | ) 315 | 316 | def forward(self, x): 317 | self.output = self.conv(x) 318 | return self.output 319 | 320 | class seg_path(nn.Module): 321 | def __init__( 322 | self, 323 | in_block, 324 | encoder, 325 | decoder, 326 | seg_out_block 327 | ): 328 | super(seg_path, self).__init__() 329 | 330 | self.in_block = in_block 331 | self.encoder = encoder 332 | self.decoder = decoder 333 | self.seg_out_block = seg_out_block 334 | 335 | def forward(self, x): 336 | 337 | self.down_level1 = self.in_block(x) 338 | 339 | self.down_level2 = self.encoder.down_block2(self.down_level1) 340 | self.down_level3 = self.encoder.down_block3(self.down_level2) 341 | self.codes = self.encoder.down_block4(self.down_level3) 342 | 343 | self.up3 = self.decoder.upsample3(self.codes) 344 | up3_dummy = torch.cat([self.up3, self.down_level3], 1) 345 | up3_dummy = self.decoder.conv3(up3_dummy) 346 | self.up_level3 = self.decoder.up_block3(up3_dummy) 347 | 348 | self.up2 = self.decoder.upsample2(self.up_level3) 349 | up2_dummy = torch.cat([self.up2, self.down_level2], 1) 350 | up2_dummy = self.decoder.conv2(up2_dummy) 351 | self.up_level2 = self.decoder.up_block2(up2_dummy) 352 | 353 | self.up1 = self.decoder.upsample1(self.up_level2) 354 | up1_dummy = torch.cat([self.up1, self.down_level1], 1) 355 | up1_dummy = self.decoder.conv1(up1_dummy) 356 | self.up_level1 = self.decoder.up_block1(up1_dummy) 357 | 358 | self.output = self.seg_out_block(self.up_level1) 359 | 360 | return self.output -------------------------------------------------------------------------------- /libs/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/optimizers/__init__.py 2 | 3 | import logging 4 | 5 | from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop 6 | 7 | logger = logging.getLogger("ptsemseg") 8 | 9 | key2opt = { 10 | "sgd": SGD, 11 | "adam": Adam, 12 | "asgd": ASGD, 13 | "adamax": Adamax, 14 | "adadelta": Adadelta, 15 | "adagrad": Adagrad, 16 | "rmsprop": RMSprop, 17 | } 18 | 19 | 20 | def get_optimizer(CONFIG): 21 | 22 | 23 | if CONFIG.TRAINING.OPTIM.name is None: 24 | logger.info("Using SGD optimizer") 25 | return SGD 26 | 27 | else: 28 | opt_name = CONFIG.TRAINING.OPTIM.name 29 | if opt_name not in key2opt: 30 | raise NotImplementedError("Optimizer {} not implemented".format(opt_name)) 31 | 32 | logger.info("Using {} optimizer".format(opt_name)) 33 | return key2opt[opt_name] -------------------------------------------------------------------------------- /libs/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/schedulers/__init__.py 2 | 3 | import logging 4 | 5 | from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, CosineAnnealingLR 6 | 7 | from .schedulers import WarmUpLR, ConstantLR, PolynomialLR 8 | 9 | logger = logging.getLogger("seg") 10 | 11 | key2scheduler = { 12 | "constant_lr": ConstantLR, 13 | "poly_lr": PolynomialLR, 14 | "multi_step": MultiStepLR, 15 | 
"cosine_annealing": CosineAnnealingLR, 16 | "exp_lr": ExponentialLR, 17 | } 18 | 19 | 20 | def get_scheduler(optimizer, scheduler_dict): 21 | if scheduler_dict is None: 22 | logger.info("Using No LR Scheduling") 23 | return ConstantLR(optimizer) 24 | 25 | s_type = scheduler_dict.name 26 | scheduler_dict.pop("name") 27 | 28 | logging.info("Using {} scheduler with {} params".format(s_type, scheduler_dict)) 29 | 30 | warmup_dict = {} 31 | if "warmup_iters" in scheduler_dict: 32 | # This can be done in a more pythonic way... 33 | warmup_dict["warmup_iters"] = scheduler_dict.get("warmup_iters", 100) 34 | warmup_dict["mode"] = scheduler_dict.get("warmup_mode", "linear") 35 | warmup_dict["gamma"] = scheduler_dict.get("warmup_factor", 0.2) 36 | 37 | logger.info( 38 | "Using Warmup with {} iters {} gamma and {} mode".format( 39 | warmup_dict["warmup_iters"], warmup_dict["gamma"], warmup_dict["mode"] 40 | ) 41 | ) 42 | 43 | scheduler_dict.pop("warmup_iters", None) 44 | scheduler_dict.pop("warmup_mode", None) 45 | scheduler_dict.pop("warmup_factor", None) 46 | 47 | base_scheduler = key2scheduler[s_type](optimizer, **scheduler_dict) 48 | return WarmUpLR(optimizer, base_scheduler, **warmup_dict) 49 | 50 | return key2scheduler[s_type](optimizer, **scheduler_dict) -------------------------------------------------------------------------------- /libs/schedulers/schedulers.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/schedulers/schedulers.py 2 | 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class ConstantLR(_LRScheduler): 7 | def __init__(self, optimizer, last_epoch=-1): 8 | super(ConstantLR, self).__init__(optimizer, last_epoch) 9 | 10 | def get_lr(self): 11 | return [base_lr for base_lr in self.base_lrs] 12 | 13 | 14 | class PolynomialLR(_LRScheduler): 15 | def __init__(self, optimizer, max_iter, decay_iter=1, gamma=0.9, last_epoch=-1): 16 | self.decay_iter = decay_iter 17 | self.max_iter = max_iter 18 | self.gamma = gamma 19 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 20 | 21 | def get_lr(self): 22 | if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter: 23 | return [base_lr for base_lr in self.base_lrs] 24 | else: 25 | factor = (1 - self.last_epoch / float(self.max_iter)) ** self.gamma 26 | return [base_lr * factor for base_lr in self.base_lrs] 27 | 28 | 29 | class WarmUpLR(_LRScheduler): 30 | def __init__( 31 | self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1 32 | ): 33 | self.mode = mode 34 | self.scheduler = scheduler 35 | self.warmup_iters = warmup_iters 36 | self.gamma = gamma 37 | super(WarmUpLR, self).__init__(optimizer, last_epoch) 38 | 39 | def get_lr(self): 40 | cold_lrs = self.scheduler.get_lr() 41 | 42 | if self.last_epoch < self.warmup_iters: 43 | if self.mode == "linear": 44 | alpha = self.last_epoch / float(self.warmup_iters) 45 | factor = self.gamma * (1 - alpha) + alpha 46 | 47 | elif self.mode == "constant": 48 | factor = self.gamma 49 | else: 50 | raise KeyError("WarmUp type {} not implemented".format(self.mode)) 51 | 52 | return [factor * base_lr for base_lr in cold_lrs] 53 | 54 | return cold_lrs -------------------------------------------------------------------------------- /libs/utils/device.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | 5 | def get_device(cuda): 6 
| cuda = cuda and torch.cuda.is_available() 7 | device = torch.device("cuda" if cuda else "cpu") 8 | if cuda: 9 | print("Device:") 10 | for i in range(torch.cuda.device_count()): 11 | print(" {}:".format(i), torch.cuda.get_device_name(i)) 12 | else: 13 | print("Device: CPU") 14 | return device 15 | 16 | def memory_usage_report(device, logger=None): 17 | max_memory_allocated = float(torch.cuda.max_memory_allocated(device=device))/(10**9) 18 | max_memory_cached = float(torch.cuda.max_memory_cached(device=device))/(10**9) 19 | print("Max memory allocated on {}: ".format(device) + str(max_memory_allocated) + "GB.") 20 | print("Max memory cached on {}: ".format(device)+ str(max_memory_cached) + "GB.") 21 | 22 | if logger is not None: 23 | logger.info("Max memory allocated on {}: ".format(device) + str(max_memory_allocated) + "GB.") 24 | logger.info("Max memory cached on {}: ".format(device)+ str(max_memory_cached) + "GB.") 25 | 26 | 27 | def dict_conversion(d): 28 | new_state_dict = OrderedDict() 29 | for k, v in d.items(): 30 | name = k[7:] # strip the `module.` prefix added by DataParallel 31 | new_state_dict[name] = v 32 | return new_state_dict -------------------------------------------------------------------------------- /libs/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/utils.py 2 | 3 | import os 4 | import datetime 5 | import logging 6 | 7 | 8 | def get_logger(logdir, job): 9 | logger = logging.getLogger(job) 10 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_") 11 | ts = ts.replace(":", "_").replace("-", "_") 12 | file_path = os.path.join(logdir, "run_{}.log".format(ts)) 13 | hdlr = logging.FileHandler(file_path) 14 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") 15 | hdlr.setFormatter(formatter) 16 | logger.addHandler(hdlr) 17 | logger.setLevel(logging.INFO) 18 | return logger -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from train import train 4 | 5 | if __name__ == "__main__": 6 | 7 | parser = argparse.ArgumentParser(description="Specify: mode and configuration file path.") 8 | 9 | parser.add_argument( 10 | "-m", 11 | "--mode", 12 | nargs="?", 13 | choices=['train', 'val', 'test', 'demo'], 14 | type=str, 15 | default="train", 16 | help="Specify the mode here: train/val/test/demo" 17 | ) 18 | 19 | parser.add_argument( 20 | "-c", 21 | "--config", 22 | nargs="?", 23 | type=str, 24 | default="config/prostateCT_deeper3dresunet_train.yml", 25 | help="Specify the path of configuration file here." 
26 | ) 27 | 28 | args = parser.parse_args() 29 | 30 | locals()[args.mode](args.config) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import time 4 | import shutil 5 | import numpy as np 6 | from tqdm import tqdm 7 | from addict import Dict 8 | 9 | import torch 10 | import torch.nn as nn 11 | from tensorboardX import SummaryWriter 12 | 13 | from libs.models import Deeper_ResUnet_3D, ResUnet_3D 14 | from libs.optimizers import get_optimizer 15 | from libs.schedulers import get_scheduler 16 | from libs.utils.logging import get_logger 17 | from libs.loss_funcs import get_loss_function 18 | from libs.data_loaders import build_data_loader 19 | from libs.metrics import running_seg_score, averageMeter 20 | from libs.utils.device import get_device, memory_usage_report, dict_conversion 21 | 22 | 23 | def train(config_file): 24 | 25 | # Configuration 26 | with open(config_file) as fp: 27 | CONFIG = Dict(yaml.safe_load(fp)) # safe_load avoids the deprecated bare yaml.load 28 | 29 | # Device 30 | device = get_device(CONFIG.CUDA) 31 | torch.backends.cudnn.benchmark = True 32 | 33 | # Setup logger&writer and run dir 34 | run_id = time.strftime("%Y%m%d-%H%M%S") 35 | log_dir = os.path.join("runs", os.path.basename(config_file)[:-4], run_id) 36 | writer = SummaryWriter(log_dir=log_dir) 37 | 38 | print("RUN DIR: {}".format(log_dir)) 39 | shutil.copy(config_file, log_dir) 40 | 41 | logger = get_logger(log_dir, CONFIG.JOB) 42 | logger.info("Starting the program...") 43 | logger.propagate = False 44 | 45 | # Dataset 46 | logger.info("Using dataset: {}".format(CONFIG.DATASET.NAME)) 47 | 48 | # Dataloader 49 | _, _, train_loader, val_loader = build_data_loader(CONFIG, writer, logger) 50 | logger.info("Dataloader Ready.") 51 | 52 | if not CONFIG.MODEL.ADV: 53 | train_base(CONFIG, writer, logger, train_loader, val_loader, device) 54 | else: 55 | pass 56 | 57 | print('-'*60) 58 | for d in range(torch.cuda.device_count()): # report every visible GPU instead of assuming two 59 | memory_usage_report(device=torch.device("cuda:{}".format(d)), logger=logger) 60 | print('-'*40) 61 | print('-'*60) 62 | 63 | 64 | def train_base(CONFIG, writer, logger, train_loader, val_loader, device): 65 | 66 | # Setup Model 67 | logger.info("Building: {}".format(CONFIG.MODEL.NAME)) 68 | model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES, 69 | base_filters=CONFIG.MODEL.BASE_FILTERS, channel_in=CONFIG.MODEL.CHANNEL_IN) 70 | 71 | if CONFIG.MODEL.INIT_MODEL is not None: 72 | 73 | # original saved file with DataParallel 74 | state_dict = torch.load(CONFIG.MODEL.INIT_MODEL)["model_state"] 75 | 76 | # create new OrderedDict that does not contain `module.` 77 | new_state_dict = dict_conversion(state_dict) 78 | 79 | # load parameters 80 | for m in model.state_dict().keys(): 81 | if m not in new_state_dict.keys(): 82 | print(" Skip init:", m) 83 | model.load_state_dict(new_state_dict, strict=False) 84 | print("Pre-trained weights loaded.") 85 | 86 | if CONFIG.PARALLEL: 87 | model = nn.DataParallel(model) 88 | model.to(device) 89 | print("Model is ready.") 90 | 91 | # Setup optimizer, lr_scheduler and loss function 92 | optimizer_cls = get_optimizer(CONFIG) 93 | optimizer_params = {k.lower(): v for k, v in CONFIG.TRAINING.OPTIM.items() if k != "name"} 94 | optim = optimizer_cls(model.parameters(), **optimizer_params) 95 | logger.info("Using optimizer {}".format(optim)) 96 | 97 | scheduler = get_scheduler(optim, CONFIG.TRAINING.LR_SCHEDULER) 98 | logger.info("Using 
lr_scheduler {}".format(scheduler)) 99 | 100 | loss_func = get_loss_function(CONFIG) 101 | logger.info("Using loss {}".format(loss_func)) 102 | 103 | # setup metrics 104 | running_metrics = running_seg_score(CONFIG.DATASET.N_CLASSES) 105 | 106 | # meters 107 | val_loss_meter = averageMeter() 108 | time_meter = averageMeter() 109 | 110 | best_dsc = -100 111 | start_iter = -1 112 | flag = True 113 | 114 | i = start_iter 115 | while i <= CONFIG.TRAINING.ITER_MAX and flag: 116 | for _, volumes, labels, _ in train_loader: 117 | 118 | i += 1 119 | 120 | ###############################Training############################### 121 | 122 | start_time = time.time() 123 | model.train() 124 | 125 | volumes = volumes.cuda() 126 | labels = labels.cuda() 127 | 128 | optim.zero_grad() 129 | 130 | outputs = model(volumes) 131 | loss = loss_func(input=outputs, target=labels) 132 | 133 | loss.backward() 134 | optim.step() 135 | scheduler.step() 136 | 137 | time_meter.update(time.time()-start_time) 138 | 139 | if (i + 1) % CONFIG.TRAINING.PRINT_INTERVAL == 0: 140 | fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" 141 | print_str = fmt_str.format( 142 | i + 1, 143 | CONFIG.TRAINING.ITER_MAX, 144 | loss.item(), 145 | time_meter.avg / CONFIG.TRAINING.BATCH_SIZE, 146 | ) 147 | 148 | print(print_str) 149 | logger.info(print_str) 150 | writer.add_scalar("loss/train_loss", loss.item(), i + 1) 151 | time_meter.reset() 152 | ###################################################################### 153 | 154 | #############################Validation############################### 155 | if (i + 1) % CONFIG.TRAINING.VAL_INTERVAL == 0: 156 | model.eval() 157 | with torch.no_grad(): 158 | for _, imgs_val, labels_val, _ in tqdm(val_loader): 159 | 160 | imgs_val = imgs_val.cuda() 161 | labels_val = labels_val.cuda() 162 | 163 | outputs = model(imgs_val) 164 | 165 | val_loss = loss_func(input=outputs, target=labels_val) 166 | 167 | pred = outputs.data.max(1)[1].cpu().numpy() 168 | gt = labels_val.data.cpu().numpy() 169 | 170 | running_metrics.update(gt, pred) 171 | val_loss_meter.update(val_loss.item()) 172 | 173 | writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1) 174 | logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) 175 | 176 | # scoring 177 | score, class_dsc = running_metrics.get_scores() 178 | for k, v in score.items(): 179 | print(k, v) 180 | logger.info("{}: {}".format(k, v)) 181 | writer.add_scalar("val_metrics/{}".format(k), v, i + 1) 182 | 183 | for k, v in class_dsc.items(): 184 | logger.info("{}: {}".format(k, v)) 185 | writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1) 186 | 187 | print('-' * 40) 188 | 189 | # reset 190 | val_loss_meter.reset() 191 | running_metrics.reset() 192 | 193 | if score["Mean Dice Coefficient: \t"] >= best_dsc: 194 | best_dsc = score["Mean Dice Coefficient: \t"] 195 | state = { 196 | "epoch": i + 1, 197 | "model_state": model.state_dict(), 198 | "optimizer_state": optim.state_dict(), 199 | "scheduler_state": scheduler.state_dict(), 200 | "best_dsc": best_dsc 201 | } 202 | 203 | writer.add_scalar("best_model/dsc", best_dsc, i+1) 204 | 205 | save_path = os.path.join( 206 | writer.file_writer.get_logdir(), 207 | "{}_{}_best_model.pkl".format(CONFIG.MODEL.NAME, CONFIG.DATASET.NAME), 208 | ) 209 | torch.save(state, save_path) 210 | ###################################################################### 211 | 212 | ################################End################################### 213 | if (i + 1) == CONFIG.TRAINING.ITER_MAX: 214 | flag = False 215 | 
break 216 | ###################################################################### 217 | --------------------------------------------------------------------------------