├── LICENSE ├── N4ITKforBRATS2017.py ├── README.md ├── braintools.py ├── braintools.pyc ├── data_loader.py ├── data_loader.pyc ├── dataset.py ├── dataset.pyc ├── loss.py ├── main.py ├── model.py ├── model.pyc ├── paths.py ├── paths.pyc ├── run_preporcessing.py ├── test.py ├── utils.py └── utils.pyc /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Po-Yu Kao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu May 1 18:54:21 2018

@author: pkao

This code applies N4ITK bias-field correction to the BRATS2018 database.

For every input image <name>.nii.gz the corrected image is written next to it
as <name>_N4ITK_corrected.nii.gz
"""
import os
from multiprocessing import Pool

# Suffix appended to every corrected output image.
N4ITK_SUFFIX = '_N4ITK_corrected.nii.gz'

# Root of the BraTS2018 training data (variable name kept from the original script).
brats2017_training_path = '/media/pkao/Dataset/BraTS2018/training'


def N4ITK(filepath):
    '''Run N4 bias-field correction on one .nii.gz image.

    Args:
        filepath (str): path of an input image ending in '.nii.gz'.

    The corrected image is written to filepath with '.nii.gz' replaced by
    N4ITK_SUFFIX.
    '''
    # Imported here so the path-collection helpers below can be used without
    # nipype/ANTs being installed.
    from nipype.interfaces.ants import N4BiasFieldCorrection

    print('Working on: ' + filepath)
    n4 = N4BiasFieldCorrection()
    n4.inputs.dimension = 3
    n4.inputs.input_image = filepath
    # Strip the 7-character '.nii.gz' extension before adding the suffix.
    n4.inputs.output_image = filepath[:-len('.nii.gz')] + N4ITK_SUFFIX
    n4.run()


def _raw_nifti_paths(root):
    '''Return every .nii.gz path under root that is not itself an N4ITK output.

    Skipping earlier outputs keeps a re-run of this script from correcting
    already corrected images a second time.
    '''
    return [os.path.join(dirpath, name)
            for dirpath, _, files in os.walk(root)
            for name in files
            if name.endswith('.nii.gz') and not name.endswith(N4ITK_SUFFIX)]


def collect_modality_paths(root):
    '''Return the T1, T1ce, T2 and FLAIR image paths found under root.

    Each modality's paths are sorted; the four groups are concatenated in the
    order T1, T1ce, T2, FLAIR. Modalities are recognised by substrings of the
    file name ('t1' without 'ce', 't1ce', 't2', 'flair').
    '''
    candidates = _raw_nifti_paths(root)

    def select(predicate):
        return sorted(p for p in candidates if predicate(os.path.basename(p)))

    t1 = select(lambda name: 't1' in name and 'ce' not in name)
    t1ce = select(lambda name: 't1ce' in name)
    t2 = select(lambda name: 't2' in name)
    flair = select(lambda name: 'flair' in name)
    return t1 + t1ce + t2 + flair


if __name__ == '__main__':
    # multiprocessing needs this guard on platforms that start workers by
    # re-importing the main module.
    file_paths = collect_modality_paths(brats2017_training_path)
    pool = Pool(6)
    try:
        pool.map(N4ITK, file_paths)
    finally:
        pool.close()
        pool.join()
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modified-3D-UNet-Pytorch 2 | This repository is a PyTorch implementation of the modified 3D UNet architecture from [Brain Tumor Segmentation and Radiomics Survival Prediction: Contribution to the BRATS 2017 Challenge](https://arxiv.org/abs/1802.10508) by Fabian Isensee et al., who participated in BraTS2017. 3 | 4 | The PyTorch version of the modified 3D UNet is in `model.py`. 5 | 6 | The model works, but the data-loading and training parts still need work. 7 | 8 | Please refer to our [BraTS2018-tumor-segmentation](https://github.com/pykao/BraTS2018-tumor-segmentation) repo if you would like to use this model for brain tumor segmentation. 9 | -------------------------------------------------------------------------------- /braintools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Mar 13 16:06:53 2018 5 | 6 | @author: pkao 7 | 8 | This code contains several tools for BRATS2017 database 9 | """ 10 | 11 | import os 12 | import numpy as np 13 | import SimpleITK as sitk 14 | 15 | #def Brats2017FASTSegPathes(bratsPath): 16 | # ''' This code returns the FAST segment maps''' 17 | # seg_filepathes = [os.path.join(root, name) for root, dirs, files in os.walk(bratsPath) 18 | # for name in files if 'seg' in name and 'FAST' in name and 'MNI152' not in name and name.endswith('.nii.gz')] 19 | # seg_filepathes.sort() 20 | 21 | # return seg_filepathes 22 | 23 | 24 | 25 | #def BrainParcellationPathes(bratsPath, brainParcellationName): 26 | # '''This function gives you the location of brain parcellation mask in patient space ''' 27 | # brain_parcellation_pathes = [os.path.join(root, name) for root, dirs, files in os.walk(bratsPath) 28 | # for name in files if brainParcellationName in 
name and 'lab' not in name and 'threshold' in root and name.endswith('.nii.gz')] 29 | # brain_parcellation_pathes.sort() 30 | # return brain_parcellation_pathes 31 | 32 | def ReadImage(path): 33 | ''' This code returns the numpy nd array for a MR image at path''' 34 | return sitk.GetArrayFromImage(sitk.ReadImage(path)).astype(np.float32) 35 | 36 | 37 | def ModalityMaximum(filepathes): 38 | ''' This code returns the maximum value for MR images''' 39 | modality_maximum = 0 40 | for i in range(len(filepathes)): 41 | temp_img = sitk.ReadImage(filepathes[i]) 42 | temp_nda = sitk.GetArrayFromImage(temp_img) 43 | temp_max = np.amax(temp_nda) 44 | if temp_max > modality_maximum: 45 | modality_maximum = temp_max 46 | print modality_maximum 47 | return modality_maximum 48 | 49 | def ModalityMinimum(filepathes): 50 | ''' This code returns the minimum value for MR images''' 51 | modality_minimum = 4000 52 | for i in range(len(filepathes)): 53 | temp_img = sitk.ReadImage(filepathes[i]) 54 | temp_nda = sitk.GetArrayFromImage(temp_img) 55 | temp_max = np.amin(temp_nda) 56 | if temp_max < modality_minimum: 57 | modality_minimum = temp_max 58 | print modality_minimum 59 | return modality_minimum 60 | 61 | #def Brats2017LesionsPathes(bratsPath): 62 | # ''' This function gives you the pathes for evey lesion files''' 63 | # necrosis_pathes = [os.path.join(root, name) for root, dirs, files in os.walk(bratsPath) 64 | # for name in files if 'MNI' not in name and 'necrosis' in name and name.endswith('.nii.gz')] 65 | # necrosis_pathes.sort() 66 | # edema_pathes = [os.path.join(root, name) for root, dirs, files in os.walk(bratsPath) 67 | # for name in files if 'MNI' not in name and 'edema' in name and name.endswith('.nii.gz')] 68 | # edema_pathes.sort() 69 | # enhancing_tumor_pathes = [os.path.join(root, name) for root, dirs, files in os.walk(bratsPath) 70 | # for name in files if 'MNI' not in name and 'enhancingTumor' in name and name.endswith('.nii.gz')] 71 | # 
#    enhancing_tumor_pathes.sort()
#    return necrosis_pathes, edema_pathes, enhancing_tumor_pathes

#def Brats2017LesionsMNI152Pathes(bratsPath):
#    ''' This function gives you the pathes for evey lesion in MNI 152 space files'''
#    necrosis_pathes = [os.path.join(root, name) for root, dirs, files in os.walk(bratsPath)
#        for name in files if 'MNI' in name and 'necrosis' in name and name.endswith('.nii.gz')]
#    necrosis_pathes.sort()
#    edema_pathes = [os.path.join(root, name) for root, dirs, files in os.walk(bratsPath)
#        for name in files if 'MNI' in name and 'edema' in name and name.endswith('.nii.gz')]
#    edema_pathes.sort()
#    enhancing_tumor_pathes = [os.path.join(root, name) for root, dirs, files in os.walk(bratsPath)
#        for name in files if 'MNI' in name and 'enhancingTumor' in name and name.endswith('.nii.gz')]
#    enhancing_tumor_pathes.sort()
#    return necrosis_pathes, edema_pathes, enhancing_tumor_pathes


def _Brats2018ModalityPaths(bratsPath, corrected):
    ''' Return sorted (t1, t1ce, t2, flair, seg) path lists for .nii.gz files under bratsPath.

    Args:
        bratsPath (str): root folder that is walked recursively.
        corrected (bool): True selects images whose file name contains
            'corrected' (N4ITK output), False selects images whose name does not.

    Image file names containing 'normalized' or 'MNI152' are always skipped.
    The seg list ignores the 'corrected' flag and skips 'FAST' and 'MNI152'
    files. Modalities are recognised by substrings of the file name.
    '''
    # Walk the tree once; each modality then filters the same listing.
    entries = [(root, name) for root, _, files in os.walk(bratsPath)
               for name in files if name.endswith('.nii.gz')]

    def select(keep):
        return sorted(os.path.join(root, name) for root, name in entries if keep(name))

    def is_image(name):
        return (('corrected' in name) == corrected
                and 'normalized' not in name and 'MNI152' not in name)

    return (select(lambda n: 't1' in n and 'ce' not in n and is_image(n)),
            select(lambda n: 't1' in n and 'ce' in n and is_image(n)),
            select(lambda n: 't2' in n and is_image(n)),
            select(lambda n: 'flair' in n and is_image(n)),
            select(lambda n: 'seg' in n and 'FAST' not in n and 'MNI152' not in n))


def Brats2018FilePaths(bratsPath):
    ''' Return the file paths of all N4ITK-corrected MR images and the ground truth.

    Returns:
        (t1, t1c, t2, flair, seg): five sorted lists of paths.

    Raises:
        AssertionError: if the five lists do not have the same length.
    '''
    t1_filepaths, t1c_filepaths, t2_filepaths, flair_filepaths, seg_filepaths = \
        _Brats2018ModalityPaths(bratsPath, corrected=True)
    # Every subject must contribute exactly one file per modality plus one seg.
    assert (len(t1_filepaths) == len(t1c_filepaths) == len(t2_filepaths) == len(flair_filepaths) == len(seg_filepaths)), "The len of different image modalities are differnt!!!"
    return t1_filepaths, t1c_filepaths, t2_filepaths, flair_filepaths, seg_filepaths


def Brats2018OriginalFilePaths(bratsPath):
    ''' Return the file paths of all original (not N4ITK-corrected) MR images and the ground truth.

    Returns:
        (t1, t1c, t2, flair, seg): five sorted lists of paths.

    Raises:
        AssertionError: if the five lists do not have the same length.
    '''
    t1_filepaths, t1c_filepaths, t2_filepaths, flair_filepaths, seg_filepaths = \
        _Brats2018ModalityPaths(bratsPath, corrected=False)
    assert len(t1_filepaths)==len(t1c_filepaths)==len(t2_filepaths)==len(flair_filepaths)==len(seg_filepaths), "Lengths are different!!!"
    return t1_filepaths, t1c_filepaths, t2_filepaths, flair_filepaths, seg_filepaths


def BrainMaskPaths(bratsPath):
    ''' Return the sorted paths of all .nii.gz files under bratsPath whose name contains 'brainmask'.'''
    brain_mask_paths = [os.path.join(root, name) for root, dirs, files in os.walk(bratsPath)
                        for name in files if 'brainmask' in name and name.endswith('.nii.gz')]
    brain_mask_paths.sort()
    return brain_mask_paths

def FindOneElement(s, ch):
    ''' Return the indices at which the character ch occurs in the string s.'''
    return [i for i, ltr in enumerate(s) if ltr == ch]

def SubjectID(bratsPath):
    ''' Return the subject ID for a file inside a BraTS subject folder.

    The subject ID is the name of the directory that directly contains the
    file, e.g. '.../HGG/Brats18_CBICA_AAB_1/Brats18_CBICA_AAB_1_seg.nii.gz'
    gives 'Brats18_CBICA_AAB_1'. This no longer depends on how deep the
    dataset root is (the previous version hard-coded the 7th/8th '/' positions).
    '''
    return os.path.basename(os.path.dirname(bratsPath))

def AllSubjectID(bratsPath):
    ''' Return the subject IDs of all subjects under bratsPath, ordered like the sorted seg paths.'''
    _, _, _, _, seg_filepaths = Brats2018FilePaths(bratsPath)
    return [SubjectID(seg_path) for seg_path in seg_filepaths]


# -------------------------------------------------------------------------------- /braintools.pyc
# https://raw.githubusercontent.com/pykao/Modified-3D-UNet-Pytorch/63f0489e8d1fdd7ec6a203bcff095f12ea030824/braintools.pyc
# -------------------------------------------------------------------------------- /data_loader.py
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from builtins import object

import numpy as np
from abc import ABCMeta, abstractmethod


class DataLoaderBase(object):
    """ Derive from this class and override generate_train_batch. If you don't want to use this you can use any
    generator.
    You can modify this class however you want. How the data is presented as batch is your responsibility. You can sample
    randomly, cycle through the training examples or sample the data according to a specific pattern. Just make sure to
    use our default data structure!
    {'data':your_batch_of_shape_(b, c, x, y(, z)),
    'seg':your_batch_of_shape_(b, c, x, y(, z)),
    'anything_else1':whatever,
    'anything_else2':whatever2,
    ...}

    (seg is optional)

    Note: the class has no ABCMeta metaclass, so @abstractmethod on
    generate_train_batch is advisory only and does not block instantiation.

    Args:
        data (anything): Your dataset. Stored as member variable self._data

        BATCH_SIZE (int): batch size. Stored as member variable self.BATCH_SIZE

        num_batches (int): How many batches will be generated before raising StopIteration. None=unlimited. Careful
        when using MultiThreadedAugmenter: Each process will produce num_batches batches.

        seed (False, None, int): seed to seed the numpy rng with. False = no seeding

    """
    def __init__(self, data, BATCH_SIZE, num_batches=None, seed=False):
        self._data = data
        self.BATCH_SIZE = BATCH_SIZE
        self._num_batches = num_batches
        self._seed = seed
        self._resetted_rng = False
        self._iter_initialized = False
        self._p = None
        if self._num_batches is None:
            # Effectively unlimited.
            self._num_batches = 1e100
        self._batches_generated = 0

    def _initialize_iter(self):
        # Seed numpy's global RNG once per iteration start (seed=False disables this).
        if self._seed is not False:
            np.random.seed(self._seed)
        self._iter_initialized = True

    def __iter__(self):
        return self

    def __next__(self):
        if not self._iter_initialized:
            self._initialize_iter()
        if self._batches_generated >= self._num_batches:
            self._iter_initialized = False
            raise StopIteration
        minibatch = self.generate_train_batch()
        self._batches_generated += 1
        return minibatch

    @abstractmethod
    def generate_train_batch(self):
        '''Override this: build and return one batch dict.

        Generate your batch from either self._train_data, self._validation_data or self._test_data. Make sure you
        generate the correct batch size (self.BATCH_SIZE)
        '''
        pass


# -------------------------------------------------------------------------------- /data_loader.pyc
# https://raw.githubusercontent.com/pykao/Modified-3D-UNet-Pytorch/63f0489e8d1fdd7ec6a203bcff095f12ea030824/data_loader.pyc
# -------------------------------------------------------------------------------- /dataset.py
import os
import cPickle
import torch
from torch.utils.data import Dataset
from multiprocessing import Pool
from shutil import copyfile

import SimpleITK as sitk
import numpy as np

import paths
from braintools import ReadImage
from utils import reshape_by_padding_upper_coords, random_crop_3D_image_batched
from data_loader import DataLoaderBase

class BraTS2018List(Dataset):
    '''PyTorch dataset over the preprocessed BraTS2018 NNN.npy / NNN.pkl files written by run().

    Each NNN.npy holds a (5, z, y, x) float32 array: channels 0-3 are the
    normalized t1, t1ce, t2 and flair images, channel 4 is the segmentation.
    '''
    def __init__(self, data_path, random_crop=None, to_tensor=True, convert_labels=True):
        """
        Args:
            data_path (string): Directory with all the numpy files, pkl files and id_name_conversion.txt file
            random_crop (sequence of 3 ints, optional): crop size in (z, y, x) order; None disables cropping
            to_tensor (bool): return torch tensors instead of numpy arrays
            convert_labels (bool): remap BraTS label 4 to 3 with convert_brats_seg
        """
        self.data_path = data_path
        self.random_crop = random_crop
        # random_crop is optional; only validate it when given (len(None) used to crash here).
        assert self.random_crop is None or len(self.random_crop) == 3, "The random crop size should be (z, y, x)"
        self.npy_names = sorted([name for name in os.listdir(self.data_path) if os.path.isfile(os.path.join(self.data_path, name)) and name.endswith('.npy')])
        self.pkl_names = sorted([name for name in os.listdir(self.data_path) if os.path.isfile(os.path.join(self.data_path, name)) and name.endswith('.pkl')])
        self.id_name_conversion = np.loadtxt(os.path.join(data_path, "id_name_conversion.txt"), dtype="str")
        self.to_tensor = to_tensor
        self.convert_labels = convert_labels

    def __getitem__(self, index):
        npy_name = self.npy_names[index]
        # The patient id is encoded in the file name ("%03d.npy"). Using it,
        # rather than the list position, keeps name/type/geometry aligned even
        # when some patients were skipped during preprocessing.
        stem = os.path.splitext(npy_name)[0]
        pat_id = int(stem)
        npy_data = np.load(os.path.join(self.data_path, npy_name))
        idxs = self.id_name_conversion[:, 1].astype(int)
        row = np.where(idxs == pat_id)[0][0]
        with open(os.path.join(self.data_path, stem + '.pkl'), 'rb') as f:
            dp = cPickle.load(f)
        sample = {}
        sample['name'] = self.id_name_conversion[row, 0]
        sample['index'] = self.id_name_conversion[row, 1]
        sample['type'] = self.id_name_conversion[row, 2]
        sample['orig_shp'] = dp['orig_shp']
        sample['spacing'] = dp['spacing']
        sample['direction'] = dp['direction']
        sample['origin'] = dp['origin']
        image = npy_data[0:4, :]
        ori_label = npy_data[4, :]

        if self.convert_labels:
            new_label = convert_brats_seg(ori_label)
        else:
            new_label = np.copy(ori_label)

        if self.random_crop:
            image, new_label = self._random_crop(image, new_label)

        if self.to_tensor:
            sample['data'] = torch.from_numpy(image)
            sample['seg'] = torch.from_numpy(new_label)
        else:
            sample['data'] = image
            sample['seg'] = new_label

        return sample

    def _random_crop(self, image, label):
        '''Cut the same random window of size self.random_crop out of image (c, z, y, x) and label (z, y, x).

        np.random.randint raises ValueError if the crop is larger than the volume.
        '''
        z, y, x = image.shape[1:]
        new_z, new_y, new_x = (self.random_crop[0], self.random_crop[1], self.random_crop[2])
        # randint's upper bound is exclusive: "+ 1" makes the last valid offset
        # (volume size - crop size) reachable, including when the two are equal.
        axial = np.random.randint(0, z - new_z + 1)
        coronal = np.random.randint(0, y - new_y + 1)
        sagittal = np.random.randint(0, x - new_x + 1)
        window = (slice(axial, axial + new_z), slice(coronal, coronal + new_y), slice(sagittal, sagittal + new_x))
        return image[(slice(None),) + window], label[window]

    def __len__(self):
        # number of preprocessed subjects (one .npy file each)
        return len(self.npy_names)


def extract_brain_region(image, brain_mask, background=0):
    ''' Crop image to the bounding box of the voxels where brain_mask != background.

    Returns the cropped image and the box as [[min_z, max_z], [min_y, max_y], [min_x, max_x]]
    with exclusive upper bounds.
    '''
    brain = np.where(brain_mask != background)
    min_z = int(np.min(brain[0]))
    max_z = int(np.max(brain[0]))+1
    min_y = int(np.min(brain[1]))
    max_y = int(np.max(brain[1]))+1
    min_x = int(np.min(brain[2]))
    max_x = int(np.max(brain[2]))+1
    resizer = (slice(min_z, max_z), slice(min_y, max_y), slice(min_x, max_x))
    return image[resizer], [[min_z, max_z], [min_y, max_y], [min_x, max_x]]


def cut_off_values_upper_lower_percentile(image, mask=None, percentile_lower=0.2, percentile_upper=99.8):
    ''' Return a copy of image whose masked voxels are clipped to the given percentiles of the masked values.

    When mask is None, the mask is every voxel that differs from image[0, 0, 0].
    '''
    if mask is None:
        mask = image != image[0, 0, 0]
    cut_off_lower = np.percentile(image[mask != 0].ravel(), percentile_lower)
    cut_off_upper = np.percentile(image[mask != 0].ravel(), percentile_upper)
    res = np.copy(image)
    res[(res < cut_off_lower) & (mask !=0)] = cut_off_lower
    res[(res > cut_off_upper) & (mask !=0)] = cut_off_upper
    return res


def _normalize_modality(nda, cut_off_threshold):
    ''' Z-score the non-zero voxels of nda; zero voxels are left unchanged.

    Mean and std come from a copy whose non-zero voxels were clipped to the
    [cut_off_threshold, 100 - cut_off_threshold] percentiles, but they are
    applied to the unclipped values.
    '''
    msk = nda != 0
    clipped = cut_off_values_upper_lower_percentile(nda, msk, cut_off_threshold, 100.0 - cut_off_threshold)
    normalized = np.copy(nda)
    normalized[msk] = (nda[msk] - clipped[msk].mean()) / clipped[msk].std()
    return normalized


def run(folder, out_folder, pat_id, name, return_if_no_seg=True, N4ITK = True):
    ''' Preprocess one patient and write "%03d.pkl" and "%03d.npy" into out_folder.

    Reads <name>_{t1,t1ce,t2,flair}[_N4ITK_corrected].nii.gz and <name>_seg.nii.gz
    from folder, crops all volumes to the joint brain bounding box, normalizes
    each modality and pads to at least 128 per axis (via
    reshape_by_padding_upper_coords). The .pkl stores the original shape,
    bounding box and image geometry; the .npy stores the 4 modalities plus
    the segmentation as one (5, z, y, x) float32 array.

    Returns early (writing nothing) if a modality is missing, or if the seg is
    missing and return_if_no_seg is True. Otherwise a missing or unreadable
    seg is replaced by zeros.
    '''
    print('%s %s' % (pat_id, name))
    suffix = "_N4ITK_corrected" if N4ITK else ""
    t1_path = os.path.join(folder, "%s_t1%s.nii.gz" % (name, suffix))
    t1ce_path = os.path.join(folder, "%s_t1ce%s.nii.gz" % (name, suffix))
    t2_path = os.path.join(folder, "%s_t2%s.nii.gz" % (name, suffix))
    flair_path = os.path.join(folder, "%s_flair%s.nii.gz" % (name, suffix))
    seg_path = os.path.join(folder, "%s_seg.nii.gz" %name)
    for path, label in ((t1_path, "T1"), (t1ce_path, "T1ce"), (t2_path, "T2"), (flair_path, "Flair")):
        if not os.path.isfile(path):
            print("%s file does not exist" % label)
            return
    if not os.path.isfile(seg_path) and return_if_no_seg:
        print("Seg file does not exist")
        return
    t1_nda = ReadImage(t1_path)
    t1_img = sitk.ReadImage(t1_path)
    t1ce_nda = ReadImage(t1ce_path)
    t2_nda = ReadImage(t2_path)
    flair_nda = ReadImage(flair_path)
    try:
        seg_nda = ReadImage(seg_path)
    except (RuntimeError, IOError):
        seg_nda = np.zeros(t1_nda.shape, dtype = np.float32)

    original_shape = t1_nda.shape

    # Brain = voxels that differ from the corner (background) value in every modality.
    brain_mask = (t1_nda != t1_nda[0, 0, 0]) & (t1ce_nda != t1ce_nda[0, 0, 0]) & (t2_nda != t2_nda[0, 0, 0]) & (flair_nda != flair_nda[0, 0, 0])

    resized_t1_nda, bbox = extract_brain_region(t1_nda, brain_mask, 0)
    resized_t1ce_nda, bbox1 = extract_brain_region(t1ce_nda, brain_mask, 0)
    resized_t2_nda, bbox2 = extract_brain_region(t2_nda, brain_mask, 0)
    resized_flair_nda, bbox3 = extract_brain_region(flair_nda, brain_mask, 0)
    resized_seg_nda, bbox4 = extract_brain_region(seg_nda, brain_mask, 0)
    assert bbox == bbox1 == bbox2 == bbox3 == bbox4
    assert resized_t1_nda.shape == resized_t1ce_nda.shape == resized_t2_nda.shape == resized_flair_nda.shape

    # Binary mode: pickles are byte streams.
    with open(os.path.join(out_folder, "%03.0d.pkl" % pat_id), 'wb') as f:
        dp = {}
        dp['orig_shp'] = original_shape
        dp['bbox_z'] = bbox[0]
        dp['bbox_y'] = bbox[1]
        dp['bbox_x'] = bbox[2]
        dp['spacing'] = t1_img.GetSpacing()
        dp['direction'] = t1_img.GetDirection()
        dp['origin'] = t1_img.GetOrigin()
        cPickle.dump(dp, f)

    # setting the cut off threshold (percent clipped at each end before computing mean/std)
    cut_off_threshold = 2.0
    normalized = [_normalize_modality(nda, cut_off_threshold)
                  for nda in (resized_t1_nda, resized_t1ce_nda, resized_t2_nda, resized_flair_nda)]

    shp = resized_t1_nda.shape
    new_shape = np.array([128, 128, 128])
    # pad every axis up to at least 128 voxels
    pad_size = np.max(np.vstack((new_shape, np.array(shp))), 0)
    channels = [reshape_by_padding_upper_coords(nda, pad_size, 0) for nda in normalized + [resized_seg_nda]]

    number_of_data = 5
    all_data = np.zeros([number_of_data]+list(channels[0].shape), dtype=np.float32)
    # channel order: t1, t1ce, t2, flair, seg
    for channel_index, channel in enumerate(channels):
        all_data[channel_index] = channel
    np.save(os.path.join(out_folder, "%03.0d" % pat_id), all_data)


def run_star(args):
    return run(*args)


def _patient_folders(parent):
    ''' Return the sorted names of the sub-directories of parent (one per patient).'''
    return sorted(name for name in os.listdir(parent) if os.path.isdir(os.path.join(parent, name)))


def _preprocess_patients(parent, patients, folder_out, first_id, return_if_no_seg, N4ITK):
    ''' Run run() for every patient folder under parent in a 7-process pool; ids start at first_id.'''
    n = len(patients)
    fldrs = [os.path.join(parent, pt) for pt in patients]
    p = Pool(7)
    try:
        p.map(run_star, zip(fldrs, [folder_out] * n, range(first_id, first_id + n), patients, n * [return_if_no_seg], n * [N4ITK]))
    finally:
        p.close()
        p.join()


def run_preprocessing_BraTS2018_training(training_data_location=paths.raw_training_data_folder, folder_out=paths.preprocessed_training_data_folder, N4ITK=True):
    ''' Preprocess the HGG and LGG training subjects into folder_out.

    Writes one .pkl/.npy pair per subject, id_name_conversion.txt (name, id,
    grade per row) and a copy of survival_data.csv.
    '''
    if not os.path.isdir(folder_out): os.mkdir(folder_out)
    ctr = 0
    id_name_conversion = []
    for f in ("HGG", "LGG"):
        fld = os.path.join(training_data_location, f)
        patients = _patient_folders(fld)
        _preprocess_patients(fld, patients, folder_out, ctr, True, N4ITK)
        for pat_id, patient in enumerate(patients, ctr):
            id_name_conversion.append([patient, pat_id, f])
        # Advance by the number of patients just processed.
        ctr += len(patients)
    id_name_conversion = np.vstack(id_name_conversion)
    np.savetxt(os.path.join(folder_out, "id_name_conversion.txt"), id_name_conversion, fmt="%s")
    copyfile(os.path.join(training_data_location, "survival_data.csv"), os.path.join(folder_out, "survival_data.csv"))


def run_preprocessing_BraTS2018_validationOrTesting(original_data_location=paths.raw_validation_data_folder, folder_out=paths.preprocessed_validation_data_folder, N4ITK=True):
    ''' Preprocess validation/testing subjects (located directly in original_data_location) into folder_out.

    Subjects without a seg file are still processed (their label channel is
    all zeros); their grade is recorded as 'unknown'.
    '''
    if not os.path.isdir(folder_out): os.mkdir(folder_out)
    patients = _patient_folders(original_data_location)
    _preprocess_patients(original_data_location, patients, folder_out, 0, False, N4ITK)
    id_name_conversion = np.vstack([[patient, pat_id, 'unknown'] for pat_id, patient in enumerate(patients)])
    np.savetxt(os.path.join(folder_out, "id_name_conversion.txt"), id_name_conversion, fmt="%s")
    # NOTE(review): this previously referenced an undefined training_data_location;
    # it now copies survival_data.csv from the input folder when present -- confirm intent.
    survival_csv = os.path.join(original_data_location, "survival_data.csv")
    if os.path.isfile(survival_csv):
        copyfile(survival_csv, os.path.join(folder_out, "survival_data.csv"))


def load_dataset(pat_ids = range(285), folder=paths.preprocessed_training_data_folder):
    ''' Return {pat_id: info dict} for every id in pat_ids whose .npy exists in folder.

    The data array is memory-mapped read-only; the remaining keys come from
    id_name_conversion.txt and the patient's .pkl file.
    '''
    id_name_conversion = np.loadtxt(os.path.join(folder, "id_name_conversion.txt"), dtype="str")
    idxs = id_name_conversion[:, 1].astype(int)
    dataset = {}
    for pat in pat_ids:
        if os.path.isfile(os.path.join(folder, "%03.0d.npy" % pat)):
            dataset[pat] = {}
            dataset[pat]['data'] = np.load(os.path.join(folder, "%03.0d.npy" %pat), mmap_mode='r')
            dataset[pat]['idx'] = pat
            dataset[pat]['name'] = id_name_conversion[np.where(idxs == pat)[0][0], 0]
            dataset[pat]['type'] = id_name_conversion[np.where(idxs == pat)[0][0], 2]
            with open(os.path.join(folder, "%03.0d.pkl" % pat), 'rb') as f:
                dp = cPickle.load(f)
            dataset[pat]['orig_shp'] = dp['orig_shp']
            dataset[pat]['bbox_z'] = dp['bbox_z']
            dataset[pat]['bbox_x'] = dp['bbox_x']
            dataset[pat]['bbox_y'] = dp['bbox_y']
            dataset[pat]['spacing'] = dp['spacing']
            dataset[pat]['direction'] = dp['direction']
            dataset[pat]['origin'] = dp['origin']
    return dataset


def convert_brats_seg(seg):
    ''' Map BraTS labels {0, 1, 2, 4} to contiguous training labels {0, 1, 2, 3}.'''
    new_seg = np.zeros(seg.shape, seg.dtype)
    new_seg[seg == 1] = 1
    new_seg[seg == 2] = 2
    # convert label 4 enhancing tumor to label 3
    new_seg[seg == 4] = 3
    return new_seg


def convert_to_brats_seg(seg):
    ''' Inverse of convert_brats_seg: map training labels {0, 1, 2, 3} back to BraTS labels {0, 1, 2, 4}.'''
    new_seg = np.zeros(seg.shape, seg.dtype)
    # labels 1 and 2 are unchanged by convert_brats_seg, so they map to themselves
    new_seg[seg == 1] = 1
    new_seg[seg == 2] = 2
    # convert label 3 back to label 4 enhancing tumor
    new_seg[seg == 3] = 4
    return new_seg


def _pad_batch_upper(batch, new_shape, pad_value=0):
    ''' Pad the spatial axes (axis 2 onwards) of a (b, c, ...) batch with pad_value at their upper end up to new_shape.'''
    pad_width = [(0, 0), (0, 0)] + [(0, int(new) - old) for new, old in zip(new_shape, batch.shape[2:])]
    return np.pad(batch, pad_width, mode='constant', constant_values=pad_value)


# Their code to generate 3D random batch for training
class BatchGenerator3D_random_sampling(DataLoaderBase):
    ''' Draw BATCH_SIZE random patients (with replacement) from a load_dataset() dict and crop random patches.'''
    def __init__(self, data, BATCH_SIZE, num_batches, seed, patch_size=(128, 128, 128), convert_labels=False):
        self.convert_labels = convert_labels
        self._patch_size = patch_size
        DataLoaderBase.__init__(self, data, BATCH_SIZE, num_batches, seed)

    def generate_train_batch(self):
        # list() so np.random.choice also accepts the keys under Python 3.
        ids = np.random.choice(list(self._data.keys()), self.BATCH_SIZE)
        data = np.zeros((self.BATCH_SIZE, 4, self._patch_size[0], self._patch_size[1], self._patch_size[2]),
                        dtype=np.float32)
        seg = np.zeros((self.BATCH_SIZE, 1, self._patch_size[0], self._patch_size[1], self._patch_size[2]),
                       dtype=np.float32)
        types = []
        patient_names = []
        identifiers = []
        ages = []
        survivals = []
        for j, i in enumerate(ids):
            patient = self._data[i]
            types.append(patient['type'])
            patient_names.append(patient['name'])
            identifiers.append(patient['idx'])
            # construct a batch, not very efficient
            data_all = patient['data'][None]
            if np.any(np.array(data_all.shape[2:]) - np.array(self._patch_size) < 0):
                # Volume smaller than the patch along some axis: pad it first
                # (resize_image_by_padding_batched used here before was never imported).
                new_shp = np.max(np.vstack((np.array(data_all.shape[2:])[None], np.array(self._patch_size)[None])), 0)
                data_all = _pad_batch_upper(data_all, new_shp, 0)
            data_all = random_crop_3D_image_batched(data_all, self._patch_size)
            data[j, :] = data_all[0, :4]
            if self.convert_labels:
                seg[j, 0] = convert_brats_seg(data_all[0, 4])
            else:
                seg[j, 0] = data_all[0, 4]
            survivals.append(patient.get('survival', np.nan))
            ages.append(patient.get('age', np.nan))
        return {'data': data, 'seg': seg, "idx": ids, "grades": types, "identifiers": identifiers, "patient_names": patient_names, 'survival':survivals, 'age':ages}


# -------------------------------------------------------------------------------- /dataset.pyc
# https://raw.githubusercontent.com/pykao/Modified-3D-UNet-Pytorch/63f0489e8d1fdd7ec6a203bcff095f12ea030824/dataset.pyc
# -------------------------------------------------------------------------------- /loss.py
import torch
from torch.autograd import Function
from itertools import repeat
import numpy as np

# Intersection = dot(A, B)
# Union = dot(A, A) + dot(B, B)
# The Dice loss function is defined as
# 1/2 * intersection / union
#
# The derivative is 2[(union * target - 2 * intersect * input) / union^2]

class DiceLoss(Function):
    def __init__(self, *args, **kwargs):
        pass

    def forward(self, input, target, save=True):
        if save:
            self.save_for_backward(input, target)
        eps
= 0.000001 21 | _, result_ = input.max(1) 22 | result_ = torch.squeeze(result_) 23 | if input.is_cuda: 24 | result = torch.cuda.FloatTensor(result_.size()) 25 | self.target_ = torch.cuda.FloatTensor(target.size()) 26 | else: 27 | result = torch.FloatTensor(result_.size()) 28 | self.target_ = torch.FloatTensor(target.size()) 29 | result.copy_(result_) 30 | self.target_.copy_(target) 31 | target = self.target_ 32 | # print(input) 33 | intersect = torch.dot(result, target) 34 | # binary values so sum the same as sum of squares 35 | result_sum = torch.sum(result) 36 | target_sum = torch.sum(target) 37 | union = result_sum + target_sum + (2*eps) 38 | 39 | # the target volume can be empty - so we still want to 40 | # end up with a score of 1 if the result is 0/0 41 | IoU = intersect / union 42 | print('union: {:.3f}\t intersect: {:.6f}\t target_sum: {:.0f} IoU: result_sum: {:.0f} IoU {:.7f}'.format( 43 | union, intersect, target_sum, result_sum, 2*IoU)) 44 | out = torch.FloatTensor(1).fill_(2*IoU) 45 | self.intersect, self.union = intersect, union 46 | return out 47 | 48 | def backward(self, grad_output): 49 | input, _ = self.saved_tensors 50 | intersect, union = self.intersect, self.union 51 | target = self.target_ 52 | gt = torch.div(target, union) 53 | IoU2 = intersect/(union*union) 54 | pred = torch.mul(input[:, 1], IoU2) 55 | dDice = torch.add(torch.mul(gt, 2), torch.mul(pred, -4)) 56 | grad_input = torch.cat((torch.mul(dDice, -grad_output[0]), 57 | torch.mul(dDice, grad_output[0])), 0) 58 | return grad_input , None 59 | 60 | def dice_loss(input, target): 61 | return DiceLoss()(input, target) 62 | 63 | def dice_error(input, target): 64 | eps = 0.000001 65 | _, result_ = input.max(1) 66 | result_ = torch.squeeze(result_) 67 | if input.is_cuda: 68 | result = torch.cuda.FloatTensor(result_.size()) 69 | target_ = torch.cuda.FloatTensor(target.size()) 70 | else: 71 | result = torch.FloatTensor(result_.size()) 72 | target_ = torch.FloatTensor(target.size()) 73 | 
result.copy_(result_.data) 74 | target_.copy_(target.data) 75 | target = target_ 76 | intersect = torch.dot(result, target) 77 | 78 | result_sum = torch.sum(result) 79 | target_sum = torch.sum(target) 80 | union = result_sum + target_sum + 2*eps 81 | intersect = np.max([eps, intersect]) 82 | # the target volume can be empty - so we still want to 83 | # end up with a score of 1 if the result is 0/0 84 | IoU = intersect / union 85 | # print('union: {:.3f}\t intersect: {:.6f}\t target_sum: {:.0f} IoU: result_sum: {:.0f} IoU {:.7f}'.format( 86 | # union, intersect, target_sum, result_sum, 2*IoU)) 87 | return 2*IoU -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import logging 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.optim 11 | from torch.utils.data import DataLoader 12 | from sklearn.model_selection import train_test_split 13 | 14 | from dataset import load_dataset, BraTS2018List 15 | from model import Modified3DUNet 16 | import paths 17 | 18 | def datestr(): 19 | now = time.localtime() 20 | return '{:04}{:02}{:02}_{:02}{:02}'.format(now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min) 21 | 22 | #print datestr() 23 | 24 | 25 | 26 | # Training setting 27 | parser = argparse.ArgumentParser(description='PyTorch Modified 3D U-Net Training') 28 | #parser.add_argument('-m', '--modality', default='T1', choices = ['T1', 'T1c', 'T2', 'FLAIR'], 29 | # type = str, help='modality of input 3d images (default:T1)') 30 | #parser.add_argument('-w', '--workers', default=8, type=int, 31 | # help='number of data loading workers (default: 8)') 32 | parser.add_argument('--epochs', default=300, type=int, 33 | help='number of total epochs to run (default: 300)') 34 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 
35 | help='manual epoch number (useful on restarts)') 36 | parser.add_argument('-b', '--batch-size', default=2, type=int, 37 | help='batch size (default: 2)') 38 | parser.add_argument('-g', '--gpu', default='0', type=str) 39 | parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float, 40 | help='initial learning rate (default:5e-4)') 41 | parser.add_argument('--momentum', default=0.9, type=float, 42 | help='momentum (default: 0.9)') 43 | parser.add_argument('--weight-decay', '--wd', default=985e-3, type=float, 44 | help='weight decay (default: 985e-3)') 45 | parser.add_argument('--print-freq', '-p', default=100, type=int, 46 | help='print frequency (default: 100)') 47 | parser.add_argument('-d', '--data', default=paths.preprocessed_training_data_folder, 48 | type=str, help='The location of BRATS2015') 49 | 50 | log_file = os.path.join("train_log.txt") 51 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s', filename=log_file) 52 | 53 | console = logging.StreamHandler() 54 | console.setLevel(logging.INFO) 55 | console.setFormatter(logging.Formatter('%(asctime)s %(message)s')) 56 | logging.getLogger('').addHandler(console) 57 | 58 | 59 | global args, best_loss 60 | best_loss = float('inf') 61 | args = parser.parse_args() 62 | #print os.environ['CUDA_VISIBLE_DEVICES'] 63 | dtype = torch.float 64 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 65 | # input = data.to(device) 66 | 67 | # Loading the model 68 | in_channels = 4 69 | n_classes = 4 70 | base_n_filter = 16 71 | model = Modified3DUNet(in_channels, n_classes, base_n_filter).to(device) 72 | #print args.data 73 | 74 | 75 | # Split the training and testing dataset 76 | 77 | test_size = 0.1 78 | train_idx, test_idx = train_test_split(range(285), test_size = test_size) 79 | train_data = load_dataset(train_idx) 80 | test_data = load_dataset(test_idx) 81 | 82 | 83 | #print all_data.keys() 84 | # create your optimizer 85 | #optimizer = 
optim.adam(net.parameteres(), lr=) 86 | ''' 87 | # in training loop: 88 | optimizer.zero_grad() 89 | output = net(input) 90 | loss = criterion(output, target) 91 | loss.backward() 92 | optimizer.step 93 | ''' 94 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class Modified3DUNet(nn.Module): 5 | def __init__(self, in_channels, n_classes, base_n_filter = 8): 6 | super(Modified3DUNet, self).__init__() 7 | self.in_channels = in_channels 8 | self.n_classes = n_classes 9 | self.base_n_filter = base_n_filter 10 | 11 | self.lrelu = nn.LeakyReLU() 12 | self.dropout3d = nn.Dropout3d(p=0.6) 13 | self.upsacle = nn.Upsample(scale_factor=2, mode='nearest') 14 | self.softmax = nn.Softmax(dim=1) 15 | 16 | # Level 1 context pathway 17 | self.conv3d_c1_1 = nn.Conv3d(self.in_channels, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False) 18 | self.conv3d_c1_2 = nn.Conv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.lrelu_conv_c1 = self.lrelu_conv(self.base_n_filter, self.base_n_filter) 20 | self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter) 21 | 22 | # Level 2 context pathway 23 | self.conv3d_c2 = nn.Conv3d(self.base_n_filter, self.base_n_filter*2, kernel_size=3, stride=2, padding=1, bias=False) 24 | self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter*2, self.base_n_filter*2) 25 | self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter*2) 26 | 27 | # Level 3 context pathway 28 | self.conv3d_c3 = nn.Conv3d(self.base_n_filter*2, self.base_n_filter*4, kernel_size=3, stride=2, padding=1, bias=False) 29 | self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter*4, self.base_n_filter*4) 30 | self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter*4) 31 | 32 | # Level 4 context pathway 33 | self.conv3d_c4 = 
nn.Conv3d(self.base_n_filter*4, self.base_n_filter*8, kernel_size=3, stride=2, padding=1, bias=False) 34 | self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter*8, self.base_n_filter*8) 35 | self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter*8) 36 | 37 | # Level 5 context pathway, level 0 localization pathway 38 | self.conv3d_c5 = nn.Conv3d(self.base_n_filter*8, self.base_n_filter*16, kernel_size=3, stride=2, padding=1, bias=False) 39 | self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter*16, self.base_n_filter*16) 40 | self.norm_lrelu_upscale_conv_norm_lrelu_l0 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*16, self.base_n_filter*8) 41 | 42 | self.conv3d_l0 = nn.Conv3d(self.base_n_filter*8, self.base_n_filter*8, kernel_size = 1, stride=1, padding=0, bias=False) 43 | self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter*8) 44 | 45 | # Level 1 localization pathway 46 | self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(self.base_n_filter*16, self.base_n_filter*16) 47 | self.conv3d_l1 = nn.Conv3d(self.base_n_filter*16, self.base_n_filter*8, kernel_size=1, stride=1, padding=0, bias=False) 48 | self.norm_lrelu_upscale_conv_norm_lrelu_l1 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*8, self.base_n_filter*4) 49 | 50 | # Level 2 localization pathway 51 | self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(self.base_n_filter*8, self.base_n_filter*8) 52 | self.conv3d_l2 = nn.Conv3d(self.base_n_filter*8, self.base_n_filter*4, kernel_size=1, stride=1, padding=0, bias=False) 53 | self.norm_lrelu_upscale_conv_norm_lrelu_l2 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*4, self.base_n_filter*2) 54 | 55 | # Level 3 localization pathway 56 | self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(self.base_n_filter*4, self.base_n_filter*4) 57 | self.conv3d_l3 = nn.Conv3d(self.base_n_filter*4, self.base_n_filter*2, kernel_size=1, stride=1, padding=0, bias=False) 58 | self.norm_lrelu_upscale_conv_norm_lrelu_l3 = 
self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*2, self.base_n_filter) 59 | 60 | # Level 4 localization pathway 61 | self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(self.base_n_filter*2, self.base_n_filter*2) 62 | self.conv3d_l4 = nn.Conv3d(self.base_n_filter*2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False) 63 | 64 | self.ds2_1x1_conv3d = nn.Conv3d(self.base_n_filter*8, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False) 65 | self.ds3_1x1_conv3d = nn.Conv3d(self.base_n_filter*4, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False) 66 | 67 | 68 | 69 | 70 | def conv_norm_lrelu(self, feat_in, feat_out): 71 | return nn.Sequential( 72 | nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), 73 | nn.InstanceNorm3d(feat_out), 74 | nn.LeakyReLU()) 75 | 76 | def norm_lrelu_conv(self, feat_in, feat_out): 77 | return nn.Sequential( 78 | nn.InstanceNorm3d(feat_in), 79 | nn.LeakyReLU(), 80 | nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False)) 81 | 82 | def lrelu_conv(self, feat_in, feat_out): 83 | return nn.Sequential( 84 | nn.LeakyReLU(), 85 | nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False)) 86 | 87 | def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out): 88 | return nn.Sequential( 89 | nn.InstanceNorm3d(feat_in), 90 | nn.LeakyReLU(), 91 | nn.Upsample(scale_factor=2, mode='nearest'), 92 | # should be feat_in*2 or feat_in 93 | nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), 94 | nn.InstanceNorm3d(feat_out), 95 | nn.LeakyReLU()) 96 | 97 | def forward(self, x): 98 | # Level 1 context pathway 99 | out = self.conv3d_c1_1(x) 100 | residual_1 = out 101 | out = self.lrelu(out) 102 | out = self.conv3d_c1_2(out) 103 | out = self.dropout3d(out) 104 | out = self.lrelu_conv_c1(out) 105 | # Element Wise Summation 106 | out += residual_1 107 | context_1 = self.lrelu(out) 108 | out = self.inorm3d_c1(out) 109 | out = 
self.lrelu(out) 110 | 111 | # Level 2 context pathway 112 | out = self.conv3d_c2(out) 113 | residual_2 = out 114 | out = self.norm_lrelu_conv_c2(out) 115 | out = self.dropout3d(out) 116 | out = self.norm_lrelu_conv_c2(out) 117 | out += residual_2 118 | out = self.inorm3d_c2(out) 119 | out = self.lrelu(out) 120 | context_2 = out 121 | 122 | # Level 3 context pathway 123 | out = self.conv3d_c3(out) 124 | residual_3 = out 125 | out = self.norm_lrelu_conv_c3(out) 126 | out = self.dropout3d(out) 127 | out = self.norm_lrelu_conv_c3(out) 128 | out += residual_3 129 | out = self.inorm3d_c3(out) 130 | out = self.lrelu(out) 131 | context_3 = out 132 | 133 | # Level 4 context pathway 134 | out = self.conv3d_c4(out) 135 | residual_4 = out 136 | out = self.norm_lrelu_conv_c4(out) 137 | out = self.dropout3d(out) 138 | out = self.norm_lrelu_conv_c4(out) 139 | out += residual_4 140 | out = self.inorm3d_c4(out) 141 | out = self.lrelu(out) 142 | context_4 = out 143 | 144 | # Level 5 145 | out = self.conv3d_c5(out) 146 | residual_5 = out 147 | out = self.norm_lrelu_conv_c5(out) 148 | out = self.dropout3d(out) 149 | out = self.norm_lrelu_conv_c5(out) 150 | out += residual_5 151 | out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out) 152 | 153 | out = self.conv3d_l0(out) 154 | out = self.inorm3d_l0(out) 155 | out = self.lrelu(out) 156 | 157 | # Level 1 localization pathway 158 | out = torch.cat([out, context_4], dim=1) 159 | out = self.conv_norm_lrelu_l1(out) 160 | out = self.conv3d_l1(out) 161 | out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out) 162 | 163 | # Level 2 localization pathway 164 | out = torch.cat([out, context_3], dim=1) 165 | out = self.conv_norm_lrelu_l2(out) 166 | ds2 = out 167 | out = self.conv3d_l2(out) 168 | out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out) 169 | 170 | # Level 3 localization pathway 171 | out = torch.cat([out, context_2], dim=1) 172 | out = self.conv_norm_lrelu_l3(out) 173 | ds3 = out 174 | out = self.conv3d_l3(out) 175 | out = 
self.norm_lrelu_upscale_conv_norm_lrelu_l3(out) 176 | 177 | # Level 4 localization pathway 178 | out = torch.cat([out, context_1], dim=1) 179 | out = self.conv_norm_lrelu_l4(out) 180 | out_pred = self.conv3d_l4(out) 181 | 182 | ds2_1x1_conv = self.ds2_1x1_conv3d(ds2) 183 | ds1_ds2_sum_upscale = self.upsacle(ds2_1x1_conv) 184 | ds3_1x1_conv = self.ds3_1x1_conv3d(ds3) 185 | ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv 186 | ds1_ds2_sum_upscale_ds3_sum_upscale = self.upsacle(ds1_ds2_sum_upscale_ds3_sum) 187 | 188 | out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale 189 | seg_layer = out 190 | out = out.permute(0, 2, 3, 4, 1).contiguous().view(-1, self.n_classes) 191 | #out = out.view(-1, self.n_classes) 192 | out = self.softmax(out) 193 | return out, seg_layer -------------------------------------------------------------------------------- /model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pykao/Modified-3D-UNet-Pytorch/63f0489e8d1fdd7ec6a203bcff095f12ea030824/model.pyc -------------------------------------------------------------------------------- /paths.py: -------------------------------------------------------------------------------- 1 | raw_training_data_folder = "/media/pkao/Dataset/BraTS2018/training" 2 | raw_validation_data_folder = "/media/pkao/Dataset/BraTS/2017/Brats17ValidationData" 3 | raw_testing_data_folder = "/media/pkao/Datase/BraTS/2017/Brats17TestingData" 4 | 5 | preprocessed_training_data_folder = "/media/pkao/Dataset/DeepLearningData/BraTS_2018_train" 6 | preprocessed_validation_data_folder = "/media/pkao/Dataset/DeepLearningData/BraTS_2017_val" 7 | preprocessed_testing_data_folder = "/media/pkao/Dataset/DeepLearningData/datasets/BraTS_2017_test" 8 | 9 | #results_folder = "/home/pkao/PhD/results/BraTS_2017_lasagne/" # where to save the network training and validation files 10 | 
-------------------------------------------------------------------------------- /paths.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pykao/Modified-3D-UNet-Pytorch/63f0489e8d1fdd7ec6a203bcff095f12ea030824/paths.pyc -------------------------------------------------------------------------------- /run_preporcessing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataset import run_preprocessing_BraTS2018_training #, run_preprocessing_BraTS2018_validationOrTesting 3 | import paths 4 | print 'start' 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("-m", "--mode", help="train for training set, val for validation set, and test for testing set", type=str) 7 | args = parser.parse_args() 8 | print args.mode 9 | 10 | if args.mode == "train": 11 | run_preprocessing_BraTS2018_training(paths.raw_training_data_folder, paths.preprocessed_training_data_folder) 12 | #elif args.mode == "val": 13 | # run_preprocessing_BraTS2017_trainSet(paths.raw_validation_data_folder, paths.preprocessed_validation_data_folder) 14 | #elif args.mode == "test": 15 | # run_preprocessing_BraTS2017_trainSet(paths.raw_testing_data_folder, paths.preprocessed_testing_data_folder) 16 | else: 17 | raise ValueError("Unknown value for --mode. 
Use \"train\", \"test\" or \"val\"") -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import SimpleITK as sitk 3 | import cPickle as pickle 4 | import os 5 | import pprint 6 | from dataset import load_dataset, BraTS2018List, BatchGenerator3D_random_sampling 7 | import paths 8 | from torch.utils.data import DataLoader 9 | from sklearn.model_selection import train_test_split 10 | 11 | 12 | def testDataPreprocessing(pat_id = 0): 13 | 14 | train_dataset = load_BraTS2018_dataset() 15 | 16 | example_nda = train_dataset[pat_id]['data'] 17 | 18 | print train_dataset[pat_id]['name'], train_dataset[pat_id]['type'] 19 | 20 | t1_nda = example_nda[0, :] 21 | 22 | t1ce_nda = example_nda[1, :] 23 | 24 | t2_nda = example_nda[2, :] 25 | 26 | flair_nda = example_nda[3, :] 27 | 28 | seg_nda = example_nda[4, :] 29 | 30 | t1_img = sitk.GetImageFromArray(t1_nda) 31 | sitk.WriteImage(t1_img, './t1.nii.gz') 32 | t1ce_img = sitk.GetImageFromArray(t1ce_nda) 33 | sitk.WriteImage(t1ce_img, './t1ce.nii.gz') 34 | t2_img = sitk.GetImageFromArray(t2_nda) 35 | sitk.WriteImage(t2_img, './t2.nii.gz') 36 | flair_img = sitk.GetImageFromArray(flair_nda) 37 | sitk.WriteImage(flair_img, './flair.nii.gz') 38 | seg_img = sitk.GetImageFromArray(seg_nda) 39 | sitk.WriteImage(seg_img, './seg.nii.gz') 40 | 41 | #data_path = paths.preprocessed_training_data_folder 42 | #dataset = BraTS2018List(data_path=data_path, random_crop=(128, 128, 128)) 43 | #sample = dataset[42] 44 | #print sample['data'].shape 45 | #print sample['data'].type(), sample['seg'].type() 46 | #print sample 47 | #dataloader = DataLoader(dataset, batch_size=2, shuffle=True) 48 | #for i_batch, sample_batch in enumerate(dataloader): 49 | # print(i_batch, sample_batch['name'], sample_batch['data'].size()) 50 | 51 | all_data = load_dataset() 52 | keys_sorted = np.sort(all_data.keys()) 53 | #print 
all_data.keys() 54 | 55 | train_idx, valid_idx = train_test_split(all_data.keys(), train_size = 0.9) 56 | 57 | print train_idx 58 | 59 | #train_keys = [keys_sorted[i] for i in train_idx] 60 | #valid_keys = [keys_sorted[i] for i in valid_idx] 61 | 62 | #print train_keys 63 | 64 | #train_data = {i:all_data[i] for i in train_keys} 65 | #valid_data = {i:all_data[i] for i in valid_keys} 66 | 67 | #print len(train_data.keys()), len(valid_data.keys()) 68 | 69 | #data_gen_validation = BatchGenerator3D_random_sampling(valid_data, 2, num_batches=None, seed=False, patch_size=(128, 128, 128), convert_labels=True) 70 | #for i_batch, sample_batch in enumerate(data_gen_validation): 71 | # print(i_batch, sample_batch['name']) 72 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import numpy as np 3 | 4 | 5 | def reshape_by_padding_upper_coords(image, new_shape, pad_value=None): 6 | shape = tuple(list(image.shape)) 7 | new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0)) 8 | if pad_value is None: 9 | if len(shape)==2: 10 | pad_value = image[0,0] 11 | elif len(shape)==3: 12 | pad_value = image[0, 0, 0] 13 | else: 14 | raise ValueError("Image must be either 2 or 3 dimensional") 15 | res = np.ones(list(new_shape), dtype=image.dtype) * pad_value 16 | if len(shape) == 2: 17 | res[0:0+int(shape[0]), 0:0+int(shape[1])] = image 18 | elif len(shape) == 3: 19 | res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image 20 | return res 21 | 22 | 23 | def random_crop_3D_image_batched(img, crop_size): 24 | if type(crop_size) not in (tuple, list): 25 | crop_size = [crop_size] * (len(img.shape) - 2) 26 | else: 27 | assert len(crop_size) == (len(img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" 28 | 29 | if crop_size[0] < 
img.shape[2]: 30 | lb_x = np.random.randint(0, img.shape[2] - crop_size[0]) 31 | elif crop_size[0] == img.shape[2]: 32 | lb_x = 0 33 | else: 34 | raise ValueError("crop_size[0] must be smaller or equal to the images x dimension") 35 | 36 | if crop_size[1] < img.shape[3]: 37 | lb_y = np.random.randint(0, img.shape[3] - crop_size[1]) 38 | elif crop_size[1] == img.shape[3]: 39 | lb_y = 0 40 | else: 41 | raise ValueError("crop_size[1] must be smaller or equal to the images y dimension") 42 | 43 | if crop_size[2] < img.shape[4]: 44 | lb_z = np.random.randint(0, img.shape[4] - crop_size[2]) 45 | elif crop_size[2] == img.shape[4]: 46 | lb_z = 0 47 | else: 48 | raise ValueError("crop_size[2] must be smaller or equal to the images z dimension") 49 | 50 | return img[:, :, lb_x:lb_x + crop_size[0], lb_y:lb_y + crop_size[1], lb_z:lb_z + crop_size[2]] -------------------------------------------------------------------------------- /utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pykao/Modified-3D-UNet-Pytorch/63f0489e8d1fdd7ec6a203bcff095f12ea030824/utils.pyc --------------------------------------------------------------------------------