├── README.md ├── archs ├── net_usage.py └── old │ ├── net │ ├── net20 │ ├── net32 │ ├── netfcn │ ├── netfcnmem │ ├── netmem │ ├── netmem20 │ ├── netmem32 │ ├── netmemvw │ ├── netpad │ ├── netpadmem │ ├── netvwmem │ ├── test │ └── testmem ├── ash.py ├── brain_data_scripts ├── __init__.py ├── find_mha_files.py ├── read_images.py └── show_images.py ├── conv3d ├── README.md ├── __init__.py ├── __init__.pyc ├── basic.py ├── basic.pyc ├── cnn3d.py ├── cnn3d.pyc ├── model.py └── model.pyc ├── data ├── build_dataset.py ├── create_dummy.py ├── info │ ├── brats2013_challenge_info.json │ ├── brats2013_leaderboard_info.json │ ├── brats_test_info.json │ ├── brats_test_names.txt │ ├── check_names.py │ └── dummy_test_info.json └── pkl_to_hdf5.py ├── demo.py ├── folds.pkl ├── gnumpy.py ├── helper ├── get_dice.py ├── store_mha_results.py └── view_test_images.py ├── model_defs.py ├── requirements.txt ├── segment.py ├── train.py ├── train_model.py ├── train_vox.py └── transformations.py /README.md: -------------------------------------------------------------------------------- 1 | **Code for reproducing the results from [our paper on cnn-based medical image segmentation](https://arxiv.org/abs/1701.03056)** 2 | 3 | For any questions or issues, contact Baris via: bkayalibay@gmail.com. 
4 | 5 | Whenever you use this code, please cite our publication 6 | ``` 7 | @article{kayalibay2017cnn, 8 | title={CNN-based Segmentation of Medical Imaging Data}, 9 | author={Kayalibay, Baris and Jensen, Grady and van der Smagt, Patrick}, 10 | journal={arXiv preprint arXiv:1701.03056}, 11 | year={2017} 12 | } 13 | ``` 14 | 15 | # Requirements: 16 | 17 | + Python 2.7 18 | + Theano (0.9.0) (along with [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn)) 19 | + [cudamat](https://github.com/cudamat/cudamat) 20 | + [this fork](https://github.com/bkayalibay/breze) of breze 21 | + [climin](https://github.com/BRML/climin) 22 | + h5py 23 | + SimpleITK 24 | 25 | Some of these requirements can be installed with ``pip install package`` (theano, h5py, SimpleITK), 26 | others (breze, cudamat, climin) should be cloned from the github links provided 27 | and installed via ``pip install -e .`` 28 | 29 | # Usage: 30 | 31 | To segment new images and get test results, you will need to train the network first. 32 | The following steps need to be taken to create a data set, train and segment new images: 33 | 34 | ## Acquire the BRATS 2015 data set: 35 | 36 | Go to the [official BRATS website](http://braintumorsegmentation.org/) and download the 37 | BRATS 2015 data. Store the **training data** in this directory under a directory called ``BRATS2015_Training``. 38 | 39 | ## Create a data set: 40 | 41 | Run the following line in the terminal: 42 | 43 | ``python brain_data_scripts/read_images.py`` 44 | 45 | This will create a .hdf5 file called ``brats_fold0.hdf5`` under ``data/datasets``. 46 | This .hdf5 file contains three randomly created partitions (train, valid and test) 47 | for training, validation and testing. You can now use it to train a neural network.
48 | 49 | ## Train the network: 50 | 51 | Run: 52 | 53 | ``python train.py fcn_rffc4 brats_fold0 brats_fold0 600 -ch False`` 54 | 55 | This will train the network used in our paper on the data set brats_fold0 for 56 | 600 iterations over the data set and store the results at the path 57 | ``models/brats_fold0``. 58 | 59 | ## Test or reuse the trained network: 60 | 61 | Once you've trained a network, its parameters and hyperparameters are stored in 62 | a subdirectory of the directory ``models`` (read the docstring of the module 63 | ``train.py`` on how to select this). You can then reuse those parameters using 64 | the API provided in the module ``segment.py``. In ``segment.py`` you will find 65 | a function segment that can be used in the following way to segment new images: 66 | 67 | ``segment('BRATS2015_Training/HGG/brats_2013_pat0001_1', 'results2', 'fcn_rffc4', 5)`` 68 | 69 | In this example snippet, we are using the network with the id ``fcn_rffc4`` along 70 | with the parameters stored at ``models/results2`` to segment a medical image 71 | contained at the path ``BRATS2015_Training/HGG/brats_2013_pat0001_1``. 72 | The general usage is: 73 | 74 | ``segment([image_path], [params_path], [model_id])`` 75 | -------------------------------------------------------------------------------- /archs/net_usage.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | lines = [] 4 | mems = [] 5 | fcs = [] 6 | with open(sys.argv[1], 'r') as f: 7 | line = f.readline() 8 | while line: 9 | lines.append(line) 10 | left, right = line.split(':') 11 | mem = 1. 12 | for n in right.split('*'): 13 | mem *= int(n) 14 | if left.startswith('FC'): 15 | if len(fcs) == 0: 16 | fcs.append(mem) 17 | mem *= (mems[-1]/4) 18 | else: 19 | fcs.append(mem) 20 | mem *= fcs[-2] 21 | mems.append(mem*4) 22 | line = f.readline() 23 | 24 | with open(sys.argv[2], 'w') as f: 25 | total = 0. 
26 | for l, m in zip(lines, mems): 27 | mem_in_mb = m / (1024.*1024.) 28 | total += mem_in_mb 29 | if mem_in_mb >= 1.: 30 | line = (l[:-1] + ' mem: %.4fMB\n') % (mem_in_mb) 31 | else: 32 | mem_in_kb = m / 1024. 33 | line = (l[:-1] + ' mem: %.4fKB\n') % (mem_in_kb) 34 | f.write(line) 35 | tot_usage = 'TOTAL USAGE~= %.4fMB\n' % total 36 | f.write(tot_usage) 37 | 38 | -------------------------------------------------------------------------------- /archs/old/net: -------------------------------------------------------------------------------- 1 | Input: 50*50*50*1 2 | C3-8: 48*48*48*8 3 | POOL2: 24*24*24*8 4 | C3-32: 22*22*22*32 5 | POOL2: 11*11*11*32 6 | C3-64: 9*9*9*64 7 | POOL2: 4*4*4*64 8 | FC: 512 9 | -------------------------------------------------------------------------------- /archs/old/net20: -------------------------------------------------------------------------------- 1 | INPT: 20*20*20*1 2 | C3-20: 18*18*18*20 3 | POOL2: 9*9*9*20 4 | C3-50: 7*7*7*50 5 | POOL2: 3*3*3*50 6 | FC: 256 7 | -------------------------------------------------------------------------------- /archs/old/net32: -------------------------------------------------------------------------------- 1 | INPUT: 32*32*32*1 2 | C3-8: 30*30*30*8 3 | POOL2: 15*15*15*8 4 | C3-16: 13*13*13*16 5 | POOL2: 6*6*6*16 6 | C3-32: 4*4*4*32 7 | POOL2: 2*2*2*32 8 | FC: 100 9 | -------------------------------------------------------------------------------- /archs/old/netfcn: -------------------------------------------------------------------------------- 1 | Input: 20*20*20*1 2 | C3-16: 20*20*20*16 3 | Pool2: 10*10*10*16 4 | C3-32: 10*10*10*32 5 | Pool2: 5*5*5*32 6 | C3-32: 5*5*5*32 7 | Uppool2: 10*10*10*32 8 | C3-32: 10*10*10*32 9 | Uppool2: 20*20*20*32 10 | C3-16: 20*20*20*16 11 | C3-2: 20*20*20*2 12 | -------------------------------------------------------------------------------- /archs/old/netfcnmem: -------------------------------------------------------------------------------- 1 | Input: 20*20*20*1 
mem: 31.2500KB 2 | C3-16: 20*20*20*16 mem: 500.0000KB 3 | Pool2: 10*10*10*16 mem: 62.5000KB 4 | C3-32: 10*10*10*32 mem: 125.0000KB 5 | Pool2: 5*5*5*32 mem: 15.6250KB 6 | C3-32: 5*5*5*32 mem: 15.6250KB 7 | Uppool2: 10*10*10*32 mem: 125.0000KB 8 | C3-32: 10*10*10*32 mem: 125.0000KB 9 | Uppool2: 20*20*20*32 mem: 1000.0000KB 10 | C3-16: 20*20*20*16 mem: 500.0000KB 11 | C3-2: 20*20*20*2 mem: 62.5000KB 12 | TOTAL USAGE~= 2.5024MB 13 | -------------------------------------------------------------------------------- /archs/old/netmem: -------------------------------------------------------------------------------- 1 | Input: 50*50*50*1 mem: 488.2812KB 2 | C3-8: 48*48*48*8 mem: 3.3750MB 3 | POOL2: 24*24*24*8 mem: 432.0000KB 4 | C3-32: 22*22*22*32 mem: 1.2998MB 5 | POOL2: 11*11*11*32 mem: 166.3750KB 6 | C3-64: 9*9*9*64 mem: 182.2500KB 7 | POOL2: 4*4*4*64 mem: 16.0000KB 8 | FC: 512 mem: 8.0000MB 9 | TOTAL USAGE~= 13.9296MB 10 | -------------------------------------------------------------------------------- /archs/old/netmem20: -------------------------------------------------------------------------------- 1 | INPT: 20*20*20*1 mem: 31.2500KB 2 | C3-20: 18*18*18*20 mem: 455.6250KB 3 | POOL2: 9*9*9*20 mem: 56.9531KB 4 | C3-50: 7*7*7*50 mem: 66.9922KB 5 | POOL2: 3*3*3*50 mem: 5.2734KB 6 | FC: 256 mem: 1.3184MB 7 | TOTAL USAGE~= 1.9200MB 8 | -------------------------------------------------------------------------------- /archs/old/netmem32: -------------------------------------------------------------------------------- 1 | INPUT: 32*32*32*1 mem: 128.0000KB 2 | C3-8: 30*30*30*8 mem: 843.7500KB 3 | POOL2: 15*15*15*8 mem: 105.4688KB 4 | C3-16: 13*13*13*16 mem: 137.3125KB 5 | POOL2: 6*6*6*16 mem: 13.5000KB 6 | C3-32: 4*4*4*32 mem: 8.0000KB 7 | POOL2: 2*2*2*32 mem: 1.0000KB 8 | FC: 100 mem: 100.0000KB 9 | TOTAL USAGE~= 1.3057MB 10 | -------------------------------------------------------------------------------- /archs/old/netmemvw: 
-------------------------------------------------------------------------------- 1 | INPUT: 20*20*20*1 mem: 31.2500KB 2 | C3-64: 18*18*18*64 mem: 1.4238MB 3 | POOL2: 9*9*9*64 mem: 182.2500KB 4 | C3-128: 7*7*7*128 mem: 171.5000KB 5 | POOL2: 3*3*3*128 mem: 13.5000KB 6 | FC: 8000 mem: 105.4688MB 7 | TOTAL USAGE~= 107.2817MB 8 | -------------------------------------------------------------------------------- /archs/old/netpad: -------------------------------------------------------------------------------- 1 | INPUT: 50*50*50*1 2 | C3-8: 50*50*50*8 3 | POOL2: 25*25*25*8 4 | C3-16: 25*25*25*16 5 | POOL2: 12*12*12*16 6 | C3-32: 12*12*12*32 7 | POOL2: 6*6*6*32 8 | C3-64: 6*6*6*64 9 | POOL2: 3*3*3*64 10 | FC: 512 11 | -------------------------------------------------------------------------------- /archs/old/netpadmem: -------------------------------------------------------------------------------- 1 | INPUT: 50*50*50*1 mem: 488.2812KB 2 | C3-8: 50*50*50*8 mem: 3.8147MB 3 | POOL2: 25*25*25*8 mem: 488.2812KB 4 | C3-16: 25*25*25*16 mem: 976.5625KB 5 | POOL2: 12*12*12*16 mem: 108.0000KB 6 | C3-32: 12*12*12*32 mem: 216.0000KB 7 | POOL2: 6*6*6*32 mem: 27.0000KB 8 | C3-64: 6*6*6*64 mem: 54.0000KB 9 | POOL2: 3*3*3*64 mem: 6.7500KB 10 | FC: 512 mem: 3.3750MB 11 | TOTAL USAGE~= 9.4991MB 12 | -------------------------------------------------------------------------------- /archs/old/netvwmem: -------------------------------------------------------------------------------- 1 | INPUT: 20*20*20*1 mem: 31.2500KB 2 | C3-16: 20*20*20*16 mem: 500.0000KB 3 | POOL2: 10*10*10*16 mem: 62.5000KB 4 | C3-32: 10*10*10*32 mem: 125.0000KB 5 | POOL2: 5*5*5*32 mem: 15.6250KB 6 | C3-32: 5*5*5*32 mem: 15.6250KB 7 | UPPOOL: 10*10*10*32 mem: 125.0000KB 8 | C3-16: 10*10*10*32 mem: 125.0000KB 9 | UPPOOL: 20*20*20*32 mem: 1000.0000KB 10 | C3-2: 20*20*20*2 mem: 62.5000KB 11 | TOTAL USAGE~= 2.0142MB 12 | -------------------------------------------------------------------------------- /archs/old/test: 
-------------------------------------------------------------------------------- 1 | NPUT: 20*20*20*1 2 | C3-64: 20*20*20*32 3 | POOL2: 10*10*10*32 4 | C3-128: 10*10*10*64 5 | POOL2: 5*5*5*64 6 | FC: 1000 7 | FC: 8000 8 | -------------------------------------------------------------------------------- /archs/old/testmem: -------------------------------------------------------------------------------- 1 | NPUT: 20*20*20*1 mem: 31.2500KB 2 | C3-64: 20*20*20*32 mem: 1000.0000KB 3 | POOL2: 10*10*10*32 mem: 125.0000KB 4 | C3-128: 10*10*10*64 mem: 250.0000KB 5 | POOL2: 5*5*5*64 mem: 31.2500KB 6 | FC: 1000 mem: 30.5176MB 7 | FC: 8000 mem: 30.5176MB 8 | TOTAL USAGE~= 62.4390MB 9 | -------------------------------------------------------------------------------- /ash.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | import exceptions 4 | import json 5 | import os 6 | import cPickle as pickle 7 | 8 | import theano 9 | import theano.tensor as T 10 | from theano.tensor.shared_randomstreams import RandomStreams 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from skimage import color 14 | 15 | import h5py 16 | import gnumpy 17 | 18 | from climin import mathadapt as ma 19 | from climin.util import iter_minibatches 20 | 21 | from breze.arch.util import lookup 22 | from breze.arch.component import transfer as _transfer 23 | 24 | def dice_demo_(seg, gt): 25 | dice = np.sum(2 * seg * gt) 26 | dice /= (np.sum(np.square(seg)) + np.sum(np.square(gt))) 27 | return dice 28 | 29 | 30 | def dice_demo(seg, gt): 31 | seg_transposed = np.transpose(seg, (3, 0, 1, 2)) 32 | gt = np.transpose(gt, (3, 0, 1, 2)) 33 | 34 | dice_list = [dice_demo_(s, g) for s, g in zip(seg_transposed, gt)] 35 | dice_list = [dice_demo_(seg_transposed, gt)] + dice_list 36 | return dice_list 37 | 38 | def vis_result(image, seg, gt, file_name='test.png'): 39 | indices = np.where(seg == 1) 40 | indices_gt = np.where(gt == 1) 41 | 42 | 
im_norm = image / image.max() 43 | rgb_image = color.gray2rgb(im_norm) 44 | multiplier = [0., 1., 1.] 45 | multiplier_gt = [1., 1., 0.] 46 | 47 | im_seg = rgb_image.copy() 48 | im_gt = rgb_image.copy() 49 | im_seg[indices[0], indices[1], :] *= multiplier 50 | im_gt[indices_gt[0], indices_gt[1], :] *= multiplier_gt 51 | 52 | fig = plt.figure() 53 | a = fig.add_subplot(1, 2, 1) 54 | plt.imshow(im_seg) 55 | a.set_title('Segmentation') 56 | a = fig.add_subplot(1, 2, 2) 57 | plt.imshow(im_gt) 58 | a.set_title('Ground truth') 59 | plt.savefig(file_name) 60 | 61 | def vis_col_im(im, seg, gt, file_name='test.png'): 62 | indices_0 = np.where(gt == 0) 63 | indices_1 = np.where(gt == 1) # green - metacarpal - necrosis 64 | indices_2 = np.where(gt == 2) # yellow - proximal - edema 65 | indices_3 = np.where(gt == 3) # orange - middle - enhancing tumor 66 | indices_4 = np.where(gt == 4) # red - distal - nonenhancing tumor 67 | 68 | indices_s0 = np.where(seg == 0) 69 | indices_s1 = np.where(seg == 1) 70 | indices_s2 = np.where(seg == 2) 71 | indices_s3 = np.where(seg == 3) 72 | indices_s4 = np.where(seg == 4) 73 | 74 | im = im * 1. / im.max() 75 | rgb_image = color.gray2rgb(im) 76 | m0 = [0.6, 0.6, 1.] 77 | m1 = [0.2, 1., 0.2] 78 | m2 = [1., 1., 0.2] 79 | m3 = [1., 0.6, 0.2] 80 | m4 = [1., 0., 0.] 
81 | 82 | im_gt = rgb_image.copy() 83 | im_seg = rgb_image.copy() 84 | im_gt[indices_0[0], indices_0[1], :] *= m0 85 | im_gt[indices_1[0], indices_1[1], :] *= m1 86 | im_gt[indices_2[0], indices_2[1], :] *= m2 87 | im_gt[indices_3[0], indices_3[1], :] *= m3 88 | im_gt[indices_4[0], indices_4[1], :] *= m4 89 | 90 | im_seg[indices_s0[0], indices_s0[1], :] *= m0 91 | im_seg[indices_s1[0], indices_s1[1], :] *= m1 92 | im_seg[indices_s2[0], indices_s2[1], :] *= m2 93 | im_seg[indices_s3[0], indices_s3[1], :] *= m3 94 | im_seg[indices_s4[0], indices_s4[1], :] *= m4 95 | 96 | fig = plt.figure() 97 | a = fig.add_subplot(1, 2, 1) 98 | plt.imshow(im_seg) 99 | a.set_title('Segmentation') 100 | a = fig.add_subplot(1, 2, 2) 101 | plt.imshow(im_gt) 102 | a.set_title('Ground truth') 103 | plt.savefig(file_name) 104 | 105 | plt.close() 106 | 107 | class TransFun(object): 108 | def __init__(self, fun, *params): 109 | self.params = params 110 | self.fun = fun 111 | 112 | def __call__(self, inpt): 113 | return self.fun(inpt, *self.params) 114 | 115 | def tensor_softmax(inpt, n_classes=2): 116 | output = inpt.dimshuffle(0, 3, 4, 1, 2) 117 | output = T.reshape(output, (-1, n_classes)) 118 | 119 | f = lookup('softmax', _transfer) 120 | output = T.reshape(f(output), (1, -1, n_classes)) 121 | return output 122 | 123 | def tensor_ident(inpt, n_classes=2): 124 | output = inpt.dimshuffle(0, 3, 4, 1, 2) 125 | output = T.reshape(output, (1, -1, n_classes)) 126 | return output 127 | 128 | def fcn_cat_ce(target, prediction, eps=1e-8): 129 | ''' 130 | This loss function assumes the data set is processed one 131 | image (patch) at a time. As a consequence, the targets and 132 | the predictions should both be of shape (1, n_voxels, n_classes). 
133 | ''' 134 | prediction = T.reshape(prediction, (prediction.shape[1], prediction.shape[2])) 135 | target = T.reshape(target, (target.shape[1], target.shape[2])) 136 | prediction = T.clip(prediction, eps, 1-eps) 137 | loss = -(target * T.log(prediction)) 138 | return loss 139 | 140 | def weighted_cat_ce(target, prediction, eps=1e-8): 141 | ''' 142 | This loss weights each class by some factor. 143 | ''' 144 | prediction = T.reshape(prediction, (prediction.shape[1], prediction.shape[2])) 145 | target = T.reshape(target, (target.shape[1], target.shape[2])) 146 | prediction = T.clip(prediction, eps, 1 - eps) 147 | loss = -((np.array([0.7, 0.3], dtype='float32') / (T.mean(target, axis=0))) * target * T.log(prediction)) 148 | return loss 149 | 150 | def cat_ce_parts(target, prediction, eps=1e-8): 151 | ''' 152 | This loss weights each class by some factor. 153 | ''' 154 | aleph = 0.4 155 | 156 | prediction = T.reshape(prediction, (prediction.shape[1], prediction.shape[2])) 157 | target = T.reshape(target, (target.shape[1], target.shape[2])) 158 | prediction = T.clip(prediction, eps, 1 - eps) 159 | 160 | b_inds = (target[:,1] > 0).nonzero() 161 | t_bones = target[b_inds] 162 | p_bones = prediction[b_inds] 163 | bones_loss = T.mean(-(t_bones * T.log(p_bones)), axis=0, keepdims=True) 164 | loss = T.mean(-(target * T.log(prediction)), axis=0, keepdims=True) 165 | 166 | return (1 - aleph) * loss + aleph * bones_loss 167 | 168 | def dice(target, prediction, eps=1e-8): 169 | ''' 170 | The dice loss as described in: 171 | https://arxiv.org/pdf/1606.04797v1.pdf 172 | (V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation) 173 | The above paper aims to maximize the dice measure. Since climin only 174 | minimizes, this function returns 1 - dice instead of dice, with the assumption 175 | that minimizing the former is equivalent to maximizing the latter. 
176 | ''' 177 | prediction = T.reshape(prediction, (prediction.shape[1], prediction.shape[2])) 178 | target = T.reshape(target, (target.shape[1], target.shape[2])) 179 | prediction = T.clip(prediction, eps, 1 - eps) 180 | loss = 2*T.sum(target*prediction,axis=0,keepdims=True) 181 | loss /= (T.sum(T.sqr(target),axis=0,keepdims=True) + T.sum(T.sqr(prediction),axis=0,keepdims=True)) 182 | return 1 - loss 183 | 184 | def jaccard(target, prediction, eps=1e-8): 185 | ''' 186 | Jaccard distance, see: https://en.wikipedia.org/wiki/Jaccard_index 187 | ''' 188 | prediction = T.reshape(prediction, (prediction.shape[1], prediction.shape[2])) 189 | target = T.reshape(target, (target.shape[1], target.shape[2])) 190 | prediction = T.clip(prediction, eps, 1 - eps) 191 | intersection = T.sum(target * prediction, axis=0, keepdims=True) 192 | loss = intersection / (T.sum(target + prediction, axis=0, keepdims=True) - intersection) 193 | return 1 - loss 194 | 195 | def tanimoto(target, prediction, eps=1e-8): 196 | ''' 197 | Tanimoto distance, see: https://en.wikipedia.org/wiki/Jaccard_index#Other_definitions_of_Tanimoto_distance 198 | ''' 199 | prediction = T.reshape(prediction, (prediction.shape[1], prediction.shape[2])) 200 | target = T.reshape(target, (target.shape[1], target.shape[2])) 201 | prediction = T.clip(prediction, eps, 1 - eps) 202 | 203 | intersection = T.sum(target * prediction, axis=0, keepdims=True) 204 | prediction_sq = T.sum(T.sqr(prediction), axis=0, keepdims=True) 205 | target_sq = T.sum(T.sqr(target), axis=0, keepdims=True) 206 | 207 | loss = intersection / (target_sq + prediction_sq - intersection) 208 | return 1 - loss 209 | 210 | def tanimoto_wmap(target_in, prediction, eps=1e-8): 211 | ''' 212 | Tanimoto distance, see: https://en.wikipedia.org/wiki/Jaccard_index#Other_definitions_of_Tanimoto_distance 213 | ''' 214 | target_in = T.reshape(target_in, (target_in.shape[1], target_in.shape[2])) 215 | target = target_in[:, :2] 216 | wmap = T.repeat(target_in[:, 
2].dimshuffle(('x', 0)), 2, axis=0).dimshuffle((1, 0)) 217 | prediction = T.reshape(prediction, (prediction.shape[1], prediction.shape[2])) 218 | prediction = T.clip(prediction, eps, 1 - eps) 219 | 220 | target_w = T.sum(T.sqr(target * wmap), axis=0, keepdims=True) 221 | pred_w = T.sum(T.sqr(prediction * wmap), axis=0, keepdims=True) 222 | intersection_w = T.sum(target_w * pred_w, axis=0, keepdims=True) 223 | 224 | intersection = T.sum(target * prediction, axis=0, keepdims=True) 225 | prediction_sq = T.sum(T.sqr(prediction), axis=0, keepdims=True) 226 | target_sq = T.sum(T.sqr(target), axis=0, keepdims=True) 227 | 228 | loss = (target_w + pred_w - 2 * intersection_w) / (target_sq + prediction_sq - intersection) 229 | return loss 230 | 231 | def kl_divergence(target, prediction, eps=1e-6): 232 | '''Kullback-Leibler divergence''' 233 | prediction = T.reshape(prediction, (prediction.shape[1], prediction.shape[2])) 234 | target = T.reshape(target, (target.shape[1], target.shape[2])) 235 | prediction = T.clip(prediction, eps, 1 - eps) 236 | target = T.clip(target, eps, 1 - eps) 237 | 238 | kl = T.sum(target * T.log(target / prediction), axis=0, keepdims=True) 239 | return kl 240 | 241 | class BatchNormFuns(object): 242 | ''' 243 | Convenience class to compute network forward passes during validation and inference. 244 | ''' 245 | def __init__(self, model, phase, fun): 246 | ''' 247 | :param model: network model 248 | :param phase: 'valid' or 'infer', model will be reset 249 | to 'train' after computing forward pass 250 | in given phase. 251 | :param fun: the concrete function of the model to call 252 | i.e. 
model.score or model.predict 253 | ''' 254 | self.model = model 255 | self.phase = phase 256 | self.fun = fun 257 | 258 | def __call__(self, *data): 259 | return batchnorm_apply_fun(self.model, self.phase, self.fun, data) 260 | 261 | def batchnorm_apply_fun(model, phase, fun, data): 262 | model.phase_select(phase_id=phase) 263 | res = fun(*data) 264 | model.reset_phase() 265 | return res 266 | 267 | class PocketTrainer(object): 268 | def __init__(self, model, data, stop, 269 | pause, score_fun, report_fun, 270 | evaluate=True, test=True, batchnorm=False, 271 | model_code=None, n_report=None): 272 | self.model = model 273 | self.data = data 274 | self.stop = stop 275 | self.pause = pause 276 | self.score_fun = score_fun 277 | self.report_fun = report_fun 278 | self.best_pars = None 279 | self.best_loss = float('inf') 280 | self.runtime = 0 281 | self.evaluate = evaluate 282 | self.test = test 283 | self.losses = [] 284 | self.test_performance = [] 285 | self.model_code = model_code 286 | self.n_epochs_done = 0 287 | self.n_iters_done = 0 288 | self.n_report = n_report 289 | 290 | # if batchnorm: 291 | # self.m_score_train = BatchNormFuns( 292 | # model=self.model, phase='valid', 293 | # fun=self.model.score 294 | # ) 295 | # if bn_mode == 'native': 296 | # print 'using batch norm with running metrics for validation' 297 | # self.m_score_valid = BatchNormFuns( 298 | # model=self.model, phase='valid', 299 | # fun=self.model.score 300 | # ) 301 | # elif bn_mode == 'batch': 302 | # print 'using batch norm without running metrics' 303 | # self.m_score_valid = self.m_score_train 304 | # else: 305 | # raise ValueError('BN modes are: native, batch') 306 | # else: 307 | # self.m_score_train = self.m_score_valid = self.model.score 308 | 309 | self.using_bn = batchnorm 310 | 311 | def demo(self, predict, image, gt, size_reduction, im_name='test.png'): 312 | output_h = self.model.image_height - size_reduction 313 | output_w = self.model.image_width - size_reduction 314 | 
output_d = self.model.image_depth - size_reduction 315 | n_chans = self.model.n_channels 316 | n_classes = self.model.n_output 317 | 318 | segmentation = predict(image) 319 | segmentation = segmentation.as_numpy_array() if isinstance(segmentation, gnumpy.garray) else segmentation 320 | segmentation = np.reshape( 321 | segmentation, 322 | (output_h, output_w, output_d, n_classes) 323 | ) 324 | 325 | gt = np.reshape( 326 | gt, (output_h, output_w, output_d, n_classes) 327 | ) 328 | 329 | dice_list = dice_demo(segmentation, gt) 330 | segmentation = segmentation.argmax(axis=3) 331 | gt = gt.argmax(axis=3) 332 | 333 | image = np.reshape(np.transpose(image, (0,2,3,4,1)), (n_chans, output_h, output_w, output_d)) 334 | im_slice = image[0,:,:,image.shape[-1]/2] 335 | 336 | seg_slice = segmentation[:,:,segmentation.shape[-1]/2] 337 | gt_slice = gt[:,:,gt.shape[-1]/2] 338 | 339 | if n_classes == 2: 340 | vis_result(im_slice, seg_slice, gt_slice, file_name=im_name) 341 | elif n_classes == 5: 342 | vis_col_im(im=im_slice, seg=seg_slice, gt=gt_slice, file_name=im_name) 343 | else: 344 | raise NotImplementedError('Can only handle 2 or 5 classes') 345 | 346 | return dice_list 347 | 348 | def fit(self): 349 | try: 350 | for i in self.iter_fit(*self.data['train']): 351 | self.report_fun(i) 352 | except exceptions.IOError, e: 353 | pass 354 | except KeyboardInterrupt: 355 | self.quit_training() 356 | import sys 357 | sys.exit(0) 358 | 359 | 360 | def iter_fit(self, *fit_data): 361 | start = time.time() 362 | 363 | for info in self.model.iter_fit(*fit_data): 364 | if self.pause(info): 365 | # Take care of batch norm 366 | # Things done here shouldn't affect running metrics since no learning is supposed to happen. 
367 | if self.using_bn: 368 | self.model.phase_select(phase_id='valid') 369 | 370 | if 'loss' not in info: 371 | info['loss'] = ma.scalar( 372 | self.score_fun(self.model.score, *self.data['train']) 373 | ) 374 | 375 | if self.evaluate: 376 | info['val_loss'] = ma.scalar( 377 | self.score_fun(self.model.score, *self.data['val']) 378 | ) 379 | 380 | if info['val_loss'] < self.best_loss: 381 | self.best_loss = info['val_loss'] 382 | self.best_pars = self.model.parameters.data.copy() 383 | 384 | self.losses.append((info['loss'], info['val_loss'])) 385 | else: 386 | self.losses.append(info['loss']) 387 | 388 | if self.test: 389 | info['test_avg'] = ma.scalar( 390 | self.score_fun(self.model.score, *self.data['test']) 391 | ) 392 | self.test_performance.append(info['test_avg']) 393 | 394 | self.runtime = time.time() - start 395 | info.update({ 396 | 'best_loss': self.best_loss, 397 | 'best_pars': self.best_pars, 398 | 'runtime': self.runtime 399 | }) 400 | self.n_epochs_done = info['n_iter'] / self.n_report 401 | self.n_iters_done = info['n_iter'] 402 | 403 | # Return to training mode, keep learning running metrics. 404 | if self.using_bn: 405 | self.model.phase_select(phase_id='train') 406 | 407 | yield info 408 | 409 | if self.stop(info): 410 | break 411 | 412 | def quit_training(self): 413 | if self.best_pars is None: 414 | print 'canceled before the end of the first epoch, nothing to do.' 
415 | return 416 | 417 | model_code = self.model_code 418 | param_loc = os.path.join('models', 'checkpoints', model_code + '.hdf5') 419 | GLOB_CKPT_DIR = os.path.join('models', 'checkpoints') 420 | if not os.path.exists(GLOB_CKPT_DIR): 421 | os.makedirs(GLOB_CKPT_DIR) 422 | print 'setting checkpoint at: ', param_loc 423 | param_file = h5py.File(param_loc, 'w') 424 | best_params = param_file.create_dataset( 425 | 'best_pars', self.model.parameters.data.shape, dtype='float32' 426 | ) 427 | last_params = param_file.create_dataset( 428 | 'last_pars', best_params.shape, dtype='float32' 429 | ) 430 | 431 | if isinstance(self.best_pars, gnumpy.garray): 432 | best_params[...] = self.best_pars.as_numpy_array() 433 | last_params[...] = self.model.parameters.data.as_numpy_array() 434 | else: 435 | best_params[...] = self.best_pars[...] 436 | last_params[...] = self.model.parameters.data[...] 437 | param_file.close() 438 | 439 | if self.using_bn: 440 | bn_pars = self.model.get_batchnorm_params() 441 | bn_pars_path = os.path.join('models', 'checkpoints', model_code + '_bn_pars.pkl') 442 | with open(bn_pars_path, 'w') as f: 443 | pickle.dump(bn_pars, f) 444 | 445 | mini_log_code = os.path.join('models', 'checkpoints', model_code + '_log.json') 446 | 447 | if os.path.exists(mini_log_code): 448 | print 'previous log found' 449 | with open(mini_log_code, 'r') as f: 450 | prev_log = json.load(f) 451 | print 'updating current log...' 
452 | 453 | self.losses = prev_log['losses'] + self.losses 454 | self.test_performance = prev_log['test_performance'] + self.test_performance 455 | self.n_epochs_done += prev_log['n_epochs'] 456 | self.n_iters_done += prev_log['n_iters'] 457 | mini_log = { 458 | 'losses': self.losses, 459 | 'test_performance': self.test_performance, 460 | 'best_loss': self.best_loss, 461 | 'n_epochs': self.n_epochs_done, 462 | 'n_iters': self.n_iters_done 463 | } 464 | print 'writing new log at: ', mini_log_code 465 | with open(mini_log_code, 'w') as f: 466 | json.dump(mini_log, f) 467 | 468 | print 'all done.' 469 | return 470 | 471 | class MinibatchTest(object): 472 | def __init__(self, max_samples, sample_dims): 473 | self.max_samples = max_samples 474 | self.sample_dims = sample_dims 475 | 476 | def __call__(self, predict_f, *data): 477 | batches = iter_minibatches(data, self.max_samples, self.sample_dims, 1) 478 | seen_samples = 0. 479 | score = 0. 480 | for batch in batches: 481 | x, z = batch 482 | y = predict_f(x) 483 | this_samples = int(y.shape[self.sample_dims[0]]) 484 | errs = (y.argmax(axis=1) != z.argmax(axis=1)).sum() 485 | score += errs 486 | seen_samples += this_samples 487 | 488 | return ma.scalar(score / seen_samples) 489 | 490 | class MinibatchTestFCN(object): 491 | def __init__(self, max_samples, sample_dims): 492 | self.max_samples = max_samples 493 | self.sample_dims = sample_dims 494 | 495 | def __call__(self, predict_f, *data): 496 | batches = iter_minibatches(data, self.max_samples, self.sample_dims, 1) 497 | seen_samples = 0. 498 | score = 0. 
499 | for batch in batches: 500 | x, z = batch 501 | y = predict_f(x) 502 | this_samples = int(y.shape[1]) 503 | errs = (y.argmax(axis=2) != z.argmax(axis=2)).sum() 504 | score += errs 505 | seen_samples += this_samples 506 | 507 | return ma.scalar(score / seen_samples) 508 | 509 | class MinibatchScoreFCN(object): 510 | def __init__(self, max_samples, sample_dims): 511 | self.max_samples = max_samples 512 | self.sample_dims = sample_dims 513 | 514 | def __call__(self, f_score, *data): 515 | batches = iter_minibatches(data, self.max_samples, self.sample_dims, 1) 516 | score = 0. 517 | seen_samples = 0. 518 | for batch in batches: 519 | x = batch[0] 520 | z = batch[1] 521 | this_samples = int(x.shape[0]) 522 | score += f_score(x, z) * this_samples 523 | seen_samples += this_samples 524 | return ma.scalar(score / seen_samples) -------------------------------------------------------------------------------- /brain_data_scripts/__init__.py: -------------------------------------------------------------------------------- 1 | from read_images import ( 2 | get_im, get_im_as_ndarray, 3 | get_gt, convert_gt_to_onehot, 4 | get_image_slice 5 | ) 6 | from show_images import ( 7 | vis_col_im 8 | ) -------------------------------------------------------------------------------- /brain_data_scripts/find_mha_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | def find_patient_dirs(dirs): 5 | patients = [] 6 | for d in dirs: 7 | pat_code = d['path'].split('VSD')[0] 8 | if pat_code not in patients: 9 | patients.append(pat_code) 10 | return patients 11 | 12 | def get_patient_dirs(path, dirs): 13 | for it in os.listdir(path): 14 | it_path = os.path.join(path, it) 15 | if os.path.isdir(it_path): 16 | if 'pat' in it: 17 | dirs.append(it_path) 18 | else: 19 | dirs = get_patient_dirs(it_path, dirs) 20 | return dirs 21 | 22 | def _crawl(path, dirs): 23 | files = [] 24 | mhas = [] 25 | for it in os.listdir(path): 26 
| it_path = os.path.join(path, it) 27 | if os.path.isfile(it_path): 28 | if it.endswith('.mha'): 29 | mhas.append(it_path) 30 | else: 31 | files.append(it_path) 32 | elif os.path.isdir(it_path): 33 | dirs = _crawl(it_path, dirs) 34 | if len(mhas) > 0 or len(files) > 0: 35 | new_dir = {'path': path, 'files': files, 'mhas': mhas} 36 | dirs.append(new_dir) 37 | return dirs 38 | 39 | def crawl(path, item, mhas): 40 | if item.endswith('.mha'): 41 | mha_path = os.path.join(path, item) 42 | mhas.append(mha_path) 43 | return mhas 44 | else: 45 | new_path = os.path.join(path, item) 46 | if not os.path.isdir(new_path): 47 | return mhas 48 | for it in os.listdir(new_path): 49 | mhas = crawl(new_path, it, mhas) 50 | return mhas 51 | 52 | def find_patients(dir_name): 53 | pats = [] 54 | pats = get_patient_dirs(dir_name, pats) 55 | 56 | for p in pats: 57 | print p 58 | print 'Found %i patient directories.' % len(pats) 59 | 60 | if __name__ == '__main__': 61 | if len(sys.argv) != 2: 62 | raise ValueError('You have to input the directory to be crawled.') 63 | dir_name = sys.argv[1] 64 | 65 | find_patients(dir_name) 66 | -------------------------------------------------------------------------------- /brain_data_scripts/read_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import six 3 | import cPickle as pickle 4 | import SimpleITK as sitk 5 | import numpy as np 6 | import warnings 7 | from breze.learn.data import one_hot 8 | import h5py 9 | from scipy.ndimage import zoom 10 | from find_mha_files import get_patient_dirs 11 | from random import shuffle 12 | from skimage import color 13 | import copy 14 | import matplotlib.pyplot as plt 15 | 16 | def vis_col_im(im, gt): 17 | indices_0 = np.where(gt == 0) # nothing 18 | indices_1 = np.where(gt == 1) # necrosis 19 | indices_2 = np.where(gt == 2) # edema 20 | indices_3 = np.where(gt == 3) # non-enhancing tumor 21 | indices_4 = np.where(gt == 4) # enhancing tumor 22 | 23 | im 
= np.asarray(im, dtype='float32') 24 | im = im*1./im.max() 25 | rgb_image = color.gray2rgb(im) 26 | m0 = [1., 1., 1.] 27 | m1 = [1., 0., 0.] 28 | m2 = [0.2, 1., 0.2] 29 | m3 = [1., 1., 0.2] 30 | m4 = [1., 0.6, 0.2] 31 | 32 | im = rgb_image.copy() 33 | im[indices_0[0], indices_0[1], :] *= m0 34 | im[indices_1[0], indices_1[1], :] *= m1 35 | im[indices_2[0], indices_2[1], :] *= m2 36 | im[indices_3[0], indices_3[1], :] *= m3 37 | im[indices_4[0], indices_4[1], :] *= m4 38 | 39 | plt.imshow(im) 40 | plt.show() 41 | plt.close() 42 | 43 | def col_im(im, gt): 44 | im = np.asarray(im, dtype='float32') 45 | im = im*1./im.max() 46 | rgb_image = color.gray2rgb(im) 47 | im = rgb_image.copy() 48 | 49 | if gt is None: 50 | return im 51 | 52 | indices_0 = np.where(gt == 0) # nothing 53 | indices_1 = np.where(gt == 1) # necrosis 54 | indices_2 = np.where(gt == 2) # edema 55 | indices_3 = np.where(gt == 3) # non-enhancing tumor 56 | indices_4 = np.where(gt == 4) # enhancing tumor 57 | 58 | m0 = [1., 1., 1.] 59 | m1 = [1., 0., 0.] 
# red: necrosis 60 | m2 = [0.2, 1., 0.2] # green: edema 61 | m3 = [1., 1., 0.2] # yellow: non-enhancing tumor 62 | m4 = [1., 0.6, 0.2] # orange: enhancing tumor 63 | 64 | im[indices_0[0], indices_0[1], :] *= m0 65 | im[indices_1[0], indices_1[1], :] *= m1 66 | im[indices_2[0], indices_2[1], :] *= m2 67 | im[indices_3[0], indices_3[1], :] *= m3 68 | im[indices_4[0], indices_4[1], :] *= m4 69 | 70 | return im 71 | 72 | def vis_ims(im0, gt0, im1, gt1, title0='Original', title1='Transformed'): 73 | im0 = col_im(im0, gt0) 74 | im1 = col_im(im1, gt1) 75 | 76 | fig = plt.figure() 77 | a = fig.add_subplot(1,2,1) 78 | plt.imshow(im0) 79 | a.set_title(title0) 80 | a = fig.add_subplot(1,2,2) 81 | plt.imshow(im1) 82 | a.set_title(title1) 83 | 84 | plt.show() 85 | plt.close() 86 | 87 | def get_im_as_ndarray(image, downsize=False): 88 | ims = [image['Flair'], image['T1'], image['T1c'], image['T2']] 89 | if downsize: 90 | ims = [zoom(x, 0.5, order=1) for x in ims] 91 | im = np.array(ims, dtype='int16') 92 | 93 | return im 94 | 95 | def get_gt(gt, n_classes, downsize=False): 96 | if not downsize: 97 | return gt 98 | original_shape = gt.shape 99 | gt_onehot = np.reshape(gt, (-1,)) 100 | gt_onehot = np.reshape(one_hot(gt_onehot, n_classes), original_shape + (n_classes,)) 101 | gt_onehot = np.transpose(gt_onehot, (3, 0, 1, 2)) 102 | 103 | zoom_gt = np.array([zoom(class_map, 0.5, order=1) for class_map in gt_onehot]) 104 | zoom_gt = zoom_gt.argmax(axis=0) 105 | zoom_gt = np.asarray(zoom_gt, dtype='int8') 106 | 107 | return zoom_gt 108 | 109 | def convert_gt_to_onehot(gt, n_classes): 110 | gt_onehot = np.transpose(gt, (1, 2, 0)) 111 | gt_onehot = np.reshape(gt_onehot, (-1,)) 112 | gt_onehot = np.reshape(one_hot(gt_onehot, n_classes), (-1, n_classes)) 113 | 114 | return gt_onehot 115 | 116 | def process_gt(gt, n_classes, downsize=False): 117 | if downsize: 118 | gt = zoom(gt, 0.5, order=0) 119 | gt = np.asarray(gt, dtype='int8') 120 | gt = np.transpose(gt, (1, 2, 0)) 121 | l = 
np.reshape(gt, (-1,)) 122 | l = np.reshape(one_hot(l, n_classes), (-1, n_classes)) 123 | return l 124 | 125 | def center(im): 126 | indices = np.where(im > 0) 127 | indices = np.array(indices) 128 | indices = indices.T 129 | 130 | return [int(i) for i in np.round(np.mean(indices, axis=0))] 131 | 132 | def get_pats(dir_name): 133 | pats = [] 134 | pats = get_patient_dirs(dir_name, pats) 135 | return pats 136 | 137 | def find_mha_paths(pat_dir, paths): 138 | for item in os.listdir(pat_dir): 139 | item_path = os.path.join(pat_dir, item) 140 | if os.path.isdir(item_path): 141 | paths = find_mha_paths(item_path, paths) 142 | elif os.path.isfile(item_path): 143 | if item.endswith('.mha'): 144 | paths.append(item_path) 145 | return paths 146 | 147 | def get_im(pat_dir): 148 | paths = [] 149 | paths = find_mha_paths(pat_dir, paths) 150 | gt = None 151 | im = {'Flair': None, 'T1': None, 'T1c': None, 'T2': None, 'gt': None} 152 | for p in paths: 153 | itk_image = sitk.ReadImage(p) 154 | nd_image = sitk.GetArrayFromImage(itk_image) 155 | if 'more' in p or 'OT' in p: 156 | if gt is None: 157 | gt = nd_image 158 | else: 159 | raise ValueError('Found multiple ground truths.') 160 | elif 'Flair' in p: 161 | im['Flair'] = nd_image 162 | elif 'T1c' in p: 163 | im['T1c'] = nd_image 164 | elif 'T1' in p: 165 | im['T1'] = nd_image 166 | elif 'T2' in p: 167 | im['T2'] = nd_image 168 | else: 169 | print 'Unexpected path: ', p 170 | if gt is None: 171 | warnings.warn('Could not find ground truth. 
Is this a test image?') 172 | im['gt'] = gt 173 | return im 174 | 175 | def check(coords, shape): 176 | z, y, x = coords 177 | sl_z = (z-64, z+64) 178 | sl_y = (y-84, y+76) # -70, +90 179 | sl_x = (x-72, x+72) 180 | if sl_z[0] < 0: 181 | sl_z = (0, 128) 182 | elif sl_z[1] > shape[0]: 183 | sl_z = (shape[0]-128, shape[0]) 184 | 185 | if sl_y[0] < 0: 186 | sl_y = (0, 160) 187 | elif sl_y[1] > shape[1]: 188 | sl_y = (shape[1]-160, shape[1]) 189 | 190 | if sl_x[0] < 0: 191 | sl_x = (0, 144) 192 | elif sl_x[1] > shape[2]: 193 | sl_x = (shape[2]-144, shape[2]) 194 | 195 | z_s = slice(sl_z[0], sl_z[1]) 196 | y_s = slice(sl_y[0], sl_y[1]) 197 | x_s = slice(sl_x[0], sl_x[1]) 198 | 199 | return (z_s, y_s, x_s) 200 | 201 | def create_folds(dir_name='..//BRATS2015_Training'): 202 | pats = get_pats(dir_name) 203 | shuffle(pats) 204 | 205 | folds = [] 206 | for i in range(3): 207 | validation_slice = slice(i*74, i*74+74) 208 | valid_and_test = pats[validation_slice] 209 | valid = valid_and_test[:37] 210 | test = valid_and_test[37:] 211 | train = pats[:i*74] + pats[i*74+74:] 212 | folds.append({ 213 | 'train': copy.deepcopy(train), 214 | 'valid': copy.deepcopy(valid), 215 | 'test': copy.deepcopy(test) 216 | }) 217 | 218 | with open('folds.pkl', 'w') as f: 219 | pickle.dump(folds, f) 220 | 221 | return folds 222 | 223 | def test_folds(): 224 | folds = create_folds() 225 | 226 | for i, fold in enumerate(folds): 227 | print 'Fold %i: ' % (i+1) 228 | for key in ['train', 'valid', 'test']: 229 | print '\t%s: ' % key 230 | print '\t%i patients' % len(fold[key]) 231 | for patient in fold[key]: 232 | print '\t', patient 233 | 234 | def build_hdf5_from_fold(): 235 | """ 236 | Function for creating our training set. 237 | This function will first search a file called 238 | folds.pkl in the current directory. folds.pkl 239 | is a file detailing three cross-validation folds, 240 | where each cross-validation fold has a different 241 | set of training, validation and testing partitions. 
242 | If no folds.pkl is present, the function create_folds 243 | will be called to create it. The variable fold_number 244 | determines which of the three cross-validation folds should 245 | be used to create the data set. 246 | The data set itself will be a .hdf5 file that will be saved under 247 | ../data/datasets/brats_foldX.hdf5 where X=fold_number 248 | """ 249 | if os.path.exists('folds.pkl'): 250 | with open('folds.pkl', 'r') as f: 251 | folds = pickle.load(f) 252 | else: 253 | folds = create_folds('BRATS2015_Training') 254 | 255 | fold_number = 0 # ADAPT TO FOLD 256 | fold = folds[fold_number] 257 | 258 | if not os.path.exists('data//datasets'): 259 | os.makedirs('data//datasets') 260 | 261 | data = h5py.File('data//datasets//brats_fold'+str(fold_number)+'.hdf5', 'w') 262 | depth, height, width = (128, 160, 144) 263 | n_chans = 4 264 | dimprod = height*width*depth 265 | n_classes = 5 266 | 267 | train_size = 200 268 | valid_size = 37 269 | test_size = 37 270 | 271 | x = data.create_dataset('train_x', (train_size, depth, n_chans, height, width), dtype='int16') 272 | vx = data.create_dataset('valid_x', (valid_size, depth, n_chans, height, width), dtype='int16') 273 | tx = data.create_dataset('test_x', (test_size, depth, n_chans, height, width), dtype='int16') 274 | y = data.create_dataset('train_y', (train_size, dimprod, n_classes), dtype='int8') 275 | vy = data.create_dataset('valid_y', (valid_size, dimprod, n_classes), dtype='int8') 276 | ty = data.create_dataset('test_y', (test_size, dimprod, n_classes), dtype='int8') 277 | 278 | dat_access = { 279 | 'train': (x, y), 280 | 'valid': (vx, vy), 281 | 'test': (tx, ty) 282 | } 283 | 284 | for key in ['train', 'valid', 'test']: 285 | print 'building %s set' % key 286 | index = 0 287 | size = len(fold[key]) 288 | for image in gen_images(custom_pats=fold[key], crop=True, n=-1): 289 | print '\treading image %i of %i...' 
% (index+1, size) 290 | gt = get_gt(image['gt'], n_classes, downsize=False) 291 | im = get_im_as_ndarray(image, downsize=False) 292 | 293 | # sanity check 294 | #t_im = im[0] 295 | #for _slice in np.arange(0, t_im.shape[0], t_im.shape[0]/15): 296 | # im_slice = t_im[_slice] 297 | # gt_slice = gt[_slice] 298 | # vis_ims(im0=im_slice, gt0=gt_slice, im1=im_slice, gt1=np.zeros(im_slice.shape)) 299 | # 300 | 301 | dat_access[key][0][index, :, :, :, :] = np.transpose(im, (1, 0, 2, 3)) 302 | dat_access[key][1][index, :, :] = convert_gt_to_onehot(gt, n_classes) 303 | index += 1 304 | data.close() 305 | 306 | def get_image_slice(image): 307 | z_s, x_s, y_s = check(center(image['Flair']), image['Flair'].shape) 308 | im = {} 309 | for key, value in six.iteritems(image): 310 | if value is not None: 311 | im.update({key: value[z_s, x_s, y_s]}) 312 | return im, (z_s, x_s, y_s) 313 | 314 | 315 | def gen_images(dir_name='..//BRATS2015_Training', n=1, specific=False, interval=None, crop=False, randomize=False, custom_pats=None): 316 | pats = get_pats(dir_name) if custom_pats is None else custom_pats 317 | print '%i images in total.' % len(pats) 318 | if randomize: 319 | print 'shuffling patients.' 320 | shuffle(pats) 321 | 322 | im_gts = [] 323 | if interval is None: 324 | a = 0 325 | b = n 326 | else: 327 | a, b = interval 328 | if b == -1: 329 | b = len(pats) 330 | if a == -1: 331 | pats = pats[::-1] 332 | a = 0 333 | print 'yielding images in reverse order.' 334 | elif b > len(pats): 335 | raise ValueError('There are %i images but user requested %i.' % (len(pats), b)) 336 | if not specific: 337 | print 'yielding images in range: (%i, %i).' 
% (a, b) 338 | for p in pats[a:b]: 339 | try: 340 | print('{}\t'.format(p)) 341 | im = get_im(p) 342 | except ValueError: 343 | print 'Problem with: ', p 344 | raise 345 | if im is not None: 346 | if not crop: 347 | yield im 348 | else: 349 | z_s, x_s, y_s = check(center(im['Flair']), im['Flair'].shape) 350 | for key in im: 351 | if im[key] is not None: 352 | im[key] = im[key][z_s, x_s, y_s] 353 | yield im 354 | else: 355 | if b != len(pats): 356 | print 'yielding image %i.' % b 357 | p = pats[b] 358 | try: 359 | im = get_im(p) 360 | except ValueError: 361 | print 'Problem with: ', p 362 | raise 363 | if not crop: 364 | yield im 365 | else: 366 | for key in im: 367 | z_s, x_s, y_s = check(center(im['Flair']), im['Flair'].shape) 368 | if im[key] is not None: 369 | im[key] = im[key][z_s, x_s, y_s] 370 | yield im 371 | else: 372 | raise ValueError('There are %i images but user requested image %i(images are zero-indexed).' % (len(pats), b)) 373 | 374 | def make_data_set(): 375 | data = h5py.File('data.hdf5', 'w') 376 | depth, height, width = (64, 80, 72) 377 | n_chans = 4 378 | dimprod = height*width*depth 379 | n_classes = 5 380 | 381 | train_size = 200 382 | valid_size = 37 383 | test_size = 37 384 | 385 | x = data.create_dataset('train_x', (train_size, depth, n_chans, height, width), dtype='int16') 386 | vx = data.create_dataset('valid_x', (valid_size, depth, n_chans, height, width), dtype='int16') 387 | tx = data.create_dataset('test_x', (test_size, depth, n_chans, height, width), dtype='int16') 388 | y = data.create_dataset('train_y', (train_size, dimprod, n_classes), dtype='int8') 389 | vy = data.create_dataset('valid_y', (valid_size, dimprod, n_classes), dtype='int8') 390 | ty = data.create_dataset('test_y', (test_size, dimprod, n_classes), dtype='int8') 391 | 392 | dat_access = { 393 | 'train': (x, y), 394 | 'valid': (vx, vy), 395 | 'test': (tx, ty) 396 | } 397 | 398 | count = 0 399 | index = 0 400 | access_code = 'train' 401 | print 'starting with train set' 
402 | for image in gen_images(n=-1, crop=True, randomize=True): 403 | if count == 274: 404 | print 'read 274 images, terminating...' 405 | break 406 | print '\tReading image %i...' % (count+1) 407 | gt = process_gt(image['gt'], n_classes, downsize=True) 408 | im = get_im_as_ndarray(image, downsize=True) 409 | 410 | dat_access[access_code][0][index, :, :, :, :] = np.transpose(im, (1, 0, 2, 3)) 411 | dat_access[access_code][1][index, :, :] = gt 412 | 413 | index += 1 414 | count += 1 415 | if count == 200: 416 | print 'train set complete, proceeding to valid set.' 417 | access_code = 'valid' 418 | index = 0 419 | elif count == 237: 420 | print 'valid set complete, proceeding to test set.' 421 | access_code = 'test' 422 | index = 0 423 | data.close() 424 | 425 | def get_shapes(im): 426 | shapes = [] 427 | for key in im: 428 | shapes.append(im[key].shape) 429 | return shapes 430 | 431 | def check_shapes(im): 432 | for sh in get_shapes(im): 433 | if sh != (128, 160, 144): 434 | return False 435 | return True 436 | 437 | def test_shapes(): 438 | count = 1 439 | errors = 0 440 | for image in gen_images(n=-1, crop=True): 441 | if not check_shapes(image): 442 | print 'Problem with image %i.' % count 443 | errors += 1 444 | else: 445 | print 'image %i is ok.' % count 446 | count += 1 447 | print 'Finished with %i errors.' 
% errors 448 | 449 | if __name__ == '__main__': 450 | #make_data_set() 451 | #test_shapes() 452 | #test_folds() 453 | #create_folds() 454 | build_hdf5_from_fold() 455 | -------------------------------------------------------------------------------- /brain_data_scripts/show_images.py: -------------------------------------------------------------------------------- 1 | from read_images import gen_images, center 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from skimage.transform import swirl, rescale, rotate, downscale_local_mean, PiecewiseAffineTransform, warp 5 | from skimage import color 6 | from skimage import filters 7 | import scipy.ndimage as ndi 8 | import scipy.signal as sig 9 | from scipy.ndimage import zoom 10 | from scipy.ndimage.interpolation import rotate as rotate_scipy 11 | from breze.learn.data import one_hot 12 | 13 | def groundtruth_(gt): 14 | """Takes a discrete label volume with zero-indexed labels and applies one_hot encoding.""" 15 | n_classes = gt.max() + 1 16 | shape = gt.shape 17 | l = np.reshape(gt, (-1,)) 18 | l = np.reshape(one_hot(l, n_classes), (-1, n_classes)) 19 | gt_onehot = np.reshape(l, shape + (n_classes,)) 20 | return gt_onehot 21 | 22 | def rotate_transform(im, gt): 23 | ang = np.random.uniform(-90, 90) 24 | axes = np.random.permutation(3)[:2] 25 | rot_im = rotate_scipy(im, ang, axes=axes, order=3, reshape=False) 26 | rot_gt = groundtruth_(gt) 27 | rot_gt = np.array([ 28 | rotate_scipy(class_map, ang, axes=axes, order=3, reshape=False) 29 | for class_map in np.transpose(rot_gt, (3, 0, 1, 2))]) 30 | rot_gt = rot_gt.argmax(axis=0) 31 | rot_gt = np.array(rot_gt, dtype='int8') 32 | 33 | return (rot_im, rot_gt) 34 | 35 | def sinus(image, strength): 36 | rows, cols = image.shape[0], image.shape[1] 37 | 38 | src_cols = np.linspace(0, cols, 5) 39 | src_rows = np.linspace(0, rows, 2) 40 | src_rows, src_cols = np.meshgrid(src_rows, src_cols) 41 | src = np.dstack([src_cols.flat, src_rows.flat])[0] 42 | 43 | # add 
sinusoidal oscillation to row coordinates 44 | dst_rows = src[:, 1] - np.sin(np.linspace(0, 2*np.pi, src.shape[0])) * strength 45 | dst_cols = src[:, 0] 46 | dst_rows *= 1. 47 | dst_rows -= 1.5 * strength 48 | dst = np.vstack([dst_cols, dst_rows]).T 49 | 50 | 51 | tform = PiecewiseAffineTransform() 52 | tform.estimate(src, dst) 53 | 54 | out_rows = image.shape[0] #- 1.5 * 5 55 | out_cols = cols 56 | out = warp(image, tform, output_shape=(out_rows, out_cols)) 57 | return np.array(out, dtype='float32') 58 | 59 | def sinus_(im, strength): 60 | return np.array([sinus(im_slice, strength) for im_slice in im], dtype='float32') 61 | 62 | def sinus_transform(im, gt): 63 | strength = np.random.uniform(3, 6) 64 | sinus_im = sinus_(im, strength) 65 | sinus_gt = groundtruth_(gt) 66 | sinus_gt = np.array([sinus_(class_map, strength) for class_map in np.transpose(sinus_gt, (3, 0, 1, 2))]) 67 | sinus_gt = sinus_gt.argmax(axis=0) 68 | sinus_gt = np.array(sinus_gt, dtype='int8') 69 | 70 | return (sinus_im, sinus_gt) 71 | 72 | def swirl_(im, strength, radius): 73 | return np.array([swirl(im_slice, rotation=0, strength=strength, radius=radius) for im_slice in im], dtype='float32') 74 | 75 | def swirl_transform(im, gt): 76 | strength = np.random.uniform(1, 2) 77 | radius = np.random.randint(90, 140) 78 | 79 | swirled_im = swirl_(im, strength, radius) 80 | swirled_gt = groundtruth_(gt) 81 | swirled_gt = np.array([swirl_(class_map, strength, radius) for class_map in np.transpose(swirled_gt, (3, 0, 1, 2))]) 82 | swirled_gt = swirled_gt.argmax(axis=0) 83 | swirled_gt = np.array(swirled_gt, dtype='int8') 84 | 85 | return (swirled_im, swirled_gt) 86 | 87 | def rotate_3d_ski(im, gt): 88 | im = np.transpose(im, (1, 2, 0)) 89 | gt = np.transpose(gt, (1, 2, 0)) 90 | 91 | ang = np.random.uniform(0, 360) 92 | r_im = rotate(im , ang, order=3) 93 | r_gt = rotate(gt, ang, order=3) 94 | 95 | return np.transpose(r_im, (2, 0, 1)), np.transpose(r_gt, (2, 0, 1)) 96 | 97 | def re_rescale(im): 98 | d_im = 
zoom(im, (1, 0.5, 0.8), order=3) 99 | d_im = zoom(d_im, (1, 2, (1/0.8)), order=3) 100 | 101 | return d_im 102 | 103 | def bounding_box(p1, p2): 104 | x1, y1, z1 = p1 105 | x2, y2, z2 = p2 106 | 107 | x_s = slice(np.minimum(x1, x2), np.maximum(x1, x2)) 108 | y_s = slice(np.minimum(y1, y2), np.maximum(y1, y2)) 109 | z_s = slice(np.minimum(z1, z2), np.maximum(z1, z2)) 110 | 111 | return (x_s, y_s, z_s) 112 | 113 | def compute_random_region(shape): 114 | x = np.random.randint(low=0, high=shape[0], size=(2,)) 115 | y = np.random.randint(low=0, high=shape[1], size=(2,)) 116 | z = np.random.randint(low=0, high=shape[2], size=(2,)) 117 | p1 = [c[0] for c in [x, y, z]] 118 | p2 = [c[1] for c in [x, y, z]] 119 | 120 | bb = bounding_box(p1, p2) 121 | 122 | return bb 123 | 124 | def noise(im, intensity=1, n=1): 125 | if n > 1: 126 | new_im = im.copy() 127 | for i in np.arange(0, n): 128 | bb = compute_random_region(im.shape) 129 | new_im[bb] = noise(new_im[bb], intensity, 1) 130 | return new_im 131 | 132 | try: 133 | noise_vol = np.random.randint(0, int(im.mean()*intensity), size=im.shape) 134 | noise_vol = np.asarray(noise_vol, dtype='int16') 135 | except ValueError: 136 | return im 137 | return im + noise_vol 138 | 139 | def flip(im, full=True): 140 | if full: 141 | return im[::-1, ::-1, ::-1] 142 | else: 143 | return im[:, ::-1, ::-1] 144 | 145 | def sharpen(blurred_im, alpha): 146 | filter_blurred_im = ndi.gaussian_filter(blurred_im, 1) 147 | sharp = blurred_im + alpha * (blurred_im - filter_blurred_im) 148 | 149 | return sharp 150 | 151 | def compute_random_shadow(shape, intensity): 152 | bb = compute_random_region(shape) 153 | shadow = np.ones(shape, dtype='int16') 154 | shadow[bb] = intensity 155 | 156 | return shadow 157 | 158 | def shadow(im, intensity=None, n=1): 159 | if intensity is None: 160 | intensity = im.mean() 161 | if im.dtype != 'float': 162 | intensity = int(intensity) 163 | 164 | sh_im = im.copy() 165 | for i in np.arange(0, n): 166 | shade = 
compute_random_shadow(im.shape, intensity) 167 | sh_im -= shade 168 | sh_im = np.maximum(0, sh_im) 169 | 170 | return sh_im 171 | 172 | def rotate_3d_scipy(image, gt): 173 | #if image.dtype != 'float32': 174 | # image = np.asarray(image, dtype='float32') 175 | #if gt.dtype != 'float32': 176 | # gt = np.asarray(gt, dtype='float32') 177 | 178 | ang = np.random.uniform(0, 360) 179 | axes = (1,2)#np.random.permutation(3)[:2] 180 | rot_im = rotate_scipy(image, ang, axes=axes, order=1, reshape=False) 181 | rot_gt = rotate_scipy(gt, ang, axes=axes, order=0, reshape=False) 182 | 183 | return rot_im, rot_gt 184 | 185 | def prep2(gt): 186 | indices0 = np.where(gt < 0.5) 187 | indices1 = np.where((gt >= 0.5) & (gt < 1.5)) 188 | indices2 = np.where((gt >= 1.5) & (gt < 2.5)) 189 | indices3 = np.where((gt >= 2.5) & (gt < 3.5)) 190 | indices4 = np.where(gt >= 3.5) 191 | 192 | res = np.zeros(gt.shape, dtype='int8') 193 | res[indices0] = 0 194 | res[indices1] = 1 195 | res[indices2] = 2 196 | res[indices3] = 3 197 | res[indices4] = 4 198 | 199 | return res 200 | 201 | def vis_col_im(im, gt): 202 | indices_0 = np.where(gt == 0) # nothing 203 | indices_1 = np.where(gt == 1) # necrosis 204 | indices_2 = np.where(gt == 2) # edema 205 | indices_3 = np.where(gt == 3) # non-enhancing tumor 206 | indices_4 = np.where(gt == 4) # enhancing tumor 207 | 208 | im = np.asarray(im, dtype='float32') 209 | im = im*1./im.max() 210 | rgb_image = color.gray2rgb(im) 211 | m0 = [1., 1., 1.] 212 | m1 = [1., 0., 0.] 
213 | m2 = [0.2, 1., 0.2] 214 | m3 = [1., 1., 0.2] 215 | m4 = [1., 0.6, 0.2] 216 | 217 | im = rgb_image.copy() 218 | im[indices_0[0], indices_0[1], :] *= m0 219 | im[indices_1[0], indices_1[1], :] *= m1 220 | im[indices_2[0], indices_2[1], :] *= m2 221 | im[indices_3[0], indices_3[1], :] *= m3 222 | im[indices_4[0], indices_4[1], :] *= m4 223 | 224 | plt.imshow(im) 225 | plt.show() 226 | plt.close() 227 | 228 | def col_im(im, gt): 229 | im = np.asarray(im, dtype='float32') 230 | im = im*1./im.max() 231 | rgb_image = color.gray2rgb(im) 232 | im = rgb_image.copy() 233 | 234 | if gt is None: 235 | return im 236 | 237 | indices_0 = np.where(gt == 0) # nothing 238 | indices_1 = np.where(gt == 1) # necrosis 239 | indices_2 = np.where(gt == 2) # edema 240 | indices_3 = np.where(gt == 3) # non-enhancing tumor 241 | indices_4 = np.where(gt == 4) # enhancing tumor 242 | 243 | m0 = [1., 1., 1.] 244 | m1 = [1., 0., 0.] # red: necrosis 245 | m2 = [0.2, 1., 0.2] # green: edema 246 | m3 = [1., 1., 0.2] # yellow: non-enhancing tumor 247 | m4 = [1., 0.6, 0.2] # orange: enhancing tumor 248 | 249 | im[indices_0[0], indices_0[1], :] *= m0 250 | im[indices_1[0], indices_1[1], :] *= m1 251 | im[indices_2[0], indices_2[1], :] *= m2 252 | im[indices_3[0], indices_3[1], :] *= m3 253 | im[indices_4[0], indices_4[1], :] *= m4 254 | 255 | return im 256 | 257 | def vis_ims(im0, gt0, im1, gt1, title0='Original', title1='Transformed'): 258 | im0 = col_im(im0, gt0) 259 | im1 = col_im(im1, gt1) 260 | 261 | fig = plt.figure() 262 | a = fig.add_subplot(1,2,1) 263 | plt.imshow(im0) 264 | a.set_title(title0) 265 | a = fig.add_subplot(1,2,2) 266 | plt.imshow(im1) 267 | a.set_title(title1) 268 | 269 | plt.show() 270 | plt.close() 271 | 272 | def vis_hems(left, gt_left, right, gt_right): 273 | left = col_im(left, gt_left) 274 | right = col_im(right, gt_right) 275 | 276 | fig = plt.figure() 277 | a = fig.add_subplot(1,2,1) 278 | plt.imshow(left) 279 | a.set_title('Left Hemisphere') 280 | a = 
fig.add_subplot(1,2,2) 281 | plt.imshow(right) 282 | a.set_title('Right Hemisphere') 283 | 284 | plt.show() 285 | plt.close() 286 | 287 | def vis_diff_modalities(*ims): 288 | flair, t1, t1_c, t2, gt = ims 289 | flair, t1, t1_c, t2 = [col_im(x, gt) for x in ims[:-1]] 290 | fig = plt.figure() 291 | a = fig.add_subplot(2,2,1) 292 | plt.imshow(flair) 293 | a.set_title('Flair') 294 | plt.axis('off') 295 | a = fig.add_subplot(2,2,2) 296 | plt.imshow(t1) 297 | a.set_title('T1') 298 | plt.axis('off') 299 | a = fig.add_subplot(2,2,3) 300 | plt.imshow(t1_c) 301 | a.set_title('T1c') 302 | plt.axis('off') 303 | a = fig.add_subplot(2,2,4) 304 | plt.imshow(t2) 305 | a.set_title('T2') 306 | plt.axis('off') 307 | 308 | plt.show() 309 | 310 | def show_brains(): 311 | for im in gen_images(n=-1): 312 | t_im = im['T1c'] 313 | gt = im['gt'] 314 | 315 | for _slice in np.arange(0, t_im.shape[0], t_im.shape[0]/15): 316 | im_slice = t_im[_slice] 317 | gt_slice = gt[_slice] 318 | 319 | vis_col_im(im=im_slice, gt=gt_slice) 320 | 321 | def show_modalities(): 322 | for im in gen_images(n=-1, crop=True): 323 | ims = [im['Flair'], im['T1'], im['T1c'], im['T2'], None] 324 | for _slice in np.arange(0, ims[0].shape[0], ims[0].shape[0]/20): 325 | im_slices = [x[_slice] if x is not None else x for x in ims] 326 | 327 | vis_diff_modalities(*im_slices) 328 | 329 | def show_downsize(): 330 | for im in gen_images(n=-1, crop=True): 331 | t_im = im['T1c'] 332 | gt = im['gt'] 333 | 334 | t_im = np.asarray(t_im, dtype='float32') 335 | gt = np.asarray(gt, dtype='float32') 336 | 337 | d_im = zoom(t_im, 0.5, order=3) 338 | d_gt = zoom(gt, 0.5, order=0) 339 | print 'New shape: ', d_im.shape 340 | 341 | slices1 = np.arange(0, d_im.shape[0], d_im.shape[0]/20) 342 | slices2 = np.arange(0, t_im.shape[0], t_im.shape[0]/20) 343 | 344 | for s1, s2 in zip(slices1, slices2): 345 | d_im_slice = d_im[s1] 346 | d_gt_slice = d_gt[s1] 347 | 348 | im_slice = t_im[s2] 349 | gt_slice = gt[s2] 350 | 351 | title0= 'Original' 352 | 
title1= 'Downsized' 353 | vis_ims(im0=im_slice, gt0=gt_slice, im1=d_im_slice, 354 | gt1=d_gt_slice, title0=title0, title1=title1) 355 | 356 | def show_crops(): 357 | x_c = 119 358 | y_c = 119 359 | z_c = 77 360 | 361 | count = 1 362 | for im in gen_images(n=-1, crop=True): 363 | print 'image %i: ' % count 364 | t_im = im['T1c'] 365 | gt = im['gt'] 366 | print t_im.shape 367 | for _slice in np.arange(0, t_im.shape[0], t_im.shape[0]/20): 368 | im_slice = t_im[_slice] 369 | gt_slice = gt[_slice] 370 | 371 | vis_col_im(im=im_slice, gt=gt_slice) 372 | count += 1 373 | 374 | def show_hemisphere(): 375 | x_c = 119 376 | y_c = 119 377 | z_c = 77 378 | 379 | for im in gen_images(n=-1, crop=True): 380 | t_im = im['T1c'] 381 | gt = im['gt'] 382 | 383 | left = t_im[:,:,:t_im.shape[-1]/2] 384 | gt_left = gt[:,:,:gt.shape[-1]/2] 385 | 386 | right = t_im[:,:,t_im.shape[-1]/2:] 387 | gt_right = gt[:,:,gt.shape[-1]/2:] 388 | 389 | for _slice in np.arange(0, t_im.shape[0], t_im.shape[0]/20): 390 | l_slice = left[_slice] 391 | gt_l_slice = gt_left[_slice] 392 | r_slice = right[_slice] 393 | gt_r_slice = gt_right[_slice] 394 | 395 | vis_hems(left=l_slice, gt_left=gt_l_slice, right=r_slice, gt_right=gt_r_slice) 396 | 397 | def show_rotation(): 398 | for im in gen_images(n=-1, crop=True): 399 | t_im = im['T1c'] 400 | gt = im['gt'] 401 | 402 | rot_im, rot_gt = rotate_3d_scipy(t_im, gt) 403 | rot_gt = np.asarray(rot_gt, dtype='int8') 404 | #rot_gt = prep2(rot_gt) 405 | for _slice in np.arange(0, rot_im.shape[0], rot_im.shape[0]/20): 406 | im_slice = rot_im[_slice] 407 | gt_slice = rot_gt[_slice] 408 | 409 | vis_col_im(im=im_slice, gt=gt_slice) 410 | 411 | def show_transform(): 412 | for im in gen_images(n=-1, crop=True): 413 | t_im = im['T1c'] 414 | gt = im['gt'] 415 | #t_im_trans, trans_gt = rotate_transform(t_im, gt) 416 | #t_im_trans = t_im 417 | #t_im_trans = re_rescale(t_im) 418 | #t_im_trans = flip(t_im) 419 | #t_im_trans = noise(t_im, intensity=1, n=10) 420 | t_im_trans, trans_gt = 
ndi.percentile_filter(t_im, np.random.randint(0, 10), (2, 2, 2)), gt 421 | #t_im_trans = ndi.morphological_gradient(t_im, size=(2, 2, 2)) 422 | #t_im_trans = ndi.grey_dilation(t_im, size=(3, 3, 3)) 423 | #t_im_trans = ndi.grey_erosion(t_im_trans, size=(3, 3, 3)) 424 | 425 | print t_im_trans.dtype 426 | 427 | for _slice in np.arange(0, t_im.shape[0], t_im.shape[0]/20): 428 | im_slice = t_im[_slice] 429 | im_slice_trans = t_im_trans[_slice] 430 | gt_slice = gt[_slice] 431 | trans_gt_slice = trans_gt[_slice] 432 | 433 | vis_ims(im0=im_slice, gt0=gt_slice, im1=im_slice_trans, gt1=trans_gt_slice) 434 | 435 | if __name__ == '__main__': 436 | #show_brains() 437 | #show_modalities() 438 | #show_downsize() 439 | show_crops() 440 | #show_hemisphere() 441 | #show_rotation() 442 | #show_transform() -------------------------------------------------------------------------------- /conv3d/README.md: -------------------------------------------------------------------------------- 1 | The modules in this package implement the convolutional neural networks used in the experiments. 
2 | -------------------------------------------------------------------------------- /conv3d/__init__.py: -------------------------------------------------------------------------------- 1 | import model 2 | -------------------------------------------------------------------------------- /conv3d/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BRML/CNNbasedMedicalSegmentation/49cacabf51cd1187d8b677532ef635d60f12bb1a/conv3d/__init__.pyc -------------------------------------------------------------------------------- /conv3d/basic.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BRML/CNNbasedMedicalSegmentation/49cacabf51cd1187d8b677532ef635d60f12bb1a/conv3d/basic.pyc -------------------------------------------------------------------------------- /conv3d/cnn3d.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BRML/CNNbasedMedicalSegmentation/49cacabf51cd1187d8b677532ef635d60f12bb1a/conv3d/cnn3d.pyc -------------------------------------------------------------------------------- /conv3d/model.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | from theano.tensor.type import TensorType 4 | 5 | import numpy as np 6 | 7 | from breze.arch.component.varprop import loss as vp_loss 8 | from breze.arch.construct.simple import SupervisedLoss 9 | from breze.arch.util import ParameterSet, lookup 10 | from breze.learn.base import SupervisedModel 11 | 12 | import cnn3d 13 | from basic import SupervisedMultiLoss 14 | 15 | def tensor5(name=None, dtype=None): 16 | """ 17 | Returns a symbolic 5D tensor variable. 
18 | """ 19 | if dtype is None: 20 | dtype = theano.config.floatX 21 | 22 | type = TensorType(dtype, ((False,)*5)) 23 | return type(name) 24 | 25 | class SequentialModel(SupervisedModel): 26 | ''' 27 | Sequential model consisting of multiple layers of different types. 28 | ''' 29 | def __init__(self, image_height, image_width, 30 | image_depth, n_channels, n_output, 31 | layer_vars, out_transfer, loss_id=None, 32 | loss_layer_def=None, optimizer='adam', 33 | batch_size=1, max_iter=1000, 34 | using_bn=False, regularize=False, 35 | l1=None, l2=None, perform_transform=None, 36 | verbose=False): 37 | ''' 38 | :param n_output: Number of classes 39 | :param layer_vars: list of dictionaries. Each dictionary specifies 40 | the type of a layer and the values of parameters 41 | associated with that kind of layer. 42 | Possible layers: conv, pool, deconv 43 | conv params: fs(:=filter shape), nkerns(:=feature maps), 44 | transfer(:=non-linearity used), bm(:=border-mode), 45 | stride(:=convolution stride), imp(:=theano implementation), 46 | bias(:=True to use biases, False to omit them) 47 | pool params: ps(:=pool shape), transfer(:=non-linearity used) 48 | deconv params: fs(:=filter shape), nkerns(:=feature maps), 49 | transfer(:=non-linearity used), imp(:=theano implementation), 50 | bias(:=True to use biases, False to omit them), 51 | up(:=upsampling factor) 52 | shortcut params: shortcut computes src.output + dst.output 53 | src(:=index of src_layer), dst(:= index of dst_layer), 54 | proj(:=projection to be used if src and dst have different 55 | numbers of feature maps, 'zero_pad' or 'project'), 56 | transfer(:=non_linearity used), imp(:=theano implementation, 57 | only relevant if src and dst have different numbers of 58 | feature maps and proj=='project') 59 | non_linearity params: transfer(:=non-linearity used) 60 | see: basic.Conv3d, basic.MaxPool3d, basic.Deconv, basic.Shortcut, 61 | basic.NonLinearity 62 | 63 | :param out_transfer: output non-linearity, has to be a 
callable 64 | :param loss_id: loss function, has to be a callable or the name 65 | of a loss function included in breze.arch.component.loss 66 | :param optimizer: name of an optimizer supported by climin 67 | :param batch_size: size of an input batch 68 | :param max_iter: maximum number of training iterations 69 | ''' 70 | self.image_height = image_height 71 | self.image_width = image_width 72 | self.image_depth = image_depth 73 | self.n_channels = n_channels 74 | self.n_output = n_output 75 | self.layer_vars = layer_vars 76 | self.out_transfer = out_transfer 77 | self.optimizer = optimizer 78 | self.batch_size = batch_size 79 | self.max_iter = max_iter 80 | self.using_bn = using_bn 81 | self.verbose = verbose 82 | self.regularize = regularize 83 | if self.regularize: 84 | if l1 is None and l2 is None: 85 | raise ValueError('Asked to use regularization but no input for l1 or l2.') 86 | else: 87 | self.l1 = l1 88 | self.l2 = l2 89 | 90 | if loss_layer_def is None: 91 | if loss_id is not None: 92 | self.loss_id = loss_id 93 | self.loss_layer_def = None 94 | else: 95 | raise ValueError('Either loss id or loss layer definition has to be specified.') 96 | else: 97 | if loss_id is None: 98 | self.loss_layer_def = loss_layer_def 99 | self.loss_id = None 100 | else: 101 | raise ValueError('loss_id and loss_layer_def can not be used at the same time.') 102 | self.perform_transform = perform_transform 103 | 104 | self._init_exprs() 105 | 106 | if self.using_bn: 107 | self.ext_phase = 0 108 | self.phase_select = self._phase_select 109 | self.reset_phase = self._reset_phase 110 | else: 111 | self.ext_phase = None 112 | self.phase_select = None 113 | self.reset_phase = None 114 | 115 | def _make_loss_layer(self, lv, target, declare, imp_weight): 116 | mode = lv['mode'] 117 | if mode == 'weighted': 118 | p_weights = lv['weights'] 119 | else: 120 | p_weights = None 121 | loss_fun = lv['loss_fun'] 122 | p_indices = lv['predictions'] 123 | transfer = lv['transfer'] if 'transfer' in 
lv else self.out_transfer 124 | 125 | layers = self.conv_net.layers 126 | 127 | predictions = [] 128 | for p_i in p_indices: 129 | predictions.append(layers[p_i].get_output()) 130 | 131 | self.loss_layer = SupervisedMultiLoss( 132 | target=target, predictions=predictions, 133 | loss=loss_fun, mode=mode, 134 | p_weights=p_weights, imp_weight=imp_weight, 135 | transfer=transfer, declare=declare 136 | ) 137 | 138 | def _init_exprs(self): 139 | inpt = tensor5('inpt') 140 | target = T.tensor3('target') 141 | 142 | parameters = ParameterSet() 143 | 144 | self.conv_net = cnn3d.SequentialModel( 145 | inpt=inpt, image_height=self.image_height, 146 | image_width=self.image_width, image_depth=self.image_depth, 147 | n_channels=self.n_channels, out_transfer=self.out_transfer, 148 | layer_vars=self.layer_vars, using_bn=self.using_bn, 149 | declare=parameters.declare 150 | ) 151 | 152 | output = self.conv_net.output 153 | 154 | if self.imp_weight: 155 | imp_weight = T.matrix('imp_weight') 156 | else: 157 | imp_weight = None 158 | 159 | if self.loss_id is not None: 160 | self.loss_layer = SupervisedLoss( 161 | target, output, loss=self.loss_id, 162 | imp_weight=imp_weight, declare=parameters.declare 163 | ) 164 | else: 165 | self._make_loss_layer( 166 | lv=self.loss_layer_def, target=target, 167 | imp_weight=imp_weight, declare=parameters.declare 168 | ) 169 | 170 | SupervisedModel.__init__(self, inpt=inpt, target=target, 171 | output=output, 172 | loss=self.loss_layer.total, 173 | parameters=parameters) 174 | 175 | self.exprs['imp_weight'] = imp_weight 176 | if self.regularize: 177 | self.exprs['true_loss'] = self.exprs['loss'].copy() 178 | if self.l2 is not None: 179 | l2_reg = T.sum(T.sqr(self.parameters.flat)) * self.l2 / 2 180 | self.exprs['loss'] += l2_reg 181 | if self.l1 is not None: 182 | l1_reg = T.sum(T.abs_(self.parameters.flat)) * self.l1 183 | self.exprs['loss'] += l1_reg 184 | 185 | def _phase_select(self, phase_id): 186 | if phase_id == 'train': 187 | phase = 0 
188 | elif phase_id == 'valid' or phase_id == 'infer': 189 | phase = 1 190 | else: 191 | raise ValueError('Phases are: train, valid, infer') 192 | 193 | self.ext_phase = phase 194 | for l_index in self.conv_net.bn_layers: 195 | self.conv_net.layers[l_index].set_phase(phase) 196 | 197 | def _reset_phase(self): 198 | self.phase_select(phase_id='train') 199 | 200 | def get_batchnorm_params(self): 201 | batchnorm_params = [] 202 | for l_index in self.conv_net.bn_layers: 203 | mean, std = self.conv_net.layers[l_index].submit() 204 | if not isinstance(mean, np.ndarray): 205 | mean = mean.as_numpy_array() 206 | std = std.as_numpy_array() 207 | 208 | mean = np.asarray(mean, dtype='float32') 209 | std = np.asarray(std, dtype='float32') 210 | 211 | mean_and_std = (mean, std) 212 | batchnorm_params.append(mean_and_std) 213 | 214 | return batchnorm_params 215 | 216 | def set_batchnorm_params(self, batchnorm_params): 217 | index = 0 218 | for l_index in self.conv_net.bn_layers: 219 | mean, std = batchnorm_params[index] 220 | 221 | self.conv_net.layers[l_index].running_mean.set_value(mean) 222 | self.conv_net.layers[l_index].running_std.set_value(std) 223 | index += 1 224 | 225 | def get_params(self): 226 | layers = self.conv_net.layers 227 | params = self.parameters 228 | 229 | for i, l in enumerate(layers): 230 | if hasattr(l, 'weights'): 231 | w = params[l.weights] 232 | else: 233 | w = None 234 | if hasattr(l, 'bias'): 235 | b = params[l.bias] 236 | else: 237 | b = None 238 | yield(w, b, i) 239 | 240 | def initialize_xavier_weights(self): 241 | layers = self.conv_net.layers 242 | params = self.parameters 243 | 244 | for layer in layers: 245 | if hasattr(layer, 'weights'): 246 | w = layer.get_weights() 247 | fan_in = layer.get_fan_in() 248 | params[w] = np.random.normal(0., 1., params[w].shape) * np.sqrt(2./fan_in) 249 | elif hasattr(layer, 'a'): 250 | a = layer.a 251 | params[a] = 0.25 252 | 253 | 254 | class FCN(SupervisedModel): 255 | def __init__(self, image_height, 
image_width, image_depth, 256 | n_channel, n_output, n_hiddens_conv, down_filter_shapes, 257 | hidden_transfers_conv, down_pools, n_hiddens_upconv, 258 | up_filter_shapes, hidden_transfers_upconv, up_pools, 259 | out_transfer, loss, optimizer='adam', 260 | bm_up='same', bm_down='same', 261 | batch_size=1, max_iter=1000, 262 | strides_d=(1, 1, 1), up_factors=(2, 2, 2), 263 | verbose=False, implementation=False): 264 | assert len(hidden_transfers_conv) == len(n_hiddens_conv) 265 | assert len(down_filter_shapes) == len(n_hiddens_conv) 266 | assert len(down_pools) == len(n_hiddens_conv) 267 | 268 | assert len(hidden_transfers_upconv) == len(n_hiddens_upconv) 269 | assert len(up_filter_shapes) == len(n_hiddens_upconv) 270 | assert len(up_pools) == len(n_hiddens_upconv) 271 | 272 | self.image_height = image_height 273 | self.image_width = image_width 274 | self.image_depth = image_depth 275 | self.n_channel = n_channel 276 | self.n_output = n_output 277 | self.n_hiddens_conv = n_hiddens_conv 278 | self.down_filter_shapes = down_filter_shapes 279 | self.hidden_transfers_conv = hidden_transfers_conv 280 | self.down_pools = down_pools 281 | self.n_hiddens_upconv = n_hiddens_upconv 282 | self.up_filter_shapes = up_filter_shapes 283 | self.hidden_transfers_upconv = hidden_transfers_upconv 284 | self.up_pools = up_pools 285 | self.out_transfer = out_transfer 286 | self.loss_ident = loss 287 | self.optimizer = optimizer 288 | self.bm_down = bm_down 289 | self.bm_up = bm_up 290 | self.batch_size = batch_size 291 | self.max_iter = max_iter 292 | self.verbose = verbose 293 | self.implementation = implementation 294 | self.strides_d = strides_d 295 | self.up_factors = up_factors 296 | 297 | self._init_exprs() 298 | 299 | def _init_exprs(self): 300 | inpt = tensor5('inpt') 301 | #inpt.tag.test_value = np.zeros(( 302 | # 2, self.image_depth, self.n_channel, 303 | # self.image_height, self.image_width 304 | #)) 305 | 306 | target = T.tensor3('target') 307 | #target.tag.test_value = 
np.zeros(( 308 | # 2,self.image_depth*self.image_width*self.image_height, self.n_output 309 | #)) 310 | 311 | parameters = ParameterSet() 312 | 313 | self.conv_net = cnn3d.FCN( 314 | inpt=inpt, image_height=self.image_height, 315 | image_width=self.image_width, image_depth=self.image_depth, 316 | n_channel=self.n_channel, n_hiddens_conv=self.n_hiddens_conv, 317 | hidden_transfers_conv=self.hidden_transfers_conv, 318 | n_hiddens_upconv=self.n_hiddens_upconv, 319 | hidden_transfers_upconv=self.hidden_transfers_upconv, 320 | d_filter_shapes=self.down_filter_shapes, 321 | u_filter_shapes=self.up_filter_shapes, 322 | down_pools=self.down_pools, 323 | up_pools=self.up_pools, 324 | out_transfer=self.out_transfer, 325 | b_modes_down=self.bm_down, 326 | b_modes_up=self.bm_up, 327 | implementation=self.implementation, 328 | strides_down=self.strides_d, 329 | up_factors=self.up_factors, 330 | declare=parameters.declare 331 | ) 332 | 333 | output = self.conv_net.output 334 | 335 | if self.imp_weight: 336 | imp_weight = T.matrix('imp_weight') 337 | else: 338 | imp_weight = None 339 | 340 | self.loss_layer = SupervisedLoss( 341 | target, output, loss=self.loss_ident, 342 | imp_weight=imp_weight, declare=parameters.declare 343 | ) 344 | 345 | SupervisedModel.__init__(self, inpt=inpt, target=target, 346 | output=output, 347 | loss=self.loss_layer.sample_wise.mean(), 348 | parameters=parameters) 349 | 350 | self.exprs['imp_weight'] = imp_weight 351 | 352 | 353 | class ConvNet3d(SupervisedModel): 354 | def __init__(self, image_height, image_width, image_depth, 355 | n_channel, n_hiddens_conv, filter_shapes, pool_shapes, 356 | n_hiddens_full, n_output, hidden_transfers_conv, 357 | hidden_transfers_full, out_transfer, loss, optimizer='adam', 358 | batch_size=1, max_iter=1000, verbose=False, border_modes='valid', 359 | implementation='dnn_conv3d', 360 | dropout=False): 361 | """Flexible Convolutional neural network model 362 | 363 | Some key things to know: 364 | :param pool_shapes: 
list of 3-tuples or string-flags. e.g: 365 | [(2,2,2), 'no_pool', (2,2,2)] 366 | 'no_pool' is to skip pooling whenever necessary. 367 | 368 | Future work: 369 | This model shouldn't actually have fully-connected layers. Rather, it should 370 | turn fully-connected layers into convolutional layers as follows: 371 | FC layer that takes (4*4*4)*10 inpt and outputs 1000 neurons will be turned 372 | into a convolutional layer with 4*4*4 receptive fields outputting on 1000 feature 373 | maps, thus producing a (1*1*1)*1000 output. If the output of the classification layer 374 | has 3 neurons (3 classes), then after the conversion you'll get a (1*1*1)*3 output, which 375 | will be reshaped to 3 neurons afterwards. 376 | """ 377 | assert len(hidden_transfers_conv) == len(n_hiddens_conv) 378 | assert len(n_hiddens_conv) == len(filter_shapes) 379 | 380 | assert len(pool_shapes) == len(filter_shapes) 381 | assert len(hidden_transfers_full) == len(n_hiddens_full) 382 | 383 | self.image_height = image_height 384 | self.image_width = image_width 385 | self.image_depth = image_depth 386 | self.n_channel = n_channel 387 | self.n_hiddens_conv = n_hiddens_conv 388 | self.filter_shapes = filter_shapes 389 | self.pool_shapes = pool_shapes 390 | self.n_hiddens_full = n_hiddens_full 391 | self.n_output = n_output 392 | self.hidden_transfers_conv = hidden_transfers_conv 393 | self.hidden_transfers_full = hidden_transfers_full 394 | self.out_transfer = out_transfer 395 | self.loss_ident = loss 396 | self.optimizer = optimizer 397 | self.batch_size = batch_size 398 | self.max_iter = max_iter 399 | self.verbose = verbose 400 | self.implementation = implementation 401 | 402 | self.dropout = dropout 403 | 404 | self.border_modes = border_modes 405 | 406 | self._init_exprs() 407 | 408 | def _init_exprs(self): 409 | inpt = tensor5('inpt') 410 | inpt.tag.test_value = np.zeros(( 411 | 2, self.image_depth, self.n_channel, 412 | self.image_height, self.image_width 413 | )) 414 | 415 | target = 
T.matrix('target') 416 | target.tag.test_value = np.zeros(( 417 | 2, self.n_output 418 | )) 419 | 420 | parameters = ParameterSet() 421 | 422 | if self.dropout: 423 | self.p_dropout_inpt = .2 424 | self.p_dropout_hiddens = [.5] * len(self.n_hiddens_full) 425 | else: 426 | self.p_dropout_inpt = None 427 | self.p_dropout_hiddens = None 428 | 429 | self.conv_net = cnn3d.ConvNet3d( 430 | inpt=inpt, image_height=self.image_height, 431 | image_width=self.image_width, image_depth=self.image_depth, 432 | n_channel=self.n_channel, n_hiddens_conv=self.n_hiddens_conv, 433 | filter_shapes=self.filter_shapes, pool_shapes=self.pool_shapes, 434 | n_hiddens_full=self.n_hiddens_full, 435 | hidden_transfers_conv=self.hidden_transfers_conv, 436 | hidden_transfers_full=self.hidden_transfers_full, n_output=self.n_output, 437 | out_transfer=self.out_transfer, 438 | border_modes=self.border_modes, 439 | declare=parameters.declare, 440 | implementation=self.implementation, 441 | dropout=self.dropout, p_dropout_inpt=self.p_dropout_inpt, 442 | p_dropout_hiddens=self.p_dropout_hiddens 443 | ) 444 | 445 | output = self.conv_net.output 446 | 447 | if self.imp_weight: 448 | imp_weight = T.matrix('imp_weight') 449 | else: 450 | imp_weight = None 451 | 452 | if not self.dropout: 453 | loss_id = self.loss_ident 454 | else: 455 | loss_id = lookup(self.loss_ident, vp_loss) 456 | 457 | self.loss_layer = SupervisedLoss( 458 | target, output, loss=loss_id, 459 | imp_weight=imp_weight, declare=parameters.declare 460 | ) 461 | 462 | SupervisedModel.__init__(self, inpt=inpt, target=target, 463 | output=output, 464 | loss=self.loss_layer.total, 465 | parameters=parameters) 466 | 467 | self.exprs['imp_weight'] = imp_weight 468 | 469 | 470 | class Lenet3d(SupervisedModel): 471 | 472 | def __init__(self, image_height, image_width, image_depth, 473 | n_channel, n_hiddens_conv, filter_shapes, pool_shapes, 474 | n_hiddens_full, n_output, hidden_transfers_conv, 475 | hidden_transfers_full, out_transfer, loss, 
optimizer='adam', 476 | batch_size=1, max_iter=1000, verbose=False, implementation='dnn_conv3d', 477 | pool=True): 478 | assert len(hidden_transfers_conv) == len(n_hiddens_conv) 479 | assert len(n_hiddens_conv) == len(filter_shapes) 480 | 481 | assert len(pool_shapes) == len(filter_shapes) 482 | assert len(hidden_transfers_full) == len(n_hiddens_full) 483 | 484 | self.image_height = image_height 485 | self.image_width = image_width 486 | self.image_depth = image_depth 487 | self.n_channel = n_channel 488 | self.n_hiddens_conv = n_hiddens_conv 489 | self.n_hiddens_full = n_hiddens_full 490 | self.filter_shapes = filter_shapes 491 | self.pool_shapes = pool_shapes 492 | self.n_output = n_output 493 | self.hidden_transfers_conv = hidden_transfers_conv 494 | self.hidden_transfers_full = hidden_transfers_full 495 | self.out_transfer = out_transfer 496 | self.loss_ident = loss 497 | self.optimizer = optimizer 498 | self.batch_size = batch_size 499 | self.max_iter = max_iter 500 | self.verbose = verbose 501 | self.implementation=implementation 502 | self.pool = pool 503 | 504 | self._init_exprs() 505 | 506 | def _init_exprs(self): 507 | inpt = tensor5('inpt') 508 | inpt.tag.test_value = np.zeros(( 509 | 2, self.image_depth, self.n_channel, 510 | self.image_height, self.image_width 511 | )) 512 | 513 | target = T.matrix('target') 514 | target.tag.test_value = np.zeros(( 515 | 2, self.n_output 516 | )) 517 | 518 | parameters = ParameterSet() 519 | 520 | self.lenet = cnn3d.Lenet3d( 521 | inpt, self.image_height, 522 | self.image_width, self.image_depth, 523 | self.n_channel, self.n_hiddens_conv, 524 | self.filter_shapes, self.pool_shapes, 525 | self.n_hiddens_full, self.hidden_transfers_conv, 526 | self.hidden_transfers_full, self.n_output, 527 | self.out_transfer, 528 | declare=parameters.declare, 529 | implementation=self.implementation, 530 | pool=self.pool 531 | ) 532 | 533 | if self.imp_weight: 534 | imp_weight = T.matrix('imp_weight') 535 | else: 536 | imp_weight = None 
537 | 538 | self.loss_layer = SupervisedLoss( 539 | target, self.lenet.output, loss=self.loss_ident, 540 | imp_weight=imp_weight, declare=parameters.declare 541 | ) 542 | 543 | SupervisedModel.__init__(self, inpt=inpt, target=target, 544 | output=self.lenet.output, 545 | loss=self.loss_layer.total, 546 | parameters=parameters) 547 | 548 | self.exprs['imp_weight'] = imp_weight 549 | -------------------------------------------------------------------------------- /conv3d/model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BRML/CNNbasedMedicalSegmentation/49cacabf51cd1187d8b677532ef635d60f12bb1a/conv3d/model.pyc -------------------------------------------------------------------------------- /data/build_dataset.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import gzip 3 | import numpy as np 4 | import h5py 5 | import json 6 | import os 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | from breze.learn.data import one_hot 11 | 12 | d_code = 'handsize2_v2' 13 | data_dir = 'datasets' 14 | hdf_code = d_code + 'fcn.hdf5' 15 | data = os.path.join(data_dir, hdf_code) 16 | h_file = h5py.File('datasets//handsize2_v2.hdf5', 'w') 17 | 18 | b_size = 6 19 | train_n = 20 20 | valid_n = 5 21 | train_size = train_n*b_size 22 | valid_size = test_size = valid_n*b_size 23 | 24 | #d = int(d_code) 25 | dims = (64, 64, 192) 26 | dimprod = dims[0] * dims[1] * dims[2] 27 | 28 | print 'Train size: ', train_size 29 | print 'Valid size: ', valid_size 30 | print 'Test size: ', test_size 31 | 32 | full_x = np.zeros((30, b_size, dims[2], 1, dims[0], dims[1]), dtype='float32') 33 | full_y = np.zeros((30, b_size, dimprod, 2), dtype='float32') 34 | 35 | train_x = h_file.create_dataset( 36 | 'train_x', (train_size, dims[2], 1, dims[0], dims[1]), dtype='float32' 37 | ) 38 | train_y = h_file.create_dataset( 39 | 'train_y', (train_size,dimprod,2), 
dtype='float32' 40 | ) 41 | 42 | valid_x = h_file.create_dataset( 43 | 'valid_x', (valid_size, dims[2], 1, dims[0], dims[1]), dtype='float32' 44 | ) 45 | valid_y = h_file.create_dataset( 46 | 'valid_y', (valid_size,dimprod,2), dtype='float32' 47 | ) 48 | 49 | test_x = h_file.create_dataset( 50 | 'test_x', (test_size, dims[2], 1, dims[0], dims[1]), dtype='float32' 51 | ) 52 | test_y = h_file.create_dataset( 53 | 'test_y', (test_size,dimprod,2), dtype='float32' 54 | ) 55 | 56 | count = 1 57 | index = 0 58 | p_code = 'fcn' + d_code 59 | 60 | means = np.zeros((30,b_size)) 61 | 62 | def shuffle_ims(ims, gts): 63 | shuffle_indices = np.random.permutation(ims.shape[0]) 64 | 65 | ims = ims[shuffle_indices] 66 | gts = gts[shuffle_indices] 67 | return ims, gts 68 | 69 | def z_mirror_ims(ims, gts): 70 | rand_indices = np.random.randint(low=0, high=ims.shape[0], size=(ims.shape[0]/2,)) 71 | ims[rand_indices] = ims[rand_indices,:,:,::-1] 72 | 73 | gt_rands = np.reshape( 74 | gts[rand_indices], (len(rand_indices), ims.shape[1], ims.shape[2], ims.shape[3]) 75 | ) 76 | gt_rands = gt_rands[:,:,:,::-1] 77 | gts[rand_indices] = np.reshape(gt_rands, (len(rand_indices),-1)) 78 | 79 | return ims, gts 80 | 81 | while count <= 30: 82 | print 'Reading image ', count 83 | #with gzip.open('../patches/patches'+p_code+'/im'+str(count)+'.pkl.gz', 'rb') as f: 84 | with gzip.open('../patches/noisy_bi_images/im'+str(count)+'.pkl.gz', 'rb') as f: 85 | patches, labels = pickle.load(f) 86 | 87 | patches, labels = shuffle_ims(patches, labels) 88 | patches, labels = z_mirror_ims(patches, labels) 89 | print labels.shape 90 | for i, lbl in enumerate(labels): 91 | mean = lbl.mean() 92 | means[count-1,i] = mean 93 | print 'Mean value: ', mean 94 | labels = np.reshape(labels, (-1,)) 95 | labels = np.asarray(labels, dtype='int16') 96 | full_x[index,:,:,0,:,:] = np.transpose(patches, (0, 3, 1, 2)) 97 | l = np.reshape(one_hot(labels, 2), (b_size, dimprod, -1)) 98 | full_y[index,:,:,:] = np.asarray(l, 
dtype='float32') 99 | #plt.imshow(full_x[index,dims[2]/2,0,:,:], cmap='Greys_r') 100 | #plt.show() 101 | gt = np.reshape(labels[:dimprod], dims) 102 | #plt.imshow(gt[:,:,dims[2]/2], cmap='Greys_r') 103 | #plt.show() 104 | index += 1 105 | count += 1 106 | 107 | rand_indices = np.random.permutation(30) 108 | #rand_indices = np.arange(30) 109 | 110 | train_x[:,:,:,:,:] = np.reshape(full_x[rand_indices[:train_n],:,:,:,:,:], train_x.shape) 111 | train_y[:,:,:] = np.reshape(full_y[rand_indices[:train_n],:,:,:], train_y.shape) 112 | train_mean = means[rand_indices[:train_n],:] 113 | train_mean = train_mean.mean() 114 | 115 | valid_x[:,:,:,:,:] = np.reshape(full_x[rand_indices[train_n:train_n+valid_n],:,:,:,:,:], valid_x.shape) 116 | valid_y[:,:,:] = np.reshape(full_y[rand_indices[train_n:train_n+valid_n],:,:,:], valid_y.shape) 117 | valid_mean = means[rand_indices[train_n:train_n+valid_n],:] 118 | valid_mean = valid_mean.mean() 119 | 120 | take = full_x[rand_indices[train_n+valid_n:],:,:,:,:,:] 121 | test_x[:,:,:,:,:] = np.reshape(take, test_x.shape) 122 | test_y[:,:,:] = np.reshape(full_y[rand_indices[train_n+valid_n:],:,:,:], test_y.shape) 123 | test_mean = means[rand_indices[train_n+valid_n:],:] 124 | test_mean = test_mean.mean() 125 | 126 | doc = { 127 | 'code': data, 128 | 'dims': dims, 129 | 'batch_size': b_size, 130 | 'train_size': train_size, 131 | 'valid_size': valid_size, 132 | 'test_size': test_size, 133 | 'means': (str(train_mean), str(valid_mean), str(test_mean)) 134 | } 135 | 136 | doc_code = 'doc' + d_code + '.json' 137 | with open(os.path.join(data_dir, doc_code), 'w') as f: 138 | json.dump(doc, f) 139 | 140 | print 'Training: ', rand_indices[:train_n] 141 | print 'Validation: ', rand_indices[train_n:train_n+valid_n] 142 | print 'Testing: ', rand_indices[train_n+valid_n:] 143 | -------------------------------------------------------------------------------- /data/create_dummy.py: 
# --------------------------------------------------------------------------------
"""Write a small HDF5 data set filled with random ("dummy") data.

The file ``datasets/dummy45.hdf5`` gets six data sets, all stored as float32:

* ``train_x`` / ``valid_x`` / ``test_x``: standard-normal inputs of shape
  ``(n_samples, ds[2], n_chans, ds[0], ds[1])``.
* ``train_y`` / ``valid_y`` / ``test_y``: labels drawn uniformly from
  ``range(n_classes)``, one-hot encoded with breze's ``one_hot`` and reshaped
  to ``(n_samples, dp, n_classes)``, where ``dp`` is the voxel count of ``ds``.
"""
import os

import numpy as np
import h5py

from breze.learn.data import one_hot

DATA_HOME = 'datasets'

# Make sure directory 'datasets' exists:
if not os.path.exists(DATA_HOME):
    os.makedirs(DATA_HOME)

ds = (64, 80, 72)
dp = np.prod(np.array(ds))  # number of voxels per volume
n_chans = 4
x_size = 2
v_size = t_size = 1

n_classes = 5


def _make_split(n_samples):
    """Return ``(inputs, targets)`` for ``n_samples`` random volumes.

    ``inputs`` has shape ``(n_samples, ds[2], n_chans, ds[0], ds[1])``.
    ``targets`` has shape ``(n_samples, dp, n_classes)``.

    The inputs are drawn before the labels. Calling this for train, valid and
    test in that order therefore consumes the global NumPy RNG in the same
    sequence as drawing every array inline.
    """
    inputs = np.random.randn(n_samples, ds[2], n_chans, ds[0], ds[1])
    # Size the label draw by n_samples so the reshape below matches the split
    # size for every split, not only when it happens to be 1.
    labels = np.random.randint(low=0, high=n_classes, size=(n_samples, dp))
    targets = one_hot(np.reshape(labels, (-1,)), n_classes)
    targets = np.reshape(targets, (n_samples, dp, n_classes))
    return inputs, targets


x, y = _make_split(x_size)
vx, vy = _make_split(v_size)
tx, ty = _make_split(t_size)

# The context manager guarantees the HDF5 file is flushed and closed before
# 'done.' is reported.
with h5py.File(os.path.join(DATA_HOME, 'dummy45.hdf5'), 'w') as f:
    print('writing to dummy...')
    for split, inputs, targets in (('train', x, y),
                                   ('valid', vx, vy),
                                   ('test', tx, ty)):
        f.create_dataset(split + '_x', data=inputs, dtype='float32')
        f.create_dataset(split + '_y', data=targets, dtype='float32')
print('done.')
-------------------------------------------------------------------------------- /data/info/brats2013_challenge_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "sizes": [ 3 | [ 140, 180, 158 ], 4 | [ 146, 187, 160 ], 5 | [ 140, 184, 159 ], 6 | [ 140, 167, 156 ], 7 | [ 148, 184, 170 ], 8 | [ 148, 173, 152 ], 9 | [ 133, 181, 151 ], 10 | [ 138, 175, 159 ], 11 | [ 145, 187, 164 ], 12 | [ 153, 200, 166 ] 13 | ], 14 | "names": [ "VSD.0301.17572.mha", "VSD.0302.17576.mha", "VSD.0303.17580.mha", "VSD.0304.17584.mha", "VSD.0305.17588.mha", "VSD.0306.17592.mha", "VSD.0307.17596.mha", "VSD.0308.17600.mha", "VSD.0309.17604.mha", "VSD.0310.17608.mha" ], 15 | "slices_list": [ 16 | [ 17 | [ 12, 172 ], 18 | [ 10, 154 ], 19 | [ 9, 137 ] 20 | ], 21 | [ 22 | [ 15, 175 ], 23 | [ 8, 152 ], 24 | [ 11, 139 ] 25 | ], 26 | [ 27 | [ 12, 172 ], 28 | [ 7, 151 ], 29 | [ 8, 136 ] 30 | ], 31 | [ 32 | [ 4, 164 ], 33 | [ 6, 150 ], 34 | [ 3, 131 ] 35 | ], 36 | [ 37 | [ 13, 173 ], 38 | [ 14, 158 ], 39 | [ 10, 138 ] 40 | ], 41 | [ 42 | [ 8, 168 ], 43 | [ 4, 148 ], 44 | [ 11, 139 ] 45 | ], 46 | [ 47 | [ 11, 171 ], 48 | [ 3, 147 ], 49 | [ 0, 128 ] 50 | ], 51 | [ 52 | [ 9, 169 ], 53 | [ 6, 150 ], 54 | [ 3, 131 ] 55 | ], 56 | [ 57 | [ 17, 177 ], 58 | [ 8, 152 ], 59 | [ 10, 138 ] 60 | ], 61 | [ 62 | [ 21, 181 ], 63 | [ 12, 156 ], 64 | [ 13, 141 ] 65 | ] 66 | ] 67 | } -------------------------------------------------------------------------------- /data/info/brats2013_leaderboard_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "sizes": [ 3 | [ 176, 216, 160 ], 4 | [ 176, 216, 160 ], 5 | [ 176, 216, 176 ], 6 | [ 196, 216, 160 ], 7 | [ 200, 200, 170 ], 8 | [ 230, 230, 162 ], 9 | [ 230, 230, 162 ], 10 | [ 230, 230, 162 ], 11 | [ 230, 230, 162 ], 12 | [ 151, 172, 147 ], 13 | [ 134, 188, 153 ], 14 | [ 161, 199, 184 ], 15 | [ 148, 209, 174 ], 16 | [ 148, 190, 182 ], 17 | [ 155, 199, 182 ], 18 | [ 148, 189, 
183 ], 19 | [ 148, 214, 179 ], 20 | [ 148, 210, 181 ], 21 | [ 142, 199, 175 ], 22 | [ 148, 199, 177 ], 23 | [ 154, 211, 179 ], 24 | [ 176, 216, 191 ], 25 | [ 176, 216, 159 ], 26 | [ 196, 216, 192 ], 27 | [ 154, 178, 147 ] 28 | ], 29 | "names": [ "VSD.0116.3508.mha", "VSD.0117.3512.mha", "VSD.0119.3516.mha", "VSD.0120.3520.mha", "VSD.0130.3524.mha", "VSD.0131.3528.mha", "VSD.0133.3532.mha", "VSD.0134.3536.mha", "VSD.0135.3540.mha", "VSD.0136.3544.mha", "VSD.0137.3548.mha", "VSD.0201.15230.mha", "VSD.0202.15236.mha", "VSD.0203.15238.mha", "VSD.0204.15242.mha", "VSD.0205.15246.mha", "VSD.0206.15250.mha", "VSD.0207.15254.mha", "VSD.0208.15258.mha", "VSD.0209.15262.mha", "VSD.0210.15266.mha", "VSD.0103.3552.mha", "VSD.0105.3556.mha", "VSD.0109.3560.mha", "VSD.0116.3564.mha" ], 30 | "slices_list": [ 31 | [ 32 | [ 41, 201 ], 33 | [ 9, 153 ], 34 | [ 28, 156 ] 35 | ], 36 | [ 37 | [ 27, 187 ], 38 | [ 8, 152 ], 39 | [ 23, 151 ] 40 | ], 41 | [ 42 | [ 41, 201 ], 43 | [ 13, 157 ], 44 | [ 23, 151 ] 45 | ], 46 | [ 47 | [ 41, 201 ], 48 | [ 11, 155 ], 49 | [ 37, 165 ] 50 | ], 51 | [ 52 | [ 21, 181 ], 53 | [ 10, 154 ], 54 | [ 44, 172 ] 55 | ], 56 | [ 57 | [ 21, 181 ], 58 | [ 8, 152 ], 59 | [ 81, 209 ] 60 | ], 61 | [ 62 | [ 43, 203 ], 63 | [ 6, 150 ], 64 | [ 78, 206 ] 65 | ], 66 | [ 67 | [ 43, 203 ], 68 | [ 7, 151 ], 69 | [ 86, 214 ] 70 | ], 71 | [ 72 | [ 33, 193 ], 73 | [ 10, 154 ], 74 | [ 79, 207 ] 75 | ], 76 | [ 77 | [ 5, 165 ], 78 | [ 1, 145 ], 79 | [ 15, 143 ] 80 | ], 81 | [ 82 | [ 11, 171 ], 83 | [ 6, 150 ], 84 | [ 0, 128 ] 85 | ], 86 | [ 87 | [ 24, 184 ], 88 | [ 19, 163 ], 89 | [ 19, 147 ] 90 | ], 91 | [ 92 | [ 26, 186 ], 93 | [ 14, 158 ], 94 | [ 10, 138 ] 95 | ], 96 | [ 97 | [ 16, 176 ], 98 | [ 19, 163 ], 99 | [ 15, 143 ] 100 | ], 101 | [ 102 | [ 25, 185 ], 103 | [ 19, 163 ], 104 | [ 22, 150 ] 105 | ], 106 | [ 107 | [ 22, 182 ], 108 | [ 18, 162 ], 109 | [ 18, 146 ] 110 | ], 111 | [ 112 | [ 27, 187 ], 113 | [ 16, 160 ], 114 | [ 13, 141 ] 115 | ], 116 | [ 117 | [ 25, 185 ], 118 
| [ 17, 161 ], 119 | [ 14, 142 ] 120 | ], 121 | [ 122 | [ 20, 180 ], 123 | [ 17, 161 ], 124 | [ 14, 142 ] 125 | ], 126 | [ 127 | [ 20, 180 ], 128 | [ 15, 159 ], 129 | [ 8, 136 ] 130 | ], 131 | [ 132 | [ 28, 188 ], 133 | [ 18, 162 ], 134 | [ 8, 136 ] 135 | ], 136 | [ 137 | [ 26, 186 ], 138 | [ 19, 163 ], 139 | [ 26, 154 ] 140 | ], 141 | [ 142 | [ 30, 190 ], 143 | [ 3, 147 ], 144 | [ 25, 153 ] 145 | ], 146 | [ 147 | [ 11, 171 ], 148 | [ 22, 166 ], 149 | [ 56, 184 ] 150 | ], 151 | [ 152 | [ 8, 168 ], 153 | [ 1, 145 ], 154 | [ 17, 145 ] 155 | ] 156 | ] 157 | } -------------------------------------------------------------------------------- /data/info/brats_test_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "names": [ "vsd.brats_2013_pat0103_1.54193.mha", "vsd.brats_2013_pat0105_1.54199.mha", "vsd.brats_2013_pat0109_1.54205.mha", "vsd.brats_2013_pat0116_1.54215.mha", "vsd.brats_2013_pat0117_1.54221.mha", "vsd.brats_2013_pat0119_1.54227.mha", "vsd.brats_2013_pat0120_1.54233.mha", "vsd.brats_2013_pat0130_1.54239.mha", "vsd.brats_2013_pat0131_1.54245.mha", "vsd.brats_2013_pat0133_1.54251.mha", "vsd.brats_2013_pat0134_1.54257.mha", "vsd.brats_2013_pat0135_1.54263.mha", "vsd.brats_2013_pat0136_1.54269.mha", "vsd.brats_2013_pat0137_1.54275.mha", "vsd.brats_2013_patx116_1.54281.mha", "vsd.brats_tcia_pat114_0001.40461.mha", "vsd.brats_tcia_pat114_0053.40453.mha", "vsd.brats_tcia_pat114_0087.40469.mha", "vsd.brats_tcia_pat114_0115.40457.mha", "vsd.brats_tcia_pat114_0149.40465.mha", "vsd.brats_tcia_pat123_0001.40493.mha", "vsd.brats_tcia_pat123_0054.40473.mha", "vsd.brats_tcia_pat123_0083.40485.mha", "vsd.brats_tcia_pat123_0116.40509.mha", "vsd.brats_tcia_pat123_0193.40481.mha", "vsd.brats_tcia_pat123_0254.40505.mha", "vsd.brats_tcia_pat123_0322.40497.mha", "vsd.brats_tcia_pat123_0335.40501.mha", "vsd.brats_tcia_pat123_0360.40477.mha", "vsd.brats_tcia_pat123_0435.40489.mha", "vsd.brats_tcia_pat123_0558.40513.mha", 
"vsd.brats_tcia_pat127_0001.40517.mha", "vsd.brats_tcia_pat139_0159.43204.mha", "vsd.brats_tcia_pat146_0001.40537.mha", "vsd.brats_tcia_pat148_0001.40541.mha", "vsd.brats_tcia_pat166_0001.40545.mha", "vsd.brats_tcia_pat172_0001.40549.mha", "vsd.brats_tcia_pat174_0001.40553.mha", "vsd.brats_tcia_pat210_0001.40557.mha", "vsd.brats_tcia_pat220_0001.40561.mha", "vsd.brats_tcia_pat225_0001.40565.mha", "vsd.brats_tcia_pat229_0001.40569.mha", "vsd.brats_tcia_pat232_0001.40573.mha", "vsd.brats_tcia_pat236_0001.40577.mha", "vsd.brats_tcia_pat239_0001.40581.mha", "vsd.brats_tcia_pat244_0001.40589.mha", "vsd.brats_tcia_pat244_0157.40585.mha", "vsd.brats_tcia_pat244_0281.40593.mha", "vsd.brats_tcia_pat263_0001.40605.mha", "vsd.brats_tcia_pat263_0060.40597.mha", "vsd.brats_tcia_pat263_0115.40609.mha", "vsd.brats_tcia_pat263_0170.40601.mha", "vsd.brats_tcia_pat263_0199.40613.mha", "vsd.brats_tcia_pat271_0001.40621.mha", "vsd.brats_tcia_pat273_0001.40625.mha", "vsd.brats_tcia_pat294_0001.40629.mha", "vsd.brats_tcia_pat337_0019.40633.mha", "vsd.brats_tcia_pat348_0002.40637.mha", "vsd.brats_tcia_pat352_0001.40645.mha", "vsd.brats_tcia_pat352_0114.40641.mha", "vsd.brats_tcia_pat359_0001.40671.mha", "vsd.brats_tcia_pat381_0001.40679.mha", "vsd.brats_tcia_pat383_0001.40683.mha", "vsd.brats_tcia_pat384_0001.40687.mha", "vsd.brats_tcia_pat385_0001.40691.mha", "vsd.brats_tcia_pat454_0001.40695.mha", "vsd.brats_tcia_pat456_0001.40711.mha", "vsd.brats_tcia_pat456_0043.40699.mha", "vsd.brats_tcia_pat456_0112.40707.mha", "vsd.brats_tcia_pat456_0127.40703.mha", "vsd.brats_tcia_pat456_0216.40715.mha", "vsd.brats_tcia_pat457_0079.40719.mha", "vsd.brats_tcia_pat457_0231.40735.mha", "vsd.brats_tcia_pat457_0371.40731.mha", "vsd.brats_tcia_pat464_0001.40739.mha", "vsd.brats_tcia_pat467_0001.40743.mha", "vsd.brats_tcia_pat476_0001.40747.mha", "vsd.brats_tcia_pat484_0001.40767.mha", "vsd.brats_tcia_pat484_0004.40763.mha", "vsd.brats_tcia_pat484_0071.40771.mha", "vsd.brats_tcia_pat484_0093.40791.mha", 
"vsd.brats_tcia_pat484_0152.40775.mha", "vsd.brats_tcia_pat484_0214.40779.mha", "vsd.brats_tcia_pat484_0275.40787.mha", "vsd.brats_tcia_pat484_0345.40759.mha", "vsd.brats_tcia_pat484_0386.40755.mha", "vsd.brats_tcia_pat484_0427.40751.mha", "vsd.brats_tcia_pat484_0458.40783.mha", "vsd.brats_tcia_pat496_0001.40795.mha", "vsd.brats_tcia_pat500_0002.40811.mha", "vsd.brats_tcia_pat500_0028.41210.mha", "vsd.brats_tcia_pat500_0030.40807.mha", "vsd.brats_tcia_pat500_0067.40803.mha", "vsd.brats_tcia_pat500_0098.40815.mha", "vsd.brats_tcia_pat500_0152.41194.mha", "vsd.brats_tcia_pat500_0206.41186.mha", "vsd.brats_tcia_pat500_0268.41190.mha", "vsd.brats_tcia_pat500_0334.41218.mha", "vsd.brats_tcia_pat500_0403.41226.mha", "vsd.brats_tcia_pat500_0465.41214.mha", "vsd.brats_tcia_pat500_0525.41159.mha", "vsd.brats_tcia_pat500_0552.41206.mha", "vsd.brats_tcia_pat500_0592.41222.mha", "vsd.brats_tcia_pat500_0613.41171.mha", "vsd.brats_tcia_pat500_0647.41163.mha", "vsd.brats_tcia_pat500_0716.41182.mha", "vsd.brats_tcia_pat500_0771.40799.mha", "vsd.brats_tcia_pat500_0832.41198.mha", "vsd.brats_tcia_pat500_0852.41202.mha", "vsd.brats_tcia_pat500_0867.41175.mha" ], 3 | "slices_list": [ 4 | [ 5 | [ 48, 208 ], 6 | [ 47, 191 ], 7 | [ 13, 141 ] 8 | ], 9 | [ 10 | [ 51, 211 ], 11 | [ 46, 190 ], 12 | [ 13, 141 ] 13 | ], 14 | [ 15 | [ 25, 185 ], 16 | [ 47, 191 ], 17 | [ 17, 145 ] 18 | ], 19 | [ 20 | [ 51, 211 ], 21 | [ 47, 191 ], 22 | [ 12, 140 ] 23 | ], 24 | [ 25 | [ 50, 210 ], 26 | [ 47, 191 ], 27 | [ 11, 139 ] 28 | ], 29 | [ 30 | [ 50, 210 ], 31 | [ 47, 191 ], 32 | [ 13, 141 ] 33 | ], 34 | [ 35 | [ 50, 210 ], 36 | [ 47, 191 ], 37 | [ 11, 139 ] 38 | ], 39 | [ 40 | [ 51, 211 ], 41 | [ 46, 190 ], 42 | [ 13, 141 ] 43 | ], 44 | [ 45 | [ 49, 209 ], 46 | [ 46, 190 ], 47 | [ 13, 141 ] 48 | ], 49 | [ 50 | [ 50, 210 ], 51 | [ 47, 191 ], 52 | [ 13, 141 ] 53 | ], 54 | [ 55 | [ 52, 212 ], 56 | [ 46, 190 ], 57 | [ 13, 141 ] 58 | ], 59 | [ 60 | [ 50, 210 ], 61 | [ 46, 190 ], 62 | [ 12, 140 ] 63 | ], 64 | [ 
65 | [ 52, 212 ], 66 | [ 46, 190 ], 67 | [ 13, 141 ] 68 | ], 69 | [ 70 | [ 49, 209 ], 71 | [ 47, 191 ], 72 | [ 12, 140 ] 73 | ], 74 | [ 75 | [ 51, 211 ], 76 | [ 46, 190 ], 77 | [ 14, 142 ] 78 | ], 79 | [ 80 | [ 50, 210 ], 81 | [ 47, 191 ], 82 | [ 13, 141 ] 83 | ], 84 | [ 85 | [ 50, 210 ], 86 | [ 47, 191 ], 87 | [ 14, 142 ] 88 | ], 89 | [ 90 | [ 50, 210 ], 91 | [ 47, 191 ], 92 | [ 14, 142 ] 93 | ], 94 | [ 95 | [ 50, 210 ], 96 | [ 47, 191 ], 97 | [ 14, 142 ] 98 | ], 99 | [ 100 | [ 50, 210 ], 101 | [ 47, 191 ], 102 | [ 14, 142 ] 103 | ], 104 | [ 105 | [ 57, 217 ], 106 | [ 47, 191 ], 107 | [ 13, 141 ] 108 | ], 109 | [ 110 | [ 57, 217 ], 111 | [ 47, 191 ], 112 | [ 13, 141 ] 113 | ], 114 | [ 115 | [ 57, 217 ], 116 | [ 47, 191 ], 117 | [ 13, 141 ] 118 | ], 119 | [ 120 | [ 57, 217 ], 121 | [ 47, 191 ], 122 | [ 13, 141 ] 123 | ], 124 | [ 125 | [ 57, 217 ], 126 | [ 47, 191 ], 127 | [ 13, 141 ] 128 | ], 129 | [ 130 | [ 57, 217 ], 131 | [ 47, 191 ], 132 | [ 13, 141 ] 133 | ], 134 | [ 135 | [ 57, 217 ], 136 | [ 47, 191 ], 137 | [ 13, 141 ] 138 | ], 139 | [ 140 | [ 57, 217 ], 141 | [ 47, 191 ], 142 | [ 13, 141 ] 143 | ], 144 | [ 145 | [ 57, 217 ], 146 | [ 47, 191 ], 147 | [ 13, 141 ] 148 | ], 149 | [ 150 | [ 57, 217 ], 151 | [ 47, 191 ], 152 | [ 13, 141 ] 153 | ], 154 | [ 155 | [ 57, 217 ], 156 | [ 47, 191 ], 157 | [ 13, 141 ] 158 | ], 159 | [ 160 | [ 50, 210 ], 161 | [ 47, 191 ], 162 | [ 9, 137 ] 163 | ], 164 | [ 165 | [ 49, 209 ], 166 | [ 46, 190 ], 167 | [ 6, 134 ] 168 | ], 169 | [ 170 | [ 51, 211 ], 171 | [ 46, 190 ], 172 | [ 12, 140 ] 173 | ], 174 | [ 175 | [ 51, 211 ], 176 | [ 46, 190 ], 177 | [ 13, 141 ] 178 | ], 179 | [ 180 | [ 49, 209 ], 181 | [ 46, 190 ], 182 | [ 11, 139 ] 183 | ], 184 | [ 185 | [ 51, 211 ], 186 | [ 47, 191 ], 187 | [ 11, 139 ] 188 | ], 189 | [ 190 | [ 51, 211 ], 191 | [ 47, 191 ], 192 | [ 8, 136 ] 193 | ], 194 | [ 195 | [ 50, 210 ], 196 | [ 46, 190 ], 197 | [ 10, 138 ] 198 | ], 199 | [ 200 | [ 51, 211 ], 201 | [ 47, 191 ], 202 | [ 10, 138 ] 203 | ], 
204 | [ 205 | [ 49, 209 ], 206 | [ 46, 190 ], 207 | [ 7, 135 ] 208 | ], 209 | [ 210 | [ 49, 209 ], 211 | [ 46, 190 ], 212 | [ 7, 135 ] 213 | ], 214 | [ 215 | [ 50, 210 ], 216 | [ 46, 190 ], 217 | [ 9, 137 ] 218 | ], 219 | [ 220 | [ 50, 210 ], 221 | [ 46, 190 ], 222 | [ 9, 137 ] 223 | ], 224 | [ 225 | [ 49, 209 ], 226 | [ 47, 191 ], 227 | [ 10, 138 ] 228 | ], 229 | [ 230 | [ 48, 208 ], 231 | [ 47, 191 ], 232 | [ 8, 136 ] 233 | ], 234 | [ 235 | [ 48, 208 ], 236 | [ 47, 191 ], 237 | [ 8, 136 ] 238 | ], 239 | [ 240 | [ 48, 208 ], 241 | [ 47, 191 ], 242 | [ 8, 136 ] 243 | ], 244 | [ 245 | [ 49, 209 ], 246 | [ 47, 191 ], 247 | [ 6, 134 ] 248 | ], 249 | [ 250 | [ 49, 209 ], 251 | [ 47, 191 ], 252 | [ 6, 134 ] 253 | ], 254 | [ 255 | [ 49, 209 ], 256 | [ 47, 191 ], 257 | [ 6, 134 ] 258 | ], 259 | [ 260 | [ 49, 209 ], 261 | [ 47, 191 ], 262 | [ 6, 134 ] 263 | ], 264 | [ 265 | [ 49, 209 ], 266 | [ 47, 191 ], 267 | [ 6, 134 ] 268 | ], 269 | [ 270 | [ 49, 209 ], 271 | [ 46, 190 ], 272 | [ 4, 132 ] 273 | ], 274 | [ 275 | [ 49, 209 ], 276 | [ 46, 190 ], 277 | [ 11, 139 ] 278 | ], 279 | [ 280 | [ 55, 215 ], 281 | [ 45, 189 ], 282 | [ 12, 140 ] 283 | ], 284 | [ 285 | [ 51, 211 ], 286 | [ 45, 189 ], 287 | [ 7, 135 ] 288 | ], 289 | [ 290 | [ 56, 216 ], 291 | [ 47, 191 ], 292 | [ 9, 137 ] 293 | ], 294 | [ 295 | [ 50, 210 ], 296 | [ 47, 191 ], 297 | [ 7, 135 ] 298 | ], 299 | [ 300 | [ 50, 210 ], 301 | [ 47, 191 ], 302 | [ 7, 135 ] 303 | ], 304 | [ 305 | [ 50, 210 ], 306 | [ 46, 190 ], 307 | [ 12, 140 ] 308 | ], 309 | [ 310 | [ 51, 211 ], 311 | [ 46, 190 ], 312 | [ 6, 134 ] 313 | ], 314 | [ 315 | [ 51, 211 ], 316 | [ 47, 191 ], 317 | [ 11, 139 ] 318 | ], 319 | [ 320 | [ 53, 213 ], 321 | [ 46, 190 ], 322 | [ 11, 139 ] 323 | ], 324 | [ 325 | [ 50, 210 ], 326 | [ 46, 190 ], 327 | [ 10, 138 ] 328 | ], 329 | [ 330 | [ 48, 208 ], 331 | [ 46, 190 ], 332 | [ 4, 132 ] 333 | ], 334 | [ 335 | [ 50, 210 ], 336 | [ 48, 192 ], 337 | [ 14, 142 ] 338 | ], 339 | [ 340 | [ 50, 210 ], 341 | [ 48, 192 ], 
342 | [ 15, 143 ] 343 | ], 344 | [ 345 | [ 50, 210 ], 346 | [ 48, 192 ], 347 | [ 15, 143 ] 348 | ], 349 | [ 350 | [ 50, 210 ], 351 | [ 48, 192 ], 352 | [ 15, 143 ] 353 | ], 354 | [ 355 | [ 50, 210 ], 356 | [ 48, 192 ], 357 | [ 15, 143 ] 358 | ], 359 | [ 360 | [ 52, 212 ], 361 | [ 45, 189 ], 362 | [ 9, 137 ] 363 | ], 364 | [ 365 | [ 52, 212 ], 366 | [ 45, 189 ], 367 | [ 9, 137 ] 368 | ], 369 | [ 370 | [ 52, 212 ], 371 | [ 45, 189 ], 372 | [ 9, 137 ] 373 | ], 374 | [ 375 | [ 53, 213 ], 376 | [ 47, 191 ], 377 | [ 10, 138 ] 378 | ], 379 | [ 380 | [ 50, 210 ], 381 | [ 47, 191 ], 382 | [ 12, 140 ] 383 | ], 384 | [ 385 | [ 49, 209 ], 386 | [ 46, 190 ], 387 | [ 10, 138 ] 388 | ], 389 | [ 390 | [ 49, 209 ], 391 | [ 45, 189 ], 392 | [ 7, 135 ] 393 | ], 394 | [ 395 | [ 49, 209 ], 396 | [ 45, 189 ], 397 | [ 7, 135 ] 398 | ], 399 | [ 400 | [ 49, 209 ], 401 | [ 45, 189 ], 402 | [ 7, 135 ] 403 | ], 404 | [ 405 | [ 49, 209 ], 406 | [ 45, 189 ], 407 | [ 7, 135 ] 408 | ], 409 | [ 410 | [ 49, 209 ], 411 | [ 45, 189 ], 412 | [ 7, 135 ] 413 | ], 414 | [ 415 | [ 49, 209 ], 416 | [ 45, 189 ], 417 | [ 7, 135 ] 418 | ], 419 | [ 420 | [ 49, 209 ], 421 | [ 45, 189 ], 422 | [ 7, 135 ] 423 | ], 424 | [ 425 | [ 49, 209 ], 426 | [ 45, 189 ], 427 | [ 7, 135 ] 428 | ], 429 | [ 430 | [ 49, 209 ], 431 | [ 45, 189 ], 432 | [ 7, 135 ] 433 | ], 434 | [ 435 | [ 49, 209 ], 436 | [ 45, 189 ], 437 | [ 7, 135 ] 438 | ], 439 | [ 440 | [ 49, 209 ], 441 | [ 45, 189 ], 442 | [ 7, 135 ] 443 | ], 444 | [ 445 | [ 49, 209 ], 446 | [ 46, 190 ], 447 | [ 8, 136 ] 448 | ], 449 | [ 450 | [ 50, 210 ], 451 | [ 46, 190 ], 452 | [ 8, 136 ] 453 | ], 454 | [ 455 | [ 50, 210 ], 456 | [ 46, 190 ], 457 | [ 8, 136 ] 458 | ], 459 | [ 460 | [ 50, 210 ], 461 | [ 46, 190 ], 462 | [ 8, 136 ] 463 | ], 464 | [ 465 | [ 50, 210 ], 466 | [ 46, 190 ], 467 | [ 8, 136 ] 468 | ], 469 | [ 470 | [ 50, 210 ], 471 | [ 46, 190 ], 472 | [ 8, 136 ] 473 | ], 474 | [ 475 | [ 50, 210 ], 476 | [ 46, 190 ], 477 | [ 8, 136 ] 478 | ], 479 | [ 480 | [ 50, 
210 ], 481 | [ 46, 190 ], 482 | [ 8, 136 ] 483 | ], 484 | [ 485 | [ 50, 210 ], 486 | [ 46, 190 ], 487 | [ 8, 136 ] 488 | ], 489 | [ 490 | [ 50, 210 ], 491 | [ 46, 190 ], 492 | [ 8, 136 ] 493 | ], 494 | [ 495 | [ 50, 210 ], 496 | [ 46, 190 ], 497 | [ 8, 136 ] 498 | ], 499 | [ 500 | [ 50, 210 ], 501 | [ 46, 190 ], 502 | [ 8, 136 ] 503 | ], 504 | [ 505 | [ 50, 210 ], 506 | [ 46, 190 ], 507 | [ 8, 136 ] 508 | ], 509 | [ 510 | [ 50, 210 ], 511 | [ 46, 190 ], 512 | [ 8, 136 ] 513 | ], 514 | [ 515 | [ 50, 210 ], 516 | [ 46, 190 ], 517 | [ 8, 136 ] 518 | ], 519 | [ 520 | [ 50, 210 ], 521 | [ 46, 190 ], 522 | [ 8, 136 ] 523 | ], 524 | [ 525 | [ 50, 210 ], 526 | [ 46, 190 ], 527 | [ 8, 136 ] 528 | ], 529 | [ 530 | [ 50, 210 ], 531 | [ 46, 190 ], 532 | [ 8, 136 ] 533 | ], 534 | [ 535 | [ 50, 210 ], 536 | [ 46, 190 ], 537 | [ 8, 136 ] 538 | ], 539 | [ 540 | [ 50, 210 ], 541 | [ 46, 190 ], 542 | [ 8, 136 ] 543 | ], 544 | [ 545 | [ 50, 210 ], 546 | [ 46, 190 ], 547 | [ 8, 136 ] 548 | ], 549 | [ 550 | [ 50, 210 ], 551 | [ 46, 190 ], 552 | [ 8, 136 ] 553 | ] 554 | ] 555 | } -------------------------------------------------------------------------------- /data/info/brats_test_names.txt: -------------------------------------------------------------------------------- 1 | vsd.brats-tcia_pat103-1.40819.mha 2 | vsd.brats-tcia_pat106-1.40449.mha 3 | vsd.brats-tcia_pat109-2.40823.mha 4 | vsd.brats-tcia_pat111-1.40827.mha 5 | vsd.brats-tcia_pat114-1.40453.mha 6 | vsd.brats-tcia_pat114-2.40457.mha 7 | vsd.brats-tcia_pat114-3.40461.mha 8 | vsd.brats-tcia_pat114-4.40465.mha 9 | vsd.brats-tcia_pat114-5.40469.mha 10 | vsd.brats-tcia_pat123-1.40473.mha 11 | vsd.brats-tcia_pat123-10.40477.mha 12 | vsd.brats-tcia_pat123-11.40481.mha 13 | vsd.brats-tcia_pat123-12.40485.mha 14 | vsd.brats-tcia_pat123-13.40489.mha 15 | vsd.brats-tcia_pat123-2.40493.mha 16 | vsd.brats-tcia_pat123-4.40497.mha 17 | vsd.brats-tcia_pat123-5.40501.mha 18 | vsd.brats-tcia_pat123-6.40505.mha 19 | 
vsd.brats-tcia_pat123-7.40509.mha 20 | vsd.brats-tcia_pat123-9.40513.mha 21 | vsd.brats-tcia_pat127-1.40517.mha 22 | vsd.brats-tcia_pat130-1.40831.mha 23 | vsd.brats-tcia_pat139-1.40521.mha 24 | vsd.brats-tcia_pat139-3.43204.mha 25 | vsd.brats-tcia_pat139-4.40529.mha 26 | vsd.brats-tcia_pat139-5.40533.mha 27 | vsd.brats-tcia_pat146-1.40537.mha 28 | vsd.brats-tcia_pat148-1.40541.mha 29 | vsd.brats-tcia_pat152-1.40835.mha 30 | vsd.brats-tcia_pat153-2.40839.mha 31 | vsd.brats-tcia_pat153-3.40843.mha 32 | vsd.brats-tcia_pat153-6.40847.mha 33 | vsd.brats-tcia_pat153-7.40851.mha 34 | vsd.brats-tcia_pat153-8.40855.mha 35 | vsd.brats-tcia_pat153-9.40859.mha 36 | vsd.brats-tcia_pat157-1.40863.mha 37 | vsd.brats-tcia_pat165-1.40871.mha 38 | vsd.brats-tcia_pat166-1.40545.mha 39 | vsd.brats-tcia_pat172-1.40549.mha 40 | vsd.brats-tcia_pat174-1.40553.mha 41 | vsd.brats-tcia_pat175-1.40875.mha 42 | vsd.brats-tcia_pat179-1.40879.mha 43 | vsd.brats-tcia_pat179-2.40883.mha 44 | vsd.brats-tcia_pat190-1.40887.mha 45 | vsd.brats-tcia_pat200-2.40895.mha 46 | vsd.brats-tcia_pat201-1.40903.mha 47 | vsd.brats-tcia_pat210-1.40557.mha 48 | vsd.brats-tcia_pat220-1.40561.mha 49 | vsd.brats-tcia_pat222-1.40907.mha 50 | vsd.brats-tcia_pat222-2.40911.mha 51 | vsd.brats-tcia_pat222-3.40915.mha 52 | vsd.brats-tcia_pat225-1.40565.mha 53 | vsd.brats-tcia_pat226-1.40919.mha 54 | vsd.brats-tcia_pat226-2.40923.mha 55 | vsd.brats-tcia_pat229-1.40569.mha 56 | vsd.brats-tcia_pat231-1.40927.mha 57 | vsd.brats-tcia_pat232-1.40573.mha 58 | vsd.brats-tcia_pat236-1.40577.mha 59 | vsd.brats-tcia_pat239-1.40581.mha 60 | vsd.brats-tcia_pat241-1.40931.mha 61 | vsd.brats-tcia_pat244-1.40585.mha 62 | vsd.brats-tcia_pat244-2.40589.mha 63 | vsd.brats-tcia_pat244-3.40593.mha 64 | vsd.brats-tcia_pat258-1.40935.mha 65 | vsd.brats-tcia_pat261-1.40939.mha 66 | vsd.brats-tcia_pat263-1.40597.mha 67 | vsd.brats-tcia_pat263-3.40601.mha 68 | vsd.brats-tcia_pat263-4.40605.mha 69 | vsd.brats-tcia_pat263-6.40609.mha 70 | 
vsd.brats-tcia_pat263-7.40613.mha 71 | vsd.brats-tcia_pat271-1.40621.mha 72 | vsd.brats-tcia_pat273-1.40625.mha 73 | vsd.brats-tcia_pat282-1.40943.mha 74 | vsd.brats-tcia_pat290-1.40947.mha 75 | vsd.brats-tcia_pat290-2.41230.mha 76 | vsd.brats-tcia_pat290-3.40955.mha 77 | vsd.brats-tcia_pat290-4.41636.mha 78 | vsd.brats-tcia_pat290-5.40963.mha 79 | vsd.brats-tcia_pat290-6.40967.mha 80 | vsd.brats-tcia_pat294-1.40629.mha 81 | vsd.brats-tcia_pat307-1.40971.mha 82 | vsd.brats-tcia_pat309-1.40975.mha 83 | vsd.brats-tcia_pat309-2.40979.mha 84 | vsd.brats-tcia_pat309-3.40983.mha 85 | vsd.brats-tcia_pat309-4.40987.mha 86 | vsd.brats-tcia_pat309-5.40991.mha 87 | vsd.brats-tcia_pat309-6.40995.mha 88 | vsd.brats-tcia_pat311-1.40999.mha 89 | vsd.brats-tcia_pat314-1.41036.mha 90 | vsd.brats-tcia_pat314-2.41006.mha 91 | vsd.brats-tcia_pat314-3.41234.mha 92 | vsd.brats-tcia_pat314-4.41640.mha 93 | vsd.brats-tcia_pat314-5.41018.mha 94 | vsd.brats-tcia_pat318-1.41022.mha 95 | vsd.brats-tcia_pat330-1.41026.mha 96 | vsd.brats-tcia_pat332-1.41030.mha 97 | vsd.brats-tcia_pat337-11.40633.mha 98 | vsd.brats-tcia_pat346-2.41040.mha 99 | vsd.brats-tcia_pat348-1.40637.mha 100 | vsd.brats-tcia_pat352-1.40641.mha 101 | vsd.brats-tcia_pat352-2.40645.mha 102 | vsd.brats-tcia_pat352-3.40649.mha 103 | vsd.brats-tcia_pat352-4.40653.mha 104 | vsd.brats-tcia_pat354-1.41044.mha 105 | vsd.brats-tcia_pat355-1.40663.mha 106 | vsd.brats-tcia_pat355-2.40667.mha 107 | vsd.brats-tcia_pat358-1.41048.mha 108 | vsd.brats-tcia_pat359-1.40671.mha 109 | vsd.brats-tcia_pat363-1.40675.mha 110 | vsd.brats-tcia_pat377-1.41052.mha 111 | vsd.brats-tcia_pat377-2.41056.mha 112 | vsd.brats-tcia_pat378-1.41060.mha 113 | vsd.brats-tcia_pat381-1.40679.mha 114 | vsd.brats-tcia_pat383-1.40683.mha 115 | vsd.brats-tcia_pat384-1.40687.mha 116 | vsd.brats-tcia_pat385-1.40691.mha 117 | vsd.brats-tcia_pat387-1.41064.mha 118 | vsd.brats-tcia_pat392-1.41068.mha 119 | vsd.brats-tcia_pat392-2.41072.mha 120 | 
vsd.brats-tcia_pat392-3.41238.mha 121 | vsd.brats-tcia_pat392-4.41080.mha 122 | vsd.brats-tcia_pat392-5.41644.mha 123 | vsd.brats-tcia_pat393-1.41088.mha 124 | vsd.brats-tcia_pat408-1.41092.mha 125 | vsd.brats-tcia_pat409-1.41096.mha 126 | vsd.brats-tcia_pat410-1.41100.mha 127 | vsd.brats-tcia_pat420-1.41104.mha 128 | vsd.brats-tcia_pat439-1.41111.mha 129 | vsd.brats-tcia_pat439-2.41115.mha 130 | vsd.brats-tcia_pat439-4.41119.mha 131 | vsd.brats-tcia_pat439-5.41123.mha 132 | vsd.brats-tcia_pat447-1.41127.mha 133 | vsd.brats-tcia_pat447-3.41131.mha 134 | vsd.brats-tcia_pat447-4.41135.mha 135 | vsd.brats-tcia_pat451-1.41139.mha 136 | vsd.brats-tcia_pat454-1.40695.mha 137 | vsd.brats-tcia_pat456-13.40699.mha 138 | vsd.brats-tcia_pat456-14.40703.mha 139 | vsd.brats-tcia_pat456-3.40707.mha 140 | vsd.brats-tcia_pat456-5.40711.mha 141 | vsd.brats-tcia_pat456-7.40715.mha 142 | vsd.brats-tcia_pat457-1.40719.mha 143 | vsd.brats-tcia_pat457-2.40723.mha 144 | vsd.brats-tcia_pat457-3.40727.mha 145 | vsd.brats-tcia_pat457-4.40731.mha 146 | vsd.brats-tcia_pat457-6.40735.mha 147 | vsd.brats-tcia_pat462-1.41143.mha 148 | vsd.brats-tcia_pat464-1.40739.mha 149 | vsd.brats-tcia_pat467-1.40743.mha 150 | vsd.brats-tcia_pat470-1.41147.mha 151 | vsd.brats-tcia_pat476-1.40747.mha 152 | vsd.brats-tcia_pat480-1.41151.mha 153 | vsd.brats-tcia_pat484-1.40751.mha 154 | vsd.brats-tcia_pat484-11.40755.mha 155 | vsd.brats-tcia_pat484-12.40759.mha 156 | vsd.brats-tcia_pat484-13.40763.mha 157 | vsd.brats-tcia_pat484-2.40767.mha 158 | vsd.brats-tcia_pat484-4.40771.mha 159 | vsd.brats-tcia_pat484-5.40775.mha 160 | vsd.brats-tcia_pat484-6.40779.mha 161 | vsd.brats-tcia_pat484-7.40783.mha 162 | vsd.brats-tcia_pat484-8.40787.mha 163 | vsd.brats-tcia_pat484-9.40791.mha 164 | vsd.brats-tcia_pat496-1.40795.mha 165 | vsd.brats-tcia_pat497-1.41155.mha 166 | vsd.brats-tcia_pat500-1.40799.mha 167 | vsd.brats-tcia_pat500-10.41159.mha 168 | vsd.brats-tcia_pat500-11.41163.mha 169 | 
vsd.brats-tcia_pat500-12.41632.mha 170 | vsd.brats-tcia_pat500-13.41171.mha 171 | vsd.brats-tcia_pat500-15.41175.mha 172 | vsd.brats-tcia_pat500-16.41182.mha 173 | vsd.brats-tcia_pat500-17.41186.mha 174 | vsd.brats-tcia_pat500-18.41190.mha 175 | vsd.brats-tcia_pat500-19.41194.mha 176 | vsd.brats-tcia_pat500-2.40803.mha 177 | vsd.brats-tcia_pat500-20.41198.mha 178 | vsd.brats-tcia_pat500-21.41202.mha 179 | vsd.brats-tcia_pat500-22.41206.mha 180 | vsd.brats-tcia_pat500-23.41210.mha 181 | vsd.brats-tcia_pat500-24.41214.mha 182 | vsd.brats-tcia_pat500-3.40807.mha 183 | vsd.brats-tcia_pat500-4.40811.mha 184 | vsd.brats-tcia_pat500-6.40815.mha 185 | vsd.brats-tcia_pat500-7.41218.mha 186 | vsd.brats-tcia_pat500-8.41222.mha 187 | vsd.brats-tcia_pat500-9.41226.mha -------------------------------------------------------------------------------- /data/info/check_names.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | with open('brats_test_names.txt', 'r') as f: 4 | official_names = f.read() 5 | 6 | official_names = official_names.split('\n') 7 | official_ids = [] 8 | 9 | for o_name in official_names: 10 | o_id = o_name.split('.')[-2] 11 | official_ids.append(o_id) 12 | 13 | with open('brats_test_info.json', 'r') as f: 14 | local_info = json.load(f) 15 | local_names = local_info['names'] 16 | 17 | local_ids = [] 18 | for l_name in local_names: 19 | l_id = l_name.split('.')[-2] 20 | local_ids.append(l_id) 21 | 22 | print 'Comparing lists...\n' 23 | 24 | count_l = 0 25 | for l_id in local_ids: 26 | if l_id not in official_ids: 27 | print 'No matching official for: ', l_id 28 | count_l += 1 29 | 30 | print 'Number of problems: %i/%i' % (count_l, len(local_ids)) 31 | print '-'*40 32 | 33 | count_o = 0 34 | for o_id in official_ids: 35 | if o_id not in local_ids: 36 | print 'No matching local for: ', o_id 37 | count_o += 1 38 | 39 | print 'Number of problems: %i/%i' % (count_o, len(official_ids)) 40 | 41 | print 
'\nCollecting matches...\n' 42 | count_l = 0 43 | for l_id in local_ids: 44 | if l_id in official_ids: 45 | print 'There is a matching official for: ', l_id 46 | count_l += 1 47 | 48 | print 'Number of matches: %i/%i' % (count_l, len(local_ids)) 49 | -------------------------------------------------------------------------------- /data/info/dummy_test_info.json: -------------------------------------------------------------------------------- 1 | {"names": ["dummy.mha"], "slices_list": [[[0, 32], [0, 32], [0, 32]]]} -------------------------------------------------------------------------------- /data/pkl_to_hdf5.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import gzip 3 | import numpy as np 4 | import h5py 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | from breze.learn.data import one_hot 9 | 10 | # 11 | # The following needs to be adapted 12 | # to the particular set of patch pickles 13 | # being read: 14 | # the name of the .hdf5 file, b_size, train_size, 15 | # test_size and valid_size, dims, gt_dims, p_code(located right 16 | # before the for-loop.) 
17 | 18 | h_file = h5py.File('data96fcnmini.hdf5', 'w') 19 | 20 | b_size = 5 21 | train_size = 20*b_size 22 | valid_size = test_size = 5*b_size 23 | dims = (96, 96, 96) 24 | gt_dims = (dims[0],dims[1],dims[2]) 25 | dimprod = np.prod(np.array(dims)) 26 | gt_dimprod = np.prod(np.array(gt_dims)) 27 | 28 | print 'Train size: ', train_size 29 | print 'Valid size: ', valid_size 30 | print 'Test size: ', test_size 31 | 32 | full_x = np.zeros((train_size+valid_size+test_size, dims[2], 1, dims[0], dims[1]), dtype='float32') 33 | full_y = np.zeros((train_size+valid_size+test_size, gt_dimprod, 2), dtype='float32') 34 | 35 | ###### DELETE THIS LATER ###### 36 | #train_size = 2*b_size 37 | #valid_size = test_size = b_size 38 | ###### DELETE THIS LATER ###### 39 | 40 | train_x = h_file.create_dataset( 41 | 'train_x', (train_size, dims[2], 1, dims[0], dims[1]), dtype='float32' 42 | ) 43 | train_y = h_file.create_dataset( 44 | 'train_y', (train_size,gt_dimprod,2), dtype='float32' 45 | ) 46 | 47 | valid_x = h_file.create_dataset( 48 | 'valid_x', (valid_size, dims[2], 1, dims[0], dims[1]), dtype='float32' 49 | ) 50 | valid_y = h_file.create_dataset( 51 | 'valid_y', (valid_size,gt_dimprod,2), dtype='float32' 52 | ) 53 | 54 | test_x = h_file.create_dataset( 55 | 'test_x', (test_size, dims[2], 1, dims[0], dims[1]), dtype='float32' 56 | ) 57 | test_y = h_file.create_dataset( 58 | 'test_y', (test_size,gt_dimprod,2), dtype='float32' 59 | ) 60 | 61 | count = 1 62 | index = train_i = valid_i = test_i = 0 63 | p_code = 'fcn96' 64 | 65 | while count <= 30: 66 | print 'Reading image ', count 67 | with gzip.open('../../patches'+p_code+'/im'+str(count)+'.pkl.gz', 'rb') as f: 68 | patches, labels = pickle.load(f) 69 | print labels.shape 70 | for lbl in labels: 71 | print 'Mean value: ', lbl.mean() 72 | labels = np.reshape(labels, (-1,)) 73 | labels = np.asarray(labels, dtype='int16') 74 | full_x[index:index+b_size,:,0,:,:] = np.transpose(patches, (0, 3, 1, 2)) 75 | l = np.reshape(one_hot(labels, 
2), (b_size, gt_dimprod, -1)) 76 | full_y[index:index+b_size,:,:] = np.asarray(l, dtype='float32') 77 | plt.imshow(full_x[index,dims[2]/2,0,:,:], cmap='Greys_r') 78 | plt.show() 79 | gt = np.reshape(labels[:gt_dimprod], gt_dims) 80 | plt.imshow(gt[:,:,gt_dims[2]/2], cmap='Greys_r') 81 | plt.show() 82 | index += b_size 83 | count += 1 84 | 85 | #rand_indices = np.random.permutation(train_size+valid_size+test_size) 86 | rand_indices = np.arange(train_size+valid_size+test_size) 87 | 88 | train_x[:,:,:,:,:] = full_x[rand_indices[:train_size],:,:,:,:] 89 | train_y[:,:,:] = full_y[rand_indices[:train_size],:,:] 90 | 91 | valid_x[:,:,:,:,:] = full_x[rand_indices[train_size:train_size+valid_size],:,:,:,:] 92 | valid_y[:,:,:] = full_y[rand_indices[train_size:train_size+valid_size],:,:] 93 | 94 | test_x[:,:,:,:,:] = full_x[rand_indices[train_size+valid_size:],:,:,:,:] 95 | test_y[:,:,:] = full_y[rand_indices[train_size+valid_size:],:,:] 96 | 97 | #while count <= 30: 98 | # print 'Reading image ', count 99 | # with gzip.open('../../patches'+p_code+'/im'+str(count)+'.pkl.gz', 'rb') as f: 100 | # patches, labels = pickle.load(f) 101 | # rand_indices = np.random.permutation(b_size) 102 | # patches = patches[rand_indices,:,:,:] 103 | # labels = labels[rand_indices,:] 104 | # labels = np.reshape(labels, (-1,)) 105 | # labels = np.asarray(labels, dtype='int16') 106 | # 107 | # if count <= 20: 108 | # train_x[train_i:train_i+b_size,:,0,:,:] = np.transpose(patches, (0, 3, 1, 2)) 109 | # l = np.reshape(one_hot(labels, 2),(b_size,dimprod,-1)) 110 | # train_y[train_i:train_i+b_size,:,:] = np.asarray(l, dtype='float32') 111 | # train_i += b_size 112 | # elif count <= 25: 113 | # valid_x[valid_i:valid_i+b_size,:,0,:,:] = np.transpose(patches, (0, 3, 1, 2)) 114 | # l = np.reshape(one_hot(labels, 2),(b_size,dimprod,-1)) 115 | # valid_y[valid_i:valid_i+b_size,:,:] = np.asarray(l, dtype='float32') 116 | # valid_i += b_size 117 | # else: 118 | # test_x[test_i:test_i+b_size,:,0,:,:] = 
np.transpose(patches, (0, 3, 1, 2)) 119 | # l = np.reshape(one_hot(labels, 2),(b_size,dimprod,-1)) 120 | # test_y[test_i:test_i+b_size,:,:] = np.asarray(l, dtype='float32') 121 | # test_i += b_size 122 | # print '\tdone.' 123 | # count += 1 124 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import json 3 | import os 4 | import time 5 | import sys 6 | from scipy.ndimage import zoom 7 | from scipy.ndimage.interpolation import rotate 8 | from skimage import color 9 | from breze.learn.data import one_hot 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import h5py 13 | import gnumpy 14 | 15 | import climin.stops 16 | from climin import mathadapt as ma 17 | 18 | import ash 19 | from model_defs import get_model 20 | 21 | from conv3d.model import SequentialModel 22 | 23 | def to_sections(im): 24 | check = [i % 2 == 0 for i in im.shape] 25 | if not all(check): 26 | raise ValueError('All dimensions must be even numbers. Got: %s' % check) 27 | size_x, size_y, size_z = im.shape 28 | step_x = size_x / 2 29 | step_y = size_y / 2 30 | step_z = size_z / 2 31 | 32 | sections = [] 33 | for z in range(2): 34 | for y in range(2): 35 | for x in range(2): 36 | begin_x = step_x*x 37 | begin_y = step_y*y 38 | begin_z = step_z*z 39 | 40 | end_x = begin_x + step_x 41 | end_y = begin_y + step_y 42 | end_z = begin_z + step_z 43 | 44 | section = im[begin_x:end_x, begin_y:end_y, begin_z:end_z].copy() 45 | sections.append(section) 46 | return sections 47 | 48 | def from_sections(sections, original_shape): 49 | if len(sections) == 0: 50 | raise ValueError('Section list is empty.') 51 | im = np.zeros(original_shape, dtype=sections[0].dtype) 52 | 53 | check = [i % 2 == 0 for i in im.shape] 54 | if not all(check): 55 | raise ValueError('All dimensions must be even numbers. 
Got: %s' % check) 56 | size_x, size_y, size_z = im.shape 57 | step_x = size_x / 2 58 | step_y = size_y / 2 59 | step_z = size_z / 2 60 | 61 | count = 0 62 | for z in range(2): 63 | for y in range(2): 64 | for x in range(2): 65 | begin_x = step_x*x 66 | begin_y = step_y*y 67 | begin_z = step_z*z 68 | 69 | end_x = begin_x + step_x 70 | end_y = begin_y + step_y 71 | end_z = begin_z + step_z 72 | 73 | section = sections[count] 74 | count += 1 75 | im[begin_x:end_x, begin_y:end_y, begin_z:end_z] = section[:,:,:].copy() 76 | 77 | return im 78 | 79 | def dice_alt(seg, gt): 80 | sim = float(np.sum(np.minimum(seg, gt))) 81 | sim /= np.sum(np.maximum(seg, gt)) 82 | return sim 83 | 84 | def dice_(seg, gt): 85 | intersection = 2. * np.sum(seg * gt) 86 | denominator = (np.sum(np.square(seg)) + np.sum(np.square(gt))) 87 | if denominator == 0: 88 | return 1. 89 | similarity = intersection / denominator 90 | return similarity 91 | 92 | def dice(seg, gt): 93 | seg_transposed = np.transpose(seg, (3, 0, 1, 2)) 94 | gt = np.transpose(gt, (3, 0, 1, 2)) 95 | 96 | dice_list = [dice_(s, g) for s, g in zip(seg_transposed, gt)] 97 | return dice_list 98 | 99 | def get_whole(map): 100 | healthy = map[:, :, :, 0] 101 | non_healthy = np.sum(map[:, :, :, 1:], axis=3) 102 | 103 | result = np.zeros((map.shape[:3] + (2,))) 104 | result[:, :, :, 0] = healthy 105 | result[:, :, :, 1] = non_healthy 106 | 107 | return result.argmax(axis=3) 108 | 109 | def get_core(map): 110 | core = map[:, :, :, 1] + np.sum(map[:, :, :, 3:], axis=3) 111 | non_core = map[:, :, :, 0] + map[:, :, :, 2] 112 | 113 | result = np.zeros((map.shape[:3] + (2,))) 114 | result[:, :, :, 0] = non_core 115 | result[:, :, :, 1] = core 116 | 117 | return result.argmax(axis=3) 118 | 119 | def get_active(map): 120 | active = map[:, :, :, 3] 121 | non_active = np.sum(map[:, :, :, :3], axis=3) + map[:, :, :, 4] 122 | 123 | result = np.zeros((map.shape[:3] + (2,))) 124 | result[:, :, :, 0] = non_active 125 | result[:, :, :, 1] = active 126 
| 127 | return result.argmax(axis=3) 128 | 129 | def brats_dice(seg, gt): 130 | whole_seg = get_whole(seg) 131 | whole_gt = get_whole(gt) 132 | core_seg = get_core(seg) 133 | core_gt = get_core(gt) 134 | active_seg = get_active(seg) 135 | active_gt = get_active(gt) 136 | 137 | seg_and_gt = [(whole_seg, whole_gt), (core_seg, core_gt), (active_seg, active_gt)] 138 | dice_list = [dice_(s, g) for s, g in seg_and_gt] 139 | 140 | return dice_list 141 | 142 | def discrete(seg, n_classes): 143 | original_shape = seg.shape 144 | discrete_seg = seg.argmax(axis=3) 145 | discrete_seg = np.reshape(discrete_seg, (-1,)) 146 | discrete_seg = np.reshape(one_hot(discrete_seg, n_classes), original_shape) 147 | 148 | return discrete_seg 149 | 150 | def vis_col_result(im, seg, gt, savefile=None): 151 | indices_0 = np.where(gt == 0) 152 | indices_1 = np.where(gt == 1) # metacarpal 153 | indices_2 = np.where(gt == 2) # proximal 154 | indices_3 = np.where(gt == 3) # middle (thumb: distal) 155 | indices_4 = np.where(gt == 4) # distal (thumb: none) 156 | 157 | indices_s0 = np.where(seg == 0) 158 | indices_s1 = np.where(seg == 1) 159 | indices_s2 = np.where(seg == 2) 160 | indices_s3 = np.where(seg == 3) 161 | indices_s4 = np.where(seg == 4) 162 | 163 | im = im * 1. / im.max() 164 | rgb_image = color.gray2rgb(im) 165 | m0 = [0.6, 0.6, 1.] 166 | m1 = [0.2, 1., 0.2] 167 | m2 = [1., 1., 0.2] 168 | m3 = [1., 0.6, 0.2] 169 | m4 = [1., 0., 0.] 
170 | 171 | im_gt = rgb_image.copy() 172 | im_seg = rgb_image.copy() 173 | im_gt[indices_0[0], indices_0[1], :] *= m0 174 | im_gt[indices_1[0], indices_1[1], :] *= m1 175 | im_gt[indices_2[0], indices_2[1], :] *= m2 176 | im_gt[indices_3[0], indices_3[1], :] *= m3 177 | im_gt[indices_4[0], indices_4[1], :] *= m4 178 | 179 | im_seg[indices_s0[0], indices_s0[1], :] *= m0 180 | im_seg[indices_s1[0], indices_s1[1], :] *= m1 181 | im_seg[indices_s2[0], indices_s2[1], :] *= m2 182 | im_seg[indices_s3[0], indices_s3[1], :] *= m3 183 | im_seg[indices_s4[0], indices_s4[1], :] *= m4 184 | 185 | fig = plt.figure() 186 | a = fig.add_subplot(1, 2, 1) 187 | plt.imshow(im_seg) 188 | a.set_title('Segmentation') 189 | a = fig.add_subplot(1, 2, 2) 190 | plt.imshow(im_gt) 191 | a.set_title('Ground truth') 192 | if savefile is not None: 193 | plt.savefig(savefile) 194 | else: 195 | plt.show() 196 | plt.close() 197 | 198 | def vis_result(image, seg, gt, title1='Segmentation', title2='Ground truth', savefile=None): 199 | indices = np.where(seg >= 0.5) 200 | indices_gt = np.where(gt >= 0.5) 201 | 202 | im_norm = image / image.max() 203 | rgb_image = color.gray2rgb(im_norm) 204 | multiplier = [0., 1., 1.] 205 | multiplier_gt = [1., 1., 0.] 
206 | 207 | im_seg = rgb_image.copy() 208 | im_gt = rgb_image.copy() 209 | im_seg[indices[0], indices[1], :] *= multiplier 210 | im_gt[indices_gt[0], indices_gt[1], :] *= multiplier_gt 211 | 212 | fig = plt.figure() 213 | a = fig.add_subplot(1, 2, 1) 214 | plt.imshow(im_seg) 215 | a.set_title(title1) 216 | a = fig.add_subplot(1, 2, 2) 217 | plt.imshow(im_gt) 218 | a.set_title(title2) 219 | 220 | if savefile is None: 221 | plt.show() 222 | else: 223 | plt.savefig(savefile) 224 | plt.close() 225 | 226 | def build_net(model_folder, model_code, n_classes, train_size, inpt_h, inpt_w, inpt_d, n_channels): 227 | model_path = os.path.join('models', model_folder) 228 | 229 | param_file = os.path.join(model_path, 'params.hdf5') 230 | bn_par_file = os.path.join(model_path, 'bn_pars.pkl') 231 | 232 | log = None 233 | for f_name in os.listdir(model_path): 234 | if f_name.endswith('.json') and not f_name.startswith('dice'): 235 | with open(os.path.join(model_path, f_name), 'r') as f: 236 | log = json.load(f) 237 | break 238 | if 'layers' not in log: 239 | log = None 240 | 241 | model_def = get_model(model_code) 242 | 243 | layer_vars = model_def.layer_vars if log is None else log['layers'] 244 | batchnorm = model_def.batchnorm 245 | loss_id = model_def.loss_id 246 | out_transfer = model_def.out_transfer 247 | 248 | batch_size = 1 249 | max_passes = 1 250 | inpt_dims = (inpt_h, inpt_w, inpt_d) 251 | 252 | n_report = train_size / batch_size 253 | max_iter = n_report * max_passes 254 | 255 | optimizer = 'adam' 256 | 257 | model = SequentialModel( 258 | image_height=inpt_dims[0], image_width=inpt_dims[1], 259 | image_depth=inpt_dims[2], n_channels=n_channels, 260 | n_output=n_classes, layer_vars=layer_vars, 261 | out_transfer=out_transfer, loss_id=loss_id, 262 | optimizer=optimizer, batch_size=batch_size, 263 | max_iter=max_iter, using_bn=batchnorm 264 | ) 265 | 266 | f_params = h5py.File(param_file, 'r') 267 | params = np.zeros(model.parameters.data.shape) 268 | params[...] 
= f_params['best_pars'] 269 | f_params.close() 270 | model.parameters.data[...] = params 271 | 272 | if batchnorm and os.path.exists(bn_par_file): 273 | with open(bn_par_file, 'r') as f: 274 | bn_pars = pickle.load(f) 275 | model.set_batchnorm_params(bn_pars) 276 | else: 277 | if batchnorm: 278 | raise AssertionError('Batch norm used but running metrics not available.') 279 | 280 | if batchnorm: 281 | predict = ash.BatchNormFuns( 282 | model=model, 283 | fun=model.predict, 284 | phase='infer' 285 | ) 286 | else: 287 | predict = model.predict 288 | 289 | return predict 290 | 291 | def get_data(data_name, x_only=False): 292 | data_path = os.path.join('data', 'datasets', data_name+'.hdf5') 293 | data = h5py.File(data_path, 'r') 294 | 295 | test_x = data['test_x'] 296 | test_y = data['test_y'] if not x_only else None 297 | 298 | return test_x, test_y 299 | 300 | def compute_results(predict, X, Y): 301 | for x, y in zip(X, Y): 302 | depth, n_channels, height, width = x.shape 303 | start = time.time() 304 | model_output = predict(x[np.newaxis]) 305 | end = time.time() 306 | n_classes = y.shape[-1] 307 | model_output = model_output.as_numpy_array() if isinstance(model_output, gnumpy.garray) else model_output 308 | seg = np.reshape( 309 | model_output, 310 | (height, width, depth, n_classes) 311 | ) 312 | gt = np.reshape(y, (height, width, depth, n_classes)) 313 | seg = discrete(seg, n_classes) 314 | dice_list = dice(seg, gt) 315 | seg = seg.argmax(axis=3) 316 | gt = gt.argmax(axis=3) 317 | dice_all = dice_alt(seg, gt) 318 | dice_list = [dice_all] + dice_list 319 | print '\tdice: ', dice_list 320 | print '\ttime taken: ', (end - start) 321 | print '-' * 20 322 | 323 | image = x[:, 0, :, :] 324 | image = np.transpose(image, (1, 2, 0)) 325 | yield (image, seg, gt, dice_list) 326 | 327 | class SeqPredict(object): 328 | def __init__(self, predict, n_classes): 329 | self.predict = predict 330 | self.n_classes = n_classes 331 | 332 | def __call__(self, x): 333 | n_classes = 
self.n_classes 334 | image = np.transpose(x[0], (1, 2, 3, 0)) 335 | sections = np.array([to_sections(modality) for modality in image], dtype='int16') # mod sect h w d 336 | sections = np.transpose(sections, (1, 4, 0, 2, 3)) 337 | seg_sections = [] 338 | for section in sections: 339 | depth, n_chans, height, width = section.shape 340 | model_output = self.predict(section[np.newaxis]) 341 | model_output = model_output.as_numpy_array() if isinstance(model_output, gnumpy.garray) else model_output 342 | seg = np.reshape( 343 | model_output, 344 | (height, width, depth, n_classes) 345 | ) 346 | seg = seg.argmax(axis=3) 347 | seg_sections.append(seg) 348 | final_seg = from_sections(seg_sections, original_shape=(x.shape[3], x.shape[4], x.shape[1])) 349 | 350 | seg_onehot = np.reshape(final_seg, (-1,)) 351 | seg_onehot = np.reshape(one_hot(seg_onehot, n_classes), (-1, n_classes)) 352 | 353 | return seg_onehot 354 | 355 | def save_results(image, seg, gt, result_path): 356 | slice_count = 0 357 | for _slice in np.arange(0, image.shape[-1], 1): 358 | im_slice = image[:, :, _slice] 359 | gt_slice = gt[:, :, _slice] 360 | seg_slice = seg[:, :, _slice] 361 | 362 | save_file = os.path.join(result_path, 'slice' + str(slice_count) + '.png') 363 | vis_col_result(im=im_slice, gt=gt_slice, seg=seg_slice, savefile=save_file) 364 | slice_count += 1 365 | 366 | def test(): 367 | model_folder = os.path.join('brats_fold0', 'session1_2') 368 | model_code = 'fcn_rffc2' 369 | data_name = 'brats2013_leaderboard_data' 370 | save_path = os.path.join('results', 'as_hdf', 'brats2013_leaderboard_results.hdf5') 371 | 372 | tx, _ = get_data(data_name, x_only=True) 373 | 374 | print 'Saving results to: ', save_path 375 | save_hdf5 = h5py.File(save_path, 'w') 376 | seg_maps = save_hdf5.create_dataset( 377 | 'test_result', (tx.shape[0], tx.shape[3], tx.shape[4], tx.shape[1]), dtype='int8') 378 | 379 | train_size, depth, n_chans, height, width = tx.shape 380 | n_classes = 5 381 | predict = 
build_net(model_folder, model_code, 382 | n_classes=n_classes, train_size=train_size, 383 | inpt_h=height, inpt_w=width, inpt_d=depth, 384 | n_channels=n_chans) 385 | 386 | index = 0 387 | for test_image in tx: 388 | model_output = predict(test_image[np.newaxis]) 389 | model_output = model_output.as_numpy_array() if isinstance(model_output, gnumpy.garray) else model_output 390 | fuzzy_seg = np.reshape( 391 | model_output, 392 | (height, width, depth, n_classes) 393 | ) 394 | seg = fuzzy_seg.argmax(axis=3) 395 | seg_maps[index,:,:,:] = seg 396 | 397 | index += 1 398 | 399 | save_hdf5.close() 400 | 401 | def demonstrate(): 402 | model_folder = os.path.join('brats_fold0', 'session1_2') 403 | model_code = 'fcn_rffc4' 404 | data_name = 'brats_fold0' 405 | sectionalized = False 406 | 407 | tx, ty = get_data(data_name) 408 | train_size, depth, n_chans, height, width = tx.shape 409 | n_classes = ty.shape[-1] 410 | if sectionalized: 411 | depth = depth / 2 412 | height = height / 2 413 | width = width / 2 414 | predict = build_net(model_folder, model_code, 415 | n_classes=n_classes, train_size=train_size, 416 | inpt_h=height, inpt_w=width, inpt_d=depth, 417 | n_channels=n_chans) 418 | 419 | if sectionalized: 420 | predict = SeqPredict(predict=predict, n_classes=n_classes) 421 | 422 | count = 1 423 | dice_lists = [] 424 | for image, seg, gt, dl in compute_results(predict, tx, ty): 425 | dice_lists.append(dl) 426 | result_path = os.path.join('results', model_folder, 'testim_'+str(count)) 427 | if not os.path.exists(result_path): 428 | os.makedirs(result_path) 429 | 430 | save_results(image, seg, gt, result_path) 431 | 432 | count += 1 433 | dice_matrix = np.array(dice_lists) 434 | dice_means = np.mean(dice_matrix, axis=0) 435 | print 'Mean dice values: ', dice_means 436 | 437 | if __name__ == '__main__': 438 | demonstrate() 439 | #test() -------------------------------------------------------------------------------- /folds.pkl: 
-------------------------------------------------------------------------------- 1 | (lp1 2 | (dp2 3 | S'test' 4 | p3 5 | (lp4 6 | S'BRATS2015_Training/HGG/brats_tcia_pat378_0001' 7 | p5 8 | aS'BRATS2015_Training/HGG/brats_tcia_pat260_0244' 9 | p6 10 | aS'BRATS2015_Training/HGG/brats_tcia_pat437_0001' 11 | p7 12 | aS'BRATS2015_Training/LGG/brats_tcia_pat101_0001' 13 | p8 14 | aS'BRATS2015_Training/HGG/brats_tcia_pat153_0002' 15 | p9 16 | aS'BRATS2015_Training/LGG/brats_tcia_pat298_0001' 17 | p10 18 | aS'BRATS2015_Training/HGG/brats_tcia_pat460_0001' 19 | p11 20 | aS'BRATS2015_Training/HGG/brats_2013_pat0011_1' 21 | p12 22 | aS'BRATS2015_Training/HGG/brats_tcia_pat447_0122' 23 | p13 24 | aS'BRATS2015_Training/LGG/brats_tcia_pat402_0001' 25 | p14 26 | aS'BRATS2015_Training/HGG/brats_tcia_pat491_0001' 27 | p15 28 | aS'BRATS2015_Training/HGG/brats_tcia_pat455_0001' 29 | p16 30 | aS'BRATS2015_Training/HGG/brats_tcia_pat444_0001' 31 | p17 32 | aS'BRATS2015_Training/HGG/brats_tcia_pat217_0001' 33 | p18 34 | aS'BRATS2015_Training/HGG/brats_tcia_pat120_0001' 35 | p19 36 | aS'BRATS2015_Training/LGG/brats_tcia_pat325_0001' 37 | p20 38 | aS'BRATS2015_Training/HGG/brats_tcia_pat171_0387' 39 | p21 40 | aS'BRATS2015_Training/LGG/brats_tcia_pat354_0001' 41 | p22 42 | aS'BRATS2015_Training/HGG/brats_tcia_pat230_0710' 43 | p23 44 | aS'BRATS2015_Training/HGG/brats_tcia_pat247_0001' 45 | p24 46 | aS'BRATS2015_Training/HGG/brats_tcia_pat396_0001' 47 | p25 48 | aS'BRATS2015_Training/HGG/brats_tcia_pat374_1627' 49 | p26 50 | aS'BRATS2015_Training/HGG/brats_tcia_pat439_0360' 51 | p27 52 | aS'BRATS2015_Training/LGG/brats_tcia_pat241_0001' 53 | p28 54 | aS'BRATS2015_Training/HGG/brats_tcia_pat396_0217' 55 | p29 56 | aS'BRATS2015_Training/HGG/brats_tcia_pat375_0001' 57 | p30 58 | aS'BRATS2015_Training/HGG/brats_tcia_pat153_0109' 59 | p31 60 | aS'BRATS2015_Training/LGG/brats_2013_pat0011_1' 61 | p32 62 | aS'BRATS2015_Training/HGG/brats_tcia_pat178_0002' 63 | p33 64 | 
aS'BRATS2015_Training/HGG/brats_tcia_pat186_0001' 65 | p34 66 | aS'BRATS2015_Training/HGG/brats_tcia_pat118_0001' 67 | p35 68 | aS'BRATS2015_Training/HGG/brats_tcia_pat234_0001' 69 | p36 70 | aS'BRATS2015_Training/HGG/brats_tcia_pat430_0001' 71 | p37 72 | aS'BRATS2015_Training/HGG/brats_tcia_pat111_0001' 73 | p38 74 | aS'BRATS2015_Training/HGG/brats_tcia_pat370_0001' 75 | p39 76 | aS'BRATS2015_Training/HGG/brats_tcia_pat399_0156' 77 | p40 78 | aS'BRATS2015_Training/HGG/brats_tcia_pat150_0001' 79 | p41 80 | asS'train' 81 | p42 82 | (lp43 83 | S'BRATS2015_Training/HGG/brats_tcia_pat404_0001' 84 | p44 85 | aS'BRATS2015_Training/HGG/brats_tcia_pat171_0618' 86 | p45 87 | aS'BRATS2015_Training/LGG/brats_tcia_pat408_0001' 88 | p46 89 | aS'BRATS2015_Training/HGG/brats_tcia_pat277_0001' 90 | p47 91 | aS'BRATS2015_Training/HGG/brats_tcia_pat314_0290' 92 | p48 93 | aS'BRATS2015_Training/HGG/brats_2013_pat0015_1' 94 | p49 95 | aS'BRATS2015_Training/HGG/brats_tcia_pat121_0001' 96 | p50 97 | aS'BRATS2015_Training/HGG/brats_tcia_pat162_0001' 98 | p51 99 | aS'BRATS2015_Training/HGG/brats_tcia_pat198_0001' 100 | p52 101 | aS'BRATS2015_Training/HGG/brats_tcia_pat203_0001' 102 | p53 103 | aS'BRATS2015_Training/HGG/brats_tcia_pat170_0002' 104 | p54 105 | aS'BRATS2015_Training/HGG/brats_2013_pat0027_1' 106 | p55 107 | aS'BRATS2015_Training/LGG/brats_tcia_pat312_0001' 108 | p56 109 | aS'BRATS2015_Training/LGG/brats_tcia_pat249_0001' 110 | p57 111 | aS'BRATS2015_Training/HGG/brats_tcia_pat474_0001' 112 | p58 113 | aS'BRATS2015_Training/LGG/brats_tcia_pat351_0001' 114 | p59 115 | aS'BRATS2015_Training/LGG/brats_tcia_pat480_0001' 116 | p60 117 | aS'BRATS2015_Training/HGG/brats_tcia_pat399_0369' 118 | p61 119 | aS'BRATS2015_Training/HGG/brats_tcia_pat153_0181' 120 | p62 121 | aS'BRATS2015_Training/HGG/brats_tcia_pat309_0203' 122 | p63 123 | aS'BRATS2015_Training/HGG/brats_tcia_pat164_0001' 124 | p64 125 | aS'BRATS2015_Training/HGG/brats_tcia_pat309_0462' 126 | p65 127 | 
aS'BRATS2015_Training/HGG/brats_tcia_pat226_0090' 128 | p66 129 | aS'BRATS2015_Training/HGG/brats_tcia_pat260_0001' 130 | p67 131 | aS'BRATS2015_Training/HGG/brats_2013_pat0008_1' 132 | p68 133 | aS'BRATS2015_Training/HGG/brats_tcia_pat436_0001' 134 | p69 135 | aS'BRATS2015_Training/HGG/brats_2013_pat0014_1' 136 | p70 137 | aS'BRATS2015_Training/LGG/brats_tcia_pat152_0001' 138 | p71 139 | aS'BRATS2015_Training/HGG/brats_tcia_pat167_0001' 140 | p72 141 | aS'BRATS2015_Training/HGG/brats_tcia_pat221_0001' 142 | p73 143 | aS'BRATS2015_Training/HGG/brats_tcia_pat193_0002' 144 | p74 145 | aS'BRATS2015_Training/HGG/brats_tcia_pat499_0001' 146 | p75 147 | aS'BRATS2015_Training/HGG/brats_tcia_pat396_0105' 148 | p76 149 | aS'BRATS2015_Training/HGG/brats_tcia_pat429_0001' 150 | p77 151 | aS'BRATS2015_Training/HGG/brats_tcia_pat374_0557' 152 | p78 153 | aS'BRATS2015_Training/HGG/brats_tcia_pat444_0038' 154 | p79 155 | aS'BRATS2015_Training/HGG/brats_tcia_pat338_0001' 156 | p80 157 | aS'BRATS2015_Training/LGG/brats_tcia_pat254_0001' 158 | p81 159 | aS'BRATS2015_Training/HGG/brats_tcia_pat439_0333' 160 | p82 161 | aS'BRATS2015_Training/HGG/brats_tcia_pat200_0210' 162 | p83 163 | aS'BRATS2015_Training/HGG/brats_tcia_pat124_0003' 164 | p84 165 | aS'BRATS2015_Training/HGG/brats_tcia_pat171_1231' 166 | p85 167 | aS'BRATS2015_Training/HGG/brats_tcia_pat448_0001' 168 | p86 169 | aS'BRATS2015_Training/HGG/brats_tcia_pat230_0637' 170 | p87 171 | aS'BRATS2015_Training/HGG/brats_tcia_pat231_0001' 172 | p88 173 | aS'BRATS2015_Training/HGG/brats_tcia_pat179_0001' 174 | p89 175 | aS'BRATS2015_Training/HGG/brats_tcia_pat230_0511' 176 | p90 177 | aS'BRATS2015_Training/HGG/brats_tcia_pat439_0263' 178 | p91 179 | aS'BRATS2015_Training/HGG/brats_tcia_pat235_0001' 180 | p92 181 | aS'BRATS2015_Training/HGG/brats_tcia_pat479_0001' 182 | p93 183 | aS'BRATS2015_Training/HGG/brats_tcia_pat198_0283' 184 | p94 185 | aS'BRATS2015_Training/LGG/brats_2013_pat0014_1' 186 | p95 187 | 
aS'BRATS2015_Training/LGG/brats_tcia_pat449_0001' 188 | p96 189 | aS'BRATS2015_Training/HGG/brats_tcia_pat171_1126' 190 | p97 191 | aS'BRATS2015_Training/HGG/brats_tcia_pat131_0001' 192 | p98 193 | aS'BRATS2015_Training/HGG/brats_tcia_pat419_0001' 194 | p99 195 | aS'BRATS2015_Training/HGG/brats_tcia_pat280_0001' 196 | p100 197 | aS'BRATS2015_Training/HGG/brats_tcia_pat138_0001' 198 | p101 199 | aS'BRATS2015_Training/HGG/brats_tcia_pat399_0002' 200 | p102 201 | aS'BRATS2015_Training/HGG/brats_2013_pat0004_1' 202 | p103 203 | aS'BRATS2015_Training/LGG/brats_tcia_pat346_0001' 204 | p104 205 | aS'BRATS2015_Training/LGG/brats_2013_pat0012_1' 206 | p105 207 | aS'BRATS2015_Training/HGG/brats_tcia_pat396_0294' 208 | p106 209 | aS'BRATS2015_Training/LGG/brats_tcia_pat282_0001' 210 | p107 211 | aS'BRATS2015_Training/LGG/brats_tcia_pat177_0001' 212 | p108 213 | aS'BRATS2015_Training/HGG/brats_tcia_pat309_0243' 214 | p109 215 | aS'BRATS2015_Training/LGG/brats_2013_pat0001_1' 216 | p110 217 | aS'BRATS2015_Training/HGG/brats_tcia_pat157_0001' 218 | p111 219 | aS'BRATS2015_Training/HGG/brats_tcia_pat105_0001' 220 | p112 221 | aS'BRATS2015_Training/HGG/brats_tcia_pat165_0001' 222 | p113 223 | aS'BRATS2015_Training/HGG/brats_tcia_pat370_0383' 224 | p114 225 | aS'BRATS2015_Training/HGG/brats_tcia_pat425_0001' 226 | p115 227 | aS'BRATS2015_Training/HGG/brats_tcia_pat260_0075' 228 | p116 229 | aS'BRATS2015_Training/HGG/brats_tcia_pat184_0001' 230 | p117 231 | aS'BRATS2015_Training/HGG/brats_tcia_pat396_0139' 232 | p118 233 | aS'BRATS2015_Training/LGG/brats_tcia_pat420_0001' 234 | p119 235 | aS'BRATS2015_Training/HGG/brats_tcia_pat399_0527' 236 | p120 237 | aS'BRATS2015_Training/HGG/brats_tcia_pat314_0150' 238 | p121 239 | aS'BRATS2015_Training/HGG/brats_tcia_pat332_0001' 240 | p122 241 | aS'BRATS2015_Training/HGG/brats_2013_pat0025_1' 242 | p123 243 | aS'BRATS2015_Training/HGG/brats_tcia_pat361_0001' 244 | p124 245 | aS'BRATS2015_Training/HGG/brats_tcia_pat211_0001' 246 | p125 247 | 
aS'BRATS2015_Training/HGG/brats_tcia_pat265_0001' 248 | p126 249 | aS'BRATS2015_Training/HGG/brats_tcia_pat396_0176' 250 | p127 251 | aS'BRATS2015_Training/HGG/brats_tcia_pat222_0122' 252 | p128 253 | aS'BRATS2015_Training/HGG/brats_tcia_pat399_0479' 254 | p129 255 | aS'BRATS2015_Training/LGG/brats_tcia_pat442_0001' 256 | p130 257 | aS'BRATS2015_Training/HGG/brats_tcia_pat260_0152' 258 | p131 259 | aS'BRATS2015_Training/LGG/brats_2013_pat0002_1' 260 | p132 261 | aS'BRATS2015_Training/HGG/brats_tcia_pat417_0001' 262 | p133 263 | aS'BRATS2015_Training/HGG/brats_tcia_pat258_0001' 264 | p134 265 | aS'BRATS2015_Training/HGG/brats_tcia_pat374_0001' 266 | p135 267 | aS'BRATS2015_Training/HGG/brats_tcia_pat290_0669' 268 | p136 269 | aS'BRATS2015_Training/HGG/brats_tcia_pat432_0001' 270 | p137 271 | aS'BRATS2015_Training/HGG/brats_tcia_pat319_0001' 272 | p138 273 | aS'BRATS2015_Training/HGG/brats_tcia_pat417_0019' 274 | p139 275 | aS'BRATS2015_Training/LGG/brats_2013_pat0015_1' 276 | p140 277 | aS'BRATS2015_Training/LGG/brats_tcia_pat276_0001' 278 | p141 279 | aS'BRATS2015_Training/HGG/brats_2013_pat0001_1' 280 | p142 281 | aS'BRATS2015_Training/HGG/brats_tcia_pat290_0305' 282 | p143 283 | aS'BRATS2015_Training/HGG/brats_tcia_pat300_0001' 284 | p144 285 | aS'BRATS2015_Training/HGG/brats_tcia_pat147_0001' 286 | p145 287 | aS'BRATS2015_Training/LGG/brats_tcia_pat387_0001' 288 | p146 289 | aS'BRATS2015_Training/HGG/brats_tcia_pat370_1126' 290 | p147 291 | aS'BRATS2015_Training/HGG/brats_tcia_pat374_1426' 292 | p148 293 | aS'BRATS2015_Training/HGG/brats_tcia_pat117_0001' 294 | p149 295 | aS'BRATS2015_Training/HGG/brats_tcia_pat368_0001' 296 | p150 297 | aS'BRATS2015_Training/HGG/brats_tcia_pat343_0001' 298 | p151 299 | aS'BRATS2015_Training/LGG/brats_tcia_pat466_0001' 300 | p152 301 | aS'BRATS2015_Training/HGG/brats_tcia_pat399_0595' 302 | p153 303 | aS'BRATS2015_Training/HGG/brats_tcia_pat321_0001' 304 | p154 305 | aS'BRATS2015_Training/HGG/brats_tcia_pat260_0129' 306 | p155 
307 | aS'BRATS2015_Training/HGG/brats_tcia_pat396_0117' 308 | p156 309 | aS'BRATS2015_Training/HGG/brats_tcia_pat444_0104' 310 | p157 311 | aS'BRATS2015_Training/HGG/brats_2013_pat0007_1' 312 | p158 313 | aS'BRATS2015_Training/HGG/brats_2013_pat0002_1' 314 | p159 315 | aS'BRATS2015_Training/LGG/brats_tcia_pat202_0001' 316 | p160 317 | aS'BRATS2015_Training/HGG/brats_tcia_pat171_0001' 318 | p161 319 | aS'BRATS2015_Training/HGG/brats_tcia_pat113_0001' 320 | p162 321 | aS'BRATS2015_Training/HGG/brats_tcia_pat374_0801' 322 | p163 323 | aS'BRATS2015_Training/HGG/brats_tcia_pat153_0277' 324 | p164 325 | aS'BRATS2015_Training/HGG/brats_tcia_pat205_0001' 326 | p165 327 | aS'BRATS2015_Training/HGG/brats_tcia_pat498_0001' 328 | p166 329 | aS'BRATS2015_Training/LGG/brats_tcia_pat410_0001' 330 | p167 331 | aS'BRATS2015_Training/HGG/brats_tcia_pat280_0003' 332 | p168 333 | aS'BRATS2015_Training/LGG/brats_2013_pat0004_1' 334 | p169 335 | aS'BRATS2015_Training/HGG/brats_tcia_pat230_0481' 336 | p170 337 | aS'BRATS2015_Training/HGG/brats_tcia_pat372_0001' 338 | p171 339 | aS'BRATS2015_Training/LGG/brats_tcia_pat307_0001' 340 | p172 341 | aS'BRATS2015_Training/HGG/brats_tcia_pat406_0001' 342 | p173 343 | aS'BRATS2015_Training/LGG/brats_tcia_pat103_0001' 344 | p174 345 | aS'BRATS2015_Training/HGG/brats_2013_pat0010_1' 346 | p175 347 | aS'BRATS2015_Training/LGG/brats_tcia_pat299_0001' 348 | p176 349 | aS'BRATS2015_Training/LGG/brats_tcia_pat255_0001' 350 | p177 351 | aS'BRATS2015_Training/HGG/brats_tcia_pat401_0001' 352 | p178 353 | aS'BRATS2015_Training/HGG/brats_tcia_pat322_0001' 354 | p179 355 | aS'BRATS2015_Training/HGG/brats_tcia_pat447_0199' 356 | p180 357 | aS'BRATS2015_Training/LGG/brats_tcia_pat130_0001' 358 | p181 359 | aS'BRATS2015_Training/HGG/brats_tcia_pat309_0320' 360 | p182 361 | aS'BRATS2015_Training/HGG/brats_2013_pat0024_1' 362 | p183 363 | aS'BRATS2015_Training/HGG/brats_tcia_pat274_0001' 364 | p184 365 | aS'BRATS2015_Training/LGG/brats_tcia_pat470_0001' 366 | p185 
367 | aS'BRATS2015_Training/HGG/brats_tcia_pat171_0780' 368 | p186 369 | aS'BRATS2015_Training/HGG/brats_tcia_pat156_0001' 370 | p187 371 | aS'BRATS2015_Training/HGG/brats_tcia_pat469_0001' 372 | p188 373 | aS'BRATS2015_Training/HGG/brats_tcia_pat370_0907' 374 | p189 375 | aS'BRATS2015_Training/HGG/brats_tcia_pat314_0016' 376 | p190 377 | aS'BRATS2015_Training/HGG/brats_tcia_pat171_0950' 378 | p191 379 | aS'BRATS2015_Training/HGG/brats_tcia_pat153_0165' 380 | p192 381 | aS'BRATS2015_Training/HGG/brats_tcia_pat133_0001' 382 | p193 383 | aS'BRATS2015_Training/HGG/brats_tcia_pat242_0001' 384 | p194 385 | aS'BRATS2015_Training/HGG/brats_tcia_pat201_0001' 386 | p195 387 | aS'BRATS2015_Training/HGG/brats_tcia_pat392_0340' 388 | p196 389 | aS'BRATS2015_Training/HGG/brats_tcia_pat149_0001' 390 | p197 391 | aS'BRATS2015_Training/HGG/brats_tcia_pat296_0001' 392 | p198 393 | aS'BRATS2015_Training/HGG/brats_2013_pat0005_1' 394 | p199 395 | aS'BRATS2015_Training/LGG/brats_tcia_pat451_0001' 396 | p200 397 | aS'BRATS2015_Training/HGG/brats_tcia_pat377_0640' 398 | p201 399 | aS'BRATS2015_Training/HGG/brats_tcia_pat399_0417' 400 | p202 401 | aS'BRATS2015_Training/HGG/brats_tcia_pat374_1165' 402 | p203 403 | aS'BRATS2015_Training/LGG/brats_2013_pat0006_1' 404 | p204 405 | aS'BRATS2015_Training/HGG/brats_tcia_pat473_0001' 406 | p205 407 | aS'BRATS2015_Training/HGG/brats_tcia_pat218_0001' 408 | p206 409 | aS'BRATS2015_Training/LGG/brats_tcia_pat261_0001' 410 | p207 411 | aS'BRATS2015_Training/LGG/brats_tcia_pat490_0001' 412 | p208 413 | aS'BRATS2015_Training/HGG/brats_tcia_pat444_0077' 414 | p209 415 | aS'BRATS2015_Training/LGG/brats_2013_pat0008_1' 416 | p210 417 | aS'BRATS2015_Training/HGG/brats_tcia_pat399_0815' 418 | p211 419 | aS'BRATS2015_Training/HGG/brats_2013_pat0013_1' 420 | p212 421 | aS'BRATS2015_Training/HGG/brats_tcia_pat192_0001' 422 | p213 423 | aS'BRATS2015_Training/HGG/brats_tcia_pat290_0001' 424 | p214 425 | aS'BRATS2015_Training/HGG/brats_tcia_pat390_0001' 426 | 
p215 427 | aS'BRATS2015_Training/LGG/brats_2013_pat0013_1' 428 | p216 429 | aS'BRATS2015_Training/LGG/brats_tcia_pat462_0001' 430 | p217 431 | aS'BRATS2015_Training/LGG/brats_tcia_pat330_0001' 432 | p218 433 | aS'BRATS2015_Training/HGG/brats_tcia_pat374_0909' 434 | p219 435 | aS'BRATS2015_Training/HGG/brats_tcia_pat409_0001' 436 | p220 437 | aS'BRATS2015_Training/HGG/brats_tcia_pat173_0001' 438 | p221 439 | aS'BRATS2015_Training/HGG/brats_tcia_pat463_0001' 440 | p222 441 | aS'BRATS2015_Training/HGG/brats_tcia_pat278_0001' 442 | p223 443 | aS'BRATS2015_Training/HGG/brats_tcia_pat374_0356' 444 | p224 445 | aS'BRATS2015_Training/HGG/brats_tcia_pat226_0001' 446 | p225 447 | aS'BRATS2015_Training/HGG/brats_tcia_pat412_0001' 448 | p226 449 | aS'BRATS2015_Training/HGG/brats_tcia_pat309_0120' 450 | p227 451 | aS'BRATS2015_Training/HGG/brats_tcia_pat199_0001' 452 | p228 453 | aS'BRATS2015_Training/HGG/brats_2013_pat0006_1' 454 | p229 455 | aS'BRATS2015_Training/HGG/brats_tcia_pat468_0001' 456 | p230 457 | aS'BRATS2015_Training/HGG/brats_tcia_pat399_0217' 458 | p231 459 | aS'BRATS2015_Training/LGG/brats_tcia_pat175_0001' 460 | p232 461 | aS'BRATS2015_Training/LGG/brats_tcia_pat141_0001' 462 | p233 463 | aS'BRATS2015_Training/HGG/brats_tcia_pat331_0001' 464 | p234 465 | aS'BRATS2015_Training/HGG/brats_tcia_pat394_0001' 466 | p235 467 | aS'BRATS2015_Training/HGG/brats_tcia_pat190_0001' 468 | p236 469 | aS'BRATS2015_Training/HGG/brats_tcia_pat168_0001' 470 | p237 471 | aS'BRATS2015_Training/HGG/brats_tcia_pat370_1354' 472 | p238 473 | aS'BRATS2015_Training/HGG/brats_tcia_pat314_0001' 474 | p239 475 | aS'BRATS2015_Training/HGG/brats_tcia_pat424_0001' 476 | p240 477 | aS'BRATS2015_Training/HGG/brats_2013_pat0003_1' 478 | p241 479 | aS'BRATS2015_Training/HGG/brats_tcia_pat328_0001' 480 | p242 481 | aS'BRATS2015_Training/HGG/brats_tcia_pat444_0033' 482 | p243 483 | asS'valid' 484 | p244 485 | (lp245 486 | S'BRATS2015_Training/LGG/brats_tcia_pat109_0001' 487 | p246 488 | 
aS'BRATS2015_Training/HGG/brats_tcia_pat283_0001' 489 | p247 490 | aS'BRATS2015_Training/HGG/brats_2013_pat0012_1' 491 | p248 492 | aS'BRATS2015_Training/HGG/brats_tcia_pat222_0304' 493 | p249 494 | aS'BRATS2015_Training/HGG/brats_tcia_pat309_0001' 495 | p250 496 | aS'BRATS2015_Training/HGG/brats_tcia_pat439_0001' 497 | p251 498 | aS'BRATS2015_Training/HGG/brats_2013_pat0009_1' 499 | p252 500 | aS'BRATS2015_Training/HGG/brats_2013_pat0022_1' 501 | p253 502 | aS'BRATS2015_Training/LGG/brats_tcia_pat483_0001' 503 | p254 504 | aS'BRATS2015_Training/LGG/brats_tcia_pat413_0001' 505 | p255 506 | aS'BRATS2015_Training/HGG/brats_tcia_pat290_0412' 507 | p256 508 | aS'BRATS2015_Training/HGG/brats_tcia_pat153_0294' 509 | p257 510 | aS'BRATS2015_Training/HGG/brats_tcia_pat171_0200' 511 | p258 512 | aS'BRATS2015_Training/HGG/brats_tcia_pat411_0001' 513 | p259 514 | aS'BRATS2015_Training/HGG/brats_tcia_pat391_0002' 515 | p260 516 | aS'BRATS2015_Training/HGG/brats_tcia_pat399_0290' 517 | p261 518 | aS'BRATS2015_Training/HGG/brats_tcia_pat335_0001' 519 | p262 520 | aS'BRATS2015_Training/HGG/brats_tcia_pat290_0580' 521 | p263 522 | aS'BRATS2015_Training/HGG/brats_2013_pat0026_1' 523 | p264 524 | aS'BRATS2015_Training/LGG/brats_tcia_pat266_0001' 525 | p265 526 | aS'BRATS2015_Training/HGG/brats_tcia_pat135_0001' 527 | p266 528 | aS'BRATS2015_Training/LGG/brats_tcia_pat393_0001' 529 | p267 530 | aS'BRATS2015_Training/HGG/brats_tcia_pat471_0001' 531 | p268 532 | aS'BRATS2015_Training/HGG/brats_tcia_pat478_0001' 533 | p269 534 | aS'BRATS2015_Training/HGG/brats_tcia_pat151_0001' 535 | p270 536 | aS'BRATS2015_Training/HGG/brats_tcia_pat257_0001' 537 | p271 538 | aS'BRATS2015_Training/LGG/brats_tcia_pat493_0001' 539 | p272 540 | aS'BRATS2015_Training/HGG/brats_tcia_pat370_0569' 541 | p273 542 | aS'BRATS2015_Training/HGG/brats_tcia_pat260_0317' 543 | p274 544 | aS'BRATS2015_Training/HGG/brats_tcia_pat447_0313' 545 | p275 546 | aS'BRATS2015_Training/HGG/brats_tcia_pat180_0001' 547 | p276 548 
| aS'BRATS2015_Training/LGG/brats_tcia_pat428_0001' 549 | p277 550 | aS'BRATS2015_Training/HGG/brats_tcia_pat208_0001' 551 | p278 552 | aS'BRATS2015_Training/HGG/brats_tcia_pat377_0001' 553 | p279 554 | aS'BRATS2015_Training/HGG/brats_tcia_pat370_1470' 555 | p280 556 | aS'BRATS2015_Training/HGG/brats_tcia_pat183_0001' 557 | p281 558 | aS'BRATS2015_Training/HGG/brats_tcia_pat230_0199' 559 | p282 560 | asa(dp283 561 | g3 562 | (lp284 563 | g81 564 | ag82 565 | ag83 566 | ag84 567 | ag85 568 | ag86 569 | ag87 570 | ag88 571 | ag89 572 | ag90 573 | ag91 574 | ag92 575 | ag93 576 | ag94 577 | ag95 578 | ag96 579 | ag97 580 | ag98 581 | ag99 582 | ag100 583 | ag101 584 | ag102 585 | ag103 586 | ag104 587 | ag105 588 | ag106 589 | ag107 590 | ag108 591 | ag109 592 | ag110 593 | ag111 594 | ag112 595 | ag113 596 | ag114 597 | ag115 598 | ag116 599 | ag117 600 | asg42 601 | (lp285 602 | g246 603 | ag247 604 | ag248 605 | ag249 606 | ag250 607 | ag251 608 | ag252 609 | ag253 610 | ag254 611 | ag255 612 | ag256 613 | ag257 614 | ag258 615 | ag259 616 | ag260 617 | ag261 618 | ag262 619 | ag263 620 | ag264 621 | ag265 622 | ag266 623 | ag267 624 | ag268 625 | ag269 626 | ag270 627 | ag271 628 | ag272 629 | ag273 630 | ag274 631 | ag275 632 | ag276 633 | ag277 634 | ag278 635 | ag279 636 | ag280 637 | ag281 638 | ag282 639 | ag5 640 | ag6 641 | ag7 642 | ag8 643 | ag9 644 | ag10 645 | ag11 646 | ag12 647 | ag13 648 | ag14 649 | ag15 650 | ag16 651 | ag17 652 | ag18 653 | ag19 654 | ag20 655 | ag21 656 | ag22 657 | ag23 658 | ag24 659 | ag25 660 | ag26 661 | ag27 662 | ag28 663 | ag29 664 | ag30 665 | ag31 666 | ag32 667 | ag33 668 | ag34 669 | ag35 670 | ag36 671 | ag37 672 | ag38 673 | ag39 674 | ag40 675 | ag41 676 | ag118 677 | ag119 678 | ag120 679 | ag121 680 | ag122 681 | ag123 682 | ag124 683 | ag125 684 | ag126 685 | ag127 686 | ag128 687 | ag129 688 | ag130 689 | ag131 690 | ag132 691 | ag133 692 | ag134 693 | ag135 694 | ag136 695 | ag137 696 | ag138 697 | ag139 698 | 
ag140 699 | ag141 700 | ag142 701 | ag143 702 | ag144 703 | ag145 704 | ag146 705 | ag147 706 | ag148 707 | ag149 708 | ag150 709 | ag151 710 | ag152 711 | ag153 712 | ag154 713 | ag155 714 | ag156 715 | ag157 716 | ag158 717 | ag159 718 | ag160 719 | ag161 720 | ag162 721 | ag163 722 | ag164 723 | ag165 724 | ag166 725 | ag167 726 | ag168 727 | ag169 728 | ag170 729 | ag171 730 | ag172 731 | ag173 732 | ag174 733 | ag175 734 | ag176 735 | ag177 736 | ag178 737 | ag179 738 | ag180 739 | ag181 740 | ag182 741 | ag183 742 | ag184 743 | ag185 744 | ag186 745 | ag187 746 | ag188 747 | ag189 748 | ag190 749 | ag191 750 | ag192 751 | ag193 752 | ag194 753 | ag195 754 | ag196 755 | ag197 756 | ag198 757 | ag199 758 | ag200 759 | ag201 760 | ag202 761 | ag203 762 | ag204 763 | ag205 764 | ag206 765 | ag207 766 | ag208 767 | ag209 768 | ag210 769 | ag211 770 | ag212 771 | ag213 772 | ag214 773 | ag215 774 | ag216 775 | ag217 776 | ag218 777 | ag219 778 | ag220 779 | ag221 780 | ag222 781 | ag223 782 | ag224 783 | ag225 784 | ag226 785 | ag227 786 | ag228 787 | ag229 788 | ag230 789 | ag231 790 | ag232 791 | ag233 792 | ag234 793 | ag235 794 | ag236 795 | ag237 796 | ag238 797 | ag239 798 | ag240 799 | ag241 800 | ag242 801 | ag243 802 | asg244 803 | (lp286 804 | g44 805 | ag45 806 | ag46 807 | ag47 808 | ag48 809 | ag49 810 | ag50 811 | ag51 812 | ag52 813 | ag53 814 | ag54 815 | ag55 816 | ag56 817 | ag57 818 | ag58 819 | ag59 820 | ag60 821 | ag61 822 | ag62 823 | ag63 824 | ag64 825 | ag65 826 | ag66 827 | ag67 828 | ag68 829 | ag69 830 | ag70 831 | ag71 832 | ag72 833 | ag73 834 | ag74 835 | ag75 836 | ag76 837 | ag77 838 | ag78 839 | ag79 840 | ag80 841 | asa(dp287 842 | g3 843 | (lp288 844 | g155 845 | ag156 846 | ag157 847 | ag158 848 | ag159 849 | ag160 850 | ag161 851 | ag162 852 | ag163 853 | ag164 854 | ag165 855 | ag166 856 | ag167 857 | ag168 858 | ag169 859 | ag170 860 | ag171 861 | ag172 862 | ag173 863 | ag174 864 | ag175 865 | ag176 866 | ag177 867 | ag178 
868 | ag179 869 | ag180 870 | ag181 871 | ag182 872 | ag183 873 | ag184 874 | ag185 875 | ag186 876 | ag187 877 | ag188 878 | ag189 879 | ag190 880 | ag191 881 | asg42 882 | (lp289 883 | g246 884 | ag247 885 | ag248 886 | ag249 887 | ag250 888 | ag251 889 | ag252 890 | ag253 891 | ag254 892 | ag255 893 | ag256 894 | ag257 895 | ag258 896 | ag259 897 | ag260 898 | ag261 899 | ag262 900 | ag263 901 | ag264 902 | ag265 903 | ag266 904 | ag267 905 | ag268 906 | ag269 907 | ag270 908 | ag271 909 | ag272 910 | ag273 911 | ag274 912 | ag275 913 | ag276 914 | ag277 915 | ag278 916 | ag279 917 | ag280 918 | ag281 919 | ag282 920 | ag5 921 | ag6 922 | ag7 923 | ag8 924 | ag9 925 | ag10 926 | ag11 927 | ag12 928 | ag13 929 | ag14 930 | ag15 931 | ag16 932 | ag17 933 | ag18 934 | ag19 935 | ag20 936 | ag21 937 | ag22 938 | ag23 939 | ag24 940 | ag25 941 | ag26 942 | ag27 943 | ag28 944 | ag29 945 | ag30 946 | ag31 947 | ag32 948 | ag33 949 | ag34 950 | ag35 951 | ag36 952 | ag37 953 | ag38 954 | ag39 955 | ag40 956 | ag41 957 | ag44 958 | ag45 959 | ag46 960 | ag47 961 | ag48 962 | ag49 963 | ag50 964 | ag51 965 | ag52 966 | ag53 967 | ag54 968 | ag55 969 | ag56 970 | ag57 971 | ag58 972 | ag59 973 | ag60 974 | ag61 975 | ag62 976 | ag63 977 | ag64 978 | ag65 979 | ag66 980 | ag67 981 | ag68 982 | ag69 983 | ag70 984 | ag71 985 | ag72 986 | ag73 987 | ag74 988 | ag75 989 | ag76 990 | ag77 991 | ag78 992 | ag79 993 | ag80 994 | ag81 995 | ag82 996 | ag83 997 | ag84 998 | ag85 999 | ag86 1000 | ag87 1001 | ag88 1002 | ag89 1003 | ag90 1004 | ag91 1005 | ag92 1006 | ag93 1007 | ag94 1008 | ag95 1009 | ag96 1010 | ag97 1011 | ag98 1012 | ag99 1013 | ag100 1014 | ag101 1015 | ag102 1016 | ag103 1017 | ag104 1018 | ag105 1019 | ag106 1020 | ag107 1021 | ag108 1022 | ag109 1023 | ag110 1024 | ag111 1025 | ag112 1026 | ag113 1027 | ag114 1028 | ag115 1029 | ag116 1030 | ag117 1031 | ag192 1032 | ag193 1033 | ag194 1034 | ag195 1035 | ag196 1036 | ag197 1037 | ag198 1038 | ag199 1039 | 
ag200 1040 | ag201 1041 | ag202 1042 | ag203 1043 | ag204 1044 | ag205 1045 | ag206 1046 | ag207 1047 | ag208 1048 | ag209 1049 | ag210 1050 | ag211 1051 | ag212 1052 | ag213 1053 | ag214 1054 | ag215 1055 | ag216 1056 | ag217 1057 | ag218 1058 | ag219 1059 | ag220 1060 | ag221 1061 | ag222 1062 | ag223 1063 | ag224 1064 | ag225 1065 | ag226 1066 | ag227 1067 | ag228 1068 | ag229 1069 | ag230 1070 | ag231 1071 | ag232 1072 | ag233 1073 | ag234 1074 | ag235 1075 | ag236 1076 | ag237 1077 | ag238 1078 | ag239 1079 | ag240 1080 | ag241 1081 | ag242 1082 | ag243 1083 | asg244 1084 | (lp290 1085 | g118 1086 | ag119 1087 | ag120 1088 | ag121 1089 | ag122 1090 | ag123 1091 | ag124 1092 | ag125 1093 | ag126 1094 | ag127 1095 | ag128 1096 | ag129 1097 | ag130 1098 | ag131 1099 | ag132 1100 | ag133 1101 | ag134 1102 | ag135 1103 | ag136 1104 | ag137 1105 | ag138 1106 | ag139 1107 | ag140 1108 | ag141 1109 | ag142 1110 | ag143 1111 | ag144 1112 | ag145 1113 | ag146 1114 | ag147 1115 | ag148 1116 | ag149 1117 | ag150 1118 | ag151 1119 | ag152 1120 | ag153 1121 | ag154 1122 | asa. 
-------------------------------------------------------------------------------- /helper/get_dice.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | import matplotlib.pyplot as plt 5 | 6 | train_folder = 'brats_fold0//session1_2' 7 | dice_path = os.path.join('..', 'models', train_folder, 'dice.json') 8 | 9 | with open(dice_path, 'r') as f: 10 | dice_dict = json.load(f) 11 | 12 | dice_vals = np.array(dice_dict['dice_values']) 13 | 14 | four_classes = dice_vals[:, 2:] 15 | mean_dice = np.mean(four_classes, axis=0) 16 | std_dice = np.std(four_classes, axis=0) 17 | print mean_dice 18 | print std_dice 19 | 20 | colors = ['lightgreen', 'yellow', 'orange', 'lightcoral'] 21 | bplot = plt.boxplot(four_classes, patch_artist=True, showmeans=True) 22 | plt.yticks(np.arange(0., 1., 0.1)) 23 | for patch, color in zip(bplot['boxes'], colors): 24 | patch.set_facecolor(color) 25 | ax = plt.gca() 26 | ax.grid(True) 27 | plt.show() -------------------------------------------------------------------------------- /helper/store_mha_results.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import SimpleITK as sitk 4 | import os 5 | import json 6 | 7 | save_dir = os.path.join('..', 'results', 'as_mha', 'brats2013_leaderboard') 8 | file_path = os.path.join('..', 'results', 'as_hdf', 'brats2013_leaderboard_results.hdf5') 9 | seg_file = h5py.File(file_path, 'r') 10 | cropped_seg_maps = seg_file['test_result'] 11 | 12 | test_set_info_path = os.path.join('..', 'data', 'info', 'brats2013_leaderboard_info.json') 13 | with open(test_set_info_path, 'r') as f: 14 | info = json.load(f) 15 | 16 | if 'sizes' not in info: 17 | print 'Using default size (240, 240, 155).' 
18 | size_list = [(155, 240, 240)] * len(info['names']) 19 | info['sizes'] = size_list 20 | 21 | for crop_seg, slice_pairs, seg_name, array_size in zip(cropped_seg_maps, info['slices_list'], info['names'], info['sizes']): 22 | size = (array_size[1], array_size[2], array_size[0]) 23 | full_seg = np.zeros(size, dtype='int8') 24 | slices = [slice(s[0], s[1]) for s in slice_pairs] 25 | x_s, y_s, z_s = slices 26 | full_seg[x_s, y_s, z_s] = crop_seg 27 | full_seg = np.transpose(full_seg, (2, 0, 1)) 28 | 29 | seg_map_itk = sitk.GetImageFromArray(full_seg) 30 | assert(seg_map_itk.GetSize() == tuple(array_size[::-1])) 31 | 32 | save_path = os.path.join(save_dir, seg_name) 33 | print 'Saving image as: ', save_path 34 | sitk.WriteImage(seg_map_itk, str(save_path)) 35 | -------------------------------------------------------------------------------- /helper/view_test_images.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | def gen_ims(image_stack): 6 | for image in image_stack: 7 | #image = np.transpose(image, (1, 0, 2, 3)) 8 | yield image 9 | 10 | data = h5py.File('..//data//datasets//brats_fold0.hdf5', 'r') 11 | 12 | starting = 3 13 | x = data['test_x'][starting:] 14 | 15 | for im in gen_ims(x): 16 | for _slice in np.arange(0, im.shape[0], 1): 17 | im_slice = im[_slice] 18 | fig = plt.figure() 19 | 20 | i = 1 21 | keys = ['Flair', 'T1', 'T1c', 'T2'] 22 | for modality in im_slice: 23 | a = fig.add_subplot(2, 2, i) 24 | plt.imshow(modality, cmap='Greys_r') 25 | a.set_title(keys[(i-1)]) 26 | plt.axis('off') 27 | i += 1 28 | plt.show() 29 | plt.close() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | breze (pre-0.1, /home/baris/Documents/breze) 2 | climin (pre-0.1, /home/baris/Documents/climin-master/climin-master) 3 | cudamat 
(0.3, /home/baris/Documents/cudamat) 4 | cycler (0.10.0) 5 | dask (0.10.0) 6 | functools32 (3.2.3.post2) 7 | h5py (2.7.0) 8 | matplotlib (2.0.2) 9 | networkx (1.11) 10 | nose (1.3.7) 11 | numpy (1.13.1) 12 | pdfkit (0.5.0) 13 | Pillow (3.3.0) 14 | pip (9.0.1) 15 | pyparsing (2.2.0) 16 | python-dateutil (2.6.1) 17 | pytz (2017.2) 18 | PyYAML (3.11) 19 | scikit-image (0.13.dev0, /home/baris/Libraries/scikit-image-master) 20 | scikit-learn (0.18.2) 21 | scipy (0.19.1) 22 | setuptools (27.2.0) 23 | SimpleITK (1.0.0) 24 | six (1.10.0) 25 | subprocess32 (3.2.7) 26 | Theano (0.9.0) 27 | toolz (0.8.0) 28 | wheel (0.29.0) 29 | -------------------------------------------------------------------------------- /segment.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import h5py 3 | import cPickle as pickle 4 | import os 5 | import json 6 | 7 | import numpy as np 8 | import gnumpy 9 | 10 | import ash 11 | import brain_data_scripts as bds 12 | from model_defs import get_model 13 | from conv3d.model import SequentialModel 14 | 15 | def load_mhas_as_dict(path): 16 | """ 17 | Takes a path to a dictionary holding a series of subdirectories, 18 | where each subdirectory corresponds to one of the MRI modalities 19 | 'Flair', 'T1', 'T1c' and 'T2'. The subdir corresponding to a mo- 20 | dality must also contain the name of that modality in its name. 21 | In other words, we assume the following kind of file hierarchy: 22 | . 
23 | ├── VSD.Brain_3more.XX.O.OT.54517 24 | │   ├── License_CC_BY_NC_SA_3.0.txt 25 | │   └── VSD.Brain_3more.XX.O.OT.54517.mha 26 | ├── VSD.Brain.XX.O.MR_Flair.54512 27 | │   ├── License_CC_BY_NC_SA_3.0.txt 28 | │   └── VSD.Brain.XX.O.MR_Flair.54512.mha 29 | ├── VSD.Brain.XX.O.MR_T1.54513 30 | │   ├── License_CC_BY_NC_SA_3.0.txt 31 | │   └── VSD.Brain.XX.O.MR_T1.54513.mha 32 | ├── VSD.Brain.XX.O.MR_T1c.54514 33 | │   ├── License_CC_BY_NC_SA_3.0.txt 34 | │   └── VSD.Brain.XX.O.MR_T1c.54514.mha 35 | └── VSD.Brain.XX.O.MR_T2.54515 36 | ├── License_CC_BY_NC_SA_3.0.txt 37 | └── VSD.Brain.XX.O.MR_T2.54515.mha 38 | 39 | The method will return a tuple dictionary with the modalities 40 | as keys and numpy ndarrays as values. 41 | """ 42 | return bds.get_im(path) 43 | 44 | 45 | def load_dict_as_inputable_ndarray(im): 46 | """" 47 | The method will take a dictionary like the one returned by 48 | load_mhas_as_dict and return a tuple (image, slices) 49 | where image is a 5D numpy array with dimensions 50 | corresponding to: (1, depth, n_chans, height, width) 51 | and slices will be used to insert the segmentation into 52 | a volume of the same size as the original image. 53 | 54 | We need this because the network is trained on images of size 55 | (128, 160, 144) so during deployment we would ideally extract 56 | a patch of this size from the original image, segment it and 57 | re-insert the segmentation into its appropriate place. 
58 | """ 59 | im, slices = bds.get_image_slice(im) 60 | 61 | np_image = bds.get_im_as_ndarray(im, downsize=False) 62 | np_image = np.transpose(np_image, (1, 0, 2, 3)) 63 | np_image = np_image[np.newaxis] 64 | 65 | return np_image, slices 66 | 67 | def build_net(model_folder, model_code, n_classes, train_size, inpt_h, inpt_w, inpt_d, n_channels): 68 | """ 69 | Takes everything that defines a trained neural network 70 | in our setting and returns a function predict that accepts 71 | a numpy array as input and returns the segmentation corresponding 72 | to it. 73 | Parameters: 74 | model_folder: path to a directory containing the training results 75 | of the network. 76 | model_code: id of the model. This will be used to find the neural 77 | net architecture in model_defs.py 78 | n_classes: number of labels in the segmentation problem 79 | inpt_h: height of input 80 | inpt_w: width of input 81 | inpt_d: depth of input 82 | n_channels: number of input channels 83 | """ 84 | model_path = os.path.join('models', model_folder) 85 | 86 | param_file = os.path.join(model_path, 'params.hdf5') 87 | bn_par_file = os.path.join(model_path, 'bn_pars.pkl') 88 | 89 | log = None 90 | for f_name in os.listdir(model_path): 91 | if f_name.endswith('.json') and not f_name.startswith('dice'): 92 | with open(os.path.join(model_path, f_name), 'r') as f: 93 | log = json.load(f) 94 | break 95 | if 'layers' not in log: 96 | log = None 97 | 98 | model_def = get_model(model_code) 99 | 100 | layer_vars = model_def.layer_vars if log is None else log['layers'] 101 | batchnorm = model_def.batchnorm 102 | loss_id = model_def.loss_id 103 | out_transfer = model_def.out_transfer 104 | 105 | batch_size = 1 106 | max_passes = 1 107 | inpt_dims = (inpt_h, inpt_w, inpt_d) 108 | 109 | n_report = train_size / batch_size 110 | max_iter = n_report * max_passes 111 | 112 | optimizer = 'adam' 113 | 114 | model = SequentialModel( 115 | image_height=inpt_dims[0], image_width=inpt_dims[1], 116 | 
image_depth=inpt_dims[2], n_channels=n_channels, 117 | n_output=n_classes, layer_vars=layer_vars, 118 | out_transfer=out_transfer, loss_id=loss_id, 119 | optimizer=optimizer, batch_size=batch_size, 120 | max_iter=max_iter, using_bn=batchnorm 121 | ) 122 | 123 | f_params = h5py.File(param_file, 'r') 124 | params = np.zeros(model.parameters.data.shape) 125 | params[...] = f_params['best_pars'] 126 | f_params.close() 127 | model.parameters.data[...] = params 128 | 129 | if batchnorm and os.path.exists(bn_par_file): 130 | with open(bn_par_file, 'r') as f: 131 | bn_pars = pickle.load(f) 132 | model.set_batchnorm_params(bn_pars) 133 | else: 134 | if batchnorm: 135 | raise AssertionError('Batch norm used but running metrics not available.') 136 | 137 | if batchnorm: 138 | predict = ash.BatchNormFuns( 139 | model=model, 140 | fun=model.predict, 141 | phase='infer' 142 | ) 143 | else: 144 | predict = model.predict 145 | 146 | return predict 147 | 148 | def apply_network(inpt, predict_fn, n_classes=5): 149 | """Applies the predict function returned by build_net to a numpy array.""" 150 | _, depth, _, height, width = inpt.shape 151 | 152 | model_output = predict_fn(inpt) 153 | model_output = model_output.as_numpy_array() if isinstance(model_output, gnumpy.garray) else model_output 154 | fuzzy_seg = np.reshape( 155 | model_output, 156 | (height, width, depth, n_classes) 157 | ) 158 | seg = fuzzy_seg.argmax(axis=3) 159 | 160 | return seg 161 | 162 | def segment_dict(im_dict, model_folder, model_code, n_classes=5): 163 | """ 164 | Segments an image using a trained neural network. 165 | Parameters: 166 | im_dict: a dictionary where the keys are 'Flair', 'T1', 'T1c' and 'T2' 167 | and the values are numpy ndarrays. 168 | model_folder: path to a directory containing the training results 169 | of the network. 170 | model_code: id of the model. 
This will be used to find the neural 171 | net architecture in model_defs.py 172 | """ 173 | orig_shape = im_dict['Flair'].shape 174 | 175 | inpt, slices = load_dict_as_inputable_ndarray(im_dict) 176 | train_size, inpt_d, n_channels, inpt_h, inpt_w = inpt.shape 177 | 178 | predict_fn = build_net(model_folder, model_code, n_classes, train_size, inpt_h, inpt_w, inpt_d, n_channels) 179 | 180 | seg = apply_network(inpt, predict_fn, n_classes) 181 | 182 | segmentation = np.zeros(orig_shape) 183 | z_s, x_s, y_s = slices 184 | segmentation[z_s, x_s, y_s] = seg.transpose((2, 0, 1)) 185 | 186 | return segmentation 187 | 188 | def segment(path, model_folder, model_code, n_classes=5): 189 | """ 190 | Segments an image using a trained neural network. 191 | Parameters: 192 | path: path of a directory containing .mha files in 193 | its subdirectories. 194 | This is specified in: load_mhas_as_dict 195 | model_folder: path to a directory containing the training results 196 | of the network. 197 | model_code: id of the model. 
This will be used to find the neural 198 | net architecture in model_defs.py 199 | """ 200 | im_dict = load_mhas_as_dict(path) 201 | 202 | return segment_dict(im_dict, model_folder, model_code, n_classes) 203 | 204 | if __name__ == '__main__': 205 | import matplotlib.pyplot as plt 206 | 207 | seg = segment('BRATS2015_Training/HGG/brats_2013_pat0001_1', 208 | 'dummy45', 'fcn_rffc4') 209 | for depth_slice in seg: 210 | bds.vis_col_im(im=np.ones_like(depth_slice), gt=depth_slice) 211 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script used for training one of the neural networks 3 | defined in model_defs.py 4 | The usage is: 5 | 6 | python train.py [model_id] [dataset] [save_folder] [n_epoch] -ch (True/False) 7 | 8 | [model_id] is the id of the model used in model_defs.py 9 | [dataset] is the name of a dataset from data/datasets 10 | [save_folder] our results will be saved to models/[save_folder] 11 | [n_epoch] We will iterate over the entire data set [n_epoch] many times 12 | 13 | -ch: Flag indicating whether we want to start training from an earlier checkpoint 14 | WARNING: checkpoints are specific to the model_id and not to the experiment. 15 | If you have two different experiments using the same model_id running 16 | in parallel, their checkpoints will be in conflict. 
17 | 18 | In our paper, we trained our network using: 19 | 20 | python train.py fcn_rffc4 brats_fold0 brats_fold0 600 -ch False 21 | """ 22 | 23 | #import break_handling 24 | import cPickle as pickle 25 | import json 26 | import os 27 | import datetime 28 | import sys 29 | import argparse 30 | 31 | import matplotlib.pyplot as plt 32 | import numpy as np 33 | import gnumpy 34 | import h5py 35 | 36 | import climin.stops 37 | from breze.learn.trainer import report 38 | 39 | import ash 40 | from ash import PocketTrainer 41 | from model_defs import get_model 42 | 43 | from conv3d.model import SequentialModel 44 | 45 | def make_parser(): 46 | parser = argparse.ArgumentParser(description='Train model on data.') 47 | parser.add_argument('model_code', metavar='model', type=str, help='model to use') 48 | parser.add_argument('data_code', metavar='data', type=str, help='data to train on') 49 | parser.add_argument('train_code', metavar='tdir', type=str, help='directory path to store results in') 50 | parser.add_argument('n_epochs', metavar='ne', type=int, help='num of passes through the training set') 51 | parser.add_argument('-ch', '--checkpoint', help='set to load from checkpoint if available') 52 | 53 | return parser 54 | 55 | def retrieve_data(data_code): 56 | data_loc = os.path.join('data', 'datasets', data_code + '.hdf5') 57 | data = h5py.File(data_loc, 'r') 58 | 59 | train_x = data['train_x'] 60 | train_y = data['train_y'] 61 | valid_x = data['valid_x'] 62 | valid_y = data['valid_y'] 63 | test_x = data['test_x'] 64 | test_y = data['test_y'] 65 | 66 | return [(train_x, train_y), (valid_x, valid_y), (test_x, test_y)] 67 | 68 | def load_checkpoint(model_code, param_shape): 69 | param_loc = os.path.join('models', 'checkpoints', model_code + '.hdf5') 70 | log_code = os.path.join('models', 'checkpoints', model_code + '_log.json') 71 | bn_pars_path = os.path.join('models', 'checkpoints', model_code + '_bn_pars.pkl') 72 | bn_pars = None 73 | if not os.path.exists(param_loc): 74 
| print 'No checkpoint available, using random initialization instead.' 75 | return np.random.normal(0, 0.01, param_shape), None 76 | else: 77 | with open(log_code, 'r') as f: 78 | log = json.load(f) 79 | n_epochs_done = log['n_epochs'] 80 | n_iters_done = log['n_iters'] 81 | b_loss = log['best_loss'] 82 | if os.path.exists(bn_pars_path): 83 | with open(bn_pars_path, 'r') as f: 84 | bn_pars = pickle.load(f) 85 | print 'bn parameters found' 86 | param_file = h5py.File(param_loc, 'r') 87 | params_np = np.zeros(param_shape) 88 | b_params_np = np.zeros(param_shape) 89 | params_np[...] = param_file['last_pars'] 90 | b_params_np[...] = param_file['best_pars'] 91 | param_file.close() 92 | t_dic = { 93 | 'best_pars': b_params_np, 94 | 'best_loss': b_loss, 95 | 'n_epochs': n_epochs_done, 96 | 'n_iters': n_iters_done, 97 | 'bn_pars': bn_pars 98 | } 99 | return params_np, t_dic 100 | 101 | def build_model(model_code, checkpoint, info): 102 | model_def = get_model(model_code=model_code) 103 | 104 | layer_vars = model_def.layer_vars 105 | batchnorm = model_def.batchnorm 106 | loss_id = model_def.loss_id 107 | loss_layer_def = model_def.loss_layer_def 108 | out_transfer = model_def.out_transfer 109 | 110 | if model_def.regularize: 111 | print 'using regularization: l1: %s, l2: %s' % (model_def.l1, model_def.l2) 112 | 113 | model = SequentialModel( 114 | image_height=info['height'], image_width=info['width'], 115 | image_depth=info['depth'], n_channels=info['n_inpt'], 116 | n_output=info['n_classes'], layer_vars=layer_vars, 117 | out_transfer=out_transfer, loss_id=loss_id, 118 | loss_layer_def=loss_layer_def, optimizer=info['optimizer'], 119 | batch_size=info['batch_size'], max_iter=info['max_iter'], 120 | using_bn=batchnorm, regularize=model_def.regularize, 121 | l1=model_def.l1, l2=model_def.l2, 122 | perform_transform=model_def.perform_transform 123 | ) 124 | 125 | if checkpoint: 126 | model.parameters.data[...], t_dic = load_checkpoint(model_code, model.parameters.data.shape) 
127 | if t_dic is not None: 128 | model.max_iter -= t_dic['n_iters'] 129 | if t_dic['bn_pars'] is not None: 130 | model.set_batchnorm_params(t_dic['bn_pars']) 131 | print 'bn parameters loaded' 132 | else: 133 | t_dic = None 134 | rng = np.random.RandomState(123) 135 | model.parameters.data[...] = rng.normal(0, 0.01, model.parameters.data.shape) 136 | 137 | return model, model_def, t_dic 138 | 139 | def setup_training(model_code, data_code, checkpoint, max_passes): 140 | train, valid, test = retrieve_data(data_code=data_code) 141 | 142 | train_size, inpt_d, n_channels, inpt_h, inpt_w = train[0].shape 143 | n_classes = train[1].shape[-1] 144 | valid_size = valid[0].shape[0] 145 | test_size = test[0].shape[0] 146 | 147 | print 'input data dimensions: h: %i w: %i d: %i' % (inpt_h, inpt_w, inpt_d) 148 | print 'set stats: train: %i, valid: %i, test: %i' % (train_size, valid_size, test_size) 149 | 150 | optimizer = 'adam' 151 | batch_size = 1 152 | 153 | n_report = train_size / batch_size 154 | max_iter = n_report * max_passes 155 | 156 | info = { 157 | 'height':inpt_h, 'width': inpt_w, 'depth': inpt_d, 158 | 'n_classes': n_classes, 'n_inpt': n_channels, 'optimizer': optimizer, 159 | 'batch_size': batch_size, 'max_iter': max_iter, 'n_report': n_report 160 | } 161 | 162 | model, model_def, t_dic = build_model(model_code, checkpoint, info) 163 | 164 | stop = climin.stops.AfterNIterations(max_iter=model.max_iter) 165 | pause = climin.stops.ModuloNIterations(n_report) 166 | 167 | data = { 168 | 'train': train, 'val': valid, 'test': test 169 | } 170 | 171 | report_fun = report.OneLinePrinter( 172 | ['n_iter', 'runtime', 'loss', 'val_loss'], 173 | spaces=['4', '7.4f', '5.4f', '7.4f'] 174 | ) 175 | score_fun = ash.MinibatchScoreFCN(max_samples=batch_size, sample_dims=[0, 0]) 176 | 177 | coach = PocketTrainer( 178 | model=model, data=data, stop=stop, 179 | pause=pause, score_fun=score_fun, 180 | report_fun=report_fun, evaluate=True, 181 | test=False, 
batchnorm=model_def.batchnorm, 182 | model_code=model_code, n_report=n_report 183 | ) 184 | 185 | if t_dic is not None: 186 | coach.best_pars = t_dic['best_pars'].copy() 187 | coach.best_loss = t_dic['best_loss'] 188 | 189 | return coach, model_def 190 | 191 | def secure_data(coach, params_shape, model_def, train_dir): 192 | param_loc = os.path.join(train_dir, 'params.hdf5') 193 | param_file = h5py.File(param_loc, 'w') 194 | model_params = param_file.create_dataset( 195 | 'best_pars', params_shape, dtype='float32' 196 | ) 197 | 198 | if isinstance(coach.best_pars, gnumpy.garray): 199 | model_params[...] = coach.best_pars.as_numpy_array() 200 | else: 201 | model_params[...] = coach.best_pars 202 | 203 | if coach.using_bn: 204 | bn_pars = coach.model.get_batchnorm_params() 205 | bn_pars_loc = os.path.join(train_dir, 'bn_pars.pkl') 206 | with open(bn_pars_loc, 'w') as f: 207 | pickle.dump(bn_pars, f) 208 | 209 | now = str(datetime.datetime.now()) 210 | date, time = now.split(' ') 211 | time = time.replace(':', '_') 212 | time = time.replace('.', '_') 213 | log_code = 'log' + date + '@' + time + '.json' 214 | log_loc = os.path.join(train_dir, log_code) 215 | 216 | log = { 217 | 'params': param_loc, 218 | 'layers': model_def.layer_vars, 219 | 'loss_id': model_def.loss_name, 220 | 'losses': coach.losses, 221 | 'test_performance': coach.test_performance, 222 | 'regularize': model_def.regularize, 223 | 'l1': model_def.l1, 224 | 'l2': model_def.l2, 225 | 'perform_transform': model_def.perform_transform.__name__ if model_def.perform_transform is not None else None 226 | } 227 | 228 | with open(log_loc, 'w') as f: 229 | json.dump(log, f) 230 | 231 | t_loss, v_loss = plt.plot(coach.losses) 232 | plt.legend([t_loss, v_loss], ['train loss', 'val loss']) 233 | save_file = os.path.join(train_dir, 'figure.png') 234 | plt.savefig(save_file) 235 | 236 | def save_demo(coach, train_dir, size_reduction): 237 | if coach.using_bn: 238 | predict = ash.BatchNormFuns( 239 | 
model=coach.model, 240 | fun=coach.model.predict, 241 | phase='infer' 242 | ) 243 | else: 244 | predict = coach.model.predict 245 | 246 | test_x, test_y = coach.data['test'] 247 | dice_values = [] 248 | for i in range(test_x.shape[0]): 249 | im_name = os.path.join(train_dir, 'im' + str(i) + '.png') 250 | this_dice_value = coach.demo( 251 | predict=predict, image=test_x[i:i + 1], 252 | gt=test_y[i], size_reduction=size_reduction, 253 | im_name=im_name 254 | ) 255 | dice_values.append(this_dice_value) 256 | 257 | mean_dice = 0 258 | for d in dice_values: 259 | mean_dice += d[0] 260 | mean_dice = mean_dice * 1./len(dice_values) 261 | dice_log = {'mean_dice': mean_dice, 'dice_values': dice_values} 262 | dice_log_path = os.path.join(train_dir, 'dice.json') 263 | with open(dice_log_path, 'w') as f: 264 | json.dump(dice_log, f) 265 | 266 | def start_training(model_code, data_code, checkpoint, max_passes, train_dir=None): 267 | if train_dir is None: 268 | train_dir = os.path.join('models', data_code) 269 | if not os.path.exists(train_dir): 270 | os.makedirs(train_dir) 271 | 272 | print 'Building model, coach...' 273 | coach, model_def = setup_training(model_code, data_code, checkpoint, max_passes) 274 | 275 | #break_handling.make_checkpoint = coach.quit_training 276 | print 'Starting training...' 277 | coach.fit() 278 | 279 | print 'Securing results...' 280 | secure_data( 281 | coach=coach, params_shape=coach.model.parameters.data.shape, 282 | model_def=model_def, train_dir=train_dir 283 | ) 284 | 285 | print 'Saving demo images...' 286 | save_demo(coach, train_dir, size_reduction=model_def.size_reduction) 287 | 288 | print 'done.' 
289 | 290 | if __name__ == '__main__': 291 | parser = make_parser() 292 | 293 | if len(sys.argv[1:]) > 0: 294 | args = parser.parse_args() 295 | else: 296 | args = parser.parse_args(['fcn_rffc4', 'dummy5', 'dummy5', '2', '-ch', 'True']) # model data train checkpoint 297 | 298 | model_code = args.model_code 299 | data_code = args.data_code 300 | checkpoint = args.checkpoint 301 | t_code = args.train_code 302 | n_epochs = args.n_epochs 303 | 304 | train_dir = os.path.join('models', t_code) 305 | start_training(model_code, data_code, checkpoint, max_passes=n_epochs, train_dir=train_dir) -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import json 3 | import os 4 | import datetime 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import h5py 9 | 10 | import climin.stops 11 | from climin import mathadapt as ma 12 | 13 | from breze.learn.trainer import report 14 | 15 | import ash 16 | from ash import PocketTrainer 17 | from model_defs import get_model 18 | 19 | from conv3d.model import SequentialModel 20 | 21 | vis = False 22 | retrain = False 23 | 24 | d_code = 'handsize2_v2' 25 | model_code = 'fcn96_rescaled' 26 | train_dir = os.path.join('models', d_code) 27 | assert os.path.exists(train_dir) 28 | 29 | now = str(datetime.datetime.now()) 30 | date, time = now.split(' ') 31 | time = time.replace(':', '_') 32 | time = time.replace('.', '_') 33 | log_code = 'log' + date + '@' + time + '.json' 34 | 35 | param_file = os.path.join(train_dir, 'params.hdf5') 36 | f = h5py.File('data/datasets/'+d_code+'.hdf5', 'r') 37 | 38 | train_x = f['train_x'] 39 | train_y = f['train_y'] 40 | 41 | valid_x = f['valid_x'] 42 | valid_y = f['valid_y'] 43 | 44 | test_x = f['test_x'] 45 | test_y = f['test_y'] 46 | 47 | n_classes = 2 48 | 49 | model_def = get_model(model_code) 50 | 51 | alpha = model_def.alpha 52 | 
layer_vars = model_def.layer_vars 53 | batchnorm = model_def.batchnorm 54 | loss_id = model_def.loss_id 55 | out_transfer = model_def.out_transfer 56 | size_reduction = model_def.size_reduction 57 | 58 | train_size, inpt_d, n_channels, inpt_h, inpt_w = train_x.shape 59 | 60 | set_x = train_x 61 | set_y = train_y 62 | 63 | output_h = inpt_h-size_reduction 64 | output_w = inpt_w-size_reduction 65 | output_d = inpt_d-size_reduction 66 | 67 | if vis: 68 | for i in range(set_x.shape[0]): 69 | plt.imshow(set_x[i,inpt_d/2,0,:,:], cmap='Greys_r') 70 | plt.show() 71 | ty = np.reshape(set_y[i], (output_h,output_w,output_d,n_classes)) 72 | ty = ty.argmax(axis=3) 73 | plt.imshow(ty[:,:,output_d/2], cmap='Greys_r') 74 | plt.show() 75 | 76 | batch_size = 1 77 | max_passes = 20 78 | inpt_dims = (inpt_h, inpt_w, inpt_d) 79 | 80 | n_report = train_size / batch_size 81 | max_iter = n_report * max_passes 82 | 83 | #stop = climin.stops.Patience( 84 | # func_or_key='val_loss', initial=max_iter, 85 | # grow_factor=2., grow_offset=0, 86 | # threshold=1e-4 87 | #) 88 | stop = climin.stops.AfterNIterations(max_iter=max_iter) 89 | pause = climin.stops.ModuloNIterations(n_report) 90 | 91 | print 'Input data dimensions: h: %i w: %i d: %i ' % (inpt_h, inpt_w, inpt_d) 92 | print 'Set stats: train: %i, valid: %i, test: %i' % (train_x.shape[0], valid_x.shape[0], test_x.shape[0]) 93 | 94 | print 'max iter: ', max_iter 95 | print 'report frequency: every %i iterations' % n_report 96 | 97 | optimizer = 'adam' 98 | 99 | print '\nbuilding model...' 100 | pkchu = SequentialModel( 101 | image_height=inpt_dims[0], image_width=inpt_dims[1], 102 | image_depth=inpt_dims[2], n_channels=n_channels, 103 | n_output=n_classes, layer_vars=layer_vars, 104 | out_transfer=out_transfer, loss_id=loss_id, 105 | optimizer=optimizer, batch_size=batch_size, 106 | max_iter=max_iter, using_bn=batchnorm 107 | ) 108 | 109 | rng = np.random.RandomState(123) 110 | pkchu.parameters.data[...] 
= rng.normal(0, 0.01, pkchu.parameters.data.shape) 111 | 112 | if retrain: 113 | print 'retrieving old params...' 114 | f_params = h5py.File(param_file, 'r') 115 | pkchu.parameters.data[...] = f_params['best_pars'] 116 | 117 | if batchnorm: 118 | bn_par_file = os.path.join(train_dir, 'bn_pars.pkl') 119 | with open(bn_par_file, 'r') as f: 120 | bn_pars = pickle.load(f) 121 | pkchu.set_batchnorm_params(bn_pars) 122 | 123 | param_file = os.path.join(train_dir, 'newparams.hdf5') 124 | 125 | report_fun = report.OneLinePrinter( 126 | ['n_iter', 'runtime', 'loss', 'val_loss', 'test_avg'], 127 | spaces=['4', '7.4f', '5.4f', '7.4f', '7.4f'] 128 | ) 129 | 130 | score_fun = ash.MinibatchScoreFCN(max_samples=batch_size, sample_dims=[0, 0]) 131 | data = { 132 | 'train':(train_x, train_y), 133 | 'val':(valid_x, valid_y), 134 | 'test':(test_x, test_y) 135 | } 136 | 137 | test_fun = ash.MinibatchTestFCN(max_samples=batch_size, sample_dims=[0, 0]) 138 | 139 | #initial_err = ma.scalar(score_fun(pkchu.score, *data['train'])) 140 | #print 'Initial train loss: %.4f' % initial_err 141 | 142 | coach = PocketTrainer( 143 | model=pkchu, data=data, stop=stop, 144 | pause=pause, score_fun=score_fun, 145 | report_fun=report_fun, test_fun=test_fun, 146 | evaluate=True, test=True, batchnorm=batchnorm 147 | ) 148 | 149 | print 'training...' 150 | coach.fit() 151 | print 'training complete.' 152 | 153 | pkchu.parameters.data[...] = coach.best_pars 154 | 155 | f_params = h5py.File(param_file, 'w') 156 | model_params = f_params.create_dataset( 157 | 'best_pars', pkchu.parameters.data.shape, dtype='float32' 158 | ) 159 | 160 | print 'securing params...' 161 | model_params[...] = coach.best_pars.as_numpy_array() 162 | 163 | if batchnorm: 164 | print 'securing batch-norm params...' 
165 | bn_pars = pkchu.get_batchnorm_params() 166 | with open(os.path.join(train_dir, 'bn_pars.pkl'), 'w') as f: 167 | pickle.dump(bn_pars, f) 168 | 169 | log = { 170 | 'data': d_code, 171 | 'params': param_file, 172 | 'layers': layer_vars, 173 | 'loss_id': loss_id.__name__, 174 | 'losses': coach.losses, 175 | 'test_performance': coach.test_performance 176 | } 177 | 178 | print 'printing log and visualizing results...' 179 | with open(os.path.join(train_dir, log_code), 'w') as f: 180 | json.dump(log, f) 181 | 182 | t_loss, v_loss = plt.plot(coach.losses) 183 | plt.legend([t_loss, v_loss], ['train loss', 'val loss']) 184 | save_file = os.path.join(train_dir, 'figure.png') 185 | plt.savefig(save_file) 186 | #plt.show() 187 | 188 | predict = ash.BatchNormFuns( 189 | model=pkchu, 190 | fun=pkchu.predict, 191 | phase='infer' 192 | ) 193 | 194 | for i in range(test_x.shape[0]): 195 | coach.demo(predict=predict, image=test_x[i:i+1], gt=test_y[i], size_reduction=size_reduction) 196 | print 'all done, good night.' 
197 | 198 | -------------------------------------------------------------------------------- /train_vox.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import h5py 6 | 7 | import climin.stops 8 | from climin import mathadapt as ma 9 | 10 | from breze.learn.trainer import report 11 | 12 | import ash 13 | from ash import PocketTrainer 14 | 15 | from conv3d.model import FCN 16 | 17 | vis = False 18 | params = False 19 | stacked_filters = False 20 | 21 | d_code = '96fcnmini' 22 | param_file = 'params/params' + d_code + '.pkl' 23 | f = h5py.File('data/data'+d_code+'.hdf5', 'r') 24 | 25 | train_x = f['train_x'] 26 | train_y = f['train_y'] 27 | 28 | valid_x = f['valid_x'] 29 | valid_y = f['valid_y'] 30 | 31 | test_x = f['test_x'] 32 | test_y = f['test_y'] 33 | 34 | n_classes = 2 35 | 36 | if d_code == '48fcnmini' or d_code == '48fcn' and stacked_filters: 37 | nkerns_d = [16, 16, 32, 32, 64, 64] 38 | fs_d = [(2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)] 39 | ps_d = ['no_pool', (2, 2, 2), 'no_pool', (2, 2, 2), 'no_pool', (2, 2, 2)] 40 | strides_d = (1, 1, 1) 41 | 42 | nkerns_u = [128, 128, 64, 64, 32, 32, 16, n_classes] 43 | fs_u = [(2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (1, 1, 1)] 44 | ps_u = ['no_pool', (2, 2, 2), 'no_pool', (2, 2, 2), 'no_pool', (2, 2, 2), 'no_pool', 'no_pool'] 45 | 46 | hidden_transfers_conv = ['rectifier']*6 47 | hidden_transfers_upconv = ['rectifier']*7 + ['identity'] 48 | 49 | bm_down = ['same']*6 50 | bm_up = ['same']*7 + ['valid'] 51 | loss_id = ash.fcn_cat_ce 52 | 53 | padding = 0 54 | elif d_code.startswith('48fcn'): 55 | nkerns_d = [2, 16, 32] 56 | fs_d = [(3, 3, 3), (3, 3, 3), (3, 3, 3)] 57 | ps_d = [(2, 2, 2), (2, 2, 2), (2, 2, 2)] 58 | strides_d = (1, 1, 1) 59 | 60 | nkerns_u = [32, 32, 16, n_classes] 61 | fs_u = [(3, 3, 3), (3, 3, 3), (3, 3, 3), (3, 3, 
3)] 62 | ps_u = [(2, 2, 2), (2, 2, 2), (2, 2, 2), 'no_pool'] 63 | 64 | hidden_transfers_conv = ['rectifier', 'rectifier', 'rectifier'] 65 | hidden_transfers_upconv = ['rectifier', 'rectifier', 'rectifier', 'identity'] 66 | 67 | bm_down = ['same', 'same', 'same'] 68 | bm_up = ['same', 'same', 'same', 'same'] 69 | loss_id = ash.fcn_cat_ce 70 | 71 | padding = 0 72 | elif d_code.startswith('9680fcn') and stacked_filters: 73 | nkerns_d = [16, 16, 32, 32, 64, 64] 74 | fs_d = [(2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)] 75 | ps_d = ['no_pool', (2, 2, 2), 'no_pool', (2, 2, 2), 'no_pool', (2, 2, 2)] 76 | strides_d = (1, 1, 1) 77 | 78 | nkerns_u = [128, 128, 64, 64, 32, 32, 16, n_classes] 79 | fs_u = [(2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (1, 1, 1)] 80 | ps_u = ['no_pool', (2, 2, 2), 'no_pool', (2, 2, 2), 'no_pool', (2, 2, 2), 'no_pool', 'no_pool'] 81 | 82 | hidden_transfers_conv = ['rectifier']*6 83 | hidden_transfers_upconv = ['rectifier']*7 + ['identity'] 84 | 85 | bm_down = ['same']*6 86 | bm_up = ['valid']*2 + ['same']*5 + ['valid'] 87 | loss_id = ash.fcn_cat_ce 88 | 89 | padding = 16 90 | elif d_code.startswith('9680fcn'): 91 | nkerns_d = [8, 16, 32] 92 | fs_d = [(3, 3, 3), (3, 3, 3), (3, 3, 3)] 93 | ps_d = [(2, 2, 2), (2, 2, 2), (2, 2, 2)] 94 | strides_d = (1, 1, 1) 95 | 96 | nkerns_u = [32, 32, 16, n_classes] 97 | fs_u = [(3, 3, 3), (3, 3, 3), (3, 3, 3), (3, 3, 3)] 98 | ps_u = [(2, 2, 2), (2, 2, 2), (2, 2, 2), 'no_pool'] 99 | 100 | hidden_transfers_conv = ['rectifier', 'rectifier', 'rectifier'] 101 | hidden_transfers_upconv = ['rectifier', 'rectifier', 'rectifier', 'identity'] 102 | 103 | bm_down = ['same', 'same', 'same'] 104 | bm_up = ['valid', 'same', 'same', 'same'] 105 | loss_id = ash.fcn_cat_ce 106 | 107 | padding = 16 108 | elif d_code.startswith('96fcn'): 109 | nkerns_d = [8, 16, 32] 110 | fs_d = [(3, 3, 3), (3, 3, 3), (3, 3, 3)] 111 | ps_d = [(2, 2, 2), (2, 2, 2), (2, 2, 2)] 112 | strides_d = (1, 
1, 1) 113 | 114 | nkerns_u = [32, 32, 16, n_classes] 115 | fs_u = [(3, 3, 3), (3, 3, 3), (3, 3, 3), (3, 3, 3)] 116 | ps_u = [(2, 2, 2), (2, 2, 2), (2, 2, 2), 'no_pool'] 117 | 118 | hidden_transfers_conv = ['rectifier', 'rectifier', 'rectifier'] 119 | hidden_transfers_upconv = ['rectifier', 'rectifier', 'rectifier', 'identity'] 120 | 121 | bm_down = ['same', 'same', 'same'] 122 | bm_up = ['same', 'same', 'same', 'same'] 123 | loss_id = ash.dice 124 | 125 | padding = 0 126 | else: 127 | raise Exception('No such dataset.') 128 | 129 | train_size, inpt_d, n_channels, inpt_h, inpt_w = train_x.shape 130 | 131 | batch_size = 1 132 | max_passes = 50 133 | 134 | set_x = train_x 135 | set_y = train_y 136 | 137 | output_h = inpt_h-padding 138 | output_w = inpt_w-padding 139 | output_d = inpt_d-padding 140 | 141 | if vis: 142 | for i in range(set_x.shape[0]): 143 | plt.imshow(set_x[i,inpt_d/2,0,:,:], cmap='Greys_r') 144 | plt.show() 145 | ty = np.reshape(set_y[i], (output_h,output_w,output_d,n_classes)) 146 | ty = ty.argmax(axis=3) 147 | plt.imshow(ty[:,:,output_d/2], cmap='Greys_r') 148 | plt.show() 149 | 150 | inpt_dims = (inpt_h, inpt_w, inpt_d) 151 | 152 | n_report = train_size / batch_size 153 | max_iter = n_report * max_passes 154 | 155 | print 'Train x shape: ', train_x.shape 156 | print 'Train y shape: ', train_y.shape 157 | print 'Valid x shape: ', valid_x.shape 158 | print 'Valid y shape: ', valid_y.shape 159 | print 'Test x shape: ', test_x.shape 160 | print 'Test y shape: ', test_y.shape 161 | 162 | print '\nmax iter: ', max_iter 163 | print 'report frequency: every %i iterations\n' % n_report 164 | 165 | stop = climin.stops.AfterNIterations(max_iter) 166 | pause = climin.stops.ModuloNIterations(n_report) 167 | 168 | optimizer = 'adam' 169 | 170 | pkchu = FCN( 171 | image_height=inpt_dims[0], image_width=inpt_dims[1], 172 | image_depth=inpt_dims[2], n_channel=n_channels, 173 | n_output=n_classes, n_hiddens_conv=nkerns_d, 174 | down_filter_shapes=fs_d, 
hidden_transfers_conv=hidden_transfers_conv, 175 | down_pools=ps_d, n_hiddens_upconv=nkerns_u, 176 | up_filter_shapes=fs_u, hidden_transfers_upconv=hidden_transfers_upconv, 177 | up_pools=ps_u, out_transfer='softmax', loss=loss_id, 178 | optimizer=optimizer, batch_size=batch_size, 179 | bm_up=bm_up, bm_down=bm_down, 180 | max_iter=max_iter, implementation='dnn_conv3d', strides_d=strides_d 181 | ) 182 | 183 | print '\nARCHITECTURE: ' 184 | print '\tFilters: ', fs_d 185 | print '\tFeature maps: ', nkerns_d 186 | print '\tPools: ', ps_d 187 | print '\tUp-filters: ', fs_u 188 | print '\tFeature maps: ', nkerns_u 189 | print '\tUppools: ', ps_u 190 | 191 | if not params: 192 | rng = np.random.RandomState(123) 193 | pkchu.parameters.data[...] = rng.normal(0, 0.01, pkchu.parameters.data.shape) 194 | else: 195 | with open(param_file, 'r') as f: 196 | #with open('params/params9680fcnmini.pkl', 'r') as f: 197 | pkchu.parameters.data[...] = pickle.load(f) 198 | 199 | report_fun = report.OneLinePrinter( 200 | ['n_iter', 'runtime', 'loss', 'val_loss', 'test_avg'], 201 | spaces=['4', '7.4f', '5.4f', '7.4f', '7.4f'] 202 | ) 203 | 204 | score_fun = ash.MinibatchScoreFCN(max_samples=batch_size, sample_dims=[0, 0]) 205 | data = { 206 | 'train':(train_x, train_y), 207 | 'val':(valid_x, valid_y), 208 | 'test':(test_x, test_y) 209 | } 210 | 211 | test_fun = ash.MinibatchTestFCN(max_samples=batch_size, sample_dims=[0, 0]) 212 | 213 | initial_err = ma.scalar(score_fun(pkchu.score, *data['train'])) 214 | 215 | print 'Initial train loss: %.4f' % initial_err 216 | 217 | coach = PocketTrainer( 218 | model=pkchu, data=data, stop=stop, 219 | pause=pause, score_fun=score_fun, 220 | report_fun=report_fun, test_fun=test_fun, 221 | evaluate=True, test=True 222 | ) 223 | 224 | coach.fit() 225 | 226 | pkchu.parameters.data[...] 
= coach.best_pars 227 | 228 | plt.plot(coach.losses) 229 | plt.show() 230 | 231 | for i in range(test_x.shape[0]): 232 | plt.imshow(test_x[i, inpt_d/2, 0, :, :], cmap='Greys_r') 233 | plt.show() 234 | y = pkchu.predict(test_x[i:i+1]) 235 | y = np.reshape(y[0].as_numpy_array(), (output_h, output_w, output_d, n_classes)) 236 | y = y.argmax(axis=3) 237 | plt.imshow(y[:, :, output_d/2], cmap='Greys_r') 238 | plt.show() 239 | 240 | with open(param_file,'w') as f: 241 | pickle.dump(coach.best_pars, f) -------------------------------------------------------------------------------- /transformations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | def extract_section(im, x, y, z, padding, section_shape): 5 | x_sect, y_sect, z_sect = section_shape 6 | size_x, size_y, size_z = im.shape 7 | take_x = x_sect + padding 8 | take_y = y_sect + padding 9 | take_z = z_sect + padding 10 | 11 | if x - (take_x / 2) < 0: 12 | sl_x = slice(0, take_x) 13 | elif x + (take_x / 2) > size_x: 14 | sl_x = slice(size_x - take_x, size_x) 15 | else: 16 | sl_x = slice(x - (take_x / 2), x + (take_x / 2)) 17 | 18 | if y - (take_y / 2) < 0: 19 | sl_y = slice(0, take_y) 20 | elif y + (take_y / 2) > size_y: 21 | sl_y = slice(size_y - take_y, size_y) 22 | else: 23 | sl_y = slice(y - (take_y / 2), y + (take_y / 2)) 24 | 25 | if z - (take_z / 2) < 0: 26 | sl_z = slice(0, take_z) 27 | elif z + (take_z / 2) > size_z: 28 | sl_z = slice(size_z - take_z, size_z) 29 | else: 30 | sl_z = slice(z - (take_z / 2), z + (take_z / 2)) 31 | 32 | return im[sl_x, sl_y, sl_z] 33 | 34 | def random_brain_points(gt): 35 | whole = np.array(np.where(gt != 0)).T 36 | core = np.array(np.where((gt != 0) & (gt != 2))).T 37 | active = np.array(np.where(gt == 4)).T 38 | healthy = np.array(np.where(gt == 0)).T 39 | 40 | region_candidates = [whole, core, active, healthy, healthy] 41 | regions = [] 42 | for candidate in region_candidates: 43 | if 
len(candidate) > 0: 44 | regions.append(candidate) 45 | 46 | if len(regions) == 0: 47 | raise ValueError('Ground truth does not make sense.') 48 | 49 | reg = random.choice(regions) 50 | i = np.random.randint(0, len(reg)) 51 | point = reg[i] 52 | return point 53 | 54 | def random_hand_points(gt): 55 | bone = np.array(np.where(gt > 0)).T 56 | center = np.array([int(i) for i in np.round(np.mean(bone, axis=0))]) 57 | offset = np.random.randint(-25, 25, size=(3,)) 58 | point = center + offset 59 | 60 | # region_candidates = [metacarpal, proximal, middle] 61 | # regions = [] 62 | # for candidate in region_candidates: 63 | # if len(candidate) > 0: 64 | # regions.append(candidate) 65 | # 66 | # if len(regions) == 0: 67 | # raise ValueError('Ground truth does not make sense.') 68 | # 69 | # reg = random.choice(regions) 70 | # i = np.random.randint(0, len(reg)) 71 | # point = reg[i] 72 | 73 | return point 74 | 75 | def extract_random_section(image, gt, random_point_selection, section_shape): 76 | point = random_point_selection(gt.argmax(axis=0)) 77 | x, y, z = point 78 | sections = [extract_section(modality, x, y, z, padding=0, section_shape=section_shape) for modality in image] 79 | gt_sections = [extract_section(class_map, x, y, z, padding=0, section_shape=section_shape) for class_map in gt] 80 | 81 | return np.array(sections, dtype='int16'), np.array(gt_sections, dtype='int8') 82 | 83 | 84 | class RandomSectionSelection(object): 85 | def __init__(self, data_mode, followed_by=None): 86 | if data_mode == 'brain': 87 | self.random_point_selection = random_brain_points 88 | self.section_shape = (80, 72, 64) 89 | elif data_mode == 'hand': 90 | self.random_point_selection = random_hand_points 91 | self.section_shape = (144, 120, 96) 92 | elif data_mode == 'debug_hand': 93 | self.random_point_selection = random_hand_points 94 | self.section_shape = (64, 64, 64) 95 | elif data_mode == 'debug_brain': 96 | self.random_point_selection = random_brain_points 97 | self.section_shape = 
(32, 32, 32) 98 | else: 99 | raise ValueError('Data modes are: hand, brain.') 100 | self.data_mode = data_mode 101 | self.followed_by = followed_by 102 | 103 | self.__name__ = 'RandomSectionSelection' 104 | 105 | def __call__(self, x, z): 106 | n_classes = z.shape[-1] 107 | 108 | nx = np.transpose(x, (0, 2, 3, 4, 1)) 109 | nz = np.reshape(z, (1, x.shape[3], x.shape[4], x.shape[1], n_classes)) 110 | nz = np.transpose(nz, (0, 4, 1, 2, 3)) 111 | 112 | nx, nz = extract_random_section(nx[0], nz[0], self.random_point_selection, self.section_shape) 113 | nx = np.transpose(nx[np.newaxis], (0, 4, 1, 2, 3)) 114 | nz = np.transpose(nz[np.newaxis], (0, 2, 3, 4, 1)) 115 | nz = np.reshape(nz, (nz.shape[0], nz.shape[1]*nz.shape[2]*nz.shape[3], n_classes)) 116 | 117 | if self.followed_by is None: 118 | return (nx, nz) 119 | else: 120 | return self.followed_by(nx, nz) 121 | 122 | def random_flip(x, z): 123 | flip_axis = random.choice([0, 1, 2]) 124 | original_shape = z.shape 125 | nz = np.reshape(z, (1, x.shape[3], x.shape[4], x.shape[1], z.shape[-1])) 126 | if flip_axis == 0: 127 | # flip along the height 128 | nx = x[:, :, :, ::-1, :] 129 | nz = nz[:, ::-1, :, :, :] 130 | elif flip_axis == 1: 131 | # flip along the width 132 | nx = x[:, :, :, :, ::-1] 133 | nz = nz[:, :, ::-1, :, :] 134 | else: 135 | # flip along the depth 136 | nx = x[:, ::-1, :, :, :] 137 | nz = nz[:, :, :, ::-1, :] 138 | nz = np.reshape(nz, original_shape) 139 | return (nx, nz) 140 | 141 | def percentile_filter(x, z): 142 | from scipy.ndimage import percentile_filter 143 | from breze.learn.data import one_hot 144 | percentile = np.random.randint(0, 10) 145 | 146 | nx = np.transpose(x, (0, 2, 1, 3, 4)) 147 | nx[0] = [percentile_filter(modality, percentile, (2, 2, 2)) for modality in nx[0]] 148 | nx = np.transpose(nx, (0, 2, 1, 3, 4)) 149 | 150 | n_classes = z.shape[-1] 151 | nz = np.reshape(z, (x.shape[3], x.shape[4], x.shape[1], n_classes)) 152 | nz = np.transpose(nz, (3, 0, 1, 2)) 153 | nz = 
np.array([percentile_filter(class_map, percentile, (2, 2, 2)) for class_map in nz]) 154 | nz = nz.argmax(axis=0) 155 | nz = np.reshape(nz, (-1,)) 156 | nz = np.reshape(one_hot(nz, n_classes), z.shape) 157 | 158 | nx = np.asarray(nx, dtype=x.dtype) 159 | nz = np.asarray(nz, dtype=z.dtype) 160 | 161 | return (nx, nz) 162 | 163 | def swirl_(im, strength, radius): 164 | from skimage.transform import swirl 165 | return [swirl(im_slice, rotation=0, strength=strength, radius=radius) for im_slice in im] 166 | 167 | def swirl_transform(x, z): 168 | """ 169 | Adds a swirl effect to every depth slice. 170 | Assuming a batch size of 1. 171 | More specifically: x is (1, depth, channels, height, width) and z is (1, height*width*depth, classes) 172 | """ 173 | from breze.learn.data import one_hot 174 | strength = np.random.uniform(1, 2) 175 | radius = np.random.randint(90, 140) 176 | z_original_shape = z.shape 177 | n_classes = z.shape[-1] 178 | 179 | nx = np.transpose(x, (0, 2, 1, 3, 4)) 180 | nz = np.reshape(z, (1, x.shape[3], x.shape[4], x.shape[1], n_classes)) 181 | nz = np.transpose(nz, (0, 4, 3, 1, 2)) 182 | nx[0] = [swirl_(modality, strength, radius) for modality in nx[0]] 183 | nx = np.transpose(nx, (0, 2, 1, 3, 4)) 184 | nz[0] = [swirl_(class_map, strength, radius) for class_map in nz[0]] 185 | nz = nz[0].argmax(axis=0) 186 | nz = np.transpose(nz, (1, 2, 0)) 187 | nz = np.reshape(nz, (-1,)) 188 | nz = np.reshape(one_hot(nz, n_classes), z_original_shape) 189 | 190 | nx = np.asarray(nx, dtype=x.dtype) 191 | nz = np.asarray(nz, dtype=z.dtype) 192 | 193 | return (nx, nz) 194 | 195 | def minor_rotation(x, z): 196 | """ 197 | Assuming a batch size of 1. 
198 | More specifically: x is (1, depth, channels, height, width) and z is (1, height*width*depth, classes) 199 | """ 200 | from scipy.ndimage.interpolation import rotate as rotate_scipy 201 | from breze.learn.data import one_hot 202 | z_original_shape = z.shape 203 | n_classes = z.shape[-1] 204 | ang = float(np.random.uniform(-90, 90)) 205 | axes = np.random.permutation(3)[:2] 206 | 207 | nx = np.transpose(x, (0, 2, 3, 4, 1)) 208 | nz = np.reshape(z, (1, x.shape[3], x.shape[4], x.shape[1], n_classes)) 209 | nz = np.transpose(nz, (0, 4, 1, 2, 3)) 210 | 211 | nx[0] = [rotate_scipy(modality, ang, axes=axes, order=3, reshape=False) for modality in nx[0]] 212 | nx = np.transpose(nx, (0, 4, 1, 2, 3)) 213 | nz[0] = [rotate_scipy(class_map, ang, axes=axes, order=3, reshape=False) for class_map in nz[0]] 214 | nz = nz[0].argmax(axis=0) 215 | nz = np.reshape(nz, (-1,)) 216 | nz = np.reshape(one_hot(nz, n_classes), z_original_shape) 217 | 218 | nx = np.asarray(nx, dtype=x.dtype) 219 | nz = np.asarray(nz, dtype=z.dtype) 220 | 221 | return (nx, nz) 222 | 223 | def full_rotation(x, z): 224 | """ 225 | Assuming a batch size of 1. 
226 | More specifically: x is (1, depth, channels, height, width) and z is (1, height*width*depth, classes) 227 | """ 228 | from scipy.ndimage.interpolation import rotate as rotate_scipy 229 | from breze.learn.data import one_hot 230 | z_original_shape = z.shape 231 | n_classes = z.shape[-1] 232 | ang = float(np.random.uniform(0, 360)) 233 | axes = np.random.permutation(3)[:2] 234 | 235 | nx = np.transpose(x, (0, 2, 3, 4, 1)) 236 | nz = np.reshape(z, (1, x.shape[3], x.shape[4], x.shape[1], n_classes)) 237 | nz = np.transpose(nz, (0, 4, 1, 2, 3)) 238 | 239 | nx[0] = [rotate_scipy(modality, ang, axes=axes, order=3, reshape=False) for modality in nx[0]] 240 | nx = np.transpose(nx, (0, 4, 1, 2, 3)) 241 | nz[0] = [rotate_scipy(class_map, ang, axes=axes, order=3, reshape=False) for class_map in nz[0]] 242 | nz = nz[0].argmax(axis=0) 243 | nz = np.reshape(nz, (-1,)) 244 | nz = np.reshape(one_hot(nz, n_classes), z_original_shape) 245 | 246 | nx = np.asarray(nx, dtype=x.dtype) 247 | nz = np.asarray(nz, dtype=z.dtype) 248 | 249 | return (nx, nz) 250 | 251 | def identity(x, z): 252 | return (x, z) 253 | 254 | def nil(x, z): 255 | nx = np.zeros(x.shape) 256 | nz = np.zeros(z.shape) 257 | 258 | return (nx, nz) 259 | 260 | def random_transformation(x, z): 261 | import random 262 | transformations = ['identity', 'random_flip', 'percentile_filter', 'full_rotation'] 263 | transform_dict = { 264 | 'identity': identity, 265 | 'random_flip': random_flip, 266 | 'percentile_filter': percentile_filter, 267 | 'full_rotation': full_rotation 268 | } 269 | 270 | transform_key = random.choice(transformations) 271 | transform_fun = transform_dict[transform_key] 272 | 273 | nx, nz = transform_fun(x, z) 274 | 275 | second_transform_key = random.choice(transformations) 276 | if second_transform_key == transform_key: 277 | return (nx, nz) 278 | else: 279 | second_transform_fun = transform_dict[second_transform_key] 280 | return second_transform_fun(nx, nz) 281 | 282 | def 
random_geometric_transformation(x, z): 283 | import random 284 | transformations = ['identity', 'random_flip', 'full_rotation'] 285 | transform_dict = { 286 | 'identity': identity, 287 | 'random_flip': random_flip, 288 | 'full_rotation': full_rotation 289 | } 290 | 291 | transform_key = random.choice(transformations) 292 | transform_fun = transform_dict[transform_key] 293 | 294 | nx, nz = transform_fun(x, z) 295 | 296 | second_transform_key = random.choice(transformations) 297 | if second_transform_key == transform_key: 298 | return (nx, nz) 299 | else: 300 | second_transform_fun = transform_dict[second_transform_key] 301 | return second_transform_fun(nx, nz) 302 | 303 | def random_soft_geometric_transformation(x, z): 304 | import random 305 | transformations = ['identity', 'random_flip', 'full_rotation'] 306 | transform_dict = { 307 | 'identity': identity, 308 | 'random_flip': random_flip, 309 | 'full_rotation': full_rotation 310 | } 311 | 312 | transform_key = random.choice(transformations) 313 | transform_fun = transform_dict[transform_key] 314 | 315 | nx, nz = transform_fun(x, z) 316 | 317 | return (nx, nz) --------------------------------------------------------------------------------