├── models ├── .gitignore ├── .vm1_cc.h5.icloud ├── .vm1_l2.h5.icloud ├── .vm2_cc.h5.icloud ├── .vm2_l2.h5.icloud └── .miccai_2018_70000_005.h5.icloud ├── ext ├── pynd-lib │ ├── .gitignore │ ├── pynd │ │ ├── __init__.py │ │ ├── imutils.py │ │ ├── segutils.py │ │ └── ndutils.py │ └── readme.md ├── medipy-lib │ ├── medipy │ │ ├── __init__.py │ │ ├── metrics.pyc │ │ ├── __init__.pyc │ │ ├── __pycache__ │ │ │ ├── metrics.cpython-35.pyc │ │ │ ├── metrics.cpython-36.pyc │ │ │ ├── __init__.cpython-35.pyc │ │ │ └── __init__.cpython-36.pyc │ │ └── metrics.py │ └── readme.md ├── pytools-lib │ ├── readme.md │ ├── pytools │ │ ├── __init__.py │ │ ├── timer.py │ │ ├── plotting.py │ │ └── iniparse.py │ └── .gitignore └── neuron │ ├── neuron │ ├── __pycache__ │ │ ├── inits.cpython-35.pyc │ │ ├── layers.cpython-35.pyc │ │ ├── models.cpython-35.pyc │ │ ├── plot.cpython-35.pyc │ │ ├── utils.cpython-35.pyc │ │ ├── __init__.cpython-35.pyc │ │ ├── dataproc.cpython-35.pyc │ │ ├── metrics.cpython-35.pyc │ │ ├── callbacks.cpython-35.pyc │ │ └── generators.cpython-35.pyc │ ├── __init__.py │ ├── inits.py │ ├── plot.py │ └── dataproc.py │ └── README.md ├── data ├── labels.mat ├── test_seg.npz ├── .test_vol.npz.icloud ├── .atlas_norm.npz.icloud ├── MAS_atlas.txt ├── test_examples.txt └── allcsv.txt ├── src ├── datagenerators.py ├── demo.py ├── test.py ├── test_miccai2018.py ├── MAS5_train.py ├── MAS2_train.py ├── MAS3_train.py ├── MAS4_train.py ├── losses.py ├── train.py ├── train_miccai2018.py ├── networks.py ├── SAS_train.py ├── SAS_test.py ├── MAS2_test_linear.py ├── MAS3_test_linear.py ├── MAS2_test.py └── MAS4_test_linear.py └── README.md /models/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ext/pynd-lib/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea -------------------------------------------------------------------------------- /ext/medipy-lib/medipy/__init__.py: -------------------------------------------------------------------------------- 1 | from . import metrics -------------------------------------------------------------------------------- /ext/pytools-lib/readme.md: -------------------------------------------------------------------------------- 1 | # pytools 2 | General python tools 3 | -------------------------------------------------------------------------------- /ext/pynd-lib/pynd/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ndutils 2 | from . 
import segutils -------------------------------------------------------------------------------- /data/labels.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/data/labels.mat -------------------------------------------------------------------------------- /ext/medipy-lib/readme.md: -------------------------------------------------------------------------------- 1 | # MedIPy 2 | Medical Image Analysis library for Python 3 | -------------------------------------------------------------------------------- /data/test_seg.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/data/test_seg.npz -------------------------------------------------------------------------------- /ext/pynd-lib/readme.md: -------------------------------------------------------------------------------- 1 | # ND utilities 2 | Python Library for ND (n-dimensional) array operations -------------------------------------------------------------------------------- /data/.test_vol.npz.icloud: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/data/.test_vol.npz.icloud -------------------------------------------------------------------------------- /models/.vm1_cc.h5.icloud: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/models/.vm1_cc.h5.icloud -------------------------------------------------------------------------------- /models/.vm1_l2.h5.icloud: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/models/.vm1_l2.h5.icloud -------------------------------------------------------------------------------- /models/.vm2_cc.h5.icloud: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/models/.vm2_cc.h5.icloud -------------------------------------------------------------------------------- /models/.vm2_l2.h5.icloud: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/models/.vm2_l2.h5.icloud -------------------------------------------------------------------------------- /data/.atlas_norm.npz.icloud: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/data/.atlas_norm.npz.icloud -------------------------------------------------------------------------------- /ext/medipy-lib/medipy/metrics.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/medipy-lib/medipy/metrics.pyc -------------------------------------------------------------------------------- /ext/medipy-lib/medipy/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/medipy-lib/medipy/__init__.pyc -------------------------------------------------------------------------------- /models/.miccai_2018_70000_005.h5.icloud: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/models/.miccai_2018_70000_005.h5.icloud 
-------------------------------------------------------------------------------- /ext/pytools-lib/pytools/__init__.py: -------------------------------------------------------------------------------- 1 | from . import iniparse 2 | from . import patchlib 3 | from . import timer 4 | from . import plotting 5 | -------------------------------------------------------------------------------- /ext/neuron/neuron/__pycache__/inits.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/neuron/neuron/__pycache__/inits.cpython-35.pyc -------------------------------------------------------------------------------- /ext/neuron/neuron/__pycache__/layers.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/neuron/neuron/__pycache__/layers.cpython-35.pyc -------------------------------------------------------------------------------- /ext/neuron/neuron/__pycache__/models.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/neuron/neuron/__pycache__/models.cpython-35.pyc -------------------------------------------------------------------------------- /ext/neuron/neuron/__pycache__/plot.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/neuron/neuron/__pycache__/plot.cpython-35.pyc -------------------------------------------------------------------------------- /ext/neuron/neuron/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/neuron/neuron/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /ext/neuron/neuron/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/neuron/neuron/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /ext/neuron/neuron/__pycache__/dataproc.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/neuron/neuron/__pycache__/dataproc.cpython-35.pyc -------------------------------------------------------------------------------- /ext/neuron/neuron/__pycache__/metrics.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/neuron/neuron/__pycache__/metrics.cpython-35.pyc -------------------------------------------------------------------------------- /ext/medipy-lib/medipy/__pycache__/metrics.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/medipy-lib/medipy/__pycache__/metrics.cpython-35.pyc -------------------------------------------------------------------------------- /ext/medipy-lib/medipy/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/medipy-lib/medipy/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /ext/neuron/neuron/__pycache__/callbacks.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/neuron/neuron/__pycache__/callbacks.cpython-35.pyc -------------------------------------------------------------------------------- /ext/neuron/neuron/__pycache__/generators.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/neuron/neuron/__pycache__/generators.cpython-35.pyc -------------------------------------------------------------------------------- /ext/medipy-lib/medipy/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/medipy-lib/medipy/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /ext/medipy-lib/medipy/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clark-s-dev/MAS/HEAD/ext/medipy-lib/medipy/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ext/neuron/neuron/__init__.py: -------------------------------------------------------------------------------- 1 | # import various 2 | from . import dataproc 3 | from . import generators 4 | from . import callbacks 5 | from . import plot 6 | from . import metrics 7 | from . import inits 8 | from . import models 9 | from . import utils 10 | from . 
import layers
11 | 
-------------------------------------------------------------------------------- /ext/pynd-lib/pynd/imutils.py: --------------------------------------------------------------------------------
1 | ''' image utilities '''
2 | 
3 | import numpy as np
4 | 
5 | def gray2color(gray, color):
6 | '''
7 | transform a gray image (2d array) to a color image given the color (1x3 vector)
8 | untested
9 | '''
10 | # stack one scaled copy of the gray image per color channel
11 | return np.stack([gray * c for c in color], 2)
-------------------------------------------------------------------------------- /data/MAS_atlas.txt: --------------------------------------------------------------------------------
1 | /home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990114_vc722.npz
2 | /home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990210_vc792.npz
3 | /home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990405_vc922.npz
4 | /home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991006_vc1337.npz
5 | /home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991120_vc1456.npz
6 | 
7 | 
8 | 
-------------------------------------------------------------------------------- /ext/neuron/neuron/inits.py: --------------------------------------------------------------------------------
1 | ''' initializations for the neuron project '''
2 | 
3 | # general imports
4 | import os
5 | import numpy as np
6 | import keras.backend as K
7 | 
8 | 
9 | def output_init(shape, name=None, dim_ordering=None):
10 | ''' initialization for output weights'''
11 | size = (shape[0], shape[1], shape[2] - shape[3], shape[3])
12 | 
13 | # initialize output weights with random and identity
14 | rpart = np.random.random(size)
15 | # idpart_ = np.eye(size[3])
16 | idpart_ = np.ones((size[3], size[3]))
17 | idpart = np.expand_dims(np.expand_dims(idpart_, 0), 0)
18 | value = np.concatenate((rpart, idpart), axis=2)
19 | return K.variable(value, name=name)
20 | 
-------------------------------------------------------------------------------- /ext/pytools-lib/.gitignore: --------------------------------------------------------------------------------
1 | # Windows image file caches
2 | Thumbs.db
3 | ehthumbs.db
4 | 
5 | # Folder config file
6 | Desktop.ini
7 | 
8 | # Recycle Bin used on file shares
9 | $RECYCLE.BIN/
10 | 
11 | # Windows Installer files
12 | *.cab
13 | *.msi
14 | *.msm
15 | *.msp
16 | 
17 | # Windows shortcuts
18 | *.lnk
19 | 
20 | # =========================
21 | # Operating System Files
22 | # =========================
23 | 
24 | # OSX
25 | # =========================
26 | 
27 | .DS_Store
28 | .AppleDouble
29 | .LSOverride
30 | 
31 | # Thumbnails
32 | ._*
33 | 
34 | # Files that might appear in the root of a volume
35 | .DocumentRevisions-V100
36 | .fseventsd
37 | .Spotlight-V100
38 | .TemporaryItems
39 | .Trashes
40 | .VolumeIcon.icns
41 | 
42 | # Directories potentially created on remote AFP share
43 | .AppleDB
44 | .AppleDesktop
45 | Network Trash Folder
46 | Temporary Items
47 | .apdisk
48 | 
49 | # python
50 | __pycache__
-------------------------------------------------------------------------------- /ext/pytools-lib/pytools/timer.py: --------------------------------------------------------------------------------
1 | ''' A collection of general python utilities '''
2 | 
3 | import time
4 | 
5 | class Timer(object):
6 | """
7 | modified from:
8 | http://stackoverflow.com/questions/5849800/tic-toc-functions-analog-in-python
9 | a helper class for timing
10 | use:
11 | with Timer('foo_stuff'):
12 | # do some foo
13 | # do some
stuff
14 | as an alternative to
15 | t = time.time()
16 | # do stuff
17 | elapsed = time.time() - t
18 | """
19 | 
20 | def __init__(self, name=None, verbose=True):
21 | self.name = name
22 | self.verbose = verbose
23 | 
24 | def __enter__(self):
25 | self.tstart = time.time()
26 | 
27 | def __exit__(self, type, value, traceback):
28 | if self.verbose:
29 | if self.name:
30 | print('[%s]' % self.name, end="")
31 | print('Elapsed: %6.4s' % (time.time() - self.tstart))
32 | 
-------------------------------------------------------------------------------- /src/datagenerators.py: --------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | 
4 | 
5 | def load_example_by_name(vol_name, seg_name):
6 | 
7 | X = np.load(vol_name)['vol_data']
8 | X = np.reshape(X, (1,) + X.shape + (1,))
9 | 
10 | return_vals = [X]
11 | 
12 | X_seg = np.load(seg_name)['vol_data']
13 | X_seg = np.reshape(X_seg, (1,) + X_seg.shape + (1,))
14 | return_vals.append(X_seg)
15 | 
16 | return tuple(return_vals)
17 | 
18 | 
19 | def example_gen(vol_names, return_segs=False, seg_dir=None):
20 | #idx = 0
21 | while True:
22 | idx = np.random.randint(len(vol_names))
23 | X = np.load(vol_names[idx])['vol_data']
24 | X = np.reshape(X, (1,) + X.shape + (1,))
25 | 
26 | return_vals = [X]
27 | 
28 | if return_segs:
29 | name = os.path.basename(vol_names[idx])
30 | X_seg = np.load(seg_dir + name[0:-8]+'aseg.npz')['vol_data']
31 | X_seg = np.reshape(X_seg, (1,) + X_seg.shape + (1,))
32 | return_vals.append(X_seg)
33 | 
34 | # print vol_names[idx] + "," + seg_dir + name[0:-8]+'aseg.npz'
35 | 
36 | yield tuple(return_vals)
37 | 
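38 | # A minimal usage sketch (hypothetical paths, not part of the original file):
39 | # gen = example_gen(glob.glob('/path/to/train/vols/*.npz'))
40 | # X = next(gen)[0]  # X has shape (1, *vol_shape, 1), ready for model.train_on_batch
41 | 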
-------------------------------------------------------------------------------- /ext/medipy-lib/medipy/metrics.py: --------------------------------------------------------------------------------
1 | '''
2 | metrics
3 | 
4 | Contact: adalca@csail.mit.edu
5 | '''
6 | 
7 | # imports
8 | import numpy as np
9 | 
10 | 
11 | def dice(vol1, vol2, labels=None, nargout=1):
12 | '''
13 | Dice [1] volume overlap metric
14 | 
15 | The default is to *not* return a measure for the background layer (label = 0)
16 | 
17 | [1] Dice, Lee R. "Measures of the amount of ecologic association between species."
18 | Ecology 26.3 (1945): 297-302.
19 | 
20 | Parameters
21 | ----------
22 | vol1 : nd array. The first volume (e.g. predicted volume)
23 | vol2 : nd array. The second volume (e.g. "true" volume)
24 | labels : optional vector of labels on which to compute Dice.
25 | If this is not provided, Dice is computed on all non-background (non-0) labels
26 | nargout : optional control of output arguments. if 1, output Dice measure(s).
27 | if 2, output tuple of (Dice, labels)
28 | 
29 | Output
30 | ------
31 | if nargout == 1 : dice : vector of dice measures for each label
32 | if nargout == 2 : (dice, labels) : where labels is a vector of the labels on which
33 | dice was computed
34 | '''
35 | if labels is None:
36 | labels = np.unique(np.concatenate((vol1, vol2)))
37 | labels = np.delete(labels, np.where(labels == 0)) # remove background
38 | 
39 | dicem = np.zeros(len(labels))
40 | for idx, lab in enumerate(labels):
41 | top = 2 * np.sum(np.logical_and(vol1 == lab, vol2 == lab))
42 | bottom = np.sum(vol1 == lab) + np.sum(vol2 == lab)
43 | bottom = np.maximum(bottom, np.finfo(float).eps) # add epsilon.
44 | dicem[idx] = top / bottom
45 | 
46 | if nargout == 1:
47 | return dicem
48 | else:
49 | return (dicem, labels)
50 | 
51 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # MAS & SAS
2 | 
3 | Multi Atlas Segmentation of 3D Brain MRI Based on Unsupervised Learning
4 | 
5 | ## SAS (Single Atlas Segmentation)
6 | ![image](https://github.com/ShouYuqing/Images/blob/master/p1-1.png)
7 | > Swap the roles of the volume data and the atlas data during training:
8 | ```python
9 | train([atlas,volume],[volume,flow])
10 | ```
11 | > Train with different similarity metrics (CC, L2).
12 | > Use the Dice score to evaluate the model.
13 | ### How does segmentation work?
14 | ![image](https://github.com/ShouYuqing/Images/blob/master/p1-2.png)
15 | ## MAS (Multi Atlas Segmentation)
16 | ![image](https://github.com/ShouYuqing/Images/blob/master/p1-5.png)
17 | >Label fusion (a sketch follows below)
18 | 
19 | >Models: MAS-2 MAS-3 MAS-4 MAS-5 (vm-1, vm-2 double)
20 | 
21 | >Spatial transform: linear/nearest
22 | 
23 | >Use the Dice score to evaluate the model.
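24 | 
25 | As a rough illustration (a hypothetical sketch, not the exact code in this repository), majority-vote label fusion over atlas segmentations already warped into subject space could look like:
26 | ```python
27 | import numpy as np
28 | from medipy.metrics import dice
29 | 
30 | def majority_vote(warped_segs):
31 |     # stack the warped atlas label maps: shape (*vol_shape, n_atlases)
32 |     stacked = np.stack(warped_segs, axis=-1).astype(int)
33 |     # most frequent label per voxel (ties resolve toward the smaller label)
34 |     return np.apply_along_axis(lambda v: np.bincount(v).argmax(), -1, stacked)
35 | 
36 | # fused = majority_vote(warped_segs)           # warped_segs: hypothetical list of warped label maps
37 | # vals = dice(fused, true_seg, labels=labels)  # evaluate with the Dice score
38 | ```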
39 | 
40 | ## Citation
41 | Based on [Voxelmorph](https://arxiv.org/abs/1809.05231/) and [Unsupervised learning for registration](https://arxiv.org/abs/1805.04605v1/)
42 | 
43 | 
44 | **Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration**
45 | [Adrian V. Dalca](http://adalca.mit.edu), [Guha Balakrishnan](http://people.csail.mit.edu/balakg/), [John Guttag](https://people.csail.mit.edu/guttag/), [Mert R. Sabuncu](http://sabuncu.engineering.cornell.edu/)
46 | MICCAI 2018. [eprint arXiv:1805.04605](https://arxiv.org/abs/1805.04605)
47 | 
48 | 
49 | **An Unsupervised Learning Model for Deformable Medical Image Registration**
50 | [Guha Balakrishnan](http://people.csail.mit.edu/balakg/), [Amy Zhao](http://people.csail.mit.edu/xamyzhao/), [Mert R. Sabuncu](http://sabuncu.engineering.cornell.edu/), [John Guttag](https://people.csail.mit.edu/guttag/), [Adrian V. Dalca](http://adalca.mit.edu)
51 | CVPR 2018. [eprint arXiv:1802.02604](https://arxiv.org/abs/1802.02604)
52 | 
-------------------------------------------------------------------------------- /ext/neuron/README.md: --------------------------------------------------------------------------------
1 | # neuron
2 | A neural networks toolbox for anatomical image analysis
3 | 
4 | This toolbox is **currently in development**, with the goal of providing a set of tools and infrastructure for medical image analysis with neural networks. While the tools are somewhat general, `neuron` will generally run with `keras` on top of `tensorflow`.
5 | 
6 | ### Main tools
7 | `callbacks`: a set of callbacks during keras training to help with understanding your fit, such as Dice measurements and volume-segmentation overlaps
8 | `generators`: generators for medical image volumes and various combinations of volumes, segmentation, categorical and other output
9 | `dataproc`: a set of tools for processing medical imaging data for preparation for training/testing
10 | `metrics`: metrics (most of which can be used as loss functions), such as dice or weighted categorical crossentropy.
11 | `models`: a set of flexible models (many parameters to play with...) particularly useful in medical image analysis, such as a U-net/hourglass model and a standard classifier.
12 | `layers`: a few simple layers
13 | `plot`: plotting tools, mostly for debugging models
14 | `utils`: various utilities useful in debugging.
15 | 
16 | Other utilities and a few `jupyter` notebooks are also provided.
17 | 
18 | ### Requirements:
19 | - tensorflow
20 | - keras and all of its requirements (e.g. h5py)
21 | - numpy, scipy
22 | - tqdm
23 | - [python libraries](https://github.com/search?q=user%3Aadalca+topic%3Apython) from @adalca github account
24 | 
25 | ### Development:
26 | Please contact Adrian Dalca, adalca@csail.mit.edu for questions related to `neuron`
27 | 
28 | ### Papers:
29 | **Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation**
30 | AV Dalca, J Guttag, MR Sabuncu
31 | *Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.*
32 | 
33 | **Spatial Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation**
34 | A.V. Dalca, J. Guttag, and M. R. Sabuncu
35 | *NIPS ML4H: Machine Learning for Health. 2017.*
36 | 
-------------------------------------------------------------------------------- /src/demo.py: --------------------------------------------------------------------------------
1 | '''
2 | demo for testing functions in neuron & voxelmorph
3 | '''
4 | import os
5 | import sys
6 | import glob
7 | 
8 | # third party
9 | import tensorflow as tf
10 | import scipy.io as sio
11 | import numpy as np
12 | from keras.backend.tensorflow_backend import set_session
13 | from scipy.interpolate import interpn
14 | import matplotlib.pyplot as plt
15 | 
16 | # project
17 | sys.path.append('../ext/medipy-lib')
18 | sys.path.append('../ext/neuron')
19 | sys.path.append('../ext/pynd-lib')
20 | sys.path.append('../ext/pytools-lib')
21 | 
22 | import medipy
23 | import networks
24 | from medipy.metrics import dice
25 | import datagenerators
26 | import neuron as nu
27 | 
28 | X_vol, X_seg = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990114_vc722.npz',
29 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990114_vc722.npz') # (160, 192, 224)
30 | print('volume shape')
31 | print(X_vol.shape)
32 | print('seg shape')
33 | print(X_seg.shape)
34 | #X_seg_slice = X_seg[0, :, :, :, 0 ]
35 | #print(X_seg_slice)
36 | #X_seg_slice.reshape([X_seg_slice.shape[0],X_seg_slice.shape[2]])
37 | #print(X_seg_slice.shape)
38 | #X_seg_slice = X_seg_slice.reshape((X_seg_slice.shape[1],X_seg_slice.shape[0],X_seg_slice.shape[2]))
39 | #X_seg_slice = X_seg_slice.reshape((X_seg_slice.shape[2],X_seg_slice.shape[1],X_seg_slice.shape[0]))
40 | #for i in range(0,X_seg_slice.shape[0]):
41 | # list.insert(X_seg_slice[i,:,:])
42 | #X_seg_slice = X_seg_slice.reshape([X_seg_slice.shape[0],X_seg_slice.shape[1],X_seg_slice.shape[2]])
43 | #X_seg_slice = [X_seg_slice]
44 | #fig,axs = nu.plot.slices(X_seg_slice)
45 | #fig.set_size_inches(width, rows/cols*width)
46 | #plt.tight_layout()
47 | #print(fig.shape)
48 | #fig.savefig("1.pdf")
49 | #warp_seg = X_seg_slice
50 | #warp_seg = X_seg_slice.reshape((warp_seg.shape[1], warp_seg.shape[0], warp_seg.shape[2]))
51 | #warp_seg2 = np.empty(shape=(warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0]))
52 | #for i in range(0, warp_seg.shape[1]):
53 | # warp_seg2[i, :, :] = np.transpose(warp_seg[:, i, :])
54 | #nu.plot.slices(warp_seg)
55 | 
-------------------------------------------------------------------------------- /src/test.py: --------------------------------------------------------------------------------
1 | # py imports
2 | import os
3 | import sys
4 | import glob
5 | 
6 | # third party
7 | import tensorflow as tf
8 | import scipy.io as sio
9 | 
import numpy as np
10 | from keras.backend.tensorflow_backend import set_session
11 | from scipy.interpolate import interpn
12 | 
13 | # project
14 | sys.path.append('../ext/medipy-lib')
15 | import medipy
16 | import networks
17 | from medipy.metrics import dice
18 | import datagenerators
19 | 
20 | 
21 | def test(model_name, iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]):
22 | """
23 | test
24 | 
25 | nf_enc and nf_dec
26 | #nf_dec = [32,32,32,32,32,16,16,3]
27 | # This needs to be changed. Ideally, we could just call load_model, and we won't have to
28 | # specify the # of channels here, but the load_model is not working with the custom loss...
29 | """
30 | 
31 | gpu = '/gpu:' + str(gpu_id)
32 | 
33 | # Anatomical labels we want to evaluate
34 | labels = sio.loadmat('../data/labels.mat')['labels'][0]
35 | 
36 | atlas = np.load('../data/atlas_norm.npz')
37 | atlas_vol = atlas['vol']
38 | atlas_seg = atlas['seg']
39 | atlas_vol = np.reshape(atlas_vol, (1,)+atlas_vol.shape+(1,))
40 | 
41 | config = tf.ConfigProto()
42 | config.gpu_options.allow_growth = True
43 | config.allow_soft_placement = True
44 | set_session(tf.Session(config=config))
45 | 
46 | # load weights of model
47 | with tf.device(gpu):
48 | net = networks.unet(vol_size, nf_enc, nf_dec)
49 | net.load_weights('../models/' + model_name +
50 | '/' + str(iter_num) + '.h5')
51 | 
52 | xx = np.arange(vol_size[1])
53 | yy = np.arange(vol_size[0])
54 | zz = np.arange(vol_size[2])
55 | grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)
56 | 
57 | X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz')
58 | 
59 | with tf.device(gpu):
60 | pred = net.predict([X_vol, atlas_vol])
61 | 
62 | # Warp segments with flow
63 | flow = pred[1][0, :, :, :, :]
64 | sample = flow+grid
65 | sample = np.stack((sample[:, :, :, 1], sample[:, :, :, 0], sample[:, :, :, 2]), 3)
66 | warp_seg = interpn((yy, xx, zz), X_seg[0, :, :, :, 0], sample, method='nearest', bounds_error=False, fill_value=0)
67 | 
68 | vals, _ = dice(warp_seg, atlas_seg, labels=labels, nargout=2)
69 | print(np.mean(vals), np.std(vals))
70 | 
71 | 
72 | if __name__ == "__main__":
73 | test(sys.argv[1], sys.argv[2], sys.argv[3])
74 | 
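75 | # Example invocation (assumes a trained model saved at ../models/<model_name>/<iter_num>.h5):
76 | #   python test.py <model_name> <iter_num> <gpu_id>
77 | 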
-------------------------------------------------------------------------------- /data/test_examples.txt: --------------------------------------------------------------------------------
1 | GSP_081030_NC89WB_FS_mri_talairach_norm.npz
2 | GSP_081101_NK25FK_FS_mri_talairach_norm.npz
3 | GSP_090106_JT53JK_FS_mri_talairach_norm.npz
4 | 
5 | GSP_090110_MG55BK_FS_mri_talairach_norm.npz
6 | 
7 | GSP_100616_GT66DH_FS_mri_talairach_norm.npz
8 | 
9 | GSP_110103_CR64WU_FS_mri_talairach_norm.npz
10 | 
11 | GSP_110103_JV75FH_FS_mri_talairach_norm.npz
12 | 
13 | GSP_120120_QV52VK_FS_mri_talairach_norm.npz
14 | 
15 | GSP_120106_RJ68UP_FS_mri_talairach_norm.npz
16 | 
17 | GSP_130103_DW69GU_FS_mri_talairach_norm.npz
18 | 
19 | GSP_130103_RP37BU_FS_mri_talairach_norm.npz
20 | 
21 | GSP_140116_DB34EP_FS_mri_talairach_norm.npz
22 | 
23 | GSP_140116_TU88FC_FS_mri_talairach_norm.npz
24 | 
25 | MCIC_Site_A_A00036501_mri_talairach_norm.npz
26 | 
27 | MCIC_Site_A_A00036517_mri_talairach_norm.npz
28 | 
29 | MCIC_Site_A_A00036518_mri_talairach_norm.npz
30 | 
31 | MCIC_Site_A_A00036520_mri_talairach_norm.npz
32 | 
33 | MCIC_Site_C_A00036131_mri_talairach_norm.npz
34 | 
35 | MCIC_Site_C_A00036133_mri_talairach_norm.npz
36 | 
37 | MCIC_Site_C_A00036136_mri_talairach_norm.npz
38 | 
39 | MCIC_Site_C_A00036169_mri_talairach_norm.npz
40 | 
41 | MCIC_Site_D_A00036412_mri_talairach_norm.npz
42 | 
43 | MCIC_Site_D_A00036461_mri_talairach_norm.npz
44 | 
45 | MCIC_Site_D_A00036481_mri_talairach_norm.npz
46 | 
47 | MCIC_Site_D_A00036492_mri_talairach_norm.npz
48 | 
49 | OASIS_OAS1_0045_MR1_mri_talairach_norm.npz
50 | 
51 | OASIS_OAS1_0049_MR1_mri_talairach_norm.npz
52 | 
53 | OASIS_OAS1_0050_MR1_mri_talairach_norm.npz
54 | 
55 | OASIS_OAS1_0051_MR1_mri_talairach_norm.npz
56 | 
57 | OASIS_OAS1_0101_MR2_mri_talairach_norm.npz
58 | 
59 | OASIS_OAS1_0108_MR1_mri_talairach_norm.npz
60 | 
61 | OASIS_OAS1_0113_MR1_mri_talairach_norm.npz
62 | 
63 | OASIS_OAS1_0117_MR2_mri_talairach_norm.npz
64 | 
65 | OASIS_OAS1_0388_MR1_mri_talairach_norm.npz
66 | 
67 | OASIS_OAS1_0390_MR1_mri_talairach_norm.npz
68 | 
69 | OASIS_OAS1_0392_MR1_mri_talairach_norm.npz
70 | 
71 | OASIS_OAS1_0395_MR1_mri_talairach_norm.npz
72 | 
73 | PPMI_3053_mri_talairach_norm.npz
74 | 
75 | PPMI_3056_mri_talairach_norm.npz
76 | 
77 | PPMI_3059_mri_talairach_norm.npz
78 | 
79 | PPMI_3061_mri_talairach_norm.npz
80 | 
81 | PPMI_3668_mri_talairach_norm.npz
82 | 
83 | PPMI_3614_mri_talairach_norm.npz
84 | 
85 | PPMI_3620_mri_talairach_norm.npz
86 | 
87 | PPMI_3621_mri_talairach_norm.npz
88 | 
89 | PPMI_4001_mri_talairach_norm.npz
90 | 
91 | PPMI_4004_mri_talairach_norm.npz
92 | 
93 | PPMI_4005_mri_talairach_norm.npz
94 | 
95 | PPMI_4006_mri_talairach_norm.npz
-------------------------------------------------------------------------------- /ext/pytools-lib/pytools/plotting.py: --------------------------------------------------------------------------------
1 | """
2 | functions to help in plotting
3 | """
4 | 
5 | import numpy as np
6 | import six
7 | import matplotlib
8 | import matplotlib.pylab as plt
9 | 
10 | 
11 | def jitter(n=256, colmap="hsv", nargout=1):
12 | """
13 | jitter colormap of size [n x 3]. The jitter colormap will (likely) have distinct colors, with
14 | neighboring colors being quite different
15 | 
16 | Parameters:
17 | n (optional): the size of the colormap. default:256
18 | colmap: the colormap to scramble. Either a string passable to plt.get_cmap,
19 | or a n-by-3 or n-by-4 array
20 | 
21 | Algorithm: given a (preferably smooth) colormap as a starting point (default "hsv"), jitter
22 | reorders the colors by skipping roughly a quarter of the colors. So given jitter(9, "hsv"),
23 | jitter would take color numbers, in order, 1, 3, 5, 7, 9, 2, 4, 6, 8.
24 | 
25 | Contact: adalca@csail.mit.edu
26 | """
27 | 
28 | # get a 1:n vector
29 | idx = range(n)
30 | 
31 | # roughly compute the quarter mark. in hsv, a quarter is enough to see a significant col change
32 | m = np.maximum(np.round(0.25 * n), 1).astype(int)
33 | 
34 | # compute a new order, by reshaping this index array as a [m x ?]
matrix, then vectorizing in
35 | # the opposite direction
36 | 
37 | # pad with -1 to make it transformable to a square
38 | nb_elems = np.ceil(n / m) * m
39 | idx = np.pad(idx, [0, (nb_elems - n).astype(int)], 'constant', constant_values=-1)
40 | 
41 | # permute elements by resizing to a matrix, transposing, and re-flattening
42 | idxnew = np.array(np.reshape(idx, [m, (nb_elems // m).astype(int)]).transpose().flatten())
43 | 
44 | # throw away the extra elements
45 | idxnew = idxnew[np.where(idxnew >= 0)]
46 | assert len(idxnew) == n, "jitter: something went wrong with some inner logic :("
47 | 
48 | # get colormap and scramble it
49 | if isinstance(colmap, six.string_types):
50 | cmap = plt.get_cmap(colmap, nb_elems)
51 | scrambled_cmap = cmap(idxnew)
52 | else:
53 | # assumes colmap is a nx3 or nx4
54 | assert colmap.shape[0] == n
55 | assert colmap.shape[1] == 3 or colmap.shape[1] == 4
56 | scrambled_cmap = colmap[idxnew, :]
57 | 
58 | new_cmap = matplotlib.colors.ListedColormap(scrambled_cmap)
59 | if nargout == 1:
60 | return new_cmap
61 | else:
62 | assert nargout == 2
63 | return (new_cmap, scrambled_cmap)
64 | 
-------------------------------------------------------------------------------- /ext/neuron/neuron/plot.py: --------------------------------------------------------------------------------
1 | ''' plot tools for the neuron project '''
2 | 
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from mpl_toolkits.axes_grid1 import make_axes_locatable # plotting
6 | 
7 | def slices(slices_in, # the 2D slices
8 | titles=None, # list of titles
9 | cmaps=None, # list of colormaps
10 | norms=None, # list of normalizations
11 | do_colorbars=True, # option to show colorbars on each slice
12 | grid=True, # option to plot the images in a grid or a single row
13 | width=200, # width in inches
14 | show=False, # option to actually show the plot (plt.show())
15 | imshow_args=None):
16 | ''' plot a grid of slices (2d images) '''
17 | 
18 | # input processing
19 | nb_plots = len(slices_in)
20 | def input_check(inputs, nb_plots, name):
21 | ''' expand a None or single-element input to a list of length nb_plots '''
22 | assert (inputs is None) or (len(inputs) == nb_plots) or (len(inputs) == 1), \
23 | 'number of %s is incorrect' % name
24 | if inputs is None:
25 | inputs = [None]
26 | if len(inputs) == 1:
27 | inputs = [inputs[0] for i in range(nb_plots)]
28 | return inputs
29 | 
30 | titles = input_check(titles, nb_plots, 'titles')
31 | cmaps = input_check(cmaps, nb_plots, 'cmaps')
32 | norms = input_check(norms, nb_plots, 'norms')
33 | imshow_args = input_check(imshow_args, nb_plots, 'imshow_args')
34 | for idx, ia in enumerate(imshow_args):
35 | imshow_args[idx] = {} if ia is None else ia
36 | 
37 | # figure out the number of rows and columns
38 | if grid:
39 | if isinstance(grid, bool):
40 | rows = np.floor(np.sqrt(nb_plots)).astype(int)
41 | cols = np.ceil(nb_plots/rows).astype(int)
42 | else:
43 | assert isinstance(grid, (list, tuple)),\
44 | "grid should either be bool or [rows,cols]"
45 | rows, cols = grid
46 | else:
47 | rows = 1
48 | cols = nb_plots
49 | 
50 | # prepare the subplot
51 | fig, axs = plt.subplots(rows, cols)
52 | if rows == 1 and cols == 1:
53 | axs = [axs]
54 | 
55 | for i in range(nb_plots):
56 | col = np.remainder(i, cols)
57 | row = np.floor(i/cols).astype(int)
58 | 
59 | # get row and column axes
60 | row_axs = axs if rows == 1 else axs[row]
61 | ax = row_axs[col]
62 | 
63 | # turn off axis
64 | ax.axis('off')
65 | 
66 | # some cleanup
67 | if titles is not None:
68 | ax.title.set_text(titles[i])
69 | 
70 | # show figure
71 | im_ax = ax.imshow(slices_in[i], cmap=cmaps[i], interpolation="nearest", norm=norms[i], **imshow_args[i])
72 | 
73 | # colorbars
74 | # http://stackoverflow.com/questions/18195758/set-matplotlib-colorbar-size-to-match-graph
75 | if do_colorbars and cmaps[i] is not None:
76 | divider = make_axes_locatable(ax)
77 | cax = divider.append_axes("right", size="5%", pad=0.05)
78 | fig.colorbar(im_ax, cax=cax)
79 | 
80 | # show the plots
81 | fig.set_size_inches(width, rows/cols*width)
82 | plt.tight_layout()
83 | plt.savefig("1.pdf")
84 | if show:
85 | plt.show()
86 | 
87 | return (fig, axs)
88 | 
-------------------------------------------------------------------------------- /ext/pynd-lib/pynd/segutils.py: --------------------------------------------------------------------------------
1 | '''
2 | nd segmentation (label map) utilities
3 | 
4 | Contact: adalca@csail.mit.edu
5 | '''
6 | 
7 | import numpy as np
8 | from . import ndutils as nd
9 | 
10 | def seg2contour(seg, exclude_zero=True, contour_type='inner', thickness=1):
11 | '''
12 | transform nd segmentation (label maps) to contour maps
13 | 
14 | Parameters
15 | ----------
16 | seg : nd array
17 | volume of labels/segmentations
18 | exclude_zero : optional logical
19 | whether to exclude the zero label.
20 | default True
21 | contour_type : string
22 | where to draw contour voxels relative to label 'inner','outer', or 'both'
23 | 
24 | Output
25 | ------
26 | con : nd array
27 | nd array (volume) of contour maps
28 | 
29 | See Also
30 | --------
31 | seg_overlap
32 | '''
33 | 
34 | # extract unique labels
35 | labels = np.unique(seg)
36 | if exclude_zero:
37 | labels = np.delete(labels, np.where(labels == 0))
38 | 
39 | # get the contour of each label
40 | contour_map = seg * 0
41 | for lab in labels:
42 | 
43 | # extract binary label map for this label
44 | label_map = seg == lab
45 | 
46 | # extract contour map for this label (small epsilon on the threshold, without accumulating across labels)
47 | eps_thickness = thickness + 0.01
48 | label_contour_map = nd.bw2contour(label_map, type=contour_type, thr=eps_thickness)
49 | 
50 | # assign contour to this label
51 | contour_map[label_contour_map] = lab
52 | 
53 | return contour_map
54 | 
55 | 
56 | 
57 | def seg_overlap(vol, seg, do_contour=True, do_rgb=True, cmap=None, thickness=1.0):
58 | '''
59 | overlap a nd volume and nd segmentation (label map)
60 | 
61 | do_contour should be None, boolean, or contour_type from seg2contour
62 | 
63 | not well tested yet.
64 | '''
65 | 
66 | # compute contours for each label if necessary
67 | if do_contour is not None and do_contour is not False:
68 | if not isinstance(do_contour, str):
69 | do_contour = 'inner'
70 | seg = seg2contour(seg, contour_type=do_contour, thickness=thickness)
71 | 
72 | # compute a rgb-contour map
73 | if do_rgb:
74 | if cmap is None:
75 | nb_labels = np.max(seg).astype(int) + 1
76 | colors = np.random.random((nb_labels, 3)) * 0.5 + 0.5
77 | colors[0, :] = [0, 0, 0]
78 | else:
79 | colors = cmap[:, 0:3]
80 | 
81 | olap = colors[seg.flat, :]
82 | sf = seg.flat == 0
83 | for d in range(3):
84 | olap[sf, d] = vol.flat[sf]
85 | olap = np.reshape(olap, vol.shape + (3, ))
86 | 
87 | else:
88 | olap = seg
89 | olap[seg == 0] = vol[seg == 0]
90 | 
91 | return olap
92 | 
93 | 
94 | def seg_overlay(vol, seg, do_rgb=True, seg_wt=0.5, cmap=None):
95 | '''
96 | overlay a nd volume and nd segmentation (label map)
97 | 
98 | not well tested yet.
99 | '''
100 | 
101 | # compute contours for each label if necessary
102 | 
103 | # compute a rgb-contour map
104 | if do_rgb:
105 | if cmap is None:
106 | nb_labels = np.max(seg) + 1
107 | colors = np.random.random((nb_labels, 3)) * 0.5 + 0.5
108 | colors[0, :] = [0, 0, 0]
109 | else:
110 | colors = cmap[:, 0:3]
111 | 
112 | seg_flat = colors[seg.flat, :]
113 | seg_rgb = np.reshape(seg_flat, vol.shape + (3, ))
114 | 
115 | # get the overlap image
116 | olap = seg_rgb * seg_wt + np.expand_dims(vol, -1) * (1-seg_wt)
117 | 
118 | else:
119 | olap = seg * seg_wt + vol * (1-seg_wt)
120 | 
121 | return olap
122 | 
123 | 
-------------------------------------------------------------------------------- /src/test_miccai2018.py: --------------------------------------------------------------------------------
1 | """
2 | Test models for MICCAI 2018 submission of VoxelMorph.
3 | """
4 | 
5 | # py imports
6 | import os
7 | import sys
8 | import glob
9 | 
10 | # third party
11 | import tensorflow as tf
12 | import scipy.io as sio
13 | import numpy as np
14 | import keras
15 | from keras.backend.tensorflow_backend import set_session
16 | from scipy.interpolate import interpn
17 | 
18 | # project
19 | sys.path.append('../ext/medipy-lib')
20 | import medipy
21 | import networks
22 | # import util  # used by the CPU compute path below (volshape2grid_3d, warp_seg)
23 | from medipy.metrics import dice
24 | import datagenerators
25 | 
26 | # Test file and anatomical labels we want to evaluate
27 | test_brain_file = open('../data/test_examples.txt')
28 | test_brain_strings = test_brain_file.readlines()
29 | test_brain_strings = [x.strip() for x in test_brain_strings]
30 | n_batches = len(test_brain_strings)
31 | good_labels = sio.loadmat('../data/labels.mat')['labels'][0]
32 | 
33 | # atlas files
34 | atlas = np.load('../data/atlas_norm.npz')
35 | atlas_vol = atlas['vol'][np.newaxis, ..., np.newaxis]
36 | atlas_seg = atlas['seg']
37 | 
38 | def test(gpu_id, model_dir, iter_num,
39 | compute_type = 'GPU', # GPU or CPU
40 | vol_size=(160,192,224),
41 | nf_enc=[16,32,32,32],
42 | nf_dec=[32,32,32,32,16,3],
43 | save_file=None):
44 | """
45 | test via segmentation propagation
46 | works by iterating over some image files, registering them to atlas,
47 | propagating the warps, then computing Dice with atlas segmentations
48 | """
49 | 
50 | # GPU handling
51 | gpu = '/gpu:' + str(gpu_id)
52 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
53 | config = tf.ConfigProto()
54 | config.gpu_options.allow_growth = True
55 | config.allow_soft_placement = True
56 | set_session(tf.Session(config=config))
57 | 
58 | # load weights of model
59 | with tf.device(gpu):
60 | # if testing miccai run, should be xy indexing.
61 | net = networks.miccai2018_net(vol_size, nf_enc,nf_dec, use_miccai_int=True, indexing='xy') 62 | net.load_weights(os.path.join(model_dir, str(iter_num) + '.h5')) 63 | 64 | # compose diffeomorphic flow output model 65 | diff_net = keras.models.Model(net.inputs, net.get_layer('diffflow').output) 66 | 67 | # NN transfer model 68 | nn_trf_model = networks.nn_trf(vol_size) 69 | 70 | # if CPU, prepare grid 71 | if compute_type == 'CPU': 72 | grid, xx, yy, zz = util.volshape2grid_3d(vol_size, nargout=4) 73 | 74 | # prepare a matrix of dice values 75 | dice_vals = np.zeros((len(good_labels), n_batches)) 76 | for k in range(n_batches): 77 | # get data 78 | vol_name, seg_name = test_brain_strings[k].split(",") 79 | X_vol, X_seg = datagenerators.load_example_by_name(vol_name, seg_name) 80 | 81 | # predict transform 82 | with tf.device(gpu): 83 | pred = diff_net.predict([X_vol, atlas_vol]) 84 | 85 | # Warp segments with flow 86 | if compute_type == 'CPU': 87 | flow = pred[0, :, :, :, :] 88 | warp_seg = util.warp_seg(X_seg, flow, grid=grid, xx=xx, yy=yy, zz=zz) 89 | 90 | else: # GPU 91 | warp_seg = nn_trf_model.predict([X_seg, pred])[0,...,0] 92 | 93 | # compute Volume Overlap (Dice) 94 | dice_vals[:, k] = dice(warp_seg, atlas_seg, labels=good_labels) 95 | print('%3d %5.3f %5.3f' % (k, np.mean(dice_vals[:, k]), np.mean(np.mean(dice_vals[:, :k+1])))) 96 | 97 | if save_file is not None: 98 | sio.savemat(save_file, {'dice_vals': dice_vals, 'labels': good_labels}) 99 | 100 | 101 | 102 | 103 | if __name__ == "__main__": 104 | """ 105 | assuming the model is model_dir/iter_num.h5 106 | python test_miccai2018.py gpu_id model_dir iter_num 107 | """ 108 | test(sys.argv[1], sys.argv[2], sys.argv[3]) 109 | -------------------------------------------------------------------------------- /src/MAS5_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | multi atlas segmentation based on Voxelmorph and Neuron 3 | 4 | """ 5 | 6 | # python imports 7 | import os 8 | import glob 9 | import sys 10 | import random 11 | from argparse import ArgumentParser 12 | 13 | # third-party imports 14 | import tensorflow as tf 15 | import numpy as np 16 | from keras.backend.tensorflow_backend import set_session 17 | from keras.optimizers import Adam 18 | from keras.models import load_model, Model 19 | 20 | 21 | import datagenerators 22 | import networks 23 | import losses 24 | 25 | 26 | vol_size = (160, 192, 224) 27 | # train data preparation 28 | base_data_dir = '/home/ys895/resize256/resize256-crop_x32/' 29 | # find all the path of .npz file in the directory 30 | # read training data 31 | train_vol_names = glob.glob(base_data_dir + 'train/vols/*.npz') 32 | # shuffle the path of .npz file 33 | # shuffle the training data 34 | random.shuffle(train_vol_names) 35 | 36 | # read the only one atlas data 37 | #atlas = np.load('../data/atlas_norm.npz') 38 | #atlas_vol = atlas['vol'] 39 | 40 | # add two more dimension into the atlas data 41 | #atlas_vol = np.reshape(atlas_vol, (1,) + atlas_vol.shape+(1,)) 42 | 43 | # atlas_list: several atlas were read 44 | atlas_file = open('../data/MAS_atlas.txt') 45 | atlas_strings = atlas_file.readlines() 46 | lenn = 5 47 | atlas_list = list() 48 | for i in range(0,lenn): 49 | st = atlas_strings[i] 50 | atlas_add = np.load(st.strip()) 51 | atlas_add = atlas_add['vol_data'] 52 | atlas_add = np.reshape(atlas_add,(1,)+atlas_add.shape+(1,)) 53 | atlas_list.append(atlas_add) 54 | 55 | # read atlas_norm as atlas used for training 56 | #atlas = 
np.load('../data/atlas_norm.npz')
57 | #atlas = atlas['vol']
58 | #atlas = np.reshape(atlas,(1,)+atlas.shape+(1,))
59 | #atlas_list.append(atlas)
60 | 
61 | list_num = len(atlas_list)
62 | 
63 | def train(model, gpu_id, lr, n_iterations, reg_param, model_save_iter, load_iter):
64 | 
65 | model_dir = '/home/ys895/MAS_Models'
66 | if not os.path.isdir(model_dir):
67 | os.mkdir(model_dir)
68 | 
69 | gpu = '/gpu:' + str(gpu_id)
70 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
71 | config = tf.ConfigProto()
72 | config.gpu_options.allow_growth = True
73 | config.allow_soft_placement = True
74 | set_session(tf.Session(config=config))
75 | 
76 | 
77 | # UNET filters
78 | nf_enc = [16,32,32,32]
79 | if(model == 'vm1'):
80 | nf_dec = [32,32,32,32,8,8,3]
81 | else:
82 | nf_dec = [32,32,32,32,32,16,16,3]
83 | 
84 | with tf.device(gpu):
85 | model = networks.unet(vol_size, nf_enc, nf_dec)
86 | if(load_iter != 0):
87 | model.load_weights('/home/ys895/MAS_Models/' + str(load_iter) + '.h5')
88 | 
89 | model.compile(optimizer=Adam(lr=lr), loss=[
90 | losses.cc3D(), losses.gradientLoss('l2')], loss_weights=[1.0, reg_param])
91 | # model.load_weights('../models/udrnet2/udrnet1_1/120000.h5')
92 | 
93 | # return the data, add one more dimension into the data
94 | train_example_gen = datagenerators.example_gen(train_vol_names)
95 | zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3))
96 | 
97 | 
98 | # In this part, the code inputs the data into the model
99 | # Before this part, the model was set
100 | for step in range(1, n_iterations+1):
101 | # choose randomly one of the atlas from the atlas_list
102 | rand_num = random.randint(0, list_num-1)
103 | atlas_vol = atlas_list[rand_num]
104 | 
105 | #Parameters for training : X(train_vol) ,atlas_vol(atlas) ,zero_flow
106 | X = train_example_gen.__next__()[0]
107 | train_loss = model.train_on_batch(
108 | [atlas_vol, X], [X, zero_flow])
109 | 
110 | if not isinstance(train_loss, list):
111 | train_loss = [train_loss]
112 | 
113 | printLoss(step, 1, train_loss)
114 | 
115 | if(step % model_save_iter == 0):
116 | model.save(model_dir + '/' + str(load_iter+step) + '.h5')
117 | 
118 | 
119 | def printLoss(step, training, train_loss):
120 | s = str(step) + "," + str(training)
121 | 
122 | if(isinstance(train_loss, list) or isinstance(train_loss, np.ndarray)):
123 | for i in range(len(train_loss)):
124 | s += "," + str(train_loss[i])
125 | else:
126 | s += "," + str(train_loss)
127 | 
128 | print(s)
129 | sys.stdout.flush()
130 | 
131 | 
132 | if __name__ == "__main__":
133 | 
134 | parser = ArgumentParser()
135 | parser.add_argument("--model", type=str,dest="model",
136 | choices=['vm1','vm2'],default='vm2',
137 | help="Voxelmorph-1 or 2")
138 | parser.add_argument("--gpu", type=int,default=0,
139 | dest="gpu_id", help="gpu id number")
140 | parser.add_argument("--lr", type=float,
141 | dest="lr", default=1e-4,help="learning rate")
142 | parser.add_argument("--iters", type=int,
143 | dest="n_iterations", default=15000,
144 | help="number of iterations")
145 | parser.add_argument("--lambda", type=float,
146 | dest="reg_param", default=1.0,
147 | help="regularization parameter")
148 | parser.add_argument("--checkpoint_iter", type=int,
149 | dest="model_save_iter", default=500,
150 | help="frequency of model saves")
151 | parser.add_argument("--load_iter", type=int,
152 | dest="load_iter", default=0,
153 | help="the iterations of the model to load")
154 | 
155 | args = parser.parse_args()
156 | train(**vars(args))
157 | 
158 | 
159 | 
160 | 
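161 | # Example invocation (a hypothetical sketch; the flags mirror the parser above):
162 | #   python MAS5_train.py --model vm2 --gpu 0 --lr 1e-4 --iters 15000 --lambda 1.0 --checkpoint_iter 500
163 | 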
-------------------------------------------------------------------------------- /src/MAS2_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | multi atlas segmentation based on Voxelmorph and Neuron 3 | 4 | """ 5 | 6 | # python imports 7 | import os 8 | import glob 9 | import sys 10 | import random 11 | from argparse import ArgumentParser 12 | 13 | # third-party imports 14 | import tensorflow as tf 15 | import numpy as np 16 | from keras.backend.tensorflow_backend import set_session 17 | from keras.optimizers import Adam 18 | from keras.models import load_model, Model 19 | 20 | 21 | import datagenerators 22 | import networks 23 | import losses 24 | 25 | 26 | vol_size = (160, 192, 224) 27 | # train data preparation 28 | base_data_dir = '/home/ys895/resize256/resize256-crop_x32/' 29 | # find all the path of .npz file in the directory 30 | # read training data 31 | train_vol_names = glob.glob(base_data_dir + 'train/vols/*.npz') 32 | # shuffle the path of .npz file 33 | # shuffle the training data 34 | random.shuffle(train_vol_names) 35 | 36 | # read the only one atlas data 37 | #atlas = np.load('../data/atlas_norm.npz') 38 | #atlas_vol = atlas['vol'] 39 | 40 | # add two more dimension into the atlas data 41 | #atlas_vol = np.reshape(atlas_vol, (1,) + atlas_vol.shape+(1,)) 42 | 43 | # atlas_list: several atlas were read 44 | atlas_file = open('../data/MAS_atlas.txt') 45 | atlas_strings = atlas_file.readlines() 46 | lenn = 2 47 | atlas_list = list() 48 | for i in range(0,lenn): 49 | st = atlas_strings[i] 50 | atlas_add = np.load(st.strip()) 51 | atlas_add = atlas_add['vol_data'] 52 | atlas_add = np.reshape(atlas_add,(1,)+atlas_add.shape+(1,)) 53 | atlas_list.append(atlas_add) 54 | 55 | # read atlas_norm as atlas used for training 56 | #atlas = np.load('../data/atlas_norm.npz') 57 | #atlas = atlas['vol'] 58 | #atlas = np.reshape(atlas,(1,)+atlas.shape+(1,)) 59 | #atlas_list.append(atlas) 60 | 61 | list_num = len(atlas_list) 62 | 63 | def train(model, gpu_id, lr, n_iterations, reg_param, model_save_iter, load_iter): 64 | 65 | model_dir = '/home/ys895/MAS2_Models' 66 | if not os.path.isdir(model_dir): 67 | os.mkdir(model_dir) 68 | 69 | gpu = '/gpu:' + str(gpu_id) 70 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 71 | config = tf.ConfigProto() 72 | config.gpu_options.allow_growth = True 73 | config.allow_soft_placement = True 74 | set_session(tf.Session(config=config)) 75 | 76 | 77 | # UNET filters 78 | nf_enc = [16,32,32,32] 79 | if(model == 'vm1'): 80 | nf_dec = [32,32,32,32,8,8,3] 81 | else: 82 | nf_dec = [32,32,32,32,32,16,16,3] 83 | 84 | with tf.device(gpu): 85 | model = networks.unet(vol_size, nf_enc, nf_dec) 86 | if(load_iter != 0): 87 | model.load_weights('/home/ys895/MAS2_Models/' + str(load_iter) + '.h5') 88 | 89 | model.compile(optimizer=Adam(lr=lr), loss=[ 90 | losses.cc3D(), losses.gradientLoss('l2')], loss_weights=[1.0, reg_param]) 91 | # model.load_weights('../models/udrnet2/udrnet1_1/120000.h5') 92 | 93 | # return the data, add one more dimension into the data 94 | train_example_gen = datagenerators.example_gen(train_vol_names) 95 | zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3)) 96 | 97 | 98 | # In this part, the code inputs the data into the model 99 | # Before this part, the model was set 100 | for step in range(1, n_iterations+1): 101 | # choose randomly one of the atlas from the atlas_list 102 | rand_num = random.randint(0, list_num-1) 103 | atlas_vol = atlas_list[rand_num] 104 | 105 | #Parameters for 
training : X(train_vol) ,atlas_vol(atlas) ,zero_flow
106 | X = train_example_gen.__next__()[0]
107 | train_loss = model.train_on_batch(
108 | [atlas_vol, X], [X, zero_flow])
109 | 
110 | if not isinstance(train_loss, list):
111 | train_loss = [train_loss]
112 | 
113 | printLoss(step, 1, train_loss)
114 | 
115 | if(step % model_save_iter == 0):
116 | model.save(model_dir + '/' + str(load_iter+step) + '.h5')
117 | 
118 | 
119 | def printLoss(step, training, train_loss):
120 | s = str(step) + "," + str(training)
121 | 
122 | if(isinstance(train_loss, list) or isinstance(train_loss, np.ndarray)):
123 | for i in range(len(train_loss)):
124 | s += "," + str(train_loss[i])
125 | else:
126 | s += "," + str(train_loss)
127 | 
128 | print(s)
129 | sys.stdout.flush()
130 | 
131 | 
132 | if __name__ == "__main__":
133 | 
134 | parser = ArgumentParser()
135 | parser.add_argument("--model", type=str,dest="model",
136 | choices=['vm1','vm2'],default='vm2',
137 | help="Voxelmorph-1 or 2")
138 | parser.add_argument("--gpu", type=int,default=0,
139 | dest="gpu_id", help="gpu id number")
140 | parser.add_argument("--lr", type=float,
141 | dest="lr", default=1e-4,help="learning rate")
142 | parser.add_argument("--iters", type=int,
143 | dest="n_iterations", default=15000,
144 | help="number of iterations")
145 | parser.add_argument("--lambda", type=float,
146 | dest="reg_param", default=1.0,
147 | help="regularization parameter")
148 | parser.add_argument("--checkpoint_iter", type=int,
149 | dest="model_save_iter", default=500,
150 | help="frequency of model saves")
151 | parser.add_argument("--load_iter", type=int,
152 | dest="load_iter", default=0,
153 | help="the iterations of the model to load")
154 | 
155 | args = parser.parse_args()
156 | train(**vars(args))
157 | 
158 | 
159 | 
160 | 
-------------------------------------------------------------------------------- /src/MAS3_train.py: --------------------------------------------------------------------------------
1 | """
2 | multi atlas segmentation based on Voxelmorph and Neuron
3 | 
4 | """
5 | 
6 | # python imports
7 | import os
8 | import glob
9 | import sys
10 | import random
11 | from argparse import ArgumentParser
12 | 
13 | # third-party imports
14 | import tensorflow as tf
15 | import numpy as np
16 | from keras.backend.tensorflow_backend import set_session
17 | from keras.optimizers import Adam
18 | from keras.models import load_model, Model
19 | 
20 | 
21 | import datagenerators
22 | import networks
23 | import losses
24 | 
25 | 
26 | vol_size = (160, 192, 224)
27 | # train data preparation
28 | base_data_dir = '/home/ys895/resize256/resize256-crop_x32/'
29 | # find all the path of .npz file in the directory
30 | # read training data
31 | train_vol_names = glob.glob(base_data_dir + 'train/vols/*.npz')
32 | # shuffle the path of .npz file
33 | # shuffle the training data
34 | random.shuffle(train_vol_names)
35 | 
36 | # read the only one atlas data
37 | #atlas = np.load('../data/atlas_norm.npz')
38 | #atlas_vol = atlas['vol']
39 | 
40 | # add two more dimension into the atlas data
41 | #atlas_vol = np.reshape(atlas_vol, (1,) + atlas_vol.shape+(1,))
42 | 
43 | # atlas_list: several atlas were read
44 | atlas_file = open('../data/MAS_atlas.txt')
45 | atlas_strings = atlas_file.readlines()
46 | lenn = 3
47 | atlas_list = list()
48 | for i in range(0,lenn):
49 | st = atlas_strings[i]
50 | atlas_add = np.load(st.strip())
51 | atlas_add = atlas_add['vol_data']
52 | atlas_add = np.reshape(atlas_add,(1,)+atlas_add.shape+(1,))
53 | atlas_list.append(atlas_add)
54 | 
55 | # read atlas_norm as atlas used for training
56 | #atlas = np.load('../data/atlas_norm.npz')
57 | #atlas = atlas['vol']
58 | #atlas = np.reshape(atlas,(1,)+atlas.shape+(1,))
59 | #atlas_list.append(atlas)
60 | 
61 | list_num = len(atlas_list)
62 | 
63 | def train(model, gpu_id, lr, n_iterations, reg_param, model_save_iter, load_iter):
64 | 
65 | model_dir = '/home/ys895/MAS3_Models'
66 | if not os.path.isdir(model_dir):
67 | os.mkdir(model_dir)
68 | 
69 | gpu = '/gpu:' + str(gpu_id)
70 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
71 | config = tf.ConfigProto()
72 | config.gpu_options.allow_growth = True
73 | config.allow_soft_placement = True
74 | set_session(tf.Session(config=config))
75 | 
76 | 
77 | # UNET filters
78 | nf_enc = [16,32,32,32]
79 | if(model == 'vm1'):
80 | nf_dec = [32,32,32,32,8,8,3]
81 | else:
82 | nf_dec = [32,32,32,32,32,16,16,3]
83 | 
84 | with tf.device(gpu):
85 | model = networks.unet(vol_size, nf_enc, nf_dec)
86 | if(load_iter != 0):
87 | model.load_weights('/home/ys895/MAS3_Models/' + str(load_iter) + '.h5')
88 | 
89 | model.compile(optimizer=Adam(lr=lr), loss=[
90 | losses.cc3D(), losses.gradientLoss('l2')], loss_weights=[1.0, reg_param])
91 | # model.load_weights('../models/udrnet2/udrnet1_1/120000.h5')
92 | 
93 | # return the data, add one more dimension into the data
94 | train_example_gen = datagenerators.example_gen(train_vol_names)
95 | zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3))
96 | 
97 | 
98 | # In this part, the code inputs the data into the model
99 | # Before this part, the model was set
100 | for step in range(1, n_iterations+1):
101 | # choose randomly one of the atlas from the atlas_list
102 | rand_num = random.randint(0, list_num-1)
103 | atlas_vol = atlas_list[rand_num]
104 | 
105 | #Parameters for training : X(train_vol) ,atlas_vol(atlas) ,zero_flow
106 | X = train_example_gen.__next__()[0]
107 | train_loss = model.train_on_batch(
108 | [atlas_vol, X], [X, zero_flow])
109 | 
110 | if not isinstance(train_loss, list):
111 | train_loss = [train_loss]
112 | 
113 | printLoss(step, 1, train_loss)
114 | 
115 | if(step % model_save_iter == 0):
116 | model.save(model_dir + '/' + str(load_iter+step) + '.h5')
117 | 
118 | 
119 | def printLoss(step, training, train_loss):
120 | s = str(step) + "," + str(training)
121 | 
122 | if(isinstance(train_loss, list) or isinstance(train_loss, np.ndarray)):
123 | for i in range(len(train_loss)):
124 | s += "," + str(train_loss[i])
125 | else:
126 | s += "," + str(train_loss)
127 | 
128 | print(s)
129 | sys.stdout.flush()
130 | 
131 | 
132 | if __name__ == "__main__":
133 | 
134 | parser = ArgumentParser()
135 | parser.add_argument("--model", type=str,dest="model",
136 | choices=['vm1','vm2'],default='vm2',
137 | help="Voxelmorph-1 or 2")
138 | parser.add_argument("--gpu", type=int,default=0,
139 | dest="gpu_id", help="gpu id number")
140 | parser.add_argument("--lr", type=float,
141 | dest="lr", default=1e-4,help="learning rate")
142 | parser.add_argument("--iters", type=int,
143 | dest="n_iterations", default=15000,
144 | help="number of iterations")
145 | parser.add_argument("--lambda", type=float,
146 | dest="reg_param", default=1.0,
147 | help="regularization parameter")
148 | parser.add_argument("--checkpoint_iter", type=int,
149 | dest="model_save_iter", default=500,
150 | help="frequency of model saves")
151 | parser.add_argument("--load_iter", type=int,
152 | dest="load_iter", default=0,
153 | help="the iterations of the model to load")
154 | 
155 | args = parser.parse_args()
156 | train(**vars(args))
157 | 
158 
159 | 160 | -------------------------------------------------------------------------------- /src/MAS4_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | multi-atlas segmentation based on VoxelMorph and Neuron 3 | 4 | """ 5 | 6 | # python imports 7 | import os 8 | import glob 9 | import sys 10 | import random 11 | from argparse import ArgumentParser 12 | 13 | # third-party imports 14 | import tensorflow as tf 15 | import numpy as np 16 | from keras.backend.tensorflow_backend import set_session 17 | from keras.optimizers import Adam 18 | from keras.models import load_model, Model 19 | 20 | 21 | import datagenerators 22 | import networks 23 | import losses 24 | 25 | 26 | vol_size = (160, 192, 224) 27 | # training data preparation 28 | base_data_dir = '/home/ys895/resize256/resize256-crop_x32/' 29 | # collect the paths of all .npz files in the directory 30 | # (these are the training volumes) 31 | train_vol_names = glob.glob(base_data_dir + 'train/vols/*.npz') 32 | # shuffle the .npz file paths, 33 | # i.e., randomize the order of the training data 34 | random.shuffle(train_vol_names) 35 | 36 | # (disabled) read a single atlas 37 | #atlas = np.load('../data/atlas_norm.npz') 38 | #atlas_vol = atlas['vol'] 39 | 40 | # (disabled) add batch and channel dimensions to the atlas 41 | #atlas_vol = np.reshape(atlas_vol, (1,) + atlas_vol.shape+(1,)) 42 | 43 | # atlas_list: several atlases are read from MAS_atlas.txt 44 | atlas_file = open('../data/MAS_atlas.txt') 45 | atlas_strings = atlas_file.readlines() 46 | lenn = 4 47 | atlas_list = list() 48 | for i in range(0,lenn): 49 | st = atlas_strings[i] 50 | atlas_add = np.load(st.strip()) 51 | atlas_add = atlas_add['vol_data'] 52 | atlas_add = np.reshape(atlas_add,(1,)+atlas_add.shape+(1,)) 53 | atlas_list.append(atlas_add) 54 | 55 | # read atlas_norm as atlas used for training 56 | #atlas = np.load('../data/atlas_norm.npz') 57 | #atlas = atlas['vol'] 58 | #atlas = np.reshape(atlas,(1,)+atlas.shape+(1,)) 59 | #atlas_list.append(atlas) 60 | 61 | list_num = len(atlas_list) 62 | 63 | def train(model, gpu_id, lr, n_iterations, reg_param, model_save_iter, load_iter): 64 | 65 | model_dir = '/home/ys895/MAS4_Models' 66 | if not os.path.isdir(model_dir): 67 | os.mkdir(model_dir) 68 | 69 | gpu = '/gpu:' + str(gpu_id) 70 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 71 | config = tf.ConfigProto() 72 | config.gpu_options.allow_growth = True 73 | config.allow_soft_placement = True 74 | set_session(tf.Session(config=config)) 75 | 76 | 77 | # UNET filters 78 | nf_enc = [16,32,32,32] 79 | if(model == 'vm1'): 80 | nf_dec = [32,32,32,32,8,8,3] 81 | else: 82 | nf_dec = [32,32,32,32,32,16,16,3] 83 | 84 | with tf.device(gpu): 85 | model = networks.unet(vol_size, nf_enc, nf_dec) 86 | if(load_iter != 0): 87 | model.load_weights('/home/ys895/MAS4_Models/' + str(load_iter) + '.h5') 88 | 89 | model.compile(optimizer=Adam(lr=lr), loss=[ 90 | losses.cc3D(), losses.gradientLoss('l2')], loss_weights=[1.0, reg_param]) 91 | # model.load_weights('../models/udrnet2/udrnet1_1/120000.h5') 92 | 93 | # the generator yields training volumes with batch and channel dimensions added 94 | train_example_gen = datagenerators.example_gen(train_vol_names) 95 | zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3)) 96 | 97 | 98 | # training loop: feed the data into the model 99 | # (the model itself was configured above) 100 | for step in range(1, n_iterations+1): 101 | # randomly choose one atlas from atlas_list 102 | rand_num = random.randint(0, list_num-1) 103 | atlas_vol = atlas_list[rand_num] 104 |
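# a note on the batch below: the unet takes [src, tgt] = [atlas_vol, X] and
# predicts [warped_atlas, flow]; cc3D compares the warped atlas to X, while
# gradientLoss penalizes the flow's spatial gradients (zero_flow only supplies
# the target tensor's shape).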
105 | # training inputs: X (training volume), atlas_vol (atlas), zero_flow (shape-only target) 106 | X = train_example_gen.__next__()[0] 107 | train_loss = model.train_on_batch( 108 | [atlas_vol, X], [X, zero_flow]) 109 | 110 | if not isinstance(train_loss, list): 111 | train_loss = [train_loss] 112 | 113 | printLoss(step, 1, train_loss) 114 | 115 | if(step % model_save_iter == 0): 116 | model.save(model_dir + '/' + str(load_iter+step) + '.h5') 117 | 118 | 119 | def printLoss(step, training, train_loss): 120 | s = str(step) + "," + str(training) 121 | 122 | if(isinstance(train_loss, list) or isinstance(train_loss, np.ndarray)): 123 | for i in range(len(train_loss)): 124 | s += "," + str(train_loss[i]) 125 | else: 126 | s += "," + str(train_loss) 127 | 128 | print(s) 129 | sys.stdout.flush() 130 | 131 | 132 | if __name__ == "__main__": 133 | 134 | parser = ArgumentParser() 135 | parser.add_argument("--model", type=str,dest="model", 136 | choices=['vm1','vm2'],default='vm2', 137 | help="Voxelmorph-1 or 2") 138 | parser.add_argument("--gpu", type=int,default=0, 139 | dest="gpu_id", help="gpu id number") 140 | parser.add_argument("--lr", type=float, 141 | dest="lr", default=1e-4,help="learning rate") 142 | parser.add_argument("--iters", type=int, 143 | dest="n_iterations", default=15000, 144 | help="number of iterations") 145 | parser.add_argument("--lambda", type=float, 146 | dest="reg_param", default=1.0, 147 | help="regularization parameter") 148 | parser.add_argument("--checkpoint_iter", type=int, 149 | dest="model_save_iter", default=500, 150 | help="frequency of model saves") 151 | parser.add_argument("--load_iter", type=int, 152 | dest="load_iter", default=0, 153 | help="iteration of the saved model to load") 154 | 155 | args = parser.parse_args() 156 | train(**vars(args)) 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | 2 | # Third party imports 3 | import tensorflow as tf 4 | import keras.backend as K 5 | import numpy as np 6 | 7 | # tensors are batch_size x height x width x depth x channels 8 | 9 | 10 | 11 | 12 | 13 | 14 | def diceLoss(y_true, y_pred): 15 | top = 2*tf.reduce_sum(y_true * y_pred, [1, 2, 3]) 16 | bottom = tf.maximum(tf.reduce_sum(y_true+y_pred, [1, 2, 3]), 1e-5) 17 | dice = tf.reduce_mean(top/bottom) 18 | return -dice 19 | 20 | 21 | def gradientLoss(penalty='l1'): 22 | def loss(y_true, y_pred): 23 | dy = tf.abs(y_pred[:, 1:, :, :, :] - y_pred[:, :-1, :, :, :]) 24 | dx = tf.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :]) 25 | dz = tf.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :]) 26 | 27 | if (penalty == 'l2'): 28 | dy = dy * dy 29 | dx = dx * dx 30 | dz = dz * dz 31 | d = tf.reduce_mean(dx)+tf.reduce_mean(dy)+tf.reduce_mean(dz) 32 | return d/3.0 33 | 34 | return loss 35 | 36 | 37 | def gradientLoss2D(): 38 | def loss(y_true, y_pred): 39 | dy = tf.abs(y_pred[:, 1:, :, :] - y_pred[:, :-1, :, :]) 40 | dx = tf.abs(y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :]) 41 | 42 | dy = dy * dy 43 | dx = dx * dx 44 | 45 | d = tf.reduce_mean(dx)+tf.reduce_mean(dy) 46 | return d/2.0 47 | 48 | return loss 49 | 50 | 51 | def cc3D(win=[9, 9, 9], voxel_weights=None): 52 | def loss(I, J): 53 | I2 = I*I 54 | J2 = J*J 55 | IJ = I*J 56 | 57 | filt = tf.ones([win[0], win[1], win[2], 1, 1]) 58 | 59 | I_sum = tf.nn.conv3d(I, filt, [1, 1, 1, 1, 1], "SAME") 60 | J_sum = tf.nn.conv3d(J, filt, [1, 1, 1, 1, 1], "SAME") 61 | I2_sum = tf.nn.conv3d(I2, filt, [1, 1, 1, 1, 1], "SAME") 62 | J2_sum
= tf.nn.conv3d(J2, filt, [1, 1, 1, 1, 1], "SAME") 63 | IJ_sum = tf.nn.conv3d(IJ, filt, [1, 1, 1, 1, 1], "SAME") 64 | 65 | win_size = win[0]*win[1]*win[2] 66 | u_I = I_sum/win_size 67 | u_J = J_sum/win_size 68 | 69 | cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size 70 | I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size 71 | J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size 72 | 73 | cc = cross*cross / (I_var*J_var+1e-5) 74 | 75 | # if(voxel_weights is not None): 76 | # cc = cc * voxel_weights 77 | 78 | return -1.0*tf.reduce_mean(cc) 79 | 80 | return loss 81 | 82 | 83 | def cc2D(win=[9, 9]): 84 | def loss(I, J): 85 | I2 = tf.multiply(I, I) 86 | J2 = tf.multiply(J, J) 87 | IJ = tf.multiply(I, J) 88 | 89 | sum_filter = tf.ones([win[0], win[1], 1, 1]) 90 | 91 | I_sum = tf.nn.conv2d(I, sum_filter, [1, 1, 1, 1], "SAME") 92 | J_sum = tf.nn.conv2d(J, sum_filter, [1, 1, 1, 1], "SAME") 93 | I2_sum = tf.nn.conv2d(I2, sum_filter, [1, 1, 1, 1], "SAME") 94 | J2_sum = tf.nn.conv2d(J2, sum_filter, [1, 1, 1, 1], "SAME") 95 | IJ_sum = tf.nn.conv2d(IJ, sum_filter, [1, 1, 1, 1], "SAME") 96 | 97 | win_size = win[0]*win[1] 98 | 99 | u_I = I_sum/win_size 100 | u_J = J_sum/win_size 101 | 102 | cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size 103 | I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size 104 | J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size 105 | 106 | cc = cross*cross / (I_var*J_var + np.finfo(float).eps) 107 | return -1.0*tf.reduce_mean(cc) 108 | return loss 109 | 110 | 111 | 112 | 113 | ## Losses for the MICCAI2018 Paper 114 | def kl_loss(alpha): 115 | def loss(_, y_pred): 116 | """ 117 | KL loss 118 | y_pred is assumed to be 6 channels: first 3 for mean, next 3 for logsigma 119 | """ 120 | mean = y_pred[..., 0:3] 121 | log_sigma = y_pred[..., 3:] 122 | 123 | # compute the degree matrix. 124 | # TODO: should only compute this once! 
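# the filter built below places ones at the six face-neighbors of the center
# voxel, so convolving an all-ones volume counts each voxel's neighbors
# (6 in the interior, fewer at the boundary): that count is the graph degree D.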
125 | # z = K.ones((1, ) + vol_size + (3, )) 126 | sz = log_sigma.get_shape().as_list()[1:] 127 | z = K.ones([1] + sz) 128 | 129 | filt = np.zeros((3, 3, 3, 3, 3)) 130 | for i in range(3): 131 | filt[1, 1, [0, 2], i, i] = 1 132 | filt[[0, 2], 1, 1, i, i] = 1 133 | filt[1, [0, 2], 1, i, i] = 1 134 | filt_tf = tf.convert_to_tensor(filt, dtype=tf.float32) 135 | D = tf.nn.conv3d(z, filt_tf, [1, 1, 1, 1, 1], "SAME") 136 | D = K.expand_dims(D, 0) 137 | 138 | sigma_terms = (alpha * D * tf.exp(log_sigma) - log_sigma) 139 | 140 | # note needs 0.5 twice, one here, one below 141 | prec_terms = 0.5 * alpha * kl_prec_term_manual(_, mean) 142 | kl = 0.5 * tf.reduce_mean(sigma_terms, [1, 2, 3]) + 0.5 * prec_terms 143 | return kl 144 | 145 | return loss 146 | 147 | def kl_prec_term_manual(y_true, y_pred): 148 | """ 149 | a more manual implementation of the precision matrix term 150 | P = D - A 151 | mu * P * mu 152 | where D is the degree matrix and A is the adjacency matrix 153 | mu * P * mu = sum_i mu_i sum_j (mu_i - mu_j) 154 | where j are neighbors of i 155 | """ 156 | dy = y_pred[:,1:,:,:,:] * (y_pred[:,1:,:,:,:] - y_pred[:,:-1,:,:,:]) 157 | dx = y_pred[:,:,1:,:,:] * (y_pred[:,:,1:,:,:] - y_pred[:,:,:-1,:,:]) 158 | dz = y_pred[:,:,:,1:,:] * (y_pred[:,:,:,1:,:] - y_pred[:,:,:,:-1,:]) 159 | dy2 = y_pred[:,:-1,:,:,:] * (y_pred[:,:-1,:,:,:] - y_pred[:,1:,:,:,:]) 160 | dx2 = y_pred[:,:,:-1,:,:] * (y_pred[:,:,:-1,:,:] - y_pred[:,:,1:,:,:]) 161 | dz2 = y_pred[:,:,:,:-1,:] * (y_pred[:,:,:,:-1,:] - y_pred[:,:,:,1:,:]) 162 | 163 | d = tf.reduce_mean(dx) + tf.reduce_mean(dy) + tf.reduce_mean(dz) + \ 164 | tf.reduce_mean(dy2) + tf.reduce_mean(dx2) + tf.reduce_mean(dz2) 165 | return d 166 | 167 | 168 | def kl_l2loss(image_sigma): 169 | def loss(y_true, y_pred): 170 | return 1. / (image_sigma**2) * K.mean(K.square(y_true - y_pred)) 171 | return loss -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | train atlas-based alignment with CVPR2018 version of VoxelMorph 3 | """ 4 | 5 | # python imports 6 | import os 7 | import glob 8 | import sys 9 | import random 10 | from argparse import ArgumentParser 11 | 12 | # third-party imports 13 | import tensorflow as tf 14 | import numpy as np 15 | from keras.backend.tensorflow_backend import set_session 16 | from keras.optimizers import Adam 17 | from keras.models import load_model, Model 18 | 19 | # project imports 20 | import datagenerators 21 | import networks 22 | import losses 23 | 24 | 25 | ## some data prep 26 | # Volume size used in our experiments. Please change to suit your data. 27 | vol_size = (160, 192, 224) 28 | 29 | # prepare the data 30 | # for the CVPR paper, we have data arranged in train/validate/test folders 31 | # inside each folder is a /vols/ and a /asegs/ folder with the volumes 32 | # and segmentations 33 | base_data_dir = '/home/ys895/resize256/resize256-crop_x32/' 34 | train_vol_names = glob.glob(base_data_dir + 'train/vols/*.npz') 35 | random.shuffle(train_vol_names) # shuffle volume list 36 | 37 | # load atlas from provided files. This atlas is 160x192x224. 
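# ([np.newaxis, ..., np.newaxis] below prepends a batch axis and appends a
# channel axis, taking the atlas from (160, 192, 224) to (1, 160, 192, 224, 1))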
38 | atlas = np.load('../data/atlas_norm.npz') 39 | atlas_vol = atlas['vol'][np.newaxis,...,np.newaxis] 40 | 41 | 42 | def train(model, model_dir, gpu_id, lr, n_iterations, reg_param, model_save_iter, batch_size=1): 43 | """ 44 | model training function 45 | :param model: either vm1 or vm2 (based on CVPR 2018 paper) 46 | :param model_dir: the model directory to save to 47 | :param gpu_id: integer specifying the gpu to use 48 | :param lr: learning rate 49 | :param n_iterations: number of training iterations 50 | :param reg_param: the smoothness/reconstruction tradeoff parameter (lambda in CVPR paper) 51 | :param model_save_iter: frequency with which to save models 52 | :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size 53 | """ 54 | 55 | # prepare model folder 56 | if not os.path.isdir(model_dir): 57 | os.mkdir(model_dir) 58 | 59 | # GPU handling 60 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 61 | config = tf.ConfigProto() 62 | config.gpu_options.allow_growth = True 63 | config.allow_soft_placement = True 64 | set_session(tf.Session(config=config)) 65 | 66 | # UNET filters for voxelmorph-1 and voxelmorph-2, 67 | # these are architectures presented in CVPR 2018 68 | nf_enc = [16, 32, 32, 32] 69 | if model == 'vm1': 70 | nf_dec = [32, 32, 32, 32, 8, 8] 71 | else: 72 | nf_dec = [32, 32, 32, 32, 32, 16, 16] 73 | 74 | # prepare the model 75 | # in the CVPR layout, the model takes in [image_1, image_2] and outputs [warped_image_1, flow] 76 | # in the experiments, we use image_2 as atlas 77 | model = networks.unet(vol_size, nf_enc, nf_dec) 78 | model.compile(optimizer=Adam(lr=lr), 79 | loss=[losses.cc3D(), losses.gradientLoss('l2')], 80 | loss_weights=[1.0, reg_param]) 81 | 82 | # if you'd like to initialize the model with saved weights, you can do it here: 83 | # model.load_weights(os.path.join(model_dir, '120000.h5')) 84 | 85 | # prepare data for training 86 | train_example_gen = datagenerators.example_gen(train_vol_names) 87 | zero_flow = np.zeros([batch_size, *vol_size, 3]) 88 | 89 | # train. Note: we use train_on_batch and design our own print function as this has enabled 90 | # faster development and debugging, but one could also use fit_generator and Keras callbacks. 91 | for step in range(0, n_iterations): 92 | 93 | # get data 94 | X = next(train_example_gen)[0] 95 | 96 | # train 97 | train_loss = model.train_on_batch([X, atlas_vol], [atlas_vol, zero_flow]) 98 | if not isinstance(train_loss, list): 99 | train_loss = [train_loss] 100 | 101 | # print the loss. 102 | print_loss(step, 1, train_loss) 103 | 104 | # save model 105 | if step % model_save_iter == 0: 106 | model.save(os.path.join(model_dir, str(step) + '.h5')) 107 | 108 | 109 | def print_loss(step, training, train_loss): 110 | """ 111 | Prints training progress to std.
out 112 | :param step: iteration number 113 | :param training: a 0/1 indicating training/testing 114 | :param train_loss: model loss at current iteration 115 | """ 116 | s = str(step) + "," + str(training) 117 | 118 | if isinstance(train_loss, list) or isinstance(train_loss, np.ndarray): 119 | for i in range(len(train_loss)): 120 | s += "," + str(train_loss[i]) 121 | else: 122 | s += "," + str(train_loss) 123 | 124 | print(s) 125 | sys.stdout.flush() 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = ArgumentParser() 130 | parser.add_argument("--model", type=str, dest="model", 131 | choices=['vm1', 'vm2'], default='vm2', 132 | help="Voxelmorph-1 or 2") 133 | parser.add_argument("--gpu", type=int, default=0, 134 | dest="gpu_id", help="gpu id number") 135 | parser.add_argument("--lr", type=float, 136 | dest="lr", default=1e-4, help="learning rate") 137 | parser.add_argument("--iters", type=int, 138 | dest="n_iterations", default=150000, 139 | help="number of iterations") 140 | parser.add_argument("--lambda", type=float, 141 | dest="reg_param", default=1.0, 142 | help="regularization parameter") 143 | parser.add_argument("--checkpoint_iter", type=int, 144 | dest="model_save_iter", default=100, 145 | help="frequency of model saves") 146 | parser.add_argument("--model_dir", type=str, 147 | dest="model_dir", default='../models/', 148 | help="models folder") 149 | 150 | args = parser.parse_args() 151 | train(**vars(args)) 152 | -------------------------------------------------------------------------------- /ext/pytools-lib/pytools/iniparse.py: -------------------------------------------------------------------------------- 1 | """ 2 | very simple ini parser and tools 3 | 4 | tested on python 3.6 5 | 6 | contact: adalca at csail.mit.edu 7 | 8 | TODO: see 9 | from collections import namedtuple 10 | instead of Struct 11 | """ 12 | 13 | # built-in modules 14 | # we'll need python's ini parser: 'configparser' 15 | import configparser 16 | 17 | def ini_to_struct(file): 18 | """ 19 | very simple ini parser that expands on configparser 20 | tries to cast values from string wherever possible 21 | the parsed ini data can be accessed with 22 | 23 | data = ini_to_struct(file) 24 | value = data.section.key 25 | 26 | does not support hierarchical sections 27 | 28 | Parameters: 29 | file: string full filename of the ini file. 30 | 31 | Returns: 32 | strct: a Struct that allows ini data to be accessed in the manner of data.section.key 33 | """ 34 | 35 | # read the file via config.
36 | conf = configparser.ConfigParser() 37 | confout = conf.read(file) 38 | assert len(confout) > 0, 'Cannot read file %s ' % file 39 | 40 | # prepare the Struct 41 | strct = Struct() 42 | 43 | # go through the sections in the ini file 44 | for sec in conf.sections(): 45 | 46 | # each section is its own struct 47 | secstrct = Struct() 48 | 49 | # go through the keys 50 | for key in conf[sec]: 51 | val = conf[sec][key] 52 | 53 | # try to cast the data 54 | ret, done = str_convert_single(val) 55 | 56 | # if couldn't cast, try a comma/whitespace separated list 57 | if not done: 58 | lst = str_to_list(val) 59 | 60 | # if the size of the list is 1, we didn't achieve anything 61 | if len(lst) == 1: 62 | ret = lst[0] # still not done 63 | 64 | # if we actually get a list, only keep it if we can cast its elements to something 65 | # otherwise keep the entry as an entire string 66 | else: 67 | # make sure all elements in the list convert to something 68 | done = all([str_convert_single(v)[1] for v in lst]) 69 | if done: 70 | ret = [str_convert_single(v)[0] for v in lst] 71 | 72 | # defeated, accept the entry as just a simple string... 73 | if not done: 74 | ret = val # accept string 75 | 76 | # assign secstrct.key = ret 77 | setattr(secstrct, key, ret) 78 | 79 | # assign strct.sec = secstrct 80 | setattr(strct, sec, secstrct) 81 | 82 | return strct 83 | 84 | 85 | class Struct(): 86 | """ 87 | a simple struct class to allow for the following syntax: 88 | data = Struct() 89 | data.foo = 'bar' 90 | """ 91 | 92 | def __str__(self): 93 | return self.__dict__.__str__() 94 | 95 | 96 | def str_to_none(val): 97 | """ 98 | cast a string to None 99 | 100 | Parameters: 101 | val: the string to cast 102 | 103 | Returns: 104 | (casted_val, success) 105 | casted val: the casted value if successful, or None 106 | success: True if casting was successful 107 | """ 108 | if val == 'None': 109 | return (None, True) 110 | else: 111 | return (None, False) 112 | 113 | 114 | def str_to_type(val, ctype): 115 | """ 116 | cast a string to a type (e.g. int('8')), with try/except 117 | do *not* use for bool casting, instead see str_to_bool 118 | 119 | Parameters: 120 | val: the string to cast 121 | 122 | Returns: 123 | (casted_val, success) 124 | casted val: the casted value if successful, or None 125 | success: bool indicating whether casting was successful 126 | """ 127 | assert ctype is not bool, 'use str_to_bool() for casting to bool' 128 | 129 | ret = None 130 | success = True 131 | try: 132 | ret = ctype(val) 133 | except ValueError: 134 | success = False 135 | return (ret, success) 136 | 137 | 138 | def str_to_bool(val): 139 | """ 140 | cast a string to a bool 141 | 142 | Parameters: 143 | val: the string to cast 144 | 145 | Returns: 146 | (casted_val, success) 147 | casted val: the casted value if successful, or None 148 | success: bool indicating whether casting was successful 149 | """ 150 | if val == 'True': 151 | return (True, True) 152 | elif val == 'False': 153 | return (False, True) 154 | else: 155 | return (None, False) 156 | 157 | 158 | def str_to_list(val): 159 | """ 160 | Split a string to a list of elements, where elements are separated by whitespace or commas 161 | Leading/ending parentheses are stripped.
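e.g. str_to_list('(160, 192, 224)') gives ['160', ' 192', ' 224']; the
elements are whitespace-stripped later by str_convert_single.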
162 | 163 | Parameters: 164 | val: the string to split 165 | 166 | Returns: 167 | lst: the list of split string elements 168 | """ 169 | val = val.replace('[', '') 170 | val = val.replace('(', '') 171 | val = val.replace(']', '') 172 | val = val.replace(')', '') 173 | 174 | if ',' in val: 175 | lst = val.split(',') 176 | else: 177 | lst = val.split() 178 | 179 | return lst 180 | 181 | 182 | def str_convert_single(val): 183 | """ 184 | try to cast a string to an int, float, bool, or None (in that order) 185 | 186 | Parameters: 187 | val: the string to cast 188 | 189 | Returns: 190 | (casted_val, success) 191 | casted val: the casted value if successful, or None 192 | success: bool indicating whether casting was successful 193 | """ 194 | val = val.strip() 195 | # try int 196 | ret, done = str_to_type(val, int) 197 | 198 | # try float 199 | if not done: 200 | ret, done = str_to_type(val, float) 201 | 202 | # try bool 203 | if not done: 204 | ret, done = str_to_bool(val) 205 | 206 | # try None 207 | if not done: 208 | ret, done = str_to_none(val) 209 | 210 | return (ret, done) 211 | 212 | -------------------------------------------------------------------------------- /src/train_miccai2018.py: -------------------------------------------------------------------------------- 1 | """ 2 | train atlas-based alignment with MICCAI2018 version of VoxelMorph, 3 | specifically adding uncertainty estimation and diffeomorphic transforms. 4 | """ 5 | 6 | 7 | # python imports 8 | import os 9 | import glob 10 | import sys 11 | import random 12 | from argparse import ArgumentParser 13 | 14 | # third-party imports 15 | import tensorflow as tf 16 | import numpy as np 17 | from keras.backend.tensorflow_backend import set_session 18 | from keras.optimizers import Adam 19 | from keras.models import load_model, Model 20 | 21 | # project imports 22 | import datagenerators 23 | import networks 24 | import losses 25 | 26 | """ 27 | the model uses the volume data and atlas data for the training process 28 | """ 29 | 30 | ## some data prep 31 | # Volume size used in our experiments. Please change to suit your data. 32 | vol_size = (160, 192, 224) 33 | 34 | # prepare the data 35 | # for the CVPR paper, we have data arranged in train/validate/test folders 36 | # inside each folder is a /vols/ and a /asegs/ folder with the volumes 37 | # and segmentations 38 | # read the volume data from the directory 39 | base_data_dir = '/home/ys895/resize256/resize256-crop_x32/' 40 | train_vol_names = glob.glob(base_data_dir + 'train/vols/*.npz') 41 | random.shuffle(train_vol_names) # shuffle volume list 42 | 43 | # load atlas from provided files. This atlas is 160x192x224. 44 | atlas = np.load('../data/atlas_norm.npz') 45 | atlas_vol = atlas['vol'][np.newaxis,...,np.newaxis] 46 | 47 | 48 | def train(model_dir, gpu_id, lr, n_iterations, alpha, image_sigma, model_save_iter, batch_size=1): 49 | """ 50 | model training function 51 | :param model_dir: model folder to save to 52 | :param gpu_id: integer specifying the gpu to use 53 | :param lr: learning rate 54 | :param n_iterations: number of training iterations 55 | :param alpha: the alpha, the scalar in front of the smoothing laplacian, in MICCAI paper 56 | :param image_sigma: the image sigma in MICCAI paper 57 | :param model_save_iter: frequency with which to save models 58 | :param batch_size: Optional, default of 1.
can be larger, depends on GPU memory and volume size 59 | """ 60 | 61 | """ 62 | preparing the model 63 | """ 64 | 65 | # prepare model folder 66 | if not os.path.isdir(model_dir): 67 | os.mkdir(model_dir) 68 | print(model_dir) 69 | 70 | # gpu handling 71 | gpu = '/gpu:' + str(gpu_id) 72 | os.environ["CUDA_VISIBLE_DEVICES"]=str(gpu_id) 73 | config = tf.ConfigProto() 74 | config.gpu_options.allow_growth = True 75 | config.allow_soft_placement = True 76 | set_session(tf.Session(config=config)) 77 | 78 | # Diffeomorphic network architecture used in MICCAI 2018 paper 79 | nf_enc = [16,32,32,32] 80 | nf_dec = [32,32,32,32,16,3] 81 | 82 | # prepare the model 83 | # in the CVPR layout, the model takes in [image_1, image_2] and outputs [warped_image_1, velocity_stats] 84 | # in the experiments, we use image_2 as atlas 85 | with tf.device(gpu): 86 | # miccai 2018 used xy indexing. 87 | model = networks.miccai2018_net(vol_size,nf_enc,nf_dec, use_miccai_int=True, indexing='xy') 88 | 89 | # compile 90 | model_losses = [losses.kl_l2loss(image_sigma), losses.kl_loss(alpha)] 91 | model.compile(optimizer=Adam(lr=lr), loss=model_losses) 92 | 93 | # save first iteration 94 | model.save(os.path.join(model_dir, str(0) + '.h5')) 95 | 96 | train_example_gen = datagenerators.example_gen(train_vol_names) 97 | zeros = np.zeros((1, *vol_size, 3)) 98 | 99 | """ 100 | training process 101 | """ 102 | 103 | # train. Note: we use train_on_batch and design our own print function as this has enabled 104 | # faster development and debugging, but one could also use fit_generator and Keras callbacks. 105 | for step in range(1, n_iterations): 106 | 107 | # get_data 108 | X = next(train_example_gen)[0] 109 | 110 | # train 111 | # the model warps X toward the atlas; the loss compares the warped image to the atlas 112 | with tf.device(gpu): 113 | train_loss = model.train_on_batch([X,atlas_vol], [atlas_vol, zeros]) 114 | 115 | if not isinstance(train_loss,list): 116 | train_loss = [train_loss] 117 | 118 | # print 119 | print_loss(step, 0, train_loss) 120 | 121 | # save model 122 | with tf.device(gpu): 123 | if (step % model_save_iter == 0) or step < 10: 124 | model.save(os.path.join(model_dir, str(step) + '.h5')) 125 | 126 | 127 | def print_loss(step, training, train_loss): 128 | """ 129 | Prints training progress to std.
out 130 | :param step: iteration number 131 | :param training: a 0/1 indicating training/testing 132 | :param train_loss: model loss at current iteration 133 | """ 134 | s = str(step) + "," + str(training) 135 | 136 | if isinstance(train_loss, list) or isinstance(train_loss, np.ndarray): 137 | for i in range(len(train_loss)): 138 | s += "," + str(train_loss[i]) 139 | else: 140 | s += "," + str(train_loss) 141 | 142 | print(s) 143 | sys.stdout.flush() 144 | 145 | 146 | if __name__ == "__main__": 147 | parser = ArgumentParser() 148 | parser.add_argument("--model_dir", type=str, 149 | dest="model_dir", default='../models/', 150 | help="models folder") 151 | parser.add_argument("--gpu", type=int, default=0, 152 | dest="gpu_id", help="gpu id number") 153 | parser.add_argument("--lr", type=float, 154 | dest="lr", default=1e-4, help="learning rate") 155 | parser.add_argument("--iters", type=int, 156 | dest="n_iterations", default=150000, 157 | help="number of iterations") 158 | parser.add_argument("--alpha", type=float, 159 | dest="alpha", default=70000/128, 160 | help="alpha regularization parameter") 161 | parser.add_argument("--image_sigma", type=float, 162 | dest="image_sigma", default=0.05, 163 | help="image noise parameter") 164 | parser.add_argument("--checkpoint_iter", type=int, 165 | dest="model_save_iter", default=100, 166 | help="frequency of model saves") 167 | 168 | args = parser.parse_args() 169 | train(**vars(args)) 170 | -------------------------------------------------------------------------------- /data/allcsv.txt: -------------------------------------------------------------------------------- 1 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/981102_vc604.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/981102_vc604.npz 2 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/981112_vc623.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/981112_vc623.npz 3 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/981204_vc660.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/981204_vc660.npz 4 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/981216_vc681.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/981216_vc681.npz 5 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990104_vc700.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990104_vc700.npz 6 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990111_vc716.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990111_vc716.npz 7 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990114_vc722.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990114_vc722.npz 8 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990114_vc723.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990114_vc723.npz 9 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990119_vc740.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990119_vc740.npz 10 | 
/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990121_vc747.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990121_vc747.npz 11 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990128_vc764.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990128_vc764.npz 12 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990205_vc783.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990205_vc783.npz 13 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990210_vc792.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990210_vc792.npz 14 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990211_vc799.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990211_vc799.npz 15 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990215_vc803.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990215_vc803.npz 16 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990217_vc809.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990217_vc809.npz 17 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990317_vc876.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990317_vc876.npz 18 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990326_vc891.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990326_vc891.npz 19 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990405_vc922.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990405_vc922.npz 20 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990525_vc1024.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990525_vc1024.npz 21 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990715_vc1131.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990715_vc1131.npz 22 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990729_vc1168.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990729_vc1168.npz 23 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990730_vc1172.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990730_vc1172.npz 24 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990902_vc1249.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990902_vc1249.npz 25 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990903_vc1253.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990903_vc1253.npz 26 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/990921_vc1289.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/990921_vc1289.npz 27 | 
/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991006_vc1337.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991006_vc1337.npz 28 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991025_vc1379.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991025_vc1379.npz 29 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991102_vc1401.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991102_vc1401.npz 30 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991109_vc1420.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991109_vc1420.npz 31 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991109_vc1423.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991109_vc1423.npz 32 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991110_vc1425.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991110_vc1425.npz 33 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991113_vc1439.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991113_vc1439.npz 34 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991113_vc1440.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991113_vc1440.npz 35 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991120_vc1456.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991120_vc1456.npz 36 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991122_vc1463.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991122_vc1463.npz 37 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991122_vc1465.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991122_vc1465.npz 38 | /data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/vols/991122_vc1479.npz,/data/ddmg/voxelmorph/data/buckner/proc/resize256-crop_x32/FromEugenio_prep/segs_edited/991122_vc1479.npz 39 | -------------------------------------------------------------------------------- /src/networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Networks for voxelmorph model 3 | 4 | In general, these are fairly specific architectures that were designed for the presented papers. 5 | However, the VoxelMorph concepts are not tied to a very particular architecture, and we 6 | encourage you to explore architectures that fit your needs. 7 | see e.g. 
more powerful unet function in https://github.com/adalca/neuron/blob/master/neuron/models.py 8 | """ 9 | # main imports 10 | import sys 11 | 12 | # third party 13 | import numpy as np 14 | import keras.backend as K 15 | from keras.models import Model 16 | from keras.layers import Conv3D, Activation, Input, UpSampling3D, concatenate 17 | from keras.layers import LeakyReLU, Reshape, Lambda 18 | from keras.initializers import RandomNormal 19 | import keras.initializers 20 | import tensorflow as tf 21 | 22 | # import neuron layers, which will be useful for transforming volumes. 23 | sys.path.append('../ext/neuron') 24 | sys.path.append('../ext/pynd-lib') 25 | sys.path.append('../ext/pytools-lib') 26 | import neuron.layers as nrn_layers 27 | import neuron.utils as nrn_utils 28 | 29 | # other vm functions 30 | import losses 31 | 32 | 33 | def unet_core(vol_size, enc_nf, dec_nf, full_size=True): 34 | """ 35 | unet architecture for voxelmorph models presented in the CVPR 2018 paper. 36 | You may need to modify this code (e.g., number of layers) to suit your project needs. 37 | 38 | :param vol_size: volume size. e.g. (256, 256, 256) 39 | :param enc_nf: list of encoder filters. right now it needs to be 1x4. 40 | e.g. [16,32,32,32] 41 | :param dec_nf: list of decoder filters. right now it must be 1x6 (like voxelmorph-1) or 1x7 (voxelmorph-2) 42 | :return: the keras model 43 | """ 44 | # inputs 45 | src = Input(shape=vol_size + (1,)) 46 | tgt = Input(shape=vol_size + (1,)) 47 | x_in = concatenate([src, tgt]) 48 | 49 | # down-sample path (encoder) 50 | x_enc = [x_in] 51 | for i in range(len(enc_nf)): 52 | x_enc.append(conv_block(x_enc[-1], enc_nf[i], 2)) 53 | 54 | # up-sample path (decoder) 55 | x = conv_block(x_enc[-1], dec_nf[0]) 56 | x = UpSampling3D()(x) 57 | x = concatenate([x, x_enc[-2]]) 58 | x = conv_block(x, dec_nf[1]) 59 | x = UpSampling3D()(x) 60 | x = concatenate([x, x_enc[-3]]) 61 | x = conv_block(x, dec_nf[2]) 62 | x = UpSampling3D()(x) 63 | x = concatenate([x, x_enc[-4]]) 64 | x = conv_block(x, dec_nf[3]) 65 | x = conv_block(x, dec_nf[4]) 66 | 67 | # only upsample to the full dimension if full_size 68 | # here we explore architectures where we essentially work with flow fields 69 | # that are 1/2 size 70 | if full_size: 71 | x = UpSampling3D()(x) 72 | x = concatenate([x, x_enc[0]]) 73 | x = conv_block(x, dec_nf[5]) 74 | 75 | # optional convolution at output resolution (used in voxelmorph-2) 76 | if len(dec_nf) == 7: 77 | x = conv_block(x, dec_nf[6]) 78 | 79 | return Model(inputs=[src, tgt], outputs=[x]) 80 | 81 | 82 | 83 | def unet(vol_size, enc_nf, dec_nf, full_size=True): 84 | """ 85 | unet architecture for voxelmorph models presented in the CVPR 2018 paper. 86 | You may need to modify this code (e.g., number of layers) to suit your project needs. 87 | 88 | :param vol_size: volume size. e.g. (256, 256, 256) 89 | :param enc_nf: list of encoder filters. right now it needs to be 1x4. 90 | e.g. [16,32,32,32] 91 | :param dec_nf: list of decoder filters. right now it must be 1x6 (like voxelmorph-1) or 1x7 (voxelmorph-2) 92 | :return: the keras model 93 | """ 94 | 95 | # get the core model 96 | unet_model = unet_core(vol_size, enc_nf, dec_nf, full_size=full_size) 97 | [src, tgt] = unet_model.inputs 98 | x = unet_model.output 99 | 100 | # transform the results into a flow field.
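# (the Conv3D below is initialized with tiny random weights, so early in
# training the predicted flow is near zero and the warp starts close to the
# identity transform)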
101 | # map the unet output to a 3-channel flow field 102 | flow = Conv3D(3, kernel_size=3, padding='same', 103 | kernel_initializer=RandomNormal(mean=0.0, stddev=1e-5), name='flow')(x) 104 | 105 | 106 | # warp the source with the flow 107 | # y is the image warped with flow 108 | y = nrn_layers.SpatialTransformer(interp_method='linear', indexing='xy')([src, flow]) 109 | # prepare model 110 | # the whole unet returns the warped image and the flow used to warp it 111 | model = Model(inputs=[src, tgt], outputs=[y, flow]) 112 | return model 113 | 114 | 115 | 116 | """ 117 | return warped image and flow params 118 | """ 119 | def miccai2018_net(vol_size, enc_nf, dec_nf, use_miccai_int=True, int_steps=7, indexing='xy'): 120 | """ 121 | architecture for probabilistic diffeomorphic VoxelMorph presented in the MICCAI 2018 paper. 122 | You may need to modify this code (e.g., number of layers) to suit your project needs. 123 | 124 | The stationary velocity field operates in a space (0.5)^3 of vol_size for computational reasons. 125 | 126 | :param vol_size: volume size. e.g. (256, 256, 256) 127 | :param enc_nf: list of encoder filters. right now it needs to be 1x4. 128 | e.g. [16,32,32,32] 129 | :param dec_nf: list of decoder filters. right now it must be 1x6, see unet function. 130 | :param use_miccai_int: whether to use the manual miccai implementation of scaling and squaring integration 131 | note that the 'velocity' field output in that case was produced by the manual loop below; 132 | since then we've updated the code to be part of a flexible layer. see neuron.layers.VecInt 133 | :param int_steps: the number of integration steps 134 | :param indexing: xy or ij indexing. we recommend ij indexing if training from scratch. 135 | miccai 2018 runs were done with xy indexing. 136 | :return: the keras model 137 | """ 138 | 139 | # get unet 140 | unet_model = unet_core(vol_size, enc_nf, dec_nf, full_size=False) 141 | [src,tgt] = unet_model.inputs 142 | x_out = unet_model.outputs[-1] 143 | 144 | # velocity mean and logsigma layers 145 | flow_mean = Conv3D(3, kernel_size=3, padding='same', 146 | kernel_initializer=RandomNormal(mean=0.0, stddev=1e-5), name='flow')(x_out) 147 | 148 | flow_log_sigma = Conv3D(3, kernel_size=3, padding='same', 149 | kernel_initializer=RandomNormal(mean=0.0, stddev=1e-10), 150 | bias_initializer=keras.initializers.Constant(value=-10), name='log_sigma')(x_out) 151 | flow_params = concatenate([flow_mean, flow_log_sigma]) 152 | 153 | # velocity sample 154 | flow = Lambda(sample, name="z_sample")([flow_mean, flow_log_sigma]) 155 | 156 | # integrate if diffeomorphic (i.e. treating 'flow' above as stationary velocity field) 157 | if use_miccai_int: 158 | # for the miccai2018 submission, the scaling and squaring layer 159 | # was manually composed of a Transform and an Add layer. 160 | flow = Lambda(lambda x: x, name='flow-fix')(flow) # remnant of old code 161 | v = flow 162 | for _ in range(int_steps): 163 | v1 = nrn_layers.SpatialTransformer(interp_method='linear', indexing=indexing)([v, v]) 164 | v = keras.layers.add([v, v1]) 165 | flow = v 166 | 167 | else: 168 | # new implementation in neuron is cleaner. 169 | # the 2**int_steps is a correcting factor left over from the miccai implementation.
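# (scaling and squaring: the velocity field is composed with itself int_steps
# times, doubling the integration time at each step, so the result
# approximates the exponential of the stationary velocity field)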
170 | # * (2**int_steps) 171 | flow = Lambda(lambda x: x, name='flow-fix')(flow) 172 | flow = nrn_layers.VecInt(method='ss', name='flow-int', int_steps=7)(flow) 173 | 174 | # get up to final resolution 175 | flow = Lambda(interp_upsampling, output_shape=vol_size+(3,), name='pre_diffflow')(flow) 176 | flow = Lambda(lambda arg: arg*2, name='diffflow')(flow) 177 | 178 | """ 179 | spatial transform 180 | """ 181 | 182 | # transform 183 | y = nrn_layers.SpatialTransformer(interp_method='linear', indexing=indexing)([src, flow]) 184 | 185 | # prepare outputs and losses 186 | # return y, which is the warped image of src with flow 187 | outputs = [y, flow_params] 188 | 189 | # build the model 190 | return Model(inputs=[src, tgt], outputs=outputs) 191 | 192 | # return a model used to compute a spatial transform 193 | def nn_trf(vol_size): 194 | """ 195 | Simple transform model for nearest-neighbor based transformation 196 | Note: this is essentially a wrapper for the neuron.utils.transform(..., interp_method='nearest') 197 | """ 198 | ndims = len(vol_size) 199 | 200 | # nn warp model 201 | subj_input = Input((*vol_size, 1), name='subj_input') 202 | trf_input = Input((*vol_size, ndims) , name='trf_input') 203 | 204 | # note the nearest neighbour interpolation method 205 | # note xy indexing because Guha's original code switched x and y dimensions 206 | nn_output = nrn_layers.SpatialTransformer(interp_method='nearest', indexing='xy') 207 | nn_spatial_output = nn_output([subj_input, trf_input]) 208 | # return a model used to compute a spatial transform 209 | return keras.models.Model([subj_input, trf_input], nn_spatial_output) 210 | 211 | 212 | # Helper functions 213 | def conv_block(x_in, nf, strides=1): 214 | """ 215 | specific convolution module including convolution followed by leakyrelu 216 | """ 217 | 218 | x_out = Conv3D(nf, kernel_size=3, padding='same', 219 | kernel_initializer='he_normal', strides=strides)(x_in) 220 | x_out = LeakyReLU(0.2)(x_out) 221 | return x_out 222 | 223 | 224 | def sample(args): 225 | """ 226 | sample from a normal distribution 227 | """ 228 | mu = args[0] 229 | log_sigma = args[1] 230 | noise = tf.random_normal(tf.shape(mu), 0, 1, dtype=tf.float32) 231 | z = mu + tf.exp(log_sigma/2.0) * noise 232 | return z 233 | 234 | def interp_upsampling(V): 235 | """ 236 | upsample a field by a factor of 2 237 | TODO: should switch this to use neuron.utils.interpn() 238 | """ 239 | 240 | [xx, yy, zz] = nrn_utils.volshape_to_ndgrid([f*2 for f in V.get_shape().as_list()[1:4]]) 241 | xx = tf.cast(xx, 'float32') 242 | yy = tf.cast(yy, 'float32') 243 | zz = tf.cast(zz, 'float32') 244 | xx = tf.expand_dims(xx/2-xx, 0) 245 | yy = tf.expand_dims(yy/2-yy, 0) 246 | zz = tf.expand_dims(zz/2-zz, 0) 247 | offset = tf.stack([xx, yy, zz], 4) 248 | 249 | # V = nrn_utils.transform(V, offset) 250 | V = nrn_layers.SpatialTransformer(interp_method='linear')([V, offset]) 251 | 252 | return V 253 | 254 | -------------------------------------------------------------------------------- /src/SAS_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | single-atlas segmentation based on VoxelMorph and Neuron: 3 | train the model to warp the atlas onto the volume; the warped atlas segmentation then labels the volume 4 | """ 5 | ''' 6 | # python imports 7 | import os 8 | import glob 9 | import sys 10 | import random 11 | from argparse import ArgumentParser 12 | 13 | # third-party imports 14 | import tensorflow as tf 15 | import numpy as np 16 | from keras.backend.tensorflow_backend import
set_session 17 | from keras.optimizers import Adam 18 | from keras.models import load_model, Model 19 | 20 | 21 | import datagenerators 22 | import networks 23 | import losses 24 | 25 | 26 | vol_size = (160, 192, 224) 27 | base_data_dir = '/home/ys895/resize256/resize256-crop_x32/' 28 | #find all the path of .npz file in the directory 29 | #read training data 30 | train_vol_names = glob.glob(base_data_dir + 'train/vols/*.npz') 31 | #shuffle the path of .npz file 32 | #shuffle the training data 33 | random.shuffle(train_vol_names) 34 | 35 | #read atlas data 36 | atlas = np.load('../data/atlas_norm.npz') 37 | atlas_vol = atlas['vol'] 38 | #add two more dimension into the atlas data 39 | atlas_vol = np.reshape(atlas_vol, (1,) + atlas_vol.shape+(1,)) 40 | 41 | def train(model, gpu_id, lr, n_iterations, reg_param, model_save_iter, load_iter): 42 | 43 | model_dir = '/home/ys895/SAS_Models' 44 | if not os.path.isdir(model_dir): 45 | os.mkdir(model_dir) 46 | 47 | gpu = '/gpu:' + str(gpu_id) 48 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 49 | config = tf.ConfigProto() 50 | config.gpu_options.allow_growth = True 51 | config.allow_soft_placement = True 52 | set_session(tf.Session(config=config)) 53 | 54 | 55 | # UNET filters 56 | nf_enc = [16,32,32,32] 57 | if(model == 'vm1'): 58 | nf_dec = [32,32,32,32,8,8,3] 59 | else: 60 | nf_dec = [32,32,32,32,32,16,16,3] 61 | 62 | with tf.device(gpu): 63 | model = networks.unet(vol_size, nf_enc, nf_dec) 64 | if(load_iter != 0): 65 | net.load_weights('/home/ys895/SAS_Models/' + str(load_iter) + '.h5') 66 | 67 | model.compile(optimizer=Adam(lr=lr), loss=[ 68 | losses.cc3D(), losses.gradientLoss('l2')], loss_weights=[1.0, reg_param]) 69 | # model.load_weights('../models/udrnet2/udrnet1_1/120000.h5') 70 | 71 | # return the data, add one more dimension into the data 72 | train_example_gen = datagenerators.example_gen(train_vol_names) 73 | zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3)) 74 | 75 | 76 | # In this part, the code inputs the data into the model 77 | # Before this part, the model was set 78 | for step in range(1, n_iterations+1): 79 | 80 | #Parameters for training : X(train_vol) ,atlas_vol(atlas) ,zero_flow 81 | X = train_example_gen.__next__()[0] 82 | train_loss = model.train_on_batch( 83 | [atlas_vol, X], [X, zero_flow]) 84 | 85 | if not isinstance(train_loss, list): 86 | train_loss = [train_loss] 87 | 88 | printLoss(step, 1, train_loss) 89 | 90 | if(step % model_save_iter == 0): 91 | model.save(model_dir + '/' + str(load_iter+step) + '.h5') 92 | 93 | 94 | def printLoss(step, training, train_loss): 95 | s = str(step) + "," + str(training) 96 | 97 | if(isinstance(train_loss, list) or isinstance(train_loss, np.ndarray)): 98 | for i in range(len(train_loss)): 99 | s += "," + str(train_loss[i]) 100 | else: 101 | s += "," + str(train_loss) 102 | 103 | print(s) 104 | sys.stdout.flush() 105 | 106 | 107 | if __name__ == "__main__": 108 | 109 | parser = ArgumentParser() 110 | parser.add_argument("--model", type=str,dest="model", 111 | choices=['vm1','vm2'],default='vm2', 112 | help="Voxelmorph-1 or 2") 113 | parser.add_argument("--gpu", type=int,default=0, 114 | dest="gpu_id", help="gpu id number") 115 | parser.add_argument("--lr", type=float, 116 | dest="lr", default=1e-4,help="learning rate") 117 | parser.add_argument("--iters", type=int, 118 | dest="n_iterations", default=15000, 119 | help="number of iterations") 120 | parser.add_argument("--lambda", type=float, 121 | dest="reg_param", default=1.0, 122 | help="regularization 
parameter") 123 | parser.add_argument("--checkpoint_iter", type=int, 124 | dest="model_save_iter", default=500, 125 | help="frequency of model saves") 126 | parser.add_argument("--load_iter", type=int, 127 | dest="load_iter", default=0, 128 | help="the iteratons of models to load") 129 | 130 | args = parser.parse_args() 131 | train(**vars(args)) 132 | 133 | ''' 134 | 135 | 136 | 137 | """ 138 | multi atlas segmentation based on Voxelmorph and Neuron 139 | 140 | """ 141 | 142 | # python imports 143 | import os 144 | import glob 145 | import sys 146 | import random 147 | from argparse import ArgumentParser 148 | 149 | # third-party imports 150 | import tensorflow as tf 151 | import numpy as np 152 | from keras.backend.tensorflow_backend import set_session 153 | from keras.optimizers import Adam 154 | from keras.models import load_model, Model 155 | 156 | 157 | import datagenerators 158 | import networks 159 | import losses 160 | 161 | 162 | vol_size = (160, 192, 224) 163 | # train data preparation 164 | base_data_dir = '/home/ys895/resize256/resize256-crop_x32/' 165 | # find all the path of .npz file in the directory 166 | # read training data 167 | train_vol_names = glob.glob(base_data_dir + 'train/vols/*.npz') 168 | # shuffle the path of .npz file 169 | # shuffle the training data 170 | random.shuffle(train_vol_names) 171 | 172 | # read the only one atlas data 173 | #atlas = np.load('../data/atlas_norm.npz') 174 | #atlas_vol = atlas['vol'] 175 | 176 | # add two more dimension into the atlas data 177 | #atlas_vol = np.reshape(atlas_vol, (1,) + atlas_vol.shape+(1,)) 178 | 179 | # atlas_list: several atlas were read 180 | atlas_file = open('../data/MAS_atlas.txt') 181 | atlas_strings = atlas_file.readlines() 182 | lenn = 1 183 | atlas_list = list() 184 | for i in range(0,lenn): 185 | st = atlas_strings[i] 186 | atlas_add = np.load(st.strip()) 187 | atlas_add = atlas_add['vol_data'] 188 | atlas_add = np.reshape(atlas_add,(1,)+atlas_add.shape+(1,)) 189 | atlas_list.append(atlas_add) 190 | 191 | # read atlas_norm as atlas used for training 192 | #atlas = np.load('../data/atlas_norm.npz') 193 | #atlas = atlas['vol'] 194 | #atlas = np.reshape(atlas,(1,)+atlas.shape+(1,)) 195 | #atlas_list.append(atlas) 196 | 197 | list_num = len(atlas_list) 198 | 199 | def train(model, gpu_id, lr, n_iterations, reg_param, model_save_iter, load_iter): 200 | 201 | model_dir = '/home/ys895/SAS_Models' 202 | if not os.path.isdir(model_dir): 203 | os.mkdir(model_dir) 204 | 205 | gpu = '/gpu:' + str(gpu_id) 206 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 207 | config = tf.ConfigProto() 208 | config.gpu_options.allow_growth = True 209 | config.allow_soft_placement = True 210 | set_session(tf.Session(config=config)) 211 | 212 | 213 | # UNET filters 214 | nf_enc = [16,32,32,32] 215 | if(model == 'vm1'): 216 | nf_dec = [32,32,32,32,8,8,3] 217 | else: 218 | nf_dec = [32,32,32,32,32,16,16,3] 219 | 220 | with tf.device(gpu): 221 | model = networks.unet(vol_size, nf_enc, nf_dec) 222 | if(load_iter != 0): 223 | model.load_weights('/home/ys895/SAS_Models/' + str(load_iter) + '.h5') 224 | 225 | model.compile(optimizer=Adam(lr=lr), loss=[ 226 | losses.cc3D(), losses.gradientLoss('l2')], loss_weights=[1.0, reg_param]) 227 | # model.load_weights('../models/udrnet2/udrnet1_1/120000.h5') 228 | 229 | # return the data, add one more dimension into the data 230 | train_example_gen = datagenerators.example_gen(train_vol_names) 231 | zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3)) 232 | 233 | 234 | # In this part, 
the training loop feeds data into the model 235 | # (the model itself was configured above) 236 | for step in range(1, n_iterations+1): 237 | # randomly choose one atlas from atlas_list 238 | rand_num = random.randint(0, list_num-1) 239 | atlas_vol = atlas_list[rand_num] 240 | 241 | # training inputs: X (training volume), atlas_vol (atlas), zero_flow (shape-only target) 242 | X = train_example_gen.__next__()[0] 243 | train_loss = model.train_on_batch( 244 | [atlas_vol, X], [X, zero_flow]) 245 | 246 | if not isinstance(train_loss, list): 247 | train_loss = [train_loss] 248 | 249 | printLoss(step, 1, train_loss) 250 | 251 | if(step % model_save_iter == 0): 252 | model.save(model_dir + '/' + str(load_iter+step) + '.h5') 253 | 254 | 255 | def printLoss(step, training, train_loss): 256 | s = str(step) + "," + str(training) 257 | 258 | if(isinstance(train_loss, list) or isinstance(train_loss, np.ndarray)): 259 | for i in range(len(train_loss)): 260 | s += "," + str(train_loss[i]) 261 | else: 262 | s += "," + str(train_loss) 263 | 264 | print(s) 265 | sys.stdout.flush() 266 | 267 | 268 | if __name__ == "__main__": 269 | 270 | parser = ArgumentParser() 271 | parser.add_argument("--model", type=str,dest="model", 272 | choices=['vm1','vm2'],default='vm2', 273 | help="Voxelmorph-1 or 2") 274 | parser.add_argument("--gpu", type=int,default=0, 275 | dest="gpu_id", help="gpu id number") 276 | parser.add_argument("--lr", type=float, 277 | dest="lr", default=1e-4,help="learning rate") 278 | parser.add_argument("--iters", type=int, 279 | dest="n_iterations", default=15000, 280 | help="number of iterations") 281 | parser.add_argument("--lambda", type=float, 282 | dest="reg_param", default=1.0, 283 | help="regularization parameter") 284 | parser.add_argument("--checkpoint_iter", type=int, 285 | dest="model_save_iter", default=500, 286 | help="frequency of model saves") 287 | parser.add_argument("--load_iter", type=int, 288 | dest="load_iter", default=0, 289 | help="iteration of the saved model to load") 290 | 291 | args = parser.parse_args() 292 | train(**vars(args)) 293 | 294 | 295 | -------------------------------------------------------------------------------- /src/SAS_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | test script for single-atlas segmentation based on VoxelMorph and Neuron 3 | 4 | """ 5 | 6 | ''' 7 | import os 8 | import sys 9 | import glob 10 | 11 | # third party 12 | import tensorflow as tf 13 | import scipy.io as sio 14 | import numpy as np 15 | from keras.backend.tensorflow_backend import set_session 16 | from scipy.interpolate import interpn 17 | 18 | # project 19 | sys.path.append('../ext/medipy-lib') 20 | sys.path.append('../ext/neuron') 21 | sys.path.append('../ext/pynd-lib') 22 | sys.path.append('../ext/pytools-lib') 23 | 24 | import medipy 25 | import networks 26 | from medipy.metrics import dice 27 | import datagenerators 28 | import neuron as nu 29 | 30 | 31 | def test(iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]): 32 | gpu = '/gpu:' + str(gpu_id) 33 | 34 | # Anatomical labels we want to evaluate 35 | labels = sio.loadmat('../data/labels.mat')['labels'][0] 36 | 37 | atlas = np.load('../data/atlas_norm.npz') 38 | atlas_vol = atlas['vol'] 39 | print('the size of atlas:') 40 | print(atlas_vol.shape) 41 | atlas_seg = atlas['seg'] 42 | atlas_vol = np.reshape(atlas_vol, (1,)+atlas_vol.shape+(1,)) 43 | 44 | # read atlas data 45 | 46 | #gpu = '/gpu:' + str(gpu_id) 47 | os.environ["CUDA_VISIBLE_DEVICES"] =
6 | ''' 7 | import os 8 | import sys 9 | import glob 10 | 11 | # third party 12 | import tensorflow as tf 13 | import scipy.io as sio 14 | import numpy as np 15 | from keras.backend.tensorflow_backend import set_session 16 | from scipy.interpolate import interpn 17 | 18 | # project 19 | sys.path.append('../ext/medipy-lib') 20 | sys.path.append('../ext/neuron') 21 | sys.path.append('../ext/pynd-lib') 22 | sys.path.append('../ext/pytools-lib') 23 | 24 | import medipy 25 | import networks 26 | from medipy.metrics import dice 27 | import datagenerators 28 | import neuron as nu 29 | 30 | 31 | def test(iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]): 32 | gpu = '/gpu:' + str(gpu_id) 33 | 34 | # Anatomical labels we want to evaluate 35 | labels = sio.loadmat('../data/labels.mat')['labels'][0] 36 | 37 | atlas = np.load('../data/atlas_norm.npz') 38 | atlas_vol = atlas['vol'] 39 | print('the size of the atlas:') 40 | print(atlas_vol.shape) 41 | atlas_seg = atlas['seg'] 42 | atlas_vol = np.reshape(atlas_vol, (1,)+atlas_vol.shape+(1,)) 43 | 44 | # read atlas data 45 | 46 | #gpu = '/gpu:' + str(gpu_id) 47 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 48 | #config = tf.ConfigProto() 49 | #config.gpu_options.allow_growth = True 50 | #config.allow_soft_placement = True 51 | #set_session(tf.Session(config=config)) 52 | 53 | # load weights of model 54 | with tf.device(gpu): 55 | net = networks.unet(vol_size, nf_enc, nf_dec) 56 | net.load_weights('/home/ys895/SAS_Models/'+str(iter_num)+'.h5') 57 | #net.load_weights('../models/' + model_name + '/' + str(iter_num) + '.h5') 58 | 59 | xx = np.arange(vol_size[1]) 60 | yy = np.arange(vol_size[0]) 61 | zz = np.arange(vol_size[2]) 62 | grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4) # (160, 192, 224, 3) 63 | X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz') 64 | 65 | # reorder the flow components below to match the (yy, xx, zz) grid ordering 66 | # pred[0].shape (1, 160, 192, 224, 1) 67 | # pred[1].shape (1, 160, 192, 224, 3) 68 | with tf.device(gpu): 69 | pred = net.predict([atlas_vol, X_vol]) 70 | # Warp segments with flow 71 | flow = pred[1][0, :, :, :, :] # (160, 192, 224, 3) 72 | sample = flow+grid 73 | sample = np.stack((sample[:, :, :, 1], sample[:, :, :, 0], sample[:, :, :, 2]), 3) 74 | warp_seg = interpn((yy, xx, zz), atlas_seg[ :, :, : ], sample, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224) 75 | vals, _ = dice(warp_seg, X_seg[0,:,:,:,0], labels=labels, nargout=2) 76 | print(np.mean(vals), np.std(vals)) 77 | 78 | 79 | # plot the outcome of warp seg 80 | #warp_seg = warp_seg.reshape((warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0])) 81 | #warp_seg2 = np.empty(shape = (warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0])) 82 | #for i in range(0,warp_seg.shape[1]): 83 | # warp_seg2[i,:,:] = np.transpose(warp_seg[:,i,:]) 84 | #nu.plot.slices(warp_seg) 85 | 86 | 87 | if __name__ == "__main__": 88 | test(sys.argv[1], sys.argv[2]) 89 | 90 | ''' 91 | 92 | import os 93 | import sys 94 | import glob 95 | 96 | # third party 97 | import tensorflow as tf 98 | import scipy.io as sio 99 | import numpy as np 100 | from scipy import stats 101 | from keras.backend.tensorflow_backend import set_session 102 | from scipy.interpolate import interpn 103 | 104 | # project 105 | sys.path.append('../ext/medipy-lib') 106 | sys.path.append('../ext/neuron') 107 | sys.path.append('../ext/pynd-lib') 108 | sys.path.append('../ext/pytools-lib') 109 | 110 | import medipy 111 | import networks 112 | from medipy.metrics import dice 113 | import datagenerators 114 | import neuron as nu 115 | 116 | 117 | def test(iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]): 118 | gpu = '/gpu:' + str(gpu_id) 119 | 120 | # Anatomical labels we want to evaluate 121 | labels = sio.loadmat('../data/labels.mat')['labels'][0] 122 | 123 | # read atlas 124 | atlas_vol1, atlas_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990114_vc722.npz', 125 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990114_vc722.npz')# [1,160,192,224,1] 126 | atlas_seg1 = atlas_seg1[0,:,:,:,0]# reduce the dimensions to [160,192,224] 127 | 128 | atlas_vol2, atlas_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990210_vc792.npz', 129 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990210_vc792.npz') 130 | atlas_seg2 = atlas_seg2[0, :, :, :, 0] 131 | 132 | #gpu = '/gpu:' + str(gpu_id) 133 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 134 | config = tf.ConfigProto() 135 | 
config.gpu_options.allow_growth = True 136 | config.allow_soft_placement = True 137 | set_session(tf.Session(config=config)) 138 | 139 | # load weights of model 140 | with tf.device(gpu): 141 | net = networks.unet(vol_size, nf_enc, nf_dec) 142 | net.load_weights('/home/ys895/MAS2_Models/'+str(iter_num)+'.h5') 143 | #net.load_weights('../models/' + model_name + '/' + str(iter_num) + '.h5') 144 | 145 | xx = np.arange(vol_size[1]) 146 | yy = np.arange(vol_size[0]) 147 | zz = np.arange(vol_size[2]) 148 | grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4) # (160, 192, 224, 3) 149 | #X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz') 150 | X_vol1, X_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/981216_vc681.npz', 151 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/981216_vc681.npz') 152 | 153 | X_vol2, X_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990205_vc783.npz', 154 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990205_vc783.npz') 155 | 156 | X_vol3, X_seg3 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990525_vc1024.npz', 157 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990525_vc1024.npz') 158 | 159 | X_vol4, X_seg4 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991025_vc1379.npz', 160 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991025_vc1379.npz') 161 | 162 | X_vol5, X_seg5 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991122_vc1463.npz', 163 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991122_vc1463.npz') 164 | 165 | # reorder the flow components to match the (yy, xx, zz) grid ordering 166 | # pred[0].shape (1, 160, 192, 224, 1) 167 | # pred[1].shape (1, 160, 192, 224, 3) 168 | # X1 169 | with tf.device(gpu): 170 | pred1 = net.predict([atlas_vol1, X_vol1]) 171 | 172 | # Warp segments with flow 173 | flow1 = pred1[1][0, :, :, :, :]# (160, 192, 224, 3) 174 | 175 | sample1 = flow1+grid 176 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 177 | 178 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[ :, :, : ], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224) 179 | 180 | 181 | 182 | # label fusion: with a single warped atlas the fused segmentation is warp_seg1 itself 183 | # (the original voxel-wise loop assigned a scalar to warp_seg instead of 184 | # warp_seg[x, y, z]; copying the whole array gives the intended result and 185 | # avoids looping over all 160*192*224 voxels) 186 | warp_seg = np.array(warp_seg1) 187 | 188 | 189 | vals, _ = dice(warp_seg, X_seg1[0, :, :, :, 0], labels=labels, nargout=2) 190 | mean1 = np.mean(vals) 191 | 192 | # X2 193 | with tf.device(gpu): 194 | pred1 = net.predict([atlas_vol1, X_vol2]) 195 | 196 | # Warp segments with flow 197 | flow1 = pred1[1][0, :, :, :, :]# (160, 192, 224, 3) 198 | 199 | sample1 = flow1+grid 200 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 201 | 202 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[ :, :, : ], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224) 203 | 204 | 205 | 206 | # label fusion: single-atlas case, see the note at X1 207 | warp_seg = np.array(warp_seg1) 208 | 209 | 210 | 211 | 212 | 213 | vals, _ = dice(warp_seg, X_seg2[0,:,:,:,0], labels=labels, nargout=2) 214 | mean2 = np.mean(vals) 215 | #print(np.mean(vals), np.std(vals)) 216 | 217 | # X3 218 | with tf.device(gpu): 219 | pred1 = net.predict([atlas_vol1, X_vol3]) 220 | 221 | # Warp segments with flow 222 | flow1 = pred1[1][0, :, :, :, :]# (160, 192, 224, 3) 223 | 224 | sample1 = flow1+grid 225 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 226 | 227 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[ :, :, : ], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224) 228 | 229 | 230 | 231 | # label fusion: single-atlas case, see the note at X1 232 | warp_seg = np.array(warp_seg1) 233 | 234 | 235 | 236 | 237 | 238 | vals, _ = dice(warp_seg, X_seg3[0, :, :, :, 0], labels=labels, nargout=2) 239 | mean3 = np.mean(vals) 240 | 241 | # X4 242 | with tf.device(gpu): 243 | pred1 = net.predict([atlas_vol1, X_vol4]) 244 | 245 | # Warp segments with flow 246 | flow1 = pred1[1][0, :, :, :, :]# (160, 192, 224, 3) 247 | 248 | sample1 = flow1+grid 249 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 250 | 251 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[ :, :, : ], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224) 252 | 253 | 254 | # label fusion: single-atlas case, see the note at X1 255 | warp_seg = np.array(warp_seg1) 256 | 257 | 258 | 259 | 260 | 261 | vals, _ = dice(warp_seg, X_seg4[0, :, :, :, 0], labels=labels, nargout=2) 262 | mean4 = np.mean(vals) 263 | 264 | 265 | # X5 266 | with tf.device(gpu): 267 | pred1 = net.predict([atlas_vol1, X_vol5]) 268 | 269 | # Warp segments with flow 270 | flow1 = pred1[1][0, :, :, :, :]# (160, 192, 224, 3) 271 | 272 | sample1 = flow1+grid 273 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 274 | 275 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[ :, :, : ], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224) 276 | 277 | 278 | # label fusion: single-atlas case; the commented lines below show the per-voxel 279 | # majority vote (stats.mode) that the multi-atlas scripts use instead 280 | warp_seg = np.array(warp_seg1) 281 | 282 | 283 | 284 | #print(warp_arr) 285 | #warp_seg[x,y,z] = stats.mode(warp_arr)[0] 286 | 287 | vals, _ = dice(warp_seg, X_seg5[0, :, :, :, 0], labels=labels, nargout=2) 288 | mean5 = np.mean(vals) 289 | 290 | # compute the mean Dice score over the five test subjects 291 | total = mean1 + mean2 + mean3 + mean4 + mean5 292 | mean_dice = total/5 293 | print(mean_dice) 294 | 295 | # plot the outcome of warp seg 296 | #warp_seg = warp_seg.reshape((warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0])) 297 | #warp_seg2 = np.empty(shape = (warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0])) 298 | #for i in range(0,warp_seg.shape[1]): 299 | # warp_seg2[i,:,:] = np.transpose(warp_seg[:,i,:]) 300 | #nu.plot.slices(warp_seg) 301 | 302 | 303 | if __name__ == "__main__": 304 | #result_list = np.empty((1000,1)) 305 | #for i in range(0,35): 306 | # iterr = (i+1)*200 307 | # result_list[i,0] = test(iterr,sys.argv[1])[0] 308 | #print(result_list) 309 | test(sys.argv[1], sys.argv[2]) 310 | 
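# --- editor's note (annotation between SAS_test.py and ndutils.py) -----------
# The multi-atlas test scripts in this repo fuse candidate segmentations either
# by a per-voxel majority vote, scipy.stats.mode over the warped label maps
# (MAS2_test.py), or by averaging one-hot label volumes and taking an argmax
# (the *_test_linear.py variants). A vectorized sketch of the majority vote,
# for illustration only (the repo's code loops over voxels instead):
import numpy as np
from scipy import stats

warped = [np.random.randint(0, 4, size=(3, 3, 3)) for _ in range(3)]  # K=3 toy label maps
stacked = np.stack(warped, axis=0)          # (K, x, y, z)
fused = stats.mode(stacked, axis=0)[0][0]   # per-voxel majority label -> (x, y, z)
assert fused.shape == (3, 3, 3)
# ------------------------------------------------------------------------------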
-------------------------------------------------------------------------------- /ext/pynd-lib/pynd/ndutils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for nd (n-dimensional) arrays 3 | Tested on Python 3.5 4 | 5 | Contact: adalca@csail.mit.edu 6 | """ 7 | 8 | import builtins 9 | import numpy as np 10 | import scipy as sp 11 | import scipy.ndimage 12 | from scipy.spatial import ConvexHull 13 | 14 | 15 | def boundingbox(bwvol): 16 | """ 17 | bounding box coordinates of an nd volume 18 | 19 | Parameters 20 | ---------- 21 | bwvol : nd array 22 | the binary (black/white) array for which to compute the boundingbox 23 | 24 | Returns 25 | ------- 26 | boundingbox : 1-by-(nd*2) array 27 | [xstart ystart ... xend yend ...] 28 | """ 29 | 30 | # find indices where bwvol is True 31 | idx = np.where(bwvol) 32 | 33 | # get the starts 34 | starts = [np.min(x) for x in idx] 35 | 36 | # get the ends 37 | ends = [np.max(x) for x in idx] 38 | 39 | # concatenate [starts, ends] 40 | return np.concatenate((starts, ends), 0) 41 | 42 | 43 | 44 | def bwdist(bwvol): 45 | """ 46 | positive distance transform from positive entries in a logical image 47 | 48 | Parameters 49 | ---------- 50 | bwvol : nd array 51 | The logical volume 52 | 53 | Returns 54 | ------- 55 | posdtrf : nd array 56 | the positive distance transform 57 | 58 | See Also 59 | -------- 60 | bw2sdtrf 61 | """ 62 | 63 | # reverse volume to run scipy function 64 | revbwvol = np.logical_not(bwvol) 65 | 66 | # get distance 67 | return scipy.ndimage.morphology.distance_transform_edt(revbwvol) 68 | 69 | 70 | 71 | def bw2sdtrf(bwvol): 72 | """ 73 | computes the signed distance transform from the surface between the 74 | binary True/False elements of logical bwvol 75 | 76 | Note: the distance transform on either side of the surface will be +1/-1 77 | - i.e. there are no voxels for which the dst should be 0. 78 | 79 | Runtime: currently the function uses bwdist twice. If there is a quick way to 80 | compute the surface, bwdist could be used only once. 81 | 82 | Parameters 83 | ---------- 84 | bwvol : nd array 85 | The logical volume 86 | 87 | Returns 88 | ------- 89 | sdtrf : nd array 90 | the signed distance transform 91 | 92 | See Also 93 | -------- 94 | bwdist 95 | """ 96 | 97 | # get the positive transform (outside the positive island) 98 | posdst = bwdist(bwvol) 99 | 100 | # get the negative transform (distance inside the island) 101 | notbwvol = np.logical_not(bwvol) 102 | negdst = bwdist(notbwvol) 103 | 104 | # combine the positive and negative map 105 | return posdst * notbwvol - negdst * bwvol 106 | 107 | 
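# --- editor's note (annotation, not part of ndutils.py) ----------------------
# bw2sdtrf above combines two Euclidean distance transforms so that voxels
# outside the foreground island get positive distances and voxels inside get
# negative ones. A tiny 2-D check (illustration only; it uses the modern
# scipy.ndimage.distance_transform_edt path rather than the older
# scipy.ndimage.morphology one imported above):
import numpy as np
import scipy.ndimage

bw = np.zeros((7, 7), dtype=bool)
bw[2:5, 2:5] = True                                    # a 3x3 "island"
outside = scipy.ndimage.distance_transform_edt(~bw)    # distance to the island
inside = scipy.ndimage.distance_transform_edt(bw)      # distance to the background
sdtrf = outside * ~bw - inside * bw                    # signed: + outside, - inside
assert sdtrf[0, 0] > 0 and sdtrf[3, 3] < 0
# ------------------------------------------------------------------------------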
108 | def bw_convex_hull(bwvol): 109 | # transform bw to mesh: collect the grid coordinates of every voxel 110 | grid = volsize2ndgrid(bwvol.shape) 111 | # stack into an (n_voxels, ndim) point array (the original concatenated over the non-existent bwvol.ndims, which crashes) 112 | q = np.stack([grid[d].reshape(-1) for d in range(bwvol.ndim)], 1) 113 | return q # NOTE: the ConvexHull computation itself was never added here 114 | 115 | def bw2contour(bwvol, type='both', thr=1.01): 116 | """ 117 | computes the contour of island(s) on an nd logical volume 118 | 119 | Parameters 120 | ---------- 121 | bwvol : nd array 122 | The logical volume 123 | type : optional string 124 | since the contour is drawn on voxels, it can be drawn on the inside 125 | of the island ('inner'), outside of the island ('outer'), or both 126 | ('both' - default) 127 | 128 | Returns 129 | ------- 130 | contour : nd array 131 | the contour map of the same size as the input 132 | 133 | See Also 134 | -------- 135 | bwdist, bw2sdtrf 136 | """ 137 | 138 | # obtain a signed distance transform for the bw volume 139 | sdtrf = bw2sdtrf(bwvol) 140 | 141 | if type == 'inner': 142 | return np.logical_and(sdtrf <= 0, sdtrf > -thr) 143 | elif type == 'outer': 144 | return np.logical_and(sdtrf >= 0, sdtrf < thr) 145 | else: 146 | assert type == 'both', 'type should only be inner, outer or both' 147 | return np.abs(sdtrf) < thr 148 | 149 | 150 | def ndgrid(*args, **kwargs): 151 | """ 152 | Disclaimer: This code is taken directly from the scitools package [1] 153 | Since at the time of writing scitools predominantly requires python 2.7 while we work with 3.5+ 154 | To avoid issues, we copy the quick code here. 155 | 156 | Same as calling ``meshgrid`` with *indexing* = ``'ij'`` (see 157 | ``meshgrid`` for documentation). 158 | """ 159 | kwargs['indexing'] = 'ij' 160 | return np.meshgrid(*args, **kwargs) 161 | 162 | 163 | def volsize2ndgrid(volsize): 164 | """ 165 | return the dense nd-grid for the volume with size volsize 166 | essentially returns the ndgrid over np.arange(e) for each entry e of volsize 167 | """ 168 | ranges = [np.arange(e) for e in volsize] 169 | return ndgrid(*ranges) 170 | 171 | 172 | def bw_sphere(volshape, rad, loc=None): 173 | """ 174 | compute a logical (black/white) image of a sphere 175 | """ 176 | 177 | # if the location is not given, use the center of the volume. 178 | if loc is None: 179 | loc = 1.0 * (np.array(volshape)-1) / 2 180 | assert len(loc) == len(volshape), \ 181 | 'Location (%d) and volume dimensions (%d) do not match' % (len(loc), len(volshape)) 182 | 183 | 184 | # compute distances between each location in the volume and ``loc`` 185 | volgrid = volsize2ndgrid(volshape) 186 | dst = [np.square(loc[d] - volgrid[d]) for d in range(len(volshape))] 187 | dst = np.sqrt(np.sum(dst, 0)) 188 | 189 | # draw the sphere 190 | return dst <= rad 191 | 192 | 193 | def volcrop(vol, new_vol_size=None, start=None, end=None, crop=None): 194 | """ 195 | crop an nd volume. 196 | 197 | Parameters 198 | ---------- 199 | vol : nd array 200 | the n-dimensional volume to crop. If no other parameters are specified, the volume is returned intact 201 | new_vol_size : nd vector, optional 202 | the new size of the cropped volume 203 | crop : nd tuple, optional 204 | either tuple of integers or tuple of tuples. 205 | If tuple of integers, will crop that amount from both sides. 
206 | if tuple of tuples, expect each inner tuple to specify (crop from start, crop from end) 207 | start : int, optional 208 | start of cropped volume 209 | end : int, optional 210 | end of cropped volume 211 | 212 | Returns 213 | ------ 214 | cropped_vol : nd array 215 | """ 216 | 217 | vol_size = np.asarray(vol.shape) 218 | 219 | # check which parameters are passed 220 | passed_new_vol_size = new_vol_size is not None 221 | passed_start = start is not None 222 | passed_end = end is not None 223 | passed_crop = crop is not None 224 | 225 | # from whatever is passed, we want to obtain start and end. 226 | if passed_start and passed_end: 227 | assert not (passed_new_vol_size or passed_crop), \ 228 | "If passing start and end, don't pass anything else" 229 | 230 | elif passed_new_vol_size: 231 | # compute new volume size and crop_size 232 | assert not passed_crop, "Cannot use both new volume size and crop info" 233 | 234 | # compute start and end 235 | if passed_start: 236 | assert not passed_end, \ 237 | "When giving passed_new_vol_size, cannot pass both start and end" 238 | end = start + new_vol_size 239 | 240 | elif passed_end: 241 | assert not passed_start, \ 242 | "When giving passed_new_vol_size, cannot pass both start and end" 243 | start = end - new_vol_size 244 | 245 | else: # none of crop_size, crop, start or end are passed 246 | mid = np.asarray(vol_size) // 2 247 | start = mid - (new_vol_size // 2) 248 | end = start + new_vol_size 249 | 250 | elif passed_crop: 251 | assert not (passed_start or passed_end or new_vol_size), \ 252 | "Cannot pass both passed_crop and start or end or new_vol_size" 253 | 254 | if isinstance(crop[0], (list, tuple)): 255 | end = vol_size - [val[1] for val in crop] 256 | start = [val[0] for val in crop] 257 | else: 258 | end = vol_size - crop 259 | start = crop 260 | 261 | elif passed_start: # nothing else is passed 262 | end = vol_size 263 | 264 | else: 265 | assert passed_end 266 | start = vol_size * 0 267 | 268 | # get indices. 
Since we want this to be an nd-volume crop function, we need one index range per dimension; the nd range() helper defined later in this file provides this (explicit equivalent below) 269 | # idx = [] 270 | # for i in range(len(end)): 271 | # idx.append(slice(start[i], end[i])) 272 | idx = range(start, end) # resolves to the nd range() below, not the builtin 273 | 274 | return vol[np.ix_(*idx)] 275 | 276 | 277 | def slice(*args): 278 | """ 279 | slice([start], end [,step]) 280 | nd version of slice, where each arg can be a vector of the same length 281 | 282 | Parameters: 283 | [start] (vector): the start 284 | 285 | """ 286 | 287 | # if passed scalars, call the built-in slice 288 | if not isinstance(args[0], (list, tuple, np.ndarray)): 289 | return builtins.slice(*args) 290 | 291 | start, end, step = _prep_range(*args) 292 | 293 | # prepare 294 | idx = [slice(start[i], end[i], step[i]) for i in range(len(end))] 295 | return idx 296 | 297 | def range(*args): 298 | """ 299 | range([start], end [,step]) 300 | nd version of range, where each arg can be a vector of the same length 301 | 302 | Parameters: 303 | [start] (vector): the start 304 | 305 | """ 306 | 307 | # if passed scalars, call the built-in range (was np.arange here, swapped with arange() below) 308 | if not isinstance(args[0], (list, tuple, np.ndarray)): 309 | return builtins.range(*args) 310 | 311 | start, end, step = _prep_range(*args) 312 | 313 | # prepare 314 | idx = [range(start[i], end[i], step[i]) for i in range(len(end))] 315 | return idx 316 | 317 | 318 | def arange(*args): 319 | """ 320 | arange([start], end [,step]) 321 | nd version of arange, where each arg can be a vector of the same length 322 | 323 | Parameters: 324 | [start] (vector): the start 325 | 326 | """ 327 | 328 | # if passed scalars, call np.arange (was builtins.range here, swapped with range() above) 329 | if not isinstance(args[0], (list, tuple, np.ndarray)): 330 | return np.arange(*args) 331 | 332 | start, end, step = _prep_range(*args) 333 | 334 | # prepare 335 | idx = [np.arange(start[i], end[i], step[i]) for i in range(len(end))] 336 | return idx 337 | 338 | 339 | 
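# --- editor's note (annotation, not part of ndutils.py) ----------------------
# slice(), range() and arange() above generalize the builtins/np.arange to
# vector arguments, one entry per dimension; volcrop() relies on this when it
# calls range(start, end) with vector start/end. A quick illustration, which
# assumes this module is importable as pynd.ndutils (as laid out under ext/):
import numpy as np
import pynd.ndutils as nd

vol = np.arange(4 * 5 * 6).reshape(4, 5, 6)
start, end = np.array([1, 1, 2]), np.array([3, 4, 6])
idx = nd.range(start, end)       # one index sequence per dimension
cropped = vol[np.ix_(*idx)]      # the same indexing volcrop() performs
assert cropped.shape == (2, 3, 4)
# ------------------------------------------------------------------------------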
340 | def axissplit(arr, axis): 341 | """ 342 | Split an nd volume along an axis into n volumes, where n is the size of the axis dim. 343 | 344 | Parameters 345 | ---------- 346 | arr : nd array 347 | array to split 348 | axis : integer 349 | indicating axis to split 350 | 351 | Output 352 | ------ 353 | outarr : 1-by-n array 354 | where n is the size of the axis dim in the original volume. 355 | each entry is a sub-volume of the original volume 356 | 357 | See also numpy.split() 358 | """ 359 | nba = arr.shape[axis] 360 | return np.split(arr, nba, axis=axis) 361 | 362 | 363 | 364 | 365 | def sub2ind(arr, size, **kwargs): 366 | """ 367 | similar to MATLAB's sub2ind 368 | 369 | Note default order is C-style, not F-style (Fortran/MATLAB) 370 | """ 371 | return np.ravel_multi_index(arr, size, **kwargs) 372 | 373 | 374 | def ind2sub(indices, size, **kwargs): 375 | """ 376 | similar to MATLAB's ind2sub 377 | 378 | Note default order is C-style, not F-style (Fortran/MATLAB) 379 | """ 380 | return np.unravel_index(indices, size, **kwargs) 381 | 382 | 383 | def centroid(im): 384 | """ 385 | compute the centroid of an nd probability (0..1) image 386 | """ 387 | volgrid = volsize2ndgrid(im.shape) 388 | prob = [np.array(im) * np.array(volgrid[d]) for d in range(len(im.shape))] 389 | return [np.sum(p.flat) / np.sum(im) for p in prob] # normalize by the total mass (was np.sum(im.shape), which is incorrect) 390 | 391 | 392 | 393 | def ind2sub_entries(indices, size, **kwargs): 394 | """ 395 | returns a nb_entries -by- nb_dims array (essentially the transpose of ind2sub) 396 | 397 | somewhat similar to MATLAB's ind2subvec 398 | https://github.com/adalca/mgt/blob/master/src/ind2subvec.m 399 | 400 | Note default order is C-style, not F-style (Fortran/MATLAB) 401 | """ 402 | sub = ind2sub(np.array(indices).flatten(), size, **kwargs) 403 | subvec = np.vstack(sub).transpose() 404 | # Warning: this might be F-style-like stacking... it's a bit confusing 405 | return subvec 406 | 407 | ############################################################################### 408 | # internal 409 | ############################################################################### 410 | 411 | def _prep_range(*args): 412 | """ 413 | _prep_range([start], end [,step]) 414 | prepare the start, end and step for range and arange 415 | 416 | Parameters: 417 | [start] (vector): the start 418 | 419 | """ 420 | 421 | # prepare the start, step and end 422 | step = np.ones(len(args[0]), 'int') 423 | if len(args) == 1: 424 | end = args[0] 425 | start = np.zeros(len(end), 'int') 426 | elif len(args) == 2: 427 | assert len(args[0]) == len(args[1]), "argument vectors do not match" 428 | start, end = args 429 | elif len(args) == 3: 430 | assert len(args[0]) == len(args[1]), "argument vectors do not match" 431 | assert len(args[0]) == len(args[2]), "argument vectors do not match" 432 | start, end, step = args 433 | else: 434 | raise ValueError('unknown arguments') 435 | 436 | return (start, end, step) -------------------------------------------------------------------------------- /src/MAS2_test_linear.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test for multi-atlas segmentation based on Voxelmorph and Neuron 3 | 4 | """ 5 | 6 | 7 | import os 8 | import sys 9 | import glob 10 | 11 | # third party 12 | import tensorflow as tf 13 | import keras 14 | import scipy.io as sio 15 | import numpy as np 16 | from scipy import stats 17 | from keras.backend.tensorflow_backend import set_session 18 | from scipy.interpolate import interpn 19 | 20 | # project 21 | sys.path.append('../ext/medipy-lib') 22 | sys.path.append('../ext/neuron') 23 | sys.path.append('../ext/pynd-lib') 24 | sys.path.append('../ext/pytools-lib') 25 | 26 | import medipy 27 | import networks 28 | from medipy.metrics import dice 29 | import datagenerators 30 | import neuron as nu 31 | 32 | 33 | def test(iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]): 
34 | gpu = '/gpu:' + str(gpu_id) 35 | 36 | # Anatomical labels we want to evaluate 37 | labels = sio.loadmat('../data/labels.mat')['labels'][0] 38 | 39 | # read atlas 40 | atlas_vol1, atlas_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990114_vc722.npz', 41 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990114_vc722.npz')# [1,160,192,224,1] 42 | atlas_seg1 = atlas_seg1[0,:,:,:,0]# reduce the dimension to [160,192,224] 43 | atlas_seg1 = keras.utils.to_categorical(atlas_seg1) 44 | 45 | atlas_vol2, atlas_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990210_vc792.npz', 46 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990210_vc792.npz') 47 | atlas_seg2 = atlas_seg2[0, :, :, :, 0] 48 | atlas_seg2 = keras.utils.to_categorical(atlas_seg2) 49 | #gpu = '/gpu:' + str(gpu_id) 50 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 51 | config = tf.ConfigProto() 52 | config.gpu_options.allow_growth = True 53 | config.allow_soft_placement = True 54 | set_session(tf.Session(config=config)) 55 | 56 | # load weights of model 57 | with tf.device(gpu): 58 | net = networks.unet(vol_size, nf_enc, nf_dec) 59 | net.load_weights('/home/ys895/MAS2_Models/'+str(iter_num)+'.h5') 60 | #net.load_weights('../models/' + model_name + '/' + str(iter_num) + '.h5') 61 | 62 | xx = np.arange(vol_size[1]) 63 | yy = np.arange(vol_size[0]) 64 | zz = np.arange(vol_size[2]) 65 | grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4) # (160, 192, 224, 3) 66 | #X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz') 67 | X_vol1, X_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/981216_vc681.npz', 68 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/981216_vc681.npz') 69 | 70 | X_vol2, X_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990205_vc783.npz', 71 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990205_vc783.npz') 72 | 73 | X_vol3, X_seg3 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990525_vc1024.npz', 74 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990525_vc1024.npz') 75 | 76 | X_vol4, X_seg4 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991025_vc1379.npz', 77 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991025_vc1379.npz') 78 | 79 | X_vol5, X_seg5 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991122_vc1463.npz', 80 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991122_vc1463.npz') 81 | 82 | # change the direction of the atlas data and volume data 83 | # pred[0].shape (1, 160, 192, 224, 1) 84 | # pred[1].shape (1, 160, 192, 224, 3) 85 | # X1 86 | with tf.device(gpu): 87 | pred1 = net.predict([atlas_vol1, X_vol1]) 88 | pred2 = net.predict([atlas_vol2, X_vol1]) 89 | #pred3 = net.predict([atlas_vol3, X_vol1]) 90 | #pred4 = net.predict([atlas_vol4, X_vol1]) 91 | #pred5 = net.predict([atlas_vol5, X_vol1]) 92 | # Warp segments with flow 93 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 94 | flow2 = pred2[1][0, :, :, :, :] 95 | #flow3 = pred3[1][0, :, :, :, :] 96 | #flow4 = pred4[1][0, :, :, :, :] 97 | #flow5 = pred5[1][0, :, :, 
:, :] 98 | 99 | sample1 = flow1+grid 100 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)# (160,192,224,3) 101 | sample2 = flow2+grid 102 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 103 | #sample3 = flow3+grid 104 | #sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 105 | #sample4 = flow4+grid 106 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 107 | #sample5 = flow5+grid 108 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3) 109 | 110 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0) # (160, 192, 224) 111 | warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0) 112 | #warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :], sample3, method='nearest', bounds_error=False, fill_value=0) 113 | #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0) 114 | #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0) 115 | print('warp segmentation shape:' + str(warp_seg1.shape)) 116 | 117 | # label fusion: get the final warp_seg 118 | warp_seg = np.empty((160, 192, 224, atlas_seg1.shape[3])) 119 | print('warp segmentation shape:' + str(warp_seg.shape)) 120 | warp_seg = (warp_seg1 + warp_seg2)/2 121 | warp_seg = np.argmax(warp_seg, axis = 3) 122 | print('warp segmentation shape:' + str(warp_seg.shape)) 123 | 124 | vals, _ = dice(warp_seg, X_seg1[0, :, :, :, 0], labels=labels, nargout=2) 125 | mean1 = np.mean(vals) 126 | 127 | # X2 128 | with tf.device(gpu): 129 | pred1 = net.predict([atlas_vol1, X_vol2]) 130 | pred2 = net.predict([atlas_vol2, X_vol2]) 131 | #pred3 = net.predict([atlas_vol3, X_vol2]) 132 | #pred4 = net.predict([atlas_vol4, X_vol2]) 133 | #pred5 = net.predict([atlas_vol5, X_vol2]) 134 | # Warp segments with flow 135 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 136 | flow2 = pred2[1][0, :, :, :, :] 137 | #flow3 = pred3[1][0, :, :, :, :] 138 | #flow4 = pred4[1][0, :, :, :, :] 139 | #flow5 = pred5[1][0, :, :, :, :] 140 | 141 | sample1 = flow1+grid 142 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 143 | sample2 = flow2+grid 144 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 145 | #sample3 = flow3+grid 146 | #sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 147 | #sample4 = flow4+grid 148 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 149 | #sample5 = flow5+grid 150 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3) 151 | 152 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224) 153 | warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='nearest', bounds_error=False, fill_value=0) 154 | #warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :], sample3, method='nearest', bounds_error=False, fill_value=0) 155 | #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0) 156 | #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', 
bounds_error=False, fill_value=0) 157 | 158 | 159 | # label fusion: get the final warp_seg 160 | warp_seg = np.empty((160, 192, 224, atlas_seg1.shape[3])) 161 | warp_seg = (warp_seg1 + warp_seg2) / 2 162 | warp_seg = np.argmax(warp_seg, axis = 3) 163 | 164 | vals, _ = dice(warp_seg, X_seg2[0,:,:,:,0], labels=labels, nargout=2) 165 | mean2 = np.mean(vals) 166 | #print(np.mean(vals), np.std(vals)) 167 | 168 | # X3 169 | with tf.device(gpu): 170 | pred1 = net.predict([atlas_vol1, X_vol3]) 171 | pred2 = net.predict([atlas_vol2, X_vol3]) 172 | #pred3 = net.predict([atlas_vol3, X_vol1]) 173 | #pred4 = net.predict([atlas_vol4, X_vol1]) 174 | #pred5 = net.predict([atlas_vol5, X_vol1]) 175 | # Warp segments with flow 176 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 177 | flow2 = pred2[1][0, :, :, :, :] 178 | #flow3 = pred3[1][0, :, :, :, :] 179 | #flow4 = pred4[1][0, :, :, :, :] 180 | #flow5 = pred5[1][0, :, :, :, :] 181 | 182 | sample1 = flow1+grid 183 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 184 | sample2 = flow2+grid 185 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 186 | #sample3 = flow3+grid 187 | #sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 188 | #sample4 = flow4+grid 189 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 190 | #sample5 = flow5+grid 191 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3) 192 | 193 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224) 194 | warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='nearest', bounds_error=False, fill_value=0) 195 | #warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :], sample3, method='nearest', bounds_error=False, fill_value=0) 196 | #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0) 197 | #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0) 198 | 199 | 200 | # label fusion: get the final warp_seg 201 | warp_seg = np.empty((160, 192, 224, atlas_seg1.shape[3])) 202 | warp_seg = (warp_seg1 + warp_seg2) / 2 203 | warp_seg = np.argmax(warp_seg, axis = 3) 204 | 205 | vals, _ = dice(warp_seg, X_seg3[0, :, :, :, 0], labels=labels, nargout=2) 206 | mean3 = np.mean(vals) 207 | 208 | # X4 209 | with tf.device(gpu): 210 | pred1 = net.predict([atlas_vol1, X_vol4]) 211 | pred2 = net.predict([atlas_vol2, X_vol4]) 212 | #pred3 = net.predict([atlas_vol3, X_vol1]) 213 | #pred4 = net.predict([atlas_vol4, X_vol1]) 214 | #pred5 = net.predict([atlas_vol5, X_vol1]) 215 | # Warp segments with flow 216 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 217 | flow2 = pred2[1][0, :, :, :, :] 218 | #flow3 = pred3[1][0, :, :, :, :] 219 | #flow4 = pred4[1][0, :, :, :, :] 220 | #flow5 = pred5[1][0, :, :, :, :] 221 | 222 | sample1 = flow1+grid 223 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 224 | sample2 = flow2+grid 225 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 226 | #sample3 = flow3+grid 227 | #sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 228 | #sample4 = flow4+grid 229 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 230 | 
#sample5 = flow5+grid 231 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3) 232 | 233 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224) 234 | warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='nearest', bounds_error=False, fill_value=0) 235 | #warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :], sample3, method='nearest', bounds_error=False, fill_value=0) 236 | #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0) 237 | #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0) 238 | 239 | 240 | # label fusion: get the final warp_seg 241 | warp_seg = np.empty((160, 192, 224, atlas_seg1.shape[3])) 242 | warp_seg = (warp_seg1 + warp_seg2) / 2 243 | warp_seg = np.argmax(warp_seg, axis = 3) 244 | 245 | 246 | vals, _ = dice(warp_seg, X_seg4[0, :, :, :, 0], labels=labels, nargout=2) 247 | mean4 = np.mean(vals) 248 | 249 | 250 | # X5 251 | with tf.device(gpu): 252 | pred1 = net.predict([atlas_vol1, X_vol5]) 253 | pred2 = net.predict([atlas_vol2, X_vol5]) 254 | #pred3 = net.predict([atlas_vol3, X_vol1]) 255 | #pred4 = net.predict([atlas_vol4, X_vol1]) 256 | #pred5 = net.predict([atlas_vol5, X_vol1]) 257 | # Warp segments with flow 258 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 259 | flow2 = pred2[1][0, :, :, :, :] 260 | #flow3 = pred3[1][0, :, :, :, :] 261 | #flow4 = pred4[1][0, :, :, :, :] 262 | #flow5 = pred5[1][0, :, :, :, :] 263 | 264 | sample1 = flow1+grid 265 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 266 | sample2 = flow2+grid 267 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 268 | #sample3 = flow3+grid 269 | #sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 270 | #sample4 = flow4+grid 271 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 272 | #sample5 = flow5+grid 273 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3) 274 | 275 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224) 276 | warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='nearest', bounds_error=False, fill_value=0) 277 | #warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :], sample3, method='nearest', bounds_error=False, fill_value=0) 278 | #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0) 279 | #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0) 280 | 281 | 282 | # label fusion: get the final warp_seg 283 | warp_seg = np.empty((160, 192, 224, atlas_seg1.shape[3])) 284 | warp_seg = (warp_seg1 + warp_seg2) / 2 285 | warp_seg = np.argmax(warp_seg, axis = 3) 286 | 287 | vals, _ = dice(warp_seg, X_seg5[0, :, :, :, 0], labels=labels, nargout=2) 288 | mean5 = np.mean(vals) 289 | 290 | # compute mean of dice score 291 | sum = mean1 + mean2 + mean3 + mean4 + mean5 292 | mean_dice = sum/5 293 | print(mean_dice) 294 | 295 | # plot the outcome of warp seg 296 | #warp_seg = warp_seg.reshape((warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0])) 297 | #warp_seg2 = np.empty(shape = (warp_seg.shape[1], 
warp_seg.shape[2], warp_seg.shape[0])) 298 | #for i in range(0,warp_seg.shape[1]): 299 | # warp_seg2[i,:,:] = np.transpose(warp_seg[:,i,:]) 300 | #nu.plot.slices(warp_seg) 301 | 302 | 303 | if __name__ == "__main__": 304 | #result_list = np.empty((1000,1)) 305 | #for i in range(0,35): 306 | # iterr = (i+1)*200 307 | # result_list[i,0] = test(iterr,sys.argv[1])[0] 308 | #print(result_list) 309 | test(sys.argv[1], sys.argv[2]) -------------------------------------------------------------------------------- /src/MAS3_test_linear.py: -------------------------------------------------------------------------------- 1 | """ 2 | the test for multi atlas segmentation based on Voxelmorph and Neuron 3 | 4 | """ 5 | 6 | 7 | import os 8 | import sys 9 | import glob 10 | 11 | # third party 12 | import tensorflow as tf 13 | import keras 14 | import scipy.io as sio 15 | import numpy as np 16 | from scipy import stats 17 | from keras.backend.tensorflow_backend import set_session 18 | from scipy.interpolate import interpn 19 | 20 | # project 21 | sys.path.append('../ext/medipy-lib') 22 | sys.path.append('../ext/neuron') 23 | sys.path.append('../ext/pynd-lib') 24 | sys.path.append('../ext/pytools-lib') 25 | 26 | import medipy 27 | import networks 28 | from medipy.metrics import dice 29 | import datagenerators 30 | import neuron as nu 31 | 32 | 33 | def test(iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]): 34 | gpu = '/gpu:' + str(gpu_id) 35 | 36 | # Anatomical labels we want to evaluate 37 | labels = sio.loadmat('../data/labels.mat')['labels'][0] 38 | 39 | # read atlas 40 | atlas_vol1, atlas_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990114_vc722.npz', 41 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990114_vc722.npz')# [1,160,192,224,1] 42 | atlas_seg1 = atlas_seg1[0,:,:,:,0]# reduce the dimension to [160,192,224] 43 | atlas_seg1 = keras.utils.to_categorical(atlas_seg1) 44 | 45 | atlas_vol2, atlas_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990210_vc792.npz', 46 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990210_vc792.npz') 47 | atlas_seg2 = atlas_seg2[0, :, :, :, 0] 48 | atlas_seg2 = keras.utils.to_categorical(atlas_seg2) 49 | atlas_vol3, atlas_seg3 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990405_vc922.npz', 50 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990405_vc922.npz') 51 | atlas_seg3 = atlas_seg3[0, :, :, :, 0] 52 | atlas_seg3 = keras.utils.to_categorical(atlas_seg3) 53 | 54 | #gpu = '/gpu:' + str(gpu_id) 55 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 56 | config = tf.ConfigProto() 57 | config.gpu_options.allow_growth = True 58 | config.allow_soft_placement = True 59 | set_session(tf.Session(config=config)) 60 | 61 | # load weights of model 62 | with tf.device(gpu): 63 | net = networks.unet(vol_size, nf_enc, nf_dec) 64 | net.load_weights('/home/ys895/MAS4_Models/'+str(iter_num)+'.h5') 65 | #net.load_weights('../models/' + model_name + '/' + str(iter_num) + '.h5') 66 | 67 | xx = np.arange(vol_size[1]) 68 | yy = np.arange(vol_size[0]) 69 | zz = np.arange(vol_size[2]) 70 | grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4) # (160, 192, 224, 3) 71 | #X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz') 72 | X_vol1, X_seg1 = 
datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/981216_vc681.npz', 73 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/981216_vc681.npz') 74 | 75 | X_vol2, X_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990205_vc783.npz', 76 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990205_vc783.npz') 77 | 78 | X_vol3, X_seg3 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990525_vc1024.npz', 79 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990525_vc1024.npz') 80 | 81 | X_vol4, X_seg4 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991025_vc1379.npz', 82 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991025_vc1379.npz') 83 | 84 | X_vol5, X_seg5 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991122_vc1463.npz', 85 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991122_vc1463.npz') 86 | 87 | # change the direction of the atlas data and volume data 88 | # pred[0].shape (1, 160, 192, 224, 1) 89 | # pred[1].shape (1, 160, 192, 224, 3) 90 | # X1 91 | with tf.device(gpu): 92 | pred1 = net.predict([atlas_vol1, X_vol1]) 93 | pred2 = net.predict([atlas_vol2, X_vol1]) 94 | pred3 = net.predict([atlas_vol3, X_vol1]) 95 | #pred4 = net.predict([atlas_vol4, X_vol1]) 96 | #pred5 = net.predict([atlas_vol5, X_vol1]) 97 | # Warp segments with flow 98 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 99 | flow2 = pred2[1][0, :, :, :, :] 100 | flow3 = pred3[1][0, :, :, :, :] 101 | #flow4 = pred4[1][0, :, :, :, :] 102 | #flow5 = pred5[1][0, :, :, :, :] 103 | 104 | sample1 = flow1+grid 105 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 106 | sample2 = flow2+grid 107 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 108 | sample3 = flow3+grid 109 | sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 110 | #sample4 = flow4+grid 111 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 112 | #sample5 = flow5+grid 113 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3) 114 | 115 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0) # (160, 192, 224) 116 | warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0) 117 | warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :, :], sample3, method='linear', bounds_error=False, fill_value=0) 118 | #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0) 119 | #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0) 120 | 121 | 122 | # label fusion: get the final warp_seg 123 | warp_seg = np.empty((160, 192, 224, atlas_seg1.shape[3])) 124 | warp_seg = (warp_seg1 + warp_seg2 + warp_seg3) / 3 125 | warp_seg = np.argmax(warp_seg, axis=3) 126 | 127 | vals, _ = dice(warp_seg, X_seg1[0, :, :, :, 0], labels=labels, nargout=2) 128 | mean1 = np.mean(vals) 129 | 130 | # X2 131 | with tf.device(gpu): 132 | pred1 = net.predict([atlas_vol1, X_vol2]) 133 | pred2 = net.predict([atlas_vol2, 
X_vol2]) 134 | pred3 = net.predict([atlas_vol3, X_vol2]) 135 | #pred4 = net.predict([atlas_vol4, X_vol2]) 136 | #pred5 = net.predict([atlas_vol5, X_vol2]) 137 | # Warp segments with flow 138 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 139 | flow2 = pred2[1][0, :, :, :, :] 140 | flow3 = pred3[1][0, :, :, :, :] 141 | #flow4 = pred4[1][0, :, :, :, :] 142 | #flow5 = pred5[1][0, :, :, :, :] 143 | 144 | sample1 = flow1+grid 145 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 146 | sample2 = flow2+grid 147 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 148 | sample3 = flow3+grid 149 | sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 150 | #sample4 = flow4+grid 151 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 152 | #sample5 = flow5+grid 153 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3) 154 | 155 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0) # (160, 192, 224) 156 | warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0) 157 | warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :, :], sample3, method='linear', bounds_error=False, fill_value=0) 158 | #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0) 159 | #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0) 160 | 161 | 162 | # label fusion: get the final warp_seg 163 | warp_seg = np.empty((160, 192, 224, atlas_seg1.shape[3])) 164 | warp_seg = (warp_seg1 + warp_seg2 + warp_seg3) / 3 165 | warp_seg = np.argmax(warp_seg, axis=3) 166 | 167 | vals, _ = dice(warp_seg, X_seg2[0,:,:,:,0], labels=labels, nargout=2) 168 | mean2 = np.mean(vals) 169 | #print(np.mean(vals), np.std(vals)) 170 | 171 | # X3 172 | with tf.device(gpu): 173 | pred1 = net.predict([atlas_vol1, X_vol3]) 174 | pred2 = net.predict([atlas_vol2, X_vol3]) 175 | pred3 = net.predict([atlas_vol3, X_vol3]) 176 | #pred4 = net.predict([atlas_vol4, X_vol3]) 177 | #pred5 = net.predict([atlas_vol5, X_vol3]) 178 | # Warp segments with flow 179 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 180 | flow2 = pred2[1][0, :, :, :, :] 181 | flow3 = pred3[1][0, :, :, :, :] 182 | #flow4 = pred4[1][0, :, :, :, :] 183 | #flow5 = pred5[1][0, :, :, :, :] 184 | 185 | sample1 = flow1+grid 186 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 187 | sample2 = flow2+grid 188 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 189 | sample3 = flow3+grid 190 | sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 191 | #sample4 = flow4+grid 192 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 193 | #sample5 = flow5+grid 194 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3) 195 | 196 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0) # (160, 192, 224) 197 | warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0) 198 | warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :, :], sample3, method='linear', bounds_error=False, 
fill_value=0) 199 | #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0) 200 | #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0) 201 | 202 | 203 | # label fusion: get the final warp_seg 204 | warp_seg = np.empty((160, 192, 224, atlas_seg1.shape[3])) 205 | warp_seg = (warp_seg1 + warp_seg2 + warp_seg3) / 3 206 | warp_seg = np.argmax(warp_seg, axis=3) 207 | 208 | vals, _ = dice(warp_seg, X_seg3[0, :, :, :, 0], labels=labels, nargout=2) 209 | mean3 = np.mean(vals) 210 | 211 | # X4 212 | with tf.device(gpu): 213 | pred1 = net.predict([atlas_vol1, X_vol4]) 214 | pred2 = net.predict([atlas_vol2, X_vol4]) 215 | pred3 = net.predict([atlas_vol3, X_vol4]) 216 | #pred4 = net.predict([atlas_vol4, X_vol1]) 217 | #pred5 = net.predict([atlas_vol5, X_vol1]) 218 | # Warp segments with flow 219 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 220 | flow2 = pred2[1][0, :, :, :, :] 221 | flow3 = pred3[1][0, :, :, :, :] 222 | #flow4 = pred4[1][0, :, :, :, :] 223 | #flow5 = pred5[1][0, :, :, :, :] 224 | 225 | sample1 = flow1+grid 226 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 227 | sample2 = flow2+grid 228 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 229 | sample3 = flow3+grid 230 | sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 231 | #sample4 = flow4+grid 232 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 233 | #sample5 = flow5+grid 234 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3) 235 | 236 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0) # (160, 192, 224) 237 | warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0) 238 | warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :, :], sample3, method='linear', bounds_error=False, fill_value=0) 239 | #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0) 240 | #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0) 241 | 242 | 243 | # label fusion: get the final warp_seg 244 | warp_seg = np.empty((160, 192, 224, atlas_seg1.shape[3])) 245 | warp_seg = (warp_seg1 + warp_seg2 + warp_seg3) / 3 246 | warp_seg = np.argmax(warp_seg, axis=3) 247 | 248 | vals, _ = dice(warp_seg, X_seg4[0, :, :, :, 0], labels=labels, nargout=2) 249 | mean4 = np.mean(vals) 250 | 251 | # X5 252 | with tf.device(gpu): 253 | pred1 = net.predict([atlas_vol1, X_vol5]) 254 | pred2 = net.predict([atlas_vol2, X_vol5]) 255 | pred3 = net.predict([atlas_vol3, X_vol5]) 256 | #pred4 = net.predict([atlas_vol4, X_vol1]) 257 | #pred5 = net.predict([atlas_vol5, X_vol1]) 258 | # Warp segments with flow 259 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 260 | flow2 = pred2[1][0, :, :, :, :] 261 | flow3 = pred3[1][0, :, :, :, :] 262 | #flow4 = pred4[1][0, :, :, :, :] 263 | #flow5 = pred5[1][0, :, :, :, :] 264 | 265 | sample1 = flow1+grid 266 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 267 | sample2 = flow2+grid 268 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 269 | sample3 = flow3+grid 270 | sample3 = 
np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 271 | #sample4 = flow4+grid 272 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 273 | #sample5 = flow5+grid 274 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3) 275 | 276 | warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0) # (160, 192, 224) 277 | warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0) 278 | warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :, :], sample3, method='linear', bounds_error=False, fill_value=0) 279 | #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0) 280 | #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0) 281 | 282 | 283 | # label fusion: get the final warp_seg 284 | warp_seg = np.empty((160, 192, 224, atlas_seg1.shape[3])) 285 | warp_seg = (warp_seg1 + warp_seg2 + warp_seg3) / 3 286 | warp_seg = np.argmax(warp_seg, axis=3) 287 | 288 | vals, _ = dice(warp_seg, X_seg5[0, :, :, :, 0], labels=labels, nargout=2) 289 | mean5 = np.mean(vals) 290 | 291 | # compute mean of dice score 292 | sum = mean1 + mean2 + mean3 + mean4 + mean5 293 | mean_dice = sum/5 294 | print(mean_dice) 295 | 296 | # plot the outcome of warp seg 297 | #warp_seg = warp_seg.reshape((warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0])) 298 | #warp_seg2 = np.empty(shape = (warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0])) 299 | #for i in range(0,warp_seg.shape[1]): 300 | # warp_seg2[i,:,:] = np.transpose(warp_seg[:,i,:]) 301 | #nu.plot.slices(warp_seg) 302 | 303 | 304 | if __name__ == "__main__": 305 | #result_list = np.empty((1000,1)) 306 | #for i in range(0,35): 307 | # iterr = (i+1)*200 308 | # result_list[i,0] = test(iterr,sys.argv[1])[0] 309 | #print(result_list) 310 | test(sys.argv[1], sys.argv[2]) -------------------------------------------------------------------------------- /src/MAS2_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | the test for multi atlas segmentation based on Voxelmorph and Neuron 3 | 4 | """ 5 | 6 | 7 | import os 8 | import sys 9 | import glob 10 | 11 | # third party 12 | import tensorflow as tf 13 | import scipy.io as sio 14 | import numpy as np 15 | from scipy import stats 16 | from keras.backend.tensorflow_backend import set_session 17 | from scipy.interpolate import interpn 18 | 19 | # project 20 | sys.path.append('../ext/medipy-lib') 21 | sys.path.append('../ext/neuron') 22 | sys.path.append('../ext/pynd-lib') 23 | sys.path.append('../ext/pytools-lib') 24 | 25 | import medipy 26 | import networks 27 | from medipy.metrics import dice 28 | import datagenerators 29 | import neuron as nu 30 | 31 | 32 | def test(iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]): 33 | gpu = '/gpu:' + str(gpu_id) 34 | 35 | # Anatomical labels we want to evaluate 36 | labels = sio.loadmat('../data/labels.mat')['labels'][0] 37 | 38 | # read atlas 39 | atlas_vol1, atlas_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990114_vc722.npz', 40 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990114_vc722.npz')# [1,160,192,224,1] 41 | atlas_seg1 = atlas_seg1[0,:,:,:,0]# reduce the dimension to 
[160,192,224] 42 | 43 | atlas_vol2, atlas_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990210_vc792.npz', 44 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990210_vc792.npz') 45 | atlas_seg2 = atlas_seg2[0, :, :, :, 0] 46 | 47 | #gpu = '/gpu:' + str(gpu_id) 48 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 49 | config = tf.ConfigProto() 50 | config.gpu_options.allow_growth = True 51 | config.allow_soft_placement = True 52 | set_session(tf.Session(config=config)) 53 | 54 | # load weights of model 55 | with tf.device(gpu): 56 | net = networks.unet(vol_size, nf_enc, nf_dec) 57 | net.load_weights('/home/ys895/MAS2_Models/'+str(iter_num)+'.h5') 58 | #net.load_weights('../models/' + model_name + '/' + str(iter_num) + '.h5') 59 | 60 | xx = np.arange(vol_size[1]) 61 | yy = np.arange(vol_size[0]) 62 | zz = np.arange(vol_size[2]) 63 | grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4) # (160, 192, 224, 3) 64 | #X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz') 65 | X_vol1, X_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/981216_vc681.npz', 66 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/981216_vc681.npz') 67 | 68 | X_vol2, X_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990205_vc783.npz', 69 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990205_vc783.npz') 70 | 71 | X_vol3, X_seg3 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990525_vc1024.npz', 72 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990525_vc1024.npz') 73 | 74 | X_vol4, X_seg4 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991025_vc1379.npz', 75 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991025_vc1379.npz') 76 | 77 | X_vol5, X_seg5 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991122_vc1463.npz', 78 | '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991122_vc1463.npz') 79 | 80 | # change the direction of the atlas data and volume data 81 | # pred[0].shape (1, 160, 192, 224, 1) 82 | # pred[1].shape (1, 160, 192, 224, 3) 83 | # X1 84 | with tf.device(gpu): 85 | pred1 = net.predict([atlas_vol1, X_vol1]) 86 | pred2 = net.predict([atlas_vol2, X_vol1]) 87 | #pred3 = net.predict([atlas_vol3, X_vol1]) 88 | #pred4 = net.predict([atlas_vol4, X_vol1]) 89 | #pred5 = net.predict([atlas_vol5, X_vol1]) 90 | # Warp segments with flow 91 | flow1 = pred1[1][0, :, :, :, :]# (1, 160, 192, 224, 3) 92 | flow2 = pred2[1][0, :, :, :, :] 93 | #flow3 = pred3[1][0, :, :, :, :] 94 | #flow4 = pred4[1][0, :, :, :, :] 95 | #flow5 = pred5[1][0, :, :, :, :] 96 | 97 | sample1 = flow1+grid 98 | sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3) 99 | sample2 = flow2+grid 100 | sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3) 101 | #sample3 = flow3+grid 102 | #sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3) 103 | #sample4 = flow4+grid 104 | #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3) 105 | #sample5 = flow5+grid 106 | #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, 
107 | 
108 |     warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :], sample1, method='nearest', bounds_error=False, fill_value=0)  # (160, 192, 224)
109 |     warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :], sample2, method='nearest', bounds_error=False, fill_value=0)
110 |     #warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :], sample3, method='nearest', bounds_error=False, fill_value=0)
111 |     #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0)
112 |     #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0)
113 | 
114 | 
115 |     # label fusion: per-voxel majority vote (mode) over the warped atlas labels
116 |     warp_seg = np.empty((160, 192, 224))
117 |     for x in range(0, 160):
118 |         for y in range(0, 192):
119 |             for z in range(0, 224):
120 |                 warp_arr = np.array([warp_seg1[x, y, z], warp_seg2[x, y, z]])
121 |                 #print(warp_arr)
122 |                 warp_seg[x, y, z] = stats.mode(warp_arr)[0]
123 | 
124 |     vals, _ = dice(warp_seg, X_seg1[0, :, :, :, 0], labels=labels, nargout=2)
125 |     mean1 = np.mean(vals)
126 | 
127 |     # X2
128 |     with tf.device(gpu):
129 |         pred1 = net.predict([atlas_vol1, X_vol2])
130 |         pred2 = net.predict([atlas_vol2, X_vol2])
131 |         #pred3 = net.predict([atlas_vol3, X_vol2])
132 |         #pred4 = net.predict([atlas_vol4, X_vol2])
133 |         #pred5 = net.predict([atlas_vol5, X_vol2])
134 |     # Warp segments with flow
135 |     flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
136 |     flow2 = pred2[1][0, :, :, :, :]
137 |     #flow3 = pred3[1][0, :, :, :, :]
138 |     #flow4 = pred4[1][0, :, :, :, :]
139 |     #flow5 = pred5[1][0, :, :, :, :]
140 | 
141 |     sample1 = flow1 + grid
142 |     sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
143 |     sample2 = flow2 + grid
144 |     sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
145 |     #sample3 = flow3 + grid
146 |     #sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
147 |     #sample4 = flow4 + grid
148 |     #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
149 |     #sample5 = flow5 + grid
150 |     #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)
151 | 
152 |     warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :], sample1, method='nearest', bounds_error=False, fill_value=0)  # (160, 192, 224)
153 |     warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :], sample2, method='nearest', bounds_error=False, fill_value=0)
154 |     #warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :], sample3, method='nearest', bounds_error=False, fill_value=0)
155 |     #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0)
156 |     #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0)
157 | 
158 | 
159 |     # label fusion: get the final warp_seg
160 |     warp_seg = np.empty((160, 192, 224))
161 |     for x in range(0, 160):
162 |         for y in range(0, 192):
163 |             for z in range(0, 224):
164 |                 warp_arr = np.array([warp_seg1[x, y, z], warp_seg2[x, y, z]])
165 |                 #print(warp_arr)
166 |                 warp_seg[x, y, z] = stats.mode(warp_arr)[0]
167 | 
168 |     vals, _ = dice(warp_seg, X_seg2[0, :, :, :, 0], labels=labels, nargout=2)
169 |     mean2 = np.mean(vals)
170 |     #print(np.mean(vals), np.std(vals))
171 | 
172 |     # X3
173 |     with tf.device(gpu):
174 |         pred1 = net.predict([atlas_vol1, X_vol3])
175 |         pred2 = net.predict([atlas_vol2, X_vol3])
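        # NOTE: the commented pred3-pred5 / flow3-flow5 / sample3-sample5 lines are the
        # hooks for fusing additional atlases; this variant registers only two. With any
        # number of atlases the per-voxel fusion loops are equivalent to a single call
        # (assuming this scipy version's ModeResult layout), e.g.:
        #   warp_seg = stats.mode(np.stack((warp_seg1, warp_seg2)), axis=0)[0][0]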
176 |         #pred3 = net.predict([atlas_vol3, X_vol3])
177 |         #pred4 = net.predict([atlas_vol4, X_vol3])
178 |         #pred5 = net.predict([atlas_vol5, X_vol3])
179 |     # Warp segments with flow
180 |     flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
181 |     flow2 = pred2[1][0, :, :, :, :]
182 |     #flow3 = pred3[1][0, :, :, :, :]
183 |     #flow4 = pred4[1][0, :, :, :, :]
184 |     #flow5 = pred5[1][0, :, :, :, :]
185 | 
186 |     sample1 = flow1 + grid
187 |     sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
188 |     sample2 = flow2 + grid
189 |     sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
190 |     #sample3 = flow3 + grid
191 |     #sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
192 |     #sample4 = flow4 + grid
193 |     #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
194 |     #sample5 = flow5 + grid
195 |     #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)
196 | 
197 |     warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :], sample1, method='nearest', bounds_error=False, fill_value=0)  # (160, 192, 224)
198 |     warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :], sample2, method='nearest', bounds_error=False, fill_value=0)
199 |     #warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :], sample3, method='nearest', bounds_error=False, fill_value=0)
200 |     #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0)
201 |     #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0)
202 | 
203 | 
204 |     # label fusion: get the final warp_seg
205 |     warp_seg = np.empty((160, 192, 224))
206 |     for x in range(0, 160):
207 |         for y in range(0, 192):
208 |             for z in range(0, 224):
209 |                 warp_arr = np.array([warp_seg1[x, y, z], warp_seg2[x, y, z]])
210 |                 #print(warp_arr)
211 |                 warp_seg[x, y, z] = stats.mode(warp_arr)[0]
212 | 
213 |     vals, _ = dice(warp_seg, X_seg3[0, :, :, :, 0], labels=labels, nargout=2)
214 |     mean3 = np.mean(vals)
215 | 
216 |     # X4
217 |     with tf.device(gpu):
218 |         pred1 = net.predict([atlas_vol1, X_vol4])
219 |         pred2 = net.predict([atlas_vol2, X_vol4])
220 |         #pred3 = net.predict([atlas_vol3, X_vol4])
221 |         #pred4 = net.predict([atlas_vol4, X_vol4])
222 |         #pred5 = net.predict([atlas_vol5, X_vol4])
223 |     # Warp segments with flow
224 |     flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
225 |     flow2 = pred2[1][0, :, :, :, :]
226 |     #flow3 = pred3[1][0, :, :, :, :]
227 |     #flow4 = pred4[1][0, :, :, :, :]
228 |     #flow5 = pred5[1][0, :, :, :, :]
229 | 
230 |     sample1 = flow1 + grid
231 |     sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
232 |     sample2 = flow2 + grid
233 |     sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
234 |     #sample3 = flow3 + grid
235 |     #sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
236 |     #sample4 = flow4 + grid
237 |     #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
238 |     #sample5 = flow5 + grid
239 |     #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)
240 | 
241 |     warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :], sample1, method='nearest', bounds_error=False, fill_value=0)  # (160, 192, 224)
242 |     warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :], sample2, method='nearest', bounds_error=False, fill_value=0)
243 |     #warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :], sample3, method='nearest', bounds_error=False, fill_value=0)
244 |     #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0)
245 |     #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0)
246 | 
247 | 
248 |     # label fusion: get the final warp_seg
249 |     warp_seg = np.empty((160, 192, 224))
250 |     for x in range(0, 160):
251 |         for y in range(0, 192):
252 |             for z in range(0, 224):
253 |                 warp_arr = np.array([warp_seg1[x, y, z], warp_seg2[x, y, z]])
254 |                 #print(warp_arr)
255 |                 warp_seg[x, y, z] = stats.mode(warp_arr)[0]
256 | 
257 |     vals, _ = dice(warp_seg, X_seg4[0, :, :, :, 0], labels=labels, nargout=2)
258 |     mean4 = np.mean(vals)
259 | 
260 | 
261 |     # X5
262 |     with tf.device(gpu):
263 |         pred1 = net.predict([atlas_vol1, X_vol5])
264 |         pred2 = net.predict([atlas_vol2, X_vol5])
265 |         #pred3 = net.predict([atlas_vol3, X_vol5])
266 |         #pred4 = net.predict([atlas_vol4, X_vol5])
267 |         #pred5 = net.predict([atlas_vol5, X_vol5])
268 |     # Warp segments with flow
269 |     flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
270 |     flow2 = pred2[1][0, :, :, :, :]
271 |     #flow3 = pred3[1][0, :, :, :, :]
272 |     #flow4 = pred4[1][0, :, :, :, :]
273 |     #flow5 = pred5[1][0, :, :, :, :]
274 | 
275 |     sample1 = flow1 + grid
276 |     sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
277 |     sample2 = flow2 + grid
278 |     sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
279 |     #sample3 = flow3 + grid
280 |     #sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
281 |     #sample4 = flow4 + grid
282 |     #sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
283 |     #sample5 = flow5 + grid
284 |     #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)
285 | 
286 |     warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :], sample1, method='nearest', bounds_error=False, fill_value=0)  # (160, 192, 224)
287 |     warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :], sample2, method='nearest', bounds_error=False, fill_value=0)
288 |     #warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :], sample3, method='nearest', bounds_error=False, fill_value=0)
289 |     #warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :], sample4, method='nearest', bounds_error=False, fill_value=0)
290 |     #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0)
291 | 
292 | 
293 |     # label fusion: get the final warp_seg
294 |     warp_seg = np.empty((160, 192, 224))
295 |     for x in range(0, 160):
296 |         for y in range(0, 192):
297 |             for z in range(0, 224):
298 |                 warp_arr = np.array([warp_seg1[x, y, z], warp_seg2[x, y, z]])
299 |                 #print(warp_arr)
300 |                 warp_seg[x, y, z] = stats.mode(warp_arr)[0]
301 | 
302 |     vals, _ = dice(warp_seg, X_seg5[0, :, :, :, 0], labels=labels, nargout=2)
303 |     mean5 = np.mean(vals)
304 | 
305 |     # compute mean of dice score over the five test subjects
306 |     total = mean1 + mean2 + mean3 + mean4 + mean5
307 |     mean_dice = total / 5
308 |     print(mean_dice)
309 | 
310 |     # plot the outcome of warp seg
311 |     #warp_seg = warp_seg.reshape((warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0]))
312 |     #warp_seg2 = np.empty(shape = (warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0]))
313 |     #for i in range(0,warp_seg.shape[1]):
314 |     #    warp_seg2[i,:,:] = np.transpose(warp_seg[:,i,:])
315 |     #nu.plot.slices(warp_seg)
316 | 
317 | 
318 | if __name__ == "__main__":
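    # NOTE: the commented lines below swept a series of checkpoints saved every 200
    # iterations; as committed, the script scores the single checkpoint named on the
    # command line (usage: python <script>.py <iter_num> <gpu_id>).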
319 |     #result_list = np.empty((1000,1))
320 |     #for i in range(0,35):
321 |     #    iterr = (i+1)*200
322 |     #    result_list[i,0] = test(iterr, sys.argv[1])[0]
323 |     #print(result_list)
324 |     test(sys.argv[1], sys.argv[2])
--------------------------------------------------------------------------------
/src/MAS4_test_linear.py:
--------------------------------------------------------------------------------
1 | """
2 | Test script for multi-atlas segmentation based on VoxelMorph and Neuron.
3 | 
4 | """
5 | 
6 | 
7 | import os
8 | import sys
9 | import glob
10 | 
11 | # third party
12 | import tensorflow as tf
13 | import keras
14 | import scipy.io as sio
15 | import numpy as np
16 | from scipy import stats
17 | from keras.backend.tensorflow_backend import set_session
18 | from scipy.interpolate import interpn
19 | 
20 | # project
21 | sys.path.append('../ext/medipy-lib')
22 | sys.path.append('../ext/neuron')
23 | sys.path.append('../ext/pynd-lib')
24 | sys.path.append('../ext/pytools-lib')
25 | 
26 | import medipy
27 | import networks
28 | from medipy.metrics import dice
29 | import datagenerators
30 | import neuron as nu
31 | 
32 | 
33 | def test(iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]):
34 |     gpu = '/gpu:' + str(gpu_id)
35 | 
36 |     # Anatomical labels we want to evaluate
37 |     labels = sio.loadmat('../data/labels.mat')['labels'][0]
38 | 
39 |     # read atlas
40 |     atlas_vol1, atlas_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990114_vc722.npz',
41 |                                                                  '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990114_vc722.npz')  # [1,160,192,224,1]
42 |     atlas_seg1 = atlas_seg1[0,:,:,:,0]  # reduce the dimension to [160,192,224]
43 |     atlas_seg1 = keras.utils.to_categorical(atlas_seg1)
44 | 
45 |     atlas_vol2, atlas_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990210_vc792.npz',
46 |                                                                  '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990210_vc792.npz')
47 |     atlas_seg2 = atlas_seg2[0, :, :, :, 0]
48 |     atlas_seg2 = keras.utils.to_categorical(atlas_seg2)
49 | 
50 |     atlas_vol3, atlas_seg3 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990405_vc922.npz',
51 |                                                                  '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990405_vc922.npz')
52 |     atlas_seg3 = atlas_seg3[0, :, :, :, 0]
53 |     atlas_seg3 = keras.utils.to_categorical(atlas_seg3)
54 | 
55 |     atlas_vol4, atlas_seg4 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991006_vc1337.npz',
56 |                                                                  '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991006_vc1337.npz')
57 |     atlas_seg4 = atlas_seg4[0, :, :, :, 0]
58 |     atlas_seg4 = keras.utils.to_categorical(atlas_seg4)
59 | 
60 | 
61 |     #gpu = '/gpu:' + str(gpu_id)
62 |     os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
63 |     config = tf.ConfigProto()
64 |     config.gpu_options.allow_growth = True
65 |     config.allow_soft_placement = True
66 |     set_session(tf.Session(config=config))
67 | 
68 |     # load weights of model
69 |     with tf.device(gpu):
70 |         net = networks.unet(vol_size, nf_enc, nf_dec)
71 |         net.load_weights('/home/ys895/MAS4_Models/'+str(iter_num)+'.h5')
72 |         #net.load_weights('../models/' + model_name + '/' + str(iter_num) + '.h5')
73 | 
74 |     xx = np.arange(vol_size[1])
75 |     yy = np.arange(vol_size[0])
76 |     zz = np.arange(vol_size[2])
77 |     grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)  # (160, 192, 224, 3)
78 |     #X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz')
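    # NOTE: the atlas segmentations were one-hot encoded above (keras.utils.to_categorical),
    # which is what allows method='linear' warping below: each warped channel carries a
    # fractional vote for its label, and averaging the four warped maps before np.argmax
    # implements soft majority voting across atlases.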
79 |     X_vol1, X_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/981216_vc681.npz',
80 |                                                          '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/981216_vc681.npz')
81 | 
82 |     X_vol2, X_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990205_vc783.npz',
83 |                                                          '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990205_vc783.npz')
84 | 
85 |     X_vol3, X_seg3 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990525_vc1024.npz',
86 |                                                          '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990525_vc1024.npz')
87 | 
88 |     X_vol4, X_seg4 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991025_vc1379.npz',
89 |                                                          '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991025_vc1379.npz')
90 | 
91 |     X_vol5, X_seg5 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991122_vc1463.npz',
92 |                                                          '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991122_vc1463.npz')
93 | 
94 |     # change the direction of the atlas data and volume data
95 |     # pred[0].shape (1, 160, 192, 224, 1)
96 |     # pred[1].shape (1, 160, 192, 224, 3)
97 |     # X1
98 |     with tf.device(gpu):
99 |         pred1 = net.predict([atlas_vol1, X_vol1])
100 |         pred2 = net.predict([atlas_vol2, X_vol1])
101 |         pred3 = net.predict([atlas_vol3, X_vol1])
102 |         pred4 = net.predict([atlas_vol4, X_vol1])
103 |         #pred5 = net.predict([atlas_vol5, X_vol1])
104 |     # Warp segments with flow
105 |     flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
106 |     flow2 = pred2[1][0, :, :, :, :]
107 |     flow3 = pred3[1][0, :, :, :, :]
108 |     flow4 = pred4[1][0, :, :, :, :]
109 |     #flow5 = pred5[1][0, :, :, :, :]
110 | 
111 |     sample1 = flow1 + grid
112 |     sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
113 |     sample2 = flow2 + grid
114 |     sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
115 |     sample3 = flow3 + grid
116 |     sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
117 |     sample4 = flow4 + grid
118 |     sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
119 |     #sample5 = flow5 + grid
120 |     #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)
121 | 
122 |     warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0)  # (160, 192, 224, nb_labels)
123 |     warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0)
124 |     warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :, :], sample3, method='linear', bounds_error=False, fill_value=0)
125 |     warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :, :], sample4, method='linear', bounds_error=False, fill_value=0)
126 |     #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0)
127 | 
128 | 
129 |     # label fusion: get the final warp_seg
130 |     # (averaging the one-hot votes gives a (160, 192, 224, nb_labels) soft map)
131 |     warp_seg = (warp_seg1 + warp_seg2 + warp_seg3 + warp_seg4) / 4
132 |     warp_seg = np.argmax(warp_seg, axis=3)
133 | 
134 |     vals, _ = dice(warp_seg, X_seg1[0, :, :, :, 0], labels=labels, nargout=2)
135 |     mean1 = np.mean(vals)
136 | 
137 |     # X2
138 |     with tf.device(gpu):
139 |         pred1 = net.predict([atlas_vol1, X_vol2])
140 |         pred2 = net.predict([atlas_vol2, X_vol2])
141 |         pred3 = net.predict([atlas_vol3, X_vol2])
142 |         pred4 = net.predict([atlas_vol4, X_vol2])
143 |         #pred5 = net.predict([atlas_vol5, X_vol2])
144 |     # Warp segments with flow
145 |     flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
146 |     flow2 = pred2[1][0, :, :, :, :]
147 |     flow3 = pred3[1][0, :, :, :, :]
148 |     flow4 = pred4[1][0, :, :, :, :]
149 |     #flow5 = pred5[1][0, :, :, :, :]
150 | 
151 |     sample1 = flow1 + grid
152 |     sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
153 |     sample2 = flow2 + grid
154 |     sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
155 |     sample3 = flow3 + grid
156 |     sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
157 |     sample4 = flow4 + grid
158 |     sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
159 |     #sample5 = flow5 + grid
160 |     #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)
161 | 
162 |     warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0)  # (160, 192, 224, nb_labels)
163 |     warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0)
164 |     warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :, :], sample3, method='linear', bounds_error=False, fill_value=0)
165 |     warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :, :], sample4, method='linear', bounds_error=False, fill_value=0)
166 |     #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0)
167 | 
168 | 
169 |     # label fusion: get the final warp_seg
170 |     # (averaging the one-hot votes gives a (160, 192, 224, nb_labels) soft map)
171 |     warp_seg = (warp_seg1 + warp_seg2 + warp_seg3 + warp_seg4) / 4
172 |     warp_seg = np.argmax(warp_seg, axis=3)
173 | 
174 |     vals, _ = dice(warp_seg, X_seg2[0, :, :, :, 0], labels=labels, nargout=2)
175 |     mean2 = np.mean(vals)
176 |     #print(np.mean(vals), np.std(vals))
177 | 
178 |     # X3
179 |     with tf.device(gpu):
180 |         pred1 = net.predict([atlas_vol1, X_vol3])
181 |         pred2 = net.predict([atlas_vol2, X_vol3])
182 |         pred3 = net.predict([atlas_vol3, X_vol3])
183 |         pred4 = net.predict([atlas_vol4, X_vol3])
184 |         #pred5 = net.predict([atlas_vol5, X_vol3])
185 |     # Warp segments with flow
186 |     flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
187 |     flow2 = pred2[1][0, :, :, :, :]
188 |     flow3 = pred3[1][0, :, :, :, :]
189 |     flow4 = pred4[1][0, :, :, :, :]
190 |     #flow5 = pred5[1][0, :, :, :, :]
191 | 
192 |     sample1 = flow1 + grid
193 |     sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
194 |     sample2 = flow2 + grid
195 |     sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
196 |     sample3 = flow3 + grid
197 |     sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
198 |     sample4 = flow4 + grid
199 |     sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
200 |     #sample5 = flow5 + grid
201 |     #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)
202 | 
203 |     warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0)  # (160, 192, 224, nb_labels)
204 |     warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0)
205 |     warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :, :], sample3, method='linear', bounds_error=False, fill_value=0)
206 |     warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :, :], sample4, method='linear', bounds_error=False, fill_value=0)
207 |     #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0)
208 | 
209 | 
210 |     # label fusion: get the final warp_seg
211 |     # (averaging the one-hot votes gives a (160, 192, 224, nb_labels) soft map)
212 |     warp_seg = (warp_seg1 + warp_seg2 + warp_seg3 + warp_seg4) / 4
213 |     warp_seg = np.argmax(warp_seg, axis=3)
214 | 
215 |     vals, _ = dice(warp_seg, X_seg3[0, :, :, :, 0], labels=labels, nargout=2)
216 |     mean3 = np.mean(vals)
217 | 
218 |     # X4
219 |     with tf.device(gpu):
220 |         pred1 = net.predict([atlas_vol1, X_vol4])
221 |         pred2 = net.predict([atlas_vol2, X_vol4])
222 |         pred3 = net.predict([atlas_vol3, X_vol4])
223 |         pred4 = net.predict([atlas_vol4, X_vol4])
224 |         #pred5 = net.predict([atlas_vol5, X_vol4])
225 |     # Warp segments with flow
226 |     flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
227 |     flow2 = pred2[1][0, :, :, :, :]
228 |     flow3 = pred3[1][0, :, :, :, :]
229 |     flow4 = pred4[1][0, :, :, :, :]
230 |     #flow5 = pred5[1][0, :, :, :, :]
231 | 
232 |     sample1 = flow1 + grid
233 |     sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
234 |     sample2 = flow2 + grid
235 |     sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
236 |     sample3 = flow3 + grid
237 |     sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
238 |     sample4 = flow4 + grid
239 |     sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
240 |     #sample5 = flow5 + grid
241 |     #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)
242 | 
243 |     warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0)  # (160, 192, 224, nb_labels)
244 |     warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0)
245 |     warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :, :], sample3, method='linear', bounds_error=False, fill_value=0)
246 |     warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :, :], sample4, method='linear', bounds_error=False, fill_value=0)
247 |     #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0)
248 | 
249 | 
250 |     # label fusion: get the final warp_seg
251 |     # (averaging the one-hot votes gives a (160, 192, 224, nb_labels) soft map)
252 |     warp_seg = (warp_seg1 + warp_seg2 + warp_seg3 + warp_seg4) / 4
253 |     warp_seg = np.argmax(warp_seg, axis=3)
254 | 
255 |     vals, _ = dice(warp_seg, X_seg4[0, :, :, :, 0], labels=labels, nargout=2)
256 |     mean4 = np.mean(vals)
257 | 
258 |     # X5
259 |     with tf.device(gpu):
260 |         pred1 = net.predict([atlas_vol1, X_vol5])
261 |         pred2 = net.predict([atlas_vol2, X_vol5])
262 |         pred3 = net.predict([atlas_vol3, X_vol5])
263 |         pred4 = net.predict([atlas_vol4, X_vol5])
264 |         #pred5 = net.predict([atlas_vol5, X_vol5])
265 |     # Warp segments with flow
266 |     flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
267 |     flow2 = pred2[1][0, :, :, :, :]
268 |     flow3 = pred3[1][0, :, :, :, :]
269 |     flow4 = pred4[1][0, :, :, :, :]
270 |     #flow5 = pred5[1][0, :, :, :, :]
271 | 
272 |     sample1 = flow1 + grid
273 |     sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
274 |     sample2 = flow2 + grid
275 |     sample2 = np.stack((sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
276 |     sample3 = flow3 + grid
277 |     sample3 = np.stack((sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
278 |     sample4 = flow4 + grid
279 |     sample4 = np.stack((sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
280 |     #sample5 = flow5 + grid
281 |     #sample5 = np.stack((sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)
282 | 
283 |     warp_seg1 = interpn((yy, xx, zz), atlas_seg1[:, :, :, :], sample1, method='linear', bounds_error=False, fill_value=0)  # (160, 192, 224, nb_labels)
284 |     warp_seg2 = interpn((yy, xx, zz), atlas_seg2[:, :, :, :], sample2, method='linear', bounds_error=False, fill_value=0)
285 |     warp_seg3 = interpn((yy, xx, zz), atlas_seg3[:, :, :, :], sample3, method='linear', bounds_error=False, fill_value=0)
286 |     warp_seg4 = interpn((yy, xx, zz), atlas_seg4[:, :, :, :], sample4, method='linear', bounds_error=False, fill_value=0)
287 |     #warp_seg5 = interpn((yy, xx, zz), atlas_seg5[:, :, :], sample5, method='nearest', bounds_error=False, fill_value=0)
288 | 
289 | 
290 |     # label fusion: get the final warp_seg
291 |     # (averaging the one-hot votes gives a (160, 192, 224, nb_labels) soft map)
292 |     warp_seg = (warp_seg1 + warp_seg2 + warp_seg3 + warp_seg4) / 4
293 |     warp_seg = np.argmax(warp_seg, axis=3)
294 | 
295 |     vals, _ = dice(warp_seg, X_seg5[0, :, :, :, 0], labels=labels, nargout=2)
296 |     mean5 = np.mean(vals)
297 | 
298 |     # compute mean of dice score over the five test subjects
299 |     total = mean1 + mean2 + mean3 + mean4 + mean5
300 |     mean_dice = total / 5
301 |     print(mean_dice)
302 | 
303 |     # plot the outcome of warp seg
304 |     #warp_seg = warp_seg.reshape((warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0]))
305 |     #warp_seg2 = np.empty(shape = (warp_seg.shape[1], warp_seg.shape[2], warp_seg.shape[0]))
306 |     #for i in range(0,warp_seg.shape[1]):
307 |     #    warp_seg2[i,:,:] = np.transpose(warp_seg[:,i,:])
308 |     #nu.plot.slices(warp_seg)
309 | 
310 | 
311 | if __name__ == "__main__":
312 |     #result_list = np.empty((1000,1))
313 |     #for i in range(0,35):
314 |     #    iterr = (i+1)*200
315 |     #    result_list[i,0] = test(iterr, sys.argv[1])[0]
316 |     #print(result_list)
317 |     test(sys.argv[1], sys.argv[2])
--------------------------------------------------------------------------------
/ext/neuron/neuron/dataproc.py:
--------------------------------------------------------------------------------
1 | ''' data processing for neuron project '''
2 | 
3 | # built-in
4 | import sys
5 | import os
6 | import shutil
7 | import six
8 | 
9 | # third party
10 | import nibabel as nib
11 | import numpy as np
12 | import scipy.ndimage.interpolation
13 | from tqdm import tqdm_notebook as tqdm  # for verbosity for forloops
14 | from PIL import Image
15 | import matplotlib.pyplot as plt
16 | 
17 | 
18 | 
19 | # not sure if tqdm_notebook falls back cleanly outside notebooks, so pick explicitly
20 | try:
21 |     get_ipython
22 |     from tqdm import tqdm_notebook as tqdm
23 | except NameError:
24 |     from tqdm import tqdm as tqdm
25 | 
26 | from subprocess import call
27 | 
28 | 
29 | # import local ndutils
30 | import pynd.ndutils as nd
31 | import re
32 | 
33 | from imp import reload
34 | reload(nd)
35 | 
36 | # from imp import reload # for re-loading modules, since some of the modules are still in development
37 | # reload(nd)
38 | 
39 | 
40 | def proc_mgh_vols(inpath,
41 |                   outpath,
42 |                   ext='.mgz',
43 |                   label_idx=None,
44 |                   **kwargs):
45 |     ''' process mgh data from mgz format and save to numpy format
46 | 
47 |     1. load file
48 |     2. normalize intensity
49 |     3. resize
50 |     4. save as python block
51 | 
52 |     TODO: check header info and such?
53 |     '''
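    # Hypothetical usage sketch (paths and values are illustrative, not from the repo);
    # extra keyword arguments are forwarded to vol_proc below:
    #   proc_mgh_vols('/path/to/mgz_dir', '/path/to/npz_dir',
    #                 rescale_prctle=99.9, resize_shape=(160, 192, 224), interp_order=1)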
54 | 
55 |     # get files in input directory
56 |     files = [f for f in os.listdir(inpath) if f.endswith(ext)]
57 | 
58 |     # go through each file
59 |     list_skipped_files = ()
60 |     for fileidx in tqdm(range(len(files)), ncols=80):
61 | 
62 |         # load nifti volume
63 |         volnii = nib.load(os.path.join(inpath, files[fileidx]))
64 | 
65 |         # get the data out
66 |         vol_data = volnii.get_data().astype(float)
67 | 
68 |         if ('dim' in volnii.header) and volnii.header['dim'][4] > 1:
69 |             vol_data = vol_data[:, :, :, -1]
70 | 
71 |         # process volume
72 |         try:
73 |             vol_data = vol_proc(vol_data, **kwargs)
74 |         except Exception as e:
75 |             list_skipped_files += (files[fileidx], )
76 |             print("Skipping %s\nError: %s" % (files[fileidx], str(e)), file=sys.stderr)
77 |             continue
78 | 
79 |         if label_idx is not None:
80 |             vol_data = (vol_data == label_idx).astype(int)
81 | 
82 |         # save numpy file
83 |         outname = os.path.splitext(os.path.join(outpath, files[fileidx]))[0] + '.npz'
84 |         np.savez_compressed(outname, vol_data=vol_data)
85 | 
86 |     for file in list_skipped_files:
87 |         print("Skipped: %s" % file, file=sys.stderr)
88 | 
89 | 
90 | def scans_to_slices(inpath, outpath, slice_nrs,
91 |                     ext='.mgz',
92 |                     label_idx=None,
93 |                     dim_idx=2,
94 |                     out_ext='.png',
95 |                     slice_pad=0,
96 |                     vol_inner_pad_for_slice_nrs=0,
97 |                     **kwargs):  # vol_proc args
98 | 
99 |     # get files in input directory
100 |     files = [f for f in os.listdir(inpath) if f.endswith(ext)]
101 | 
102 |     # go through each file
103 |     list_skipped_files = ()
104 |     for fileidx in tqdm(range(len(files)), ncols=80):
105 | 
106 |         # load nifti volume
107 |         volnii = nib.load(os.path.join(inpath, files[fileidx]))
108 | 
109 |         # get the data out
110 |         vol_data = volnii.get_data().astype(float)
111 | 
112 |         if ('dim' in volnii.header) and volnii.header['dim'][4] > 1:
113 |             vol_data = vol_data[:, :, :, -1]
114 | 
115 |         if slice_pad > 0:
116 |             assert (out_ext != '.png'), "slice pad can only be used with volumes"
117 | 
118 |         # process volume
119 |         try:
120 |             vol_data = vol_proc(vol_data, **kwargs)
121 |         except Exception as e:
122 |             list_skipped_files += (files[fileidx], )
123 |             print("Skipping %s\nError: %s" % (files[fileidx], str(e)), file=sys.stderr)
124 |             continue
125 | 
126 |         mult_fact = 255
127 |         if label_idx is not None:
128 |             vol_data = (vol_data == label_idx).astype(int)
129 |             mult_fact = 1
130 | 
131 |         # extract slice
132 |         if slice_nrs is None:
133 |             slice_nrs_sel = range(vol_inner_pad_for_slice_nrs+slice_pad, vol_data.shape[dim_idx]-slice_pad-vol_inner_pad_for_slice_nrs)
134 |         else:
135 |             slice_nrs_sel = slice_nrs
136 | 
137 |         for slice_nr in slice_nrs_sel:
138 |             slice_nr_out = range(slice_nr - slice_pad, slice_nr + slice_pad + 1)
139 |             if dim_idx == 2:  # TODO: fix in one line
140 |                 vol_img = np.squeeze(vol_data[:, :, slice_nr_out])
141 |             elif dim_idx == 1:
142 |                 vol_img = np.squeeze(vol_data[:, slice_nr_out, :])
143 |             else:
144 |                 vol_img = np.squeeze(vol_data[slice_nr_out, :, :])
145 | 
146 |             # save file
147 |             if out_ext == '.png':
148 |                 # save png file
149 |                 img = (vol_img*mult_fact).astype('uint8')
150 |                 outname = os.path.splitext(os.path.join(outpath, files[fileidx]))[0] + '_slice%d.png' % slice_nr
151 |                 Image.fromarray(img).convert('RGB').save(outname)
152 |             else:
153 |                 if slice_pad == 0:  # dimension has collapsed
154 |                     assert vol_img.ndim == 2
155 |                     vol_img = np.expand_dims(vol_img, dim_idx)
156 |                 # assuming nibabel saving image
157 |                 nii = nib.Nifti1Image(vol_img, np.diag([1,1,1,1]))
158 |                 outname = os.path.splitext(os.path.join(outpath, files[fileidx]))[0] + '_slice%d.nii.gz' % slice_nr
159 |                 nib.save(nii, outname)
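# Hypothetical usage sketch (paths and values are illustrative): export every axial
# slice of each scan as a PNG, normalizing intensities by the 99.9th percentile:
#   scans_to_slices('/path/to/mgz_dir', '/path/to/png_dir', slice_nrs=None,
#                   dim_idx=2, out_ext='.png', rescale_prctle=99.9)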
160 | 
161 | 
162 | def vol_proc(vol_data,
163 |              crop=None,
164 |              resize_shape=None,  # None (to not resize), or vector. If vector, third entry can be None
165 |              interp_order=None,
166 |              rescale=None,
167 |              rescale_prctle=None,
168 |              resize_slices=None,
169 |              resize_slices_dim=None,
170 |              offset=None,
171 |              clip=None,
172 |              extract_nd=None,  # extracts a particular section
173 |              force_binary=None,  # forces anything > 0 to be 1
174 |              permute=None):
175 |     ''' process a volume with a series of intensity rescale, resize and crop operations '''
176 | 
177 |     if offset is not None:
178 |         vol_data = vol_data + offset
179 | 
180 |     # intensity normalize data .* rescale
181 |     if rescale is not None:
182 |         vol_data = np.multiply(vol_data, rescale)
183 | 
184 |     if rescale_prctle is not None:
185 |         # print("max:", np.max(vol_data.flat))
186 |         # print("test")
187 |         rescale = np.percentile(vol_data.flat, rescale_prctle)
188 |         # print("rescaling by 1/%f" % (rescale))
189 |         vol_data = np.multiply(vol_data.astype(float), 1/rescale)
190 | 
191 |     if resize_slices is not None:
192 |         resize_slices = [*resize_slices]
193 |         assert resize_shape is None, "if resize_slices is given, resize_shape has to be None"
194 |         resize_shape = resize_slices
195 |         if resize_slices_dim is None:
196 |             resize_slices_dim = np.where([f is None for f in resize_slices])[0]
197 |             assert len(resize_slices_dim) == 1, "Could not find dimension or slice resize"
198 |             resize_slices_dim = resize_slices_dim[0]
199 |         resize_shape[resize_slices_dim] = vol_data.shape[resize_slices_dim]
200 | 
201 |     # resize (downsample) matrices
202 |     if resize_shape is not None and resize_shape != vol_data.shape:
203 |         resize_shape = [*resize_shape]
204 |         # allow for the last entry to be None
205 |         if resize_shape[-1] is None:
206 |             resize_ratio = np.divide(resize_shape[0], vol_data.shape[0])
207 |             resize_shape[-1] = np.round(resize_ratio * vol_data.shape[-1]).astype('int')
208 |         resize_ratio = np.divide(resize_shape, vol_data.shape)
209 |         vol_data = scipy.ndimage.interpolation.zoom(vol_data, resize_ratio, order=interp_order)
210 | 
211 |     # crop data if necessary
212 |     if crop is not None:
213 |         vol_data = nd.volcrop(vol_data, crop=crop)
214 | 
215 |     # needs to be last to guarantee clip limits.
216 |     # For e.g., resize might screw this up due to bicubic interpolation if it was done after.
217 |     if clip is not None:
218 |         vol_data = np.clip(vol_data, clip[0], clip[1])
219 | 
220 |     if extract_nd is not None:
221 |         vol_data = vol_data[np.ix_(*extract_nd)]
222 | 
223 |     if force_binary:
224 |         vol_data = (vol_data > 0).astype(float)
225 | 
226 |     # return with checks. this check should be right at the end before return
227 |     if clip is not None:
228 |         assert np.max(vol_data) <= clip[1], "clip failed"
229 |         assert np.min(vol_data) >= clip[0], "clip failed"
230 |     return vol_data
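# Hypothetical usage sketch (values are illustrative): normalize by the 99.9th
# percentile, resample to a model-friendly grid, then clip intensities to [0, 1]:
#   vol = vol_proc(vol, rescale_prctle=99.9, resize_shape=(160, 192, 224),
#                  interp_order=1, clip=(0, 1))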
231 | 
232 | 
233 | def prior_to_weights(prior_filename, nargout=1, min_freq=0, force_binary=False, verbose=False):
234 | 
235 |     ''' transform a prior (spatial dims + nb_labels) into a class weight vector '''
236 | 
237 |     # load prior
238 |     if isinstance(prior_filename, six.string_types):
239 |         prior = np.load(prior_filename)['prior']
240 |     else:
241 |         prior = prior_filename
242 | 
243 |     # assumes prior is 3D or 4D, with labels along the last dimension
244 |     assert np.ndim(prior) == 4 or np.ndim(prior) == 3, "prior is the wrong number of dimensions"
245 |     prior_flat = np.reshape(prior, (np.prod(prior.shape[0:(np.ndim(prior)-1)]), prior.shape[-1]))
246 | 
247 |     if force_binary:
248 |         nb_labels = prior_flat.shape[-1]
249 |         prior_flat[:, 1] = np.sum(prior_flat[:, 1:nb_labels], 1)
250 |         prior_flat = np.delete(prior_flat, range(2, nb_labels), 1)
251 | 
252 |     # sum total class votes
253 |     class_count = np.sum(prior_flat, 0)
254 |     class_prior = class_count / np.sum(class_count)
255 | 
256 |     # adding minimum frequency
257 |     class_prior[class_prior < min_freq] = min_freq
258 |     class_prior = class_prior / np.sum(class_prior)
259 | 
260 |     if np.any(class_prior == 0):
261 |         print("Warning, found a label with 0 support. Setting its weight to 0!", file=sys.stderr)
262 |         class_prior[class_prior == 0] = np.inf
263 | 
264 |     # compute weights from class frequencies
265 |     weights = 1/class_prior
266 |     weights = weights / np.sum(weights)
267 |     # weights[0] = 0 # explicitly don't care about bg
268 | 
269 |     # a bit of verbosity
270 |     if verbose:
271 |         f, (ax1, ax2, ax3) = plt.subplots(1, 3)
272 |         ax1.bar(range(class_prior.size), np.log(class_prior))
273 |         ax1.set_title('log class freq')
274 |         ax2.bar(range(weights.size), weights)
275 |         ax2.set_title('weights')
276 |         ax3.bar(range(weights.size), np.log((weights))-np.min(np.log((weights))))
277 |         ax3.set_title('log(weights)-minlog')
278 |         f.set_size_inches(12, 3)
279 |         plt.show()
280 |         np.set_printoptions(precision=3)
281 | 
282 |     # return
283 |     if nargout == 1:
284 |         return weights
285 |     else:
286 |         return (weights, prior)
287 | 
288 | 
289 | 
290 | 
291 | def filestruct_change(in_path, out_path, re_map,
292 |                       mode='subj_to_type',
293 |                       use_symlinks=False, name=""):
294 |     """
295 |     change from independent subjects in a folder to breakdown structure
296 | 
297 |     example: filestruct_change('/../in_path', '/../out_path', {'asegs.nii.gz':'asegs', 'norm.nii.gz':'vols'})
298 | 
299 | 
300 |     input structure:
301 |     /.../in_path/subj_1 --> with files that match regular expressions defined in re_map.keys()
302 |     /.../in_path/subj_2 --> with files that match regular expressions defined in re_map.keys()
303 |     ...
304 |     output structure:
305 |     /.../out_path/asegs/subj_1.nii.gz, subj_2.nii.gz
306 |     /.../out_path/vols/subj_1.nii.gz, subj_2.nii.gz
307 | 
308 |     Parameters:
309 |         in_path (string): input path
310 |         out_path (string): output path
311 |         re_map (dictionary): keys are reg-exs that match files in the input folders.
312 |             values are the folders to put those files in the new structure.
313 |             values can also be tuples, in which case values[0] is the dst folder,
314 |             and values[1] is the extension of the output file
315 |         mode (optional)
316 |         use_symlinks (bool): whether to just use symlinks rather than copy files
317 |             default: False
318 |     """
319 | 
320 | 
321 |     if not os.path.isdir(out_path):
322 |         os.mkdir(out_path)
323 | 
324 |     # go through folders
325 |     for subj in tqdm(os.listdir(in_path), desc=name):
326 | 
327 |         # go through files in a folder
328 |         files = os.listdir(os.path.join(in_path, subj))
329 |         for file in files:
330 | 
331 |             # see which key matches. Make sure only one does.
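            # (note: re.match anchors each pattern at the start of the filename and
            # treats the re_map keys as regexes, so an unescaped '.' matches any character)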
332 |             matches = [re.match(k, file) for k in re_map.keys()]
333 |             nb_matches = sum([f is not None for f in matches])
334 |             assert nb_matches == 1, "Found %d matches for file %s/%s" % (nb_matches, subj, file)
335 | 
336 |             # get the matched key
337 |             match_idx = [i for i, f in enumerate(matches) if f is not None][0]
338 |             matched_dst = re_map[list(re_map.keys())[match_idx]]
339 |             _, ext = os.path.splitext(file)
340 |             if isinstance(matched_dst, tuple):
341 |                 ext = matched_dst[1]
342 |                 matched_dst = matched_dst[0]
343 | 
344 |             # prepare source and destination file
345 |             src_file = os.path.join(in_path, subj, file)
346 |             dst_path = os.path.join(out_path, matched_dst)
347 |             if not os.path.isdir(dst_path):
348 |                 os.mkdir(dst_path)
349 |             dst_file = os.path.join(dst_path, subj + ext)
350 | 
351 |             if use_symlinks:
352 |                 # on windows there are permission problems.
353 |                 # Can try : call(['mklink', 'LINK', 'TARGET'], shell=True)
354 |                 # or note https://stackoverflow.com/questions/6260149/os-symlink-support-in-windows
355 |                 os.symlink(src_file, dst_file)
356 | 
357 |             else:
358 |                 shutil.copyfile(src_file, dst_file)
359 | 
360 | 
361 | def ml_split(in_path, out_path,
362 |              cat_titles=['train', 'validate', 'test'],
363 |              cat_prop=[0.5, 0.3, 0.2],
364 |              use_symlinks=False,
365 |              seed=None,
366 |              tqdm=tqdm):
367 |     """
368 |     split a dataset of per-subject folders into train/validate/test subfolders
369 |     """
370 | 
371 |     if seed is not None:
372 |         np.random.seed(seed)
373 | 
374 |     if not os.path.isdir(out_path):
375 |         os.makedirs(out_path)
376 | 
377 |     # get subjects and randomize their order
378 |     subjs = sorted(os.listdir(in_path))
379 |     nb_subj = len(subjs)
380 |     subj_order = np.random.permutation(nb_subj)
381 | 
382 |     # prepare split
383 |     cat_tot = np.cumsum(cat_prop)
384 |     if cat_tot[-1] != 1:
385 |         print("cat_prop sums to %f, re-normalizing" % cat_tot[-1])
386 |         cat_tot = np.array(cat_tot) / cat_tot[-1]
387 |     nb_cat_subj = np.round(cat_tot * nb_subj).astype(int)
388 |     cat_subj_start = [0, *nb_cat_subj[:-1]]
389 | 
390 |     # go through each category
391 |     for cat_idx, cat in enumerate(cat_titles):
392 |         if not os.path.isdir(os.path.join(out_path, cat)):
393 |             os.mkdir(os.path.join(out_path, cat))
394 | 
395 |         cat_subj_idx = subj_order[cat_subj_start[cat_idx]:nb_cat_subj[cat_idx]]
396 |         for subj_idx in tqdm(cat_subj_idx, desc=cat):
397 |             src_folder = os.path.join(in_path, subjs[subj_idx])
398 |             dst_folder = os.path.join(out_path, cat, subjs[subj_idx])
399 | 
400 |             if use_symlinks:
401 |                 # on windows there are permission problems.
402 |                 # Can try : call(['mklink', 'LINK', 'TARGET'], shell=True)
403 |                 # or note https://stackoverflow.com/questions/6260149/os-symlink-support-in-windows
404 |                 os.symlink(src_folder, dst_folder)
405 | 
406 |             else:
407 |                 if os.path.isdir(src_folder):
408 |                     shutil.copytree(src_folder, dst_folder)
409 |                 else:
410 |                     shutil.copyfile(src_folder, dst_folder)
411 | 
412 | 
413 | 
414 | 
415 | 
416 | 
417 | 
418 | 
419 | 
420 | 
421 | 
--------------------------------------------------------------------------------