├── src
│   ├── test.py
│   ├── settings.py
│   ├── settings.ini
│   ├── Normalize.py
│   ├── run.py
│   ├── data_utils.py
│   └── convert_h5.py
└── README.md

/src/test.py:
--------------------------------------------------------------------------------
import numpy as np

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 3DUnetCNN
Keras 3D U-Net Convolutional Neural Network (CNN) designed for medical image segmentation

## Background

Originally designed after [this paper](https://lmb.informatik.uni-freiburg.de/Publications/2016/CABR16/cicek16miccai.pdf) on volumetric segmentation with a 3D U-Net.
The code was written to be trained on the BRATS data set for brain tumors,
but it can easily be modified for other 3D segmentation applications.

--------------------------------------------------------------------------------
/src/settings.py:
--------------------------------------------------------------------------------
import ast
import configparser
from collections.abc import Mapping


def parse_values(config):
    """Run every value through ast.literal_eval so numbers, tuples and lists
    come back as Python objects instead of raw strings."""
    config_parsed = {}
    for section in config.sections():
        config_parsed[section] = {}
        for key, value in config[section].items():
            config_parsed[section][key] = ast.literal_eval(value)
    return config_parsed


class Settings(Mapping):
    """Read-only, dict-like view of an INI settings file."""

    def __init__(self, setting_file='settings.ini'):
        config = configparser.ConfigParser()
        config.read(setting_file)
        self.settings_dict = parse_values(config)

    def __getitem__(self, key):
        return self.settings_dict[key]

    def __len__(self):
        return len(self.settings_dict)

    def __iter__(self):
        # Mapping expects iteration over keys (the section names).
        return iter(self.settings_dict)

--------------------------------------------------------------------------------
/src/settings.ini:
--------------------------------------------------------------------------------
[COMMON]
save_model_dir = "saved_models"
model_name = "quicknat"
log_dir = "logs"
device = "0"
exp_dir = "experiments"

[DATA]
data_dir = "/Users/sq566/PycharmProjects/data"
train_data_file = "Data_train.h5"
train_label_file = "Label_train.h5"
test_data_file = "Data_test.h5"
test_label_file = "Label_test.h5"
train_weights_file = "Weights_train.h5"
test_weights_file = "Weights_test.h5"
labels = ["Background", "CP"]

[NETWORK]
num_class = 2
num_channels = 1
num_filters = 64
pool = 2
stride_pool = 2
drop_out = 0
kernel_c = 1
stride_conv = 1

[TRAINING]
exp_name = "CP_Finetuning"
final_model_file = "quicknat_finetuned_1.pth.tar"
learning_rate = 1e-6
train_batch_size = 1
num_epochs = 10
optim_betas = (0.9, 0.99)
optim_eps = 1e-8
optim_weight_decay = 0.00001
lr_scheduler_step_size = 5
lr_scheduler_gamma = 0.5

--------------------------------------------------------------------------------
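Because every value in settings.ini is passed through ast.literal_eval, numbers, tuples and lists are returned as native Python objects rather than raw strings. A minimal usage sketch (assuming the interpreter is started from src/, so the default settings.ini path resolves):

    from settings import Settings

    settings = Settings('settings.ini')
    print(settings['NETWORK']['num_class'])      # 2 (an int, not the string "2")
    print(settings['TRAINING']['optim_betas'])   # (0.9, 0.99) as a Python tuple
    print(settings['DATA']['labels'])            # ['Background', 'CP']
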
/src/Normalize.py:
--------------------------------------------------------------------------------
import numpy as np
import nibabel as nib


def normalize_data(data, mean, std):
    data -= mean
    data /= std
    return data


def normalize_data_storage(case_arr):
    modality = '/t1ce.nii.gz'
    means = []
    stds = []
    # First pass: collect per-case intensity statistics of the T1ce volumes.
    for subject in case_arr:
        img = nib.load(subject + modality)
        img_data = img.get_fdata().astype(np.int16)  # get_data() is deprecated in nibabel
        means.append(img_data.mean(axis=(0, 1, 2)))
        stds.append(img_data.std(axis=(0, 1, 2)))

    mean = np.asarray(means).mean(axis=0)
    std = np.asarray(stds).mean(axis=0)

    # Second pass: normalize each volume with the pooled statistics and save it.
    count = 0
    for subject in case_arr:
        img = nib.load(subject + modality)
        data = img.get_fdata().astype(np.float64)
        data_n = normalize_data(data, mean, std)
        data_n[data_n < 0.0] = 0
        print('Case ' + str(count) + ' done')
        image = nib.Nifti1Image(data_n, img.affine, img.header)
        nib.save(image, subject + '/' + 't1ce_n.nii.gz')
        count += 1


# Site-specific text file listing the case directories, one per line.
cases = '/rfanfs/pnl-zorro/home/sq566/pycharm/brats/3DUnetCNN/brats/data/original/case.txt'
with open(cases) as f:
    case_arr = f.read().splitlines()

normalize_data_storage(case_arr)

--------------------------------------------------------------------------------
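Normalize.py expects case.txt to contain one case directory per line, each holding a t1ce.nii.gz volume; the normalized image is written back into the same directory as t1ce_n.nii.gz. A small sketch of how such a list could be generated (the directory names are hypothetical):

    # Each entry is a directory that contains t1ce.nii.gz.
    case_dirs = [
        '/path/to/data/original/case_0001',
        '/path/to/data/original/case_0002',
    ]
    with open('case.txt', 'w') as f:
        f.write('\n'.join(case_dirs))
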
/src/run.py:
--------------------------------------------------------------------------------
import os
import torch
from settings import Settings
from quicknat import QuickNat
from data_utils import get_imdb_dataset
from solver import Solver


def load_data(data_params):
    print("Loading dataset")
    train_data, test_data = get_imdb_dataset(data_params)
    print("Train size: %i" % len(train_data))
    print("Test size: %i" % len(test_data))
    return train_data, test_data


def train(train_params, common_params, data_params, net_params):

    train_data, test_data = load_data(data_params)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_params['train_batch_size'],
                                               shuffle=False, num_workers=4, pin_memory=False)

    quicknat_model = QuickNat(net_params)
    # print(quicknat_model)

    solver = Solver(quicknat_model,
                    device=common_params['device'],
                    num_class=net_params['num_class'],
                    optim_args={"lr": train_params['learning_rate'],
                                "betas": train_params['optim_betas'],
                                "eps": train_params['optim_eps'],
                                "weight_decay": train_params['optim_weight_decay']
                                },
                    model_name=common_params['model_name'],
                    labels=data_params['labels'],
                    num_epochs=train_params['num_epochs'],
                    lr_scheduler_step_size=train_params['lr_scheduler_step_size'],
                    lr_scheduler_gamma=train_params['lr_scheduler_gamma'])

    solver.train(train_loader)
    # final_model_path = os.path.join(common_params['save_model_dir'], train_params['final_model_file'])
    # quicknat_model.save_model(final_model_path)
    # print("Final model saved @ " + str(final_model_path))


if __name__ == '__main__':
    settings = Settings()

    common_params = settings['COMMON']
    data_params = settings['DATA']
    net_params = settings['NETWORK']
    train_params = settings['TRAINING']

    train(train_params, common_params, data_params, net_params)

--------------------------------------------------------------------------------
/src/data_utils.py:
--------------------------------------------------------------------------------
import os
import sys
import h5py
import nibabel as nb
import numpy as np
import torch
import torch.utils.data as data


class ImdbData(data.Dataset):

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        img = torch.from_numpy(self.x[index])
        label = torch.from_numpy(self.y[index])
        return img, label

    def __len__(self):
        return len(self.y)


def get_imdb_dataset(data_params):

    data_train = h5py.File(os.path.join(data_params['data_dir'], data_params['train_data_file']), 'r')
    label_train = h5py.File(os.path.join(data_params['data_dir'], data_params['train_label_file']), 'r')

    data_test = h5py.File(os.path.join(data_params['data_dir'], data_params['test_data_file']), 'r')
    label_test = h5py.File(os.path.join(data_params['data_dir'], data_params['test_label_file']), 'r')

    return ImdbData(data_train['data'], label_train['label']), ImdbData(data_test['data'], label_test['label'])


def load_file_paths(data_dir, label_dir):
    """
    Returns a list of lists containing the paths to each case's T1.mgz and CP-ROI.mgz, e.g.
    [['/Users/sq566/case1/T1.mgz', '/Users/sq566/case1/CP-ROI.mgz'],
     ['/Users/sq566/case2/T1.mgz', '/Users/sq566/case2/CP-ROI.mgz']]
    """
    volume_to_use = os.listdir(data_dir)
    file_path = [
        [
            os.path.join(data_dir, vol, 'T1.mgz'),
            os.path.join(label_dir, vol, 'CP-ROI.mgz')
        ] for vol in volume_to_use
    ]

    return file_path


def load_data(file_path):
    volume_nifty, labelmap_nifty = nb.load(file_path[0]), nb.load(file_path[1])
    volume, labelmap = volume_nifty.get_fdata(), labelmap_nifty.get_fdata()
    # Scale intensities by the 99th percentile and clip to (0, 1]; binarize the label map.
    p = np.percentile(volume, 99)
    vol_data = volume / p
    vol_data[vol_data > 1] = 1
    vol_data[vol_data < 0] = sys.float_info.epsilon
    labelmap[labelmap > 0.0] = 1
    print(vol_data.shape)
    print(labelmap.shape)
    return vol_data, labelmap, volume_nifty.header


def load_and_preprocess(file_path):
    volume, labelmap, header = load_data(file_path)
    return volume, labelmap, header


def load_dataset(file_paths):

    volume_list, labelmap_list, headers = [], [], []

    for file_path in file_paths:
        volume, labelmap, header = load_and_preprocess(file_path)

        # Append the 3D numpy arrays and the NIfTI header for each case.
        volume_list.append(volume)
        labelmap_list.append(labelmap)
        headers.append(header)

    return volume_list, labelmap_list, headers

--------------------------------------------------------------------------------
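A minimal sketch of how the H5 files produced by convert_h5.py (below) are consumed; in run.py the data_params dictionary comes from the [DATA] section of settings.ini, and the paths here are placeholders:

    from data_utils import get_imdb_dataset

    data_params = {
        'data_dir': '/path/to/h5_output',
        'train_data_file': 'Data_train.h5',
        'train_label_file': 'Label_train.h5',
        'test_data_file': 'Data_test.h5',
        'test_label_file': 'Label_test.h5',
    }
    train_data, test_data = get_imdb_dataset(data_params)
    img, label = train_data[0]   # torch tensors holding one 2D slice and its label map
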
/src/convert_h5.py:
--------------------------------------------------------------------------------
import argparse
import os
import data_utils as du
import numpy as np
import h5py


def apply_split(data_split, data_dir, label_dir):
    """
    Randomly splits the data into training and test sets.
    """
    file_paths = du.load_file_paths(data_dir, label_dir)
    print("Total no of volumes to process: %d" % len(file_paths))
    train_ratio, test_ratio = data_split.split(",")
    train_len = int((int(train_ratio) / 100.0) * len(file_paths))

    train_idx = np.random.choice(len(file_paths), train_len, replace=False)
    test_idx = np.array([i for i in range(len(file_paths)) if i not in train_idx])
    train_file_paths = [file_paths[i] for i in train_idx]
    test_file_paths = [file_paths[i] for i in test_idx]
    return train_file_paths, test_file_paths


def _write_h5(data, label, f, mode):

    n_slices, h, w = data[0].shape
    with h5py.File(f[mode]['data'], "w") as data_handle:
        # -1 lets numpy infer the leading (slice) dimension after concatenation
        data_handle.create_dataset("data", data=np.concatenate(data).reshape(-1, h, w))
    with h5py.File(f[mode]['label'], "w") as label_handle:
        label_handle.create_dataset("label", data=np.concatenate(label).reshape(-1, h, w))


def convert_h5(data_dir, label_dir, data_split, f):

    if data_split:
        train_file_paths, test_file_paths = apply_split(data_split, data_dir, label_dir)
    else:
        raise ValueError('Please provide the split ratio')

    print("Training dataset size: ", len(train_file_paths))
    print("Testing dataset size: ", len(test_file_paths))

    # data_train / label_train are lists of 3D numpy arrays (volumes / label maps); headers are discarded.
    print("Loading and pre-processing Training data...")
    data_train, label_train, _ = du.load_dataset(train_file_paths)
    _write_h5(data_train, label_train, f, mode="train")

    print("Loading and pre-processing Testing data...")
    data_test, label_test, _ = du.load_dataset(test_file_paths)
    _write_h5(data_test, label_test, f, mode="test")


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir')
    parser.add_argument('--label_dir')
    parser.add_argument('--data_split')
    parser.add_argument('--destination_folder')

    args = parser.parse_args()

    f = {
        'train': {
            "data": os.path.join(args.destination_folder, "Data_train.h5"),
            "label": os.path.join(args.destination_folder, "Label_train.h5")
        },
        'test': {
            "data": os.path.join(args.destination_folder, "Data_test.h5"),
            "label": os.path.join(args.destination_folder, "Label_test.h5")
        }
    }

    convert_h5(args.data_dir, args.label_dir, args.data_split, f)

--------------------------------------------------------------------------------
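Taken together, a hedged end-to-end run might look like the following (all paths are placeholders, the commands assume they are run from src/, and --data_split takes "train,test" percentages; the label directory can be the same as the data directory, since T1.mgz and CP-ROI.mgz are looked up per case):

    python convert_h5.py --data_dir /path/to/cases --label_dir /path/to/cases --data_split 80,20 --destination_folder /path/to/h5_output
    python run.py   # reads settings.ini, whose [DATA] section should point at /path/to/h5_output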