├── .travis.yml ├── LICENSE ├── README.md ├── fetal_brain_segmentation_mri ├── README.md ├── create_csv.ipynb ├── deploy.py ├── dsc.png ├── example.png ├── fetal_fcn.py ├── iFind_fetal.csv ├── reader.py └── train.py ├── logo.jpg ├── setup.cfg ├── synapse_btcv_abdominal_ct_segmentation ├── README.md ├── config.json ├── config_asym_unet_balce.json ├── config_asym_unet_ce.json ├── config_fcn_balce.json ├── config_fcn_ce.json ├── config_unet_balce.json ├── config_unet_ce.json ├── deploy.py ├── pred.png ├── preprocessing.py ├── reader.py ├── test.csv ├── train.csv ├── train.py └── val.csv ├── ukbb_cardiac_segmentation_cine_mri ├── README.md ├── demo.py ├── deploy_network.py ├── image_utils.py ├── network.py └── train_network.py └── ukbb_neuronet_brain_segmentation ├── README.md ├── config_all.json ├── config_fsl_fast.json ├── config_fsl_first.json ├── config_malp_em.json ├── config_malp_em_tissue.json ├── config_spm_tissue.json ├── config_tissue.json ├── deploy.py ├── eval.ipynb ├── figures ├── boxplot.png ├── ex.png ├── example_seg.png └── fail.png ├── neuronet.py ├── parse_csvs.ipynb ├── reader.py ├── sandbox.ipynb ├── test.csv ├── train.csv ├── train.py └── val.csv /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | dist: trusty 3 | python: 4 | - 2.7.9 5 | - 2.7 6 | - 3.5 7 | - 3.6 8 | cache: 9 | directories: 10 | - "$HOME/.cache/pip" 11 | install: 12 | - pip install tensorflow 13 | - pip install --upgrade pytest 14 | - pip install -e git+https://github.com/DLTK/DLTK.git@dev#egg=dltk[tests] 15 | script: 16 | - pytest --flake8 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | ## DLTK Model Zoo
2 | [![Gitter](https://badges.gitter.im/DLTK/DLTK.svg)](https://gitter.im/DLTK/DLTK?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
3 | [![Build Status](https://travis-ci.org/DLTK/models.svg?branch=master)](https://travis-ci.org/DLTK/models)
4 | 
5 | ![DLTK Model Zoo logo](logo.jpg)
6 | 
7 | ### Referencing and citing methods in the Model Zoo
8 | To find out how to reference each implementation, please refer to the specifications in the authors' README.md. If you use DLTK in your work, please refer to this citation:
9 | 
10 | ```
11 | @article{pawlowski2017state,
12 |   title={DLTK: State of the Art Reference Implementations for Deep Learning on Medical Images},
13 |   author={Nick Pawlowski and S. Ira Ktena and Matthew C.H. Lee and Bernhard Kainz and Daniel Rueckert and Ben Glocker and Martin Rajchl},
14 |   journal={arXiv preprint arXiv:1711.06853},
15 |   year={2017}
16 | }
17 | ```
18 | 
19 | ### Installation
20 | To install DLTK, check out the installation instructions on the main [repo](https://github.com/DLTK/DLTK/blob/master/README.md). Although not encouraged, additional dependencies might need to be installed for each separate model implementation. Please refer to the individual README.md files for further instructions.
21 | Other than that, clone the Model Zoo repository via
22 | 
23 | ```
24 | git clone https://github.com/DLTK/models.git
25 | ```
26 | and download any pre-trained models, if available.
27 | 
28 | ### How to contribute
29 | We appreciate any contributions to DLTK and its Model Zoo. If you have improvements, features or patches, please send us your pull requests! You can find specific instructions on how to issue a PR on GitHub [here](https://help.github.com/articles/about-pull-requests/). Feel free to open an [issue](https://github.com/DLTK/DLTK/issues) if you find a bug, or come chat with us directly on our Gitter channel [![Gitter](https://badges.gitter.im/DLTK/DLTK.svg)](https://gitter.im/DLTK/DLTK?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge).
30 | 
31 | #### Basic contribution guidelines
32 | - Python coding style: Like TensorFlow, we loosely adhere to [google coding style](https://google.github.io/styleguide/pyguide.html) and [google docstrings](https://google.github.io/styleguide/pyguide.html#Comments).
33 | - Entirely new features should be committed to ```dltk/contrib``` before we can sensibly integrate them into the core.
34 | - Standalone problem-specific applications or (re-)implementations of published methods should be committed to the [DLTK Model Zoo](https://github.com/DLTK/models) repo and provide a README.md file with author/coder contact information.
35 | 
36 | ### The team
37 | The DLTK Model Zoo is currently maintained by [@pawni](https://github.com/pawni) and [@mrajchl](https://github.com/mrajchl), with greatly appreciated contributions from [@baiwenjia](https://github.com/baiwenjia) and [@farrell236](https://github.com/farrell236) (alphabetical order).
38 | 
39 | ### License
40 | See [LICENSE](https://github.com/DLTK/models/blob/master/LICENSE)
41 | 
42 | ### Acknowledgments
43 | We would like to thank [NVIDIA GPU Computing](http://www.nvidia.com/) for providing us with hardware for our research.
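The pre-trained models linked from the individual model folders are tf.estimator SavedModel exports, which the per-model `deploy.py` scripts restore via `tf.contrib.predictor`. The snippet below is a minimal, illustrative sketch of that loading step only; the export path is a placeholder, and the `'x'` / `'y_prob'` feed and fetch names follow the deploy scripts in this repo:

```
from tensorflow.contrib import predictor

# Path to an extracted, timestamped SavedModel export folder (placeholder).
export_dir = '/tmp/my_downloaded_model/1511556748'

# Restore a predictor from the exported tf.estimator model.
my_predictor = predictor.from_saved_model(export_dir)

# The deploy scripts in this repo feed the input image tensor under 'x'
# and fetch the softmax output under 'y_prob'.
x_tensor = my_predictor._feed_tensors['x']
y_prob = my_predictor._fetch_tensors['y_prob']
```

See the individual model READMEs and `deploy.py` scripts for complete, runnable inference examples.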
44 | 
-------------------------------------------------------------------------------- /fetal_brain_segmentation_mri/README.md: --------------------------------------------------------------------------------
1 | ## Fetal brain segmentation from motion corrupted 2D MR image stacks
2 | 
3 | ![Exemplary segmentations](example.png)
4 | 
5 | ### Contact and referencing this work
6 | If there are any issues, please contact the corresponding author of this implementation. If you employ this model in your work, please refer to this citation of the [paper](https://arxiv.org/abs/1606.01100), which contains more detailed information on the implementation than the original.
7 | ```
8 | @article{rajchl2016learning,
9 |   title={Learning under distributed weak supervision},
10 |   author={Rajchl, Martin and Lee, Matthew CH and Schrans, Franklin and Davidson, Alice and Passerat-Palmbach, Jonathan and Tarroni, Giacomo and Alansary, Amir and Oktay, Ozan and Kainz, Bernhard and Rueckert, Daniel},
11 |   journal={arXiv preprint arXiv:1606.01100},
12 |   year={2016}
13 | }
14 | ```
15 | 
16 | ### Important Notes
17 | This is a close implementation of the original Caffe code; however, it differs from the method described in the paper in the following points:
18 | - The model was trained on [None, 3, 128, 128, 1] slices, rather than encoding the adjacent slices as channels as in the original code (i.e. [None, 1, 128, 128, 3])
19 | - Batch normalisation was employed before each ReLU non-linearity
20 | 
21 | ### Data
22 | The MR images have been collected during the [iFind (intelligent Fetal Imaging and Diagnosis) project](http://www.ifindproject.com/) at King's College London. For this experimental setup, the model was trained on 30 volumes and validated on 7. If you require access to the image data, please contact the [iFind MRI team](http://www.ifindproject.com/team-2/).
23 | 
24 | Images and segmentations are read from a csv file in the format below. The original file (iFind_fetal.csv) is provided in this repo.
25 | 
26 | iFind_fetal.csv:
27 | ```
28 | iFIND_id,image,segmentation
29 | iFIND00011,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00011.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00011.nii.gz
30 | iFIND00018,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00018.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00018.nii.gz
31 | ...
32 | ```
33 | 
34 | These are parsed in `reader.py` to extract tf.Tensor examples for training and evaluation, using [SimpleITK](http://www.simpleitk.org/) for i/o of the .nii files.
35 | 
36 | 
37 | ### Usage
38 | - You can download a pre-trained model for fine-tuning or deployment [here](https://www.doc.ic.ac.uk/~mrajchl/dltk_models/model_zoo/fetal_brain_segmentation_mri.tar.gz).
39 | The archive contains both the tf.estimator export folder and the standard
40 | .index, .meta and .data-* files for continuing training. Extract the model
41 | folder from the .tar.gz file and point your ```--model_path``` MY_MODEL_PATH
42 | argument to its location (see below).
43 | 
44 | - To train a new model, run the train.py script.
Display run options with 45 | ``` python train.py --help ```: 46 | 47 | ``` 48 | usage: train.py [-h] [--run_validation RUN_VALIDATION] [--restart] [--verbose] 49 | [--cuda_devices CUDA_DEVICES] [--model_path MODEL_PATH] 50 | [--data_csv DATA_CSV] 51 | ``` 52 | 53 | To start training, run the training script with the desired options: 54 | 55 | ``` 56 | python train.py MY_OPTIONS 57 | ``` 58 | 59 | The model and training events will be saved to a ```model_path``` 60 | MY_MODEL_PATH 61 | 62 | - For monitoring and metric tracking, spawn a tensorboard webserver and point 63 | the log directory to MY_MODEL_PATH: 64 | 65 | ``` 66 | tensorboard --logdir MY_MODEL_PATH 67 | ``` 68 | 69 | ![Dice_similarity_coefficient](dsc.png) 70 | 71 | 72 | - To deploy a model and run inference, run the deploy.py script and point to 73 | the trained model: 74 | 75 | ``` 76 | python -u deploy.py --model_path MY_MODEL_PATH 77 | ``` 78 | 79 | Note that during deploy we average the predictions of 4 random crops of a test input, so results may vary a bit from run to run. The expected output of deploy should look similar to the one below and yield a test Dice of approximately 0.924 +/- 0.022: 80 | 81 | ``` 82 | Loading from .../1511556748 83 | Dice=0.94843941927; input_dim=(1, 120, 336, 336, 1); time=17.9820091724; output_fn=.../iFIND00031.nii.gz; 84 | Dice=0.951662540436; input_dim=(1, 170, 336, 336, 1); time=22.7679200172; output_fn=.../iFIND00035.nii.gz; 85 | ... 86 | ``` 87 | 88 | -------------------------------------------------------------------------------- /fetal_brain_segmentation_mri/create_csv.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 11, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import pandas as pd\n", 11 | "import os" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 12, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "data_path = '/vol/biomedic2/mrajchl/data/iFind2_db'" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 13, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "ids = ['iFIND00011', 'iFIND00014', 'iFIND00018', 'iFIND00022', 'iFIND00025', 'iFIND00029', 'iFIND00033', 'iFIND00036', 'iFIND00039', 'iFIND00042', 'iFIND00046', 'iFIND00049', \n", 30 | "'iFIND00012', 'iFIND00016', 'iFIND00019', 'iFIND00023', 'iFIND00026', 'iFIND00030', 'iFIND00034', 'iFIND00037', 'iFIND00040', 'iFIND00043', 'iFIND00047', 'iFIND00050', \n", 31 | "'iFIND00013', 'iFIND00017', 'iFIND00020', 'iFIND00024', 'iFIND00028', 'iFIND00031', 'iFIND00035', 'iFIND00038', 'iFIND00041', 'iFIND00045', 'iFIND00048', 'iFIND00051', ]" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 14, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "hdr = ['iFIND_id', 'image', 'segmentation']\n", 41 | "\n", 42 | "data = []\n", 43 | "for i in ids:\n", 44 | " \n", 45 | " img_fn = os.path.join(data_path, 'imgs', i + '.nii.gz')\n", 46 | " seg_fn = os.path.join(data_path, 'seg', i + '.nii.gz')\n", 47 | " \n", 48 | " data.append([i, img_fn, seg_fn] )" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 15, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "['iFIND00011', '/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00011.nii.gz', '/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00011.nii.gz']\n" 61 | 
] 62 | } 63 | ], 64 | "source": [ 65 | "print (data[0])" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 16, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "pd.DataFrame(data).to_csv('iFind_fetal.csv', index=False, header=hdr)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [] 83 | } 84 | ], 85 | "metadata": { 86 | "kernelspec": { 87 | "display_name": "Python 2", 88 | "language": "python", 89 | "name": "python2" 90 | }, 91 | "language_info": { 92 | "codemirror_mode": { 93 | "name": "ipython", 94 | "version": 2 95 | }, 96 | "file_extension": ".py", 97 | "mimetype": "text/x-python", 98 | "name": "python", 99 | "nbconvert_exporter": "python", 100 | "pygments_lexer": "ipython2", 101 | "version": "2.7.12" 102 | } 103 | }, 104 | "nbformat": 4, 105 | "nbformat_minor": 2 106 | } 107 | -------------------------------------------------------------------------------- /fetal_brain_segmentation_mri/deploy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os 7 | import time 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import tensorflow as tf 12 | import SimpleITK as sitk 13 | 14 | from tensorflow.contrib import predictor 15 | 16 | from dltk.core import metrics as metrics 17 | from dltk.utils import sliding_window_segmentation_inference 18 | 19 | from reader import read_fn 20 | 21 | 22 | READER_PARAMS = {'extract_examples': False} 23 | N_VALIDATION_SUBJECTS = 7 24 | 25 | 26 | def predict(args): 27 | 28 | # Read in the csv with the file names you would want to predict on 29 | file_names = pd.read_csv(args.csv, 30 | dtype=object, 31 | keep_default_na=False, 32 | na_values=[]).as_matrix() 33 | 34 | # We trained on the first 4 subjects, so we predict on the rest 35 | file_names = file_names[-N_VALIDATION_SUBJECTS:] 36 | 37 | # From the model model_path, parse the latest saved estimator model 38 | # and restore a predictor from it 39 | export_dir = [os.path.join(args.model_path, o) for o in os.listdir(args.model_path) 40 | if os.path.isdir(os.path.join(args.model_path, o)) and o.isdigit()][-1] 41 | print('Loading from {}'.format(export_dir)) 42 | my_predictor = predictor.from_saved_model(export_dir) 43 | 44 | # Fetch the output probability op of the trained network 45 | y_prob = my_predictor._fetch_tensors['y_prob'] 46 | num_classes = y_prob.get_shape().as_list()[-1] 47 | 48 | if (args.mode == 'TRAIN'): 49 | mode = tf.estimator.ModeKeys.TRAIN 50 | elif (args.mode == 'EVAL'): 51 | mode = tf.estimator.ModeKeys.EVAL 52 | elif (args.mode == 'PREDICT'): 53 | mode = tf.estimator.ModeKeys.PREDICT 54 | 55 | # Iterate through the files, predict on the full volumes and 56 | # compute a Dice similariy coefficient 57 | for output in read_fn(file_references=file_names, 58 | mode=mode, 59 | params=READER_PARAMS): 60 | 61 | t0 = time.time() 62 | 63 | # Parse the read function output and add a dummy batch dimension 64 | # as required 65 | img = np.expand_dims(output['features']['x'], axis=0) 66 | lbl = np.expand_dims(output['labels']['y'], axis=0) 67 | 68 | # Do a sliding window inference with our DLTK wrapper 69 | pred = sliding_window_segmentation_inference( 70 | session=my_predictor.session, 71 | ops_list=[y_prob], 72 | sample_dict={my_predictor._feed_tensors['x']: img}, 73 | batch_size=32)[0] 74 | 75 | # 
Calculate the prediction from the probabilities 76 | pred = np.argmax(pred, -1) 77 | 78 | # Calculate the Dice coefficient 79 | dsc = -1 80 | if (args.mode != 'PREDICT'): 81 | dsc = metrics.dice(pred, lbl, num_classes)[1:].mean() 82 | 83 | # Save the file as .nii.gz using the header information from the 84 | # original sitk image 85 | file_id = str(output['img_id']) 86 | output_fn = os.path.join(args.model_path, '{}.nii.gz'.format(file_id)) 87 | new_sitk = sitk.GetImageFromArray(pred[0].astype(np.int32)) 88 | 89 | new_sitk.CopyInformation(output['sitk']) 90 | sitk.WriteImage(new_sitk, output_fn) 91 | 92 | # Print outputs 93 | print('Dice={}; input_dim={}; time={}; output_fn={};'.format( 94 | dsc, img.shape, time.time() - t0, output_fn)) 95 | 96 | 97 | if __name__ == '__main__': 98 | # Set up argument parser 99 | parser = argparse.ArgumentParser(description='iFind2 fetal segmentation deploy script') 100 | parser.add_argument('--verbose', default=False, action='store_true') 101 | parser.add_argument('--cuda_devices', '-c', default='0') 102 | 103 | parser.add_argument('--model_path', '-p', default='/tmp/fetal_segmentation/') 104 | parser.add_argument('--csv', default='iFind_fetal.csv') 105 | parser.add_argument('--mode', default='EVAL', help='TRAIN, EVAL or PREDICT') 106 | 107 | args = parser.parse_args() 108 | 109 | # Set verbosity 110 | if args.verbose: 111 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 112 | tf.logging.set_verbosity(tf.logging.INFO) 113 | else: 114 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 115 | tf.logging.set_verbosity(tf.logging.ERROR) 116 | 117 | # GPU allocation options 118 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices 119 | 120 | # Call training 121 | predict(args) 122 | -------------------------------------------------------------------------------- /fetal_brain_segmentation_mri/dsc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLTK/models/7fd907a325cd7a23ccca62d2def2f9f770020cff/fetal_brain_segmentation_mri/dsc.png -------------------------------------------------------------------------------- /fetal_brain_segmentation_mri/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLTK/models/7fd907a325cd7a23ccca62d2def2f9f770020cff/fetal_brain_segmentation_mri/example.png -------------------------------------------------------------------------------- /fetal_brain_segmentation_mri/fetal_fcn.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | from dltk.networks.segmentation.fcn import upscore_layer_3d 9 | 10 | 11 | def fetal_fcn_3d(inputs, 12 | num_classes, 13 | filters=(32, 32, 32, 32, 128), 14 | kernel_sizes=(5, 5, 5, 3, 1), 15 | strides=((1, 1, 1), (1, 2, 2), (1, 2, 2), (1, 2, 2), (1, 1, 1)), 16 | mode=tf.estimator.ModeKeys.EVAL, 17 | use_bias=False, 18 | kernel_initializer=tf.uniform_unit_scaling_initializer(), 19 | bias_initializer=tf.zeros_initializer(), 20 | kernel_regularizer=None, 21 | bias_regularizer=None): 22 | """Image segmentation network based on a modified [1] FCN architecture [2]. 23 | 24 | [1] M. Rajchl et al. Learning under Distributed Weak Supervision. arXiv:1606.01100 2016. 25 | [2] K. He et al. Deep residual learning for image recognition. CVPR 2016. 
26 | 27 | Args: 28 | inputs (tf.Tensor): Input feature tensor to the network (rank 5 29 | required). 30 | num_classes (int): Number of output classes. 31 | num_res_units (int, optional): Number of residual units at each 32 | resolution scale. 33 | filters (tuple, optional): Number of filters for all residual units at 34 | each resolution scale. 35 | strides (tuple, optional): Stride of the first unit on a resolution 36 | scale. 37 | mode (TYPE, optional): One of the tf.estimator.ModeKeys strings: 38 | TRAIN, EVAL or PREDICT 39 | use_bias (bool, optional): Boolean, whether the layer uses a bias. 40 | kernel_initializer (TYPE, optional): An initializer for the convolution 41 | kernel. 42 | bias_initializer (TYPE, optional): An initializer for the bias vector. 43 | If None, no bias will be applied. 44 | kernel_regularizer (None, optional): Optional regularizer for the 45 | convolution kernel. 46 | bias_regularizer (None, optional): Optional regularizer for the bias 47 | vector. 48 | Returns: 49 | dict: dictionary of output tensors 50 | """ 51 | 52 | outputs = {} 53 | assert len(strides) == len(filters) 54 | assert len(inputs.get_shape().as_list()) == 5, \ 55 | 'inputs are required to have a rank of 5.' 56 | 57 | padding = 'same' 58 | 59 | conv_params = {'use_bias': use_bias, 60 | 'kernel_initializer': kernel_initializer, 61 | 'bias_initializer': bias_initializer, 62 | 'kernel_regularizer': kernel_regularizer, 63 | 'bias_regularizer': bias_regularizer} 64 | 65 | relu_op = tf.nn.relu6 66 | pool_op = tf.layers.max_pooling3d 67 | 68 | x = inputs 69 | 70 | tf.logging.info('Init conv tensor shape {}'.format(x.get_shape())) 71 | 72 | res_scales = [x] 73 | 74 | for res_scale in range(0, len(filters)): 75 | 76 | # Use max pooling when required 77 | if np.prod(strides[res_scale]) > 1: 78 | 79 | with tf.variable_scope('pool_{}'.format(res_scale)): 80 | x = pool_op( 81 | inputs=x, 82 | pool_size=[2 * s for s in strides[res_scale]], 83 | strides=strides[res_scale], 84 | padding=padding) 85 | 86 | # Add two blocks of conv/relu units for feature extraction 87 | with tf.variable_scope('enc_unit_{}'.format(res_scale)): 88 | for block in range(2): 89 | x = tf.layers.conv3d( 90 | inputs=x, 91 | filters=filters[res_scale], 92 | kernel_size=[kernel_sizes[res_scale]] * 3, 93 | strides=(1, 1, 1), 94 | padding=padding, 95 | **conv_params) 96 | 97 | x = tf.layers.batch_normalization( 98 | x, training=mode == tf.estimator.ModeKeys.TRAIN) 99 | x = relu_op(x) 100 | 101 | # Dropout with 0.5 on the last scale 102 | if res_scale is len(filters) - 1: 103 | with tf.variable_scope('dropout_{}'.format(res_scale)): 104 | x = tf.layers.dropout(x) 105 | 106 | tf.logging.info('Encoder at res_scale {} shape: {}'.format( 107 | res_scale, x.get_shape())) 108 | 109 | res_scales.append(x) 110 | 111 | # Upscore layers [2] reconstruct the predictions to higher resolution scales 112 | for res_scale in reversed(range(0, len(filters))): 113 | 114 | with tf.variable_scope('upscore_{}'.format(res_scale)): 115 | 116 | x = upscore_layer_3d(inputs=x, 117 | inputs2=res_scales[res_scale], 118 | out_filters=num_classes, 119 | strides=strides[res_scale], 120 | mode=mode, 121 | **conv_params) 122 | tf.logging.info('Decoder at res_scale {} shape: {}'.format( 123 | res_scale, x.get_shape())) 124 | 125 | # Last convolution 126 | with tf.variable_scope('last'): 127 | x = tf.layers.conv3d( 128 | inputs=x, 129 | filters=num_classes, 130 | kernel_size=(1, 1, 1), 131 | strides=(1, 1, 1), 132 | padding='same', 133 | **conv_params) 134 | 135 | 
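    # Note (descriptive): with 'same' padding and pooling/upscore strides that
    # mirror each other, x now holds per-voxel logits with num_classes channels
    # at the input resolution (assuming in-plane dimensions divisible by the
    # overall pooling factor); the outputs dict below derives 'y_prob'
    # (softmax) and 'y_' (argmax) from it.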
tf.logging.info('Output tensor shape {}'.format(x.get_shape())) 136 | 137 | # Define the outputs 138 | outputs['logits'] = x 139 | 140 | with tf.variable_scope('pred'): 141 | y_prob = tf.nn.softmax(x) 142 | outputs['y_prob'] = y_prob 143 | 144 | y_ = tf.argmax(x, axis=-1) if num_classes > 1 \ 145 | else tf.cast(tf.greater_equal(x[..., 0], 0.5), tf.int32) 146 | outputs['y_'] = y_ 147 | 148 | return outputs 149 | -------------------------------------------------------------------------------- /fetal_brain_segmentation_mri/iFind_fetal.csv: -------------------------------------------------------------------------------- 1 | iFIND_id,image,segmentation 2 | iFIND00011,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00011.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00011.nii.gz 3 | iFIND00014,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00014.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00014.nii.gz 4 | iFIND00018,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00018.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00018.nii.gz 5 | iFIND00022,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00022.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00022.nii.gz 6 | iFIND00025,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00025.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00025.nii.gz 7 | iFIND00029,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00029.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00029.nii.gz 8 | iFIND00033,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00033.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00033.nii.gz 9 | iFIND00036,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00036.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00036.nii.gz 10 | iFIND00039,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00039.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00039.nii.gz 11 | iFIND00042,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00042.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00042.nii.gz 12 | iFIND00046,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00046.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00046.nii.gz 13 | iFIND00049,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00049.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00049.nii.gz 14 | iFIND00012,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00012.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00012.nii.gz 15 | iFIND00016,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00016.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00016.nii.gz 16 | iFIND00019,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00019.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00019.nii.gz 17 | iFIND00023,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00023.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00023.nii.gz 18 | iFIND00026,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00026.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00026.nii.gz 19 | iFIND00030,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00030.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00030.nii.gz 20 | iFIND00034,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00034.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00034.nii.gz 21 | iFIND00037,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00037.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00037.nii.gz 22 | iFIND00040,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00040.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00040.nii.gz 23 | 
iFIND00043,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00043.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00043.nii.gz 24 | iFIND00047,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00047.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00047.nii.gz 25 | iFIND00050,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00050.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00050.nii.gz 26 | iFIND00013,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00013.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00013.nii.gz 27 | iFIND00017,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00017.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00017.nii.gz 28 | iFIND00020,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00020.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00020.nii.gz 29 | iFIND00024,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00024.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00024.nii.gz 30 | iFIND00028,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00028.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00028.nii.gz 31 | iFIND00031,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00031.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00031.nii.gz 32 | iFIND00035,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00035.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00035.nii.gz 33 | iFIND00038,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00038.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00038.nii.gz 34 | iFIND00041,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00041.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00041.nii.gz 35 | iFIND00045,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00045.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00045.nii.gz 36 | iFIND00048,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00048.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00048.nii.gz 37 | iFIND00051,/vol/biomedic2/mrajchl/data/iFind2_db/imgs/iFIND00051.nii.gz,/vol/biomedic2/mrajchl/data/iFind2_db/seg/iFIND00051.nii.gz 38 | -------------------------------------------------------------------------------- /fetal_brain_segmentation_mri/reader.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | from dltk.io.augmentation import add_gaussian_offset, flip, extract_class_balanced_example_array 6 | from dltk.io.preprocessing import whitening 7 | 8 | 9 | def read_fn(file_references, mode, params=None): 10 | """A custom python read function for interfacing with nii image files. 11 | 12 | Args: 13 | file_references (list): A list of lists containing file references, 14 | such as [['id_0', 'image_filename_0', target_value_0], ..., 15 | ['id_N', 'image_filename_N', target_value_N]]. 16 | mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL 17 | or PREDICT. 18 | params (dict, optional): A dictionary to parameterise read_fn ouputs 19 | (e.g. reader_params = {'n_examples': 10, 'example_size': 20 | [64, 64, 64], 'extract_examples': True}, etc.). 21 | 22 | Yields: 23 | dict: A dictionary of reader outputs for dltk.io.abstract_reader. 
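    Example (illustrative sketch, mirroring the call in deploy.py;
        `file_names` here are the parsed rows of iFind_fetal.csv):
        >>> for output in read_fn(file_references=file_names,
        ...                       mode=tf.estimator.ModeKeys.EVAL,
        ...                       params={'extract_examples': False}):
        ...     img = output['features']['x']
        ...     lbl = output['labels']['y']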
24 | """ 25 | 26 | def _augment(img, lbl): 27 | """An image augmentation function.""" 28 | 29 | img = add_gaussian_offset(img, sigma=1.0) 30 | for a in range(3): 31 | [img, lbl] = flip([img, lbl], axis=a) 32 | 33 | return img, lbl 34 | 35 | def _map_labels(lbl, convert_to_protocol=False): 36 | """Map dataset specific label id protocols to consecutive integer ids 37 | for training and back. 38 | 39 | iFind segment ids: 40 | 0 background 41 | 2 brain 42 | 9 placenta 43 | 10 uterus ROI 44 | 45 | Args: 46 | lbl (np.array): A label map to be converted. 47 | convert_to_protocol (bool, optional) A flag to determine to convert 48 | from or to the protocol ids. 49 | 50 | Returns: 51 | np.array: The converted label map 52 | 53 | """ 54 | 55 | ids = [0, 2] 56 | 57 | out_lbl = np.zeros_like(lbl) 58 | 59 | if convert_to_protocol: 60 | 61 | # Map from consecutive ints to protocol labels 62 | for i in range(len(ids)): 63 | out_lbl[lbl == i] = ids[i] 64 | else: 65 | 66 | # Map from protocol labels to consecutive ints 67 | for i in range(len(ids)): 68 | out_lbl[lbl == ids[i]] = i 69 | 70 | return out_lbl 71 | 72 | for f in file_references: 73 | 74 | # Read the image nii with sitk 75 | img_id = f[0] 76 | img_fn = f[1] 77 | img_sitk = sitk.ReadImage(str(img_fn)) 78 | img = sitk.GetArrayFromImage(img_sitk) 79 | 80 | # Normalise volume image 81 | img = whitening(img) 82 | 83 | # Create a 4D image (i.e. [x, y, z, channels]) 84 | images = np.expand_dims(img, axis=-1).astype(np.float32) 85 | 86 | if mode == tf.estimator.ModeKeys.PREDICT: 87 | yield {'features': {'x': images}, 'labels': {'y': np.array([0])}, 'sitk': img_sitk, 'img_id': img_id} 88 | continue 89 | 90 | # Read the label nii with sitk 91 | lbl_fn = f[2] 92 | lbl = sitk.GetArrayFromImage(sitk.ReadImage(str(lbl_fn))).astype(np.int32) 93 | 94 | # Map the label ids to consecutive integers 95 | lbl = _map_labels(lbl) 96 | 97 | # Augment if used in training mode 98 | if mode == tf.estimator.ModeKeys.TRAIN: 99 | images, lbl = _augment(images, lbl) 100 | 101 | # Check if the reader is supposed to return training examples or 102 | # full images 103 | if params['extract_examples']: 104 | images, lbl = extract_class_balanced_example_array( 105 | images, 106 | lbl, 107 | example_size=params['example_size'], 108 | n_examples=params['n_examples'], classes=2) 109 | 110 | for e in range(params['n_examples']): 111 | yield {'features': {'x': images[e].astype(np.float32)}, 112 | 'labels': {'y': lbl[e].astype(np.int32)}} 113 | else: 114 | yield {'features': {'x': images}, 115 | 'labels': {'y': lbl}, 116 | 'sitk': img_sitk, 117 | 'img_id': img_id} 118 | return 119 | -------------------------------------------------------------------------------- /fetal_brain_segmentation_mri/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os 7 | 8 | import pandas as pd 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | from dltk.core.metrics import dice 13 | from fetal_fcn import fetal_fcn_3d 14 | from dltk.io.abstract_reader import Reader 15 | 16 | from reader import read_fn 17 | 18 | 19 | EVAL_EVERY_N_STEPS = 500 20 | EVAL_STEPS = 1 21 | 22 | NUM_CLASSES = 2 23 | NUM_CHANNELS = 1 24 | 25 | BATCH_SIZE = 10 26 | SHUFFLE_CACHE_SIZE = 64 27 | 28 | MAX_STEPS = 50000 29 | 30 | 31 | def model_fn(features, labels, mode, params=None): 32 | """Model function to construct a tf.estimator.EstimatorSpec. 
It creates a 33 | network given input features (e.g. from a dltk.io.abstract_reader) and 34 | training targets (labels). Further, loss, optimiser, evaluation ops 35 | and custom tensorboard summary ops can be added. For additional 36 | information, please refer to 37 | https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#model_fn. 38 | 39 | Args: 40 | features (tf.Tensor): Tensor of input features to train from. Required 41 | rank and dimensions are determined by the subsequent ops 42 | (i.e. the network). 43 | labels (tf.Tensor): Tensor of training targets or labels. Required 44 | rank and dimensions are determined by the network output. 45 | mode (str): One of the tf.estimator.ModeKeys: TRAIN, EVAL or PREDICT 46 | params (dict, optional): A dictionary to parameterise the model_fn 47 | (e.g. learning_rate) 48 | 49 | Returns: 50 | tf.estimator.EstimatorSpec: A custom EstimatorSpec for this experiment 51 | """ 52 | 53 | # 1. create a model and its outputs 54 | net_output_ops = fetal_fcn_3d( 55 | inputs=features['x'], 56 | num_classes=NUM_CLASSES, 57 | mode=mode, 58 | kernel_regularizer=tf.contrib.layers.l2_regularizer(5e-4)) 59 | 60 | # 1.1 Generate predictions only (for `ModeKeys.PREDICT`) 61 | if mode == tf.estimator.ModeKeys.PREDICT: 62 | return tf.estimator.EstimatorSpec( 63 | mode=mode, 64 | predictions=net_output_ops, 65 | export_outputs={ 66 | 'out': tf.estimator.export.PredictOutput(net_output_ops)}) 67 | 68 | # 2. set up a loss function 69 | ce = tf.nn.sparse_softmax_cross_entropy_with_logits( 70 | logits=net_output_ops['logits'], 71 | labels=labels['y']) 72 | loss = tf.reduce_mean(ce) 73 | 74 | # 3. define a training op and ops for updating moving averages 75 | # (i.e. for batch normalisation) 76 | global_step = tf.train.get_global_step() 77 | optimiser = tf.train.MomentumOptimizer( 78 | learning_rate=params["learning_rate"], 79 | momentum=0.9) 80 | 81 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 82 | with tf.control_dependencies(update_ops): 83 | train_op = optimiser.minimize(loss, global_step=global_step) 84 | 85 | # 4.1 (optional) create custom image summaries for tensorboard 86 | my_image_summaries = {} 87 | my_image_summaries['feat_0'] = features['x'][0, 1, :, :, 0] 88 | my_image_summaries['labels'] = tf.cast(labels['y'], tf.float32)[0, 1, :, :] 89 | my_image_summaries['predictions'] = tf.cast(net_output_ops['y_'], 90 | tf.float32)[0, 1, :, :] 91 | 92 | expected_output_size = [1, 128, 128, 1] # [B, W, H, C] 93 | [tf.summary.image(name, tf.reshape(image, expected_output_size)) 94 | for name, image in my_image_summaries.items()] 95 | 96 | # 4.2 (optional) create custom metric summaries for tensorboard 97 | dice_tensor = tf.py_func(dice, 98 | [net_output_ops['y_'], 99 | labels['y'], 100 | tf.constant(NUM_CLASSES)], tf.float32) 101 | 102 | [tf.summary.scalar('dsc_l{}'.format(i), dice_tensor[i]) 103 | for i in range(NUM_CLASSES)] 104 | 105 | # 5. 
Return EstimatorSpec object 106 | return tf.estimator.EstimatorSpec(mode=mode, 107 | predictions=net_output_ops, 108 | loss=loss, 109 | train_op=train_op, 110 | eval_metric_ops=None) 111 | 112 | 113 | def train(args): 114 | np.random.seed(42) 115 | tf.set_random_seed(42) 116 | 117 | print('Setting up...') 118 | 119 | # Parse csv files for file names 120 | all_filenames = pd.read_csv(args.train_csv, 121 | dtype=object, 122 | keep_default_na=False, 123 | na_values=[]).as_matrix() 124 | 125 | train_filenames = all_filenames[:30] 126 | val_filenames = all_filenames[30:] 127 | 128 | # Set up a data reader to handle the file i/o. 129 | reader_params = { 130 | 'n_examples': 20, 131 | 'example_size': [3, 128, 128], 132 | 'extract_examples': True} 133 | 134 | reader_example_shapes = { 135 | 'features': {'x': reader_params['example_size'] + [NUM_CHANNELS, ]}, 136 | 'labels': {'y': reader_params['example_size']}} 137 | 138 | reader = Reader( 139 | read_fn, 140 | {'features': {'x': tf.float32}, 141 | 'labels': {'y': tf.int32}}) 142 | 143 | # Get input functions and queue initialisation hooks for training and 144 | # validation data 145 | train_input_fn, train_qinit_hook = reader.get_inputs( 146 | file_references=train_filenames, 147 | mode=tf.estimator.ModeKeys.TRAIN, 148 | example_shapes=reader_example_shapes, 149 | batch_size=BATCH_SIZE, 150 | shuffle_cache_size=SHUFFLE_CACHE_SIZE, 151 | params=reader_params) 152 | 153 | val_input_fn, val_qinit_hook = reader.get_inputs( 154 | file_references=val_filenames, 155 | mode=tf.estimator.ModeKeys.EVAL, 156 | example_shapes=reader_example_shapes, 157 | batch_size=BATCH_SIZE, 158 | shuffle_cache_size=SHUFFLE_CACHE_SIZE, 159 | params=reader_params) 160 | 161 | # Instantiate the neural network estimator 162 | nn = tf.estimator.Estimator( 163 | model_fn=model_fn, 164 | model_dir=args.model_path, 165 | params={"learning_rate": 0.01}, 166 | config=tf.estimator.RunConfig()) 167 | 168 | # Hooks for validation summaries 169 | val_summary_hook = tf.contrib.training.SummaryAtEndHook( 170 | os.path.join(args.model_path, 'eval')) 171 | step_cnt_hook = tf.train.StepCounterHook( 172 | every_n_steps=EVAL_EVERY_N_STEPS, 173 | output_dir=args.model_path) 174 | 175 | print('Starting training...') 176 | try: 177 | for _ in range(MAX_STEPS // EVAL_EVERY_N_STEPS): 178 | nn.train( 179 | input_fn=train_input_fn, 180 | hooks=[train_qinit_hook, step_cnt_hook], 181 | steps=EVAL_EVERY_N_STEPS) 182 | 183 | if args.run_validation: 184 | results_val = nn.evaluate( 185 | input_fn=val_input_fn, 186 | hooks=[val_qinit_hook, val_summary_hook], 187 | steps=EVAL_STEPS * EVAL_STEPS) 188 | print('Step = {}; val loss = {:.5f};'.format( 189 | results_val['global_step'], results_val['loss'])) 190 | 191 | except KeyboardInterrupt: 192 | print('Stopping now.') 193 | finally: 194 | export_dir = nn.export_savedmodel( 195 | export_dir_base=args.model_path, 196 | serving_input_receiver_fn=reader.serving_input_receiver_fn( 197 | reader_example_shapes)) 198 | print('Model saved to {}.'.format(export_dir)) 199 | 200 | 201 | if __name__ == '__main__': 202 | 203 | # Set up argument parser 204 | parser = argparse.ArgumentParser(description='iFind2 fetal segmentation training script') 205 | parser.add_argument('--run_validation', default=True) 206 | parser.add_argument('--restart', default=False, action='store_true') 207 | parser.add_argument('--verbose', default=False, action='store_true') 208 | parser.add_argument('--cuda_devices', '-c', default='0') 209 | 210 | parser.add_argument('--model_path', '-p', 
default='/tmp/fetal_segmentation/')
211 |     parser.add_argument('--train_csv', default='iFind_fetal.csv')
212 | 
213 |     args = parser.parse_args()
214 | 
215 |     # Set verbosity
216 |     if args.verbose:
217 |         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
218 |         tf.logging.set_verbosity(tf.logging.INFO)
219 |     else:
220 |         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
221 |         tf.logging.set_verbosity(tf.logging.ERROR)
222 | 
223 |     # GPU allocation options
224 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices
225 | 
226 |     # Handle restarting and resuming training
227 |     if args.restart:
228 |         print('Restarting training from scratch.')
229 |         os.system('rm -rf {}'.format(args.model_path))
230 | 
231 |     if not os.path.isdir(args.model_path):
232 |         os.system('mkdir -p {}'.format(args.model_path))
233 |     else:
234 |         print('Resuming training on model_path {}'.format(args.model_path))
235 | 
236 |     # Call training
237 |     train(args)
238 | 
-------------------------------------------------------------------------------- /logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLTK/models/7fd907a325cd7a23ccca62d2def2f9f770020cff/logo.jpg
-------------------------------------------------------------------------------- /setup.cfg: --------------------------------------------------------------------------------
1 | [tool:pytest]
2 | flake8-ignore =
3 |     * E501
-------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/README.md: --------------------------------------------------------------------------------
1 | ## Abdominal organ segmentation from 3D CT data
2 | 
3 | ![Sample prediction on test data](pred.png)
4 | 
5 | ### Contact and referencing this work
6 | If there are any issues, please contact the corresponding author of this implementation. If you employ this model in your work, please refer to this citation of the [paper](https://arxiv.org/abs/1711.06853).
7 | ```
8 | @article{pawlowski2017dltk,
9 |   title={{DLTK: State of the Art Reference Implementations for Deep Learning on Medical Images}},
10 |   author={Pawlowski, Nick and Ktena, Sofia Ira and Lee, Matthew CH and Kainz, Bernhard and Rueckert, Daniel and Glocker, Ben and Rajchl, Martin},
11 |   journal={arXiv preprint arXiv:1711.06853},
12 |   year={2017}
13 | }
14 | ```
15 | 
16 | ### Important Notes
17 | - The original model was trained with DLTK 0.1, which had a slightly different U-Net implementation. We provide the original model and training scripts to closely replicate this.
18 | - We originally trained with `batch_size=8`, which requires > 12 GB of GPU memory. We therefore decreased the batch size of the uploaded script to 4.
19 | 
20 | ### Data
21 | The data can be downloaded after registration from the [challenge website](http://synapse.org/#!Synapse:syn3193805/wiki/217785).
22 | 
23 | Images and segmentations are read from a csv file in the format below. The original files (*.csv) are provided in this repo.
24 | 
25 | These are parsed in `reader.py` to extract tf.Tensor examples for training and evaluation, using [SimpleITK](http://www.simpleitk.org/) for i/o of the .nii files.
26 | 
27 | 
28 | ### Usage
29 | You can use the code (train.py) to train the model on the data yourself.
Alternatively, we provide pretrained models here:
30 | - [original submission based on DLTK 0.1](https://www.doc.ic.ac.uk/~np716/dltk_models/ct_synapse/orig_unet.tar.gz)
31 | - [DLTK 0.2 asymmetric U-Net](https://www.doc.ic.ac.uk/~np716/dltk_models/ct_synapse/asym_unet_balce_mom.tar.gz)
32 | - [DLTK 0.2 FCN](https://www.doc.ic.ac.uk/~np716/dltk_models/ct_synapse/fcn_balce.tar.gz)
33 | 
34 | #### Data Preprocessing
35 | 
36 | Use `preprocessing.py` for data preprocessing. You should call it as
37 | ```
38 | python preprocessing.py -d path/to/training_data -p /path/to/save/processed/data
39 | ```
40 | for the training data and
41 | ```
42 | python preprocessing.py -d path/to/test_data -p /path/to/save/processed/data -n -s
43 | ```
44 | for the test data. This will generate the csv files for training and testing.
45 | 
46 | #### Training
47 | 
48 | You can start a basic training run with
49 | ```
50 | python train.py -c CUDA_DEVICE
51 | ```
52 | which will load the file paths from the previously created csvs and save the model to `/tmp/synapse_ct_seg`. For further settings, you can change `config.json`.
53 | 
54 | #### Deploy
55 | 
56 | To deploy a model and run inference, run the deploy.py script and point it to the saved model path:
57 | 
58 | ```
59 | python deploy.py -p path/to/saved/model -e path/to/save/predictions -c CUDA_DEVICE --csv CSV_WITH_FILES_TO_PREDICT
60 | ```
61 | 
62 | Please note that this implementation imports saved models created via [tf.estimator.Estimator.export_savedmodel](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#export_savedmodel); during deploy, the path passed must point to the exported saved model itself and not to the save path specified for [tf.estimator.Estimator.export_savedmodel](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#export_savedmodel) (a sketch of resolving this folder is shown below).
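The following sketch is illustrative only; it assumes training exported the model with `export_savedmodel` into a timestamped subfolder of the training `model_path` (as, e.g., the fetal example's train.py does), and the paths are placeholders:

```
import os

# Training model_path (placeholder).
model_path = '/tmp/synapse_ct_seg'

# export_savedmodel() writes a timestamped subfolder; pick the latest one
# and pass it to deploy.py via -p.
export_dir = sorted(
    [os.path.join(model_path, d) for d in os.listdir(model_path)
     if os.path.isdir(os.path.join(model_path, d)) and d.isdigit()])[-1]

print('deploy.py -p {}'.format(export_dir))
```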
For prediction of data without labels use the `-n` flag 63 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": "asym_unet", 3 | "opt": "momentum", 4 | "filters": [16, 64, 128, 256, 512], 5 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 1, 1]], 6 | "num_residual_units": 3, 7 | "loss": "balce", 8 | "learning_rate": 0.001 9 | } 10 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/config_asym_unet_balce.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": "asym_unet", 3 | "opt": "momentum", 4 | "filters": [16, 64, 128, 256, 512], 5 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 1, 1]], 6 | "num_residual_units": 3, 7 | "loss": "balce", 8 | "learning_rate": 0.001 9 | } 10 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/config_asym_unet_ce.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": "asym_unet", 3 | "opt": "rmsprop", 4 | "filters": [16, 64, 128, 256, 512], 5 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 1, 1]], 6 | "num_residual_units": 3, 7 | "loss": "ce", 8 | "learning_rate": 0.001 9 | } 10 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/config_fcn_balce.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": "fcn", 3 | "filters": [16, 64, 128, 256, 512], 4 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 1, 1]], 5 | "num_residual_units": 3, 6 | "loss": "balce", 7 | "learning_rate": 0.001 8 | } 9 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/config_fcn_ce.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": "fcn", 3 | "filters": [16, 64, 128, 256, 512], 4 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 1, 1]], 5 | "num_residual_units": 3, 6 | "loss": "ce", 7 | "learning_rate": 0.001 8 | } 9 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/config_unet_balce.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": "unet", 3 | "filters": [16, 64, 128, 256, 512], 4 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 1, 1]], 5 | "num_residual_units": 3, 6 | "loss": "balce", 7 | "learning_rate": 0.001 8 | } 9 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/config_unet_ce.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": "unet", 3 | "filters": [16, 64, 128, 256, 512], 4 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 1, 1]], 5 | "num_residual_units": 3, 6 | "loss": "ce", 7 | "learning_rate": 0.001 8 | } 9 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/deploy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | 
from __future__ import print_function 4 | 5 | import argparse 6 | import os 7 | import time 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import tensorflow as tf 12 | import SimpleITK as sitk 13 | 14 | from tensorflow.contrib import predictor 15 | 16 | from dltk.core import metrics as metrics 17 | 18 | from dltk.utils import sliding_window_segmentation_inference 19 | 20 | 21 | from reader import read_fn 22 | 23 | READER_PARAMS = {'extract_examples': False} 24 | 25 | 26 | def predict(args): 27 | # Read in the csv with the file names you would want to predict on 28 | file_names = pd.read_csv( 29 | args.csv, 30 | dtype=object, 31 | keep_default_na=False, 32 | na_values=[]).as_matrix() 33 | 34 | print('Loading from {}'.format(args.model_path)) 35 | my_predictor = predictor.from_saved_model(args.model_path) 36 | 37 | # Fetch the output probability op of the trained network 38 | y_prob = my_predictor._fetch_tensors['y_prob'] 39 | print('Got y_prob as {}'.format(y_prob)) 40 | num_classes = y_prob.get_shape().as_list()[-1] 41 | 42 | mode = (tf.estimator.ModeKeys.PREDICT if args.predict_only 43 | else tf.estimator.ModeKeys.EVAL) 44 | 45 | # Iterate through the files, predict on the full volumes and compute a Dice 46 | # coefficient 47 | for output in read_fn(file_references=file_names, 48 | mode=mode, 49 | params=READER_PARAMS): 50 | t0 = time.time() 51 | 52 | # Parse the read function output and add a dummy batch dimension as 53 | # required 54 | img = np.expand_dims(output['features']['x'], axis=0) 55 | 56 | print('running inference on {} with img {} and op {}'.format( 57 | my_predictor._feed_tensors['x'], img.shape, y_prob)) 58 | # Do a sliding window inference with our DLTK wrapper 59 | pred = sliding_window_segmentation_inference( 60 | session=my_predictor.session, 61 | ops_list=[y_prob], 62 | sample_dict={my_predictor._feed_tensors['x']: img}, 63 | batch_size=32)[0] 64 | 65 | # Calculate the prediction from the probabilities 66 | pred = np.argmax(pred, -1) 67 | 68 | if not args.predict_only: 69 | lbl = np.expand_dims(output['labels']['y'], axis=0) 70 | # Calculate the Dice coefficient 71 | dsc = metrics.dice(pred, lbl, num_classes)[1:].mean() 72 | 73 | # Save the file as .nii.gz using the header information from the 74 | # original sitk image 75 | output_fn = os.path.join( 76 | args.export_path, '{}_seg.nii.gz'.format(output['img_name'])) 77 | 78 | new_sitk = sitk.GetImageFromArray(pred[0].astype(np.int32)) 79 | new_sitk.CopyInformation(output['sitk']) 80 | 81 | sitk.WriteImage(new_sitk, output_fn) 82 | 83 | if args.predict_only: 84 | print('Id={}; time={:0.2} secs; output_path={};'.format( 85 | output['img_name'], time.time() - t0, output_fn)) 86 | else: 87 | # Print outputs 88 | print( 89 | 'Id={}; Dice={:0.4f} time={:0.2} secs; output_path={};'.format( 90 | output['img_name'], dsc, time.time() - t0, output_fn)) 91 | 92 | 93 | if __name__ == '__main__': 94 | # Set up argument parser 95 | parser = argparse.ArgumentParser( 96 | description='Synapse MultiAtlas example segmentation deploy script') 97 | parser.add_argument('--verbose', default=False, action='store_true') 98 | parser.add_argument('--predict_only', '-n', default=False, 99 | action='store_true') 100 | parser.add_argument('--cuda_devices', '-c', default='0') 101 | 102 | parser.add_argument('--model_path', '-p', default='/tmp/synapse_ct_seg/') 103 | parser.add_argument('--export_path', '-e', default='/tmp/synapse_ct_seg/') 104 | parser.add_argument('--csv', default='train.csv') 105 | 106 | args = parser.parse_args() 107 | 
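    # Example invocation, e.g. for the unlabelled test images listed in test.csv:
    #   python deploy.py --model_path /tmp/synapse_ct_seg/ --csv test.csv -n
    # where -n (--predict_only) skips label loading and the Dice computation.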
108 | # Set verbosity 109 | if args.verbose: 110 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 111 | tf.logging.set_verbosity(tf.logging.INFO) 112 | else: 113 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 114 | tf.logging.set_verbosity(tf.logging.ERROR) 115 | 116 | # GPU allocation options 117 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices 118 | 119 | # Call training 120 | predict(args) 121 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLTK/models/7fd907a325cd7a23ccca62d2def2f9f770020cff/synapse_btcv_abdominal_ct_segmentation/pred.png -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/preprocessing.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import pandas as pd 4 | import SimpleITK as sitk 5 | import numpy as np 6 | import os 7 | import argparse 8 | 9 | 10 | def resample_img(itk_image, out_spacing=[2.0, 2.0, 2.0], is_label=False): 11 | # resample images to 2mm spacing with simple itk 12 | 13 | original_spacing = itk_image.GetSpacing() 14 | original_size = itk_image.GetSize() 15 | 16 | out_size = [ 17 | int(np.round(original_size[0] * (original_spacing[0] / out_spacing[0]))), 18 | int(np.round(original_size[1] * (original_spacing[1] / out_spacing[1]))), 19 | int(np.round(original_size[2] * (original_spacing[2] / out_spacing[2])))] 20 | 21 | resample = sitk.ResampleImageFilter() 22 | resample.SetOutputSpacing(out_spacing) 23 | resample.SetSize(out_size) 24 | resample.SetOutputDirection(itk_image.GetDirection()) 25 | resample.SetOutputOrigin(itk_image.GetOrigin()) 26 | resample.SetTransform(sitk.Transform()) 27 | resample.SetDefaultPixelValue(itk_image.GetPixelIDValue()) 28 | 29 | if is_label: 30 | resample.SetInterpolator(sitk.sitkNearestNeighbor) 31 | else: 32 | resample.SetInterpolator(sitk.sitkBSpline) 33 | 34 | return resample.Execute(itk_image) 35 | 36 | 37 | def normalise(itk_image): 38 | # normalise and clip images 39 | 40 | np_img = sitk.GetArrayFromImage(itk_image) 41 | np_img = np.clip(np_img, -1000., 800.).astype(np.float32) 42 | np_img = (np_img + 1000.) / 900. - 1. 
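    # The clip + linear map above rescales intensities from [-1000, 800] HU
    # to [-1, 1] (e.g. -1000 -> -1.0, -100 -> 0.0, 800 -> 1.0).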
43 | s_itk_image = sitk.GetImageFromArray(np_img) 44 | s_itk_image.CopyInformation(itk_image) 45 | return s_itk_image 46 | 47 | 48 | def split_data(files, path, no_split=True, no_label=False): 49 | if no_split: 50 | # use this for test or so 51 | imgs = [os.path.join(path, 'img', 'img{}.nii.gz'.format(f)) 52 | for f in files] 53 | 54 | if not no_label: 55 | lbls = [os.path.join(path, 'label', 'label{}.nii.gz'.format(f)) 56 | for f in files] 57 | 58 | pd.DataFrame(data={'imgs': imgs, 'lbls': lbls}).to_csv( 59 | 'no_split.csv', index=False) 60 | else: 61 | pd.DataFrame(data={'imgs': imgs}).to_csv( 62 | 'no_split.csv', index=False) 63 | else: 64 | # split train data into train and val 65 | rng = np.random.RandomState(42) 66 | ids = [f[3:7] for f in files] 67 | validation = rng.choice(ids, 7) 68 | train = [f for f in ids if f not in validation] 69 | 70 | train_imgs = [os.path.join(path, 'img', 'img{}.nii.gz'.format(f)) 71 | for f in train] 72 | if not no_label: 73 | train_lbls = [os.path.join( 74 | path, 'label', 'label{}.nii.gz'.format(f)) for f in train] 75 | 76 | pd.DataFrame(data={'imgs': train_imgs, 'lbls': train_lbls}).to_csv( 77 | 'train.csv', index=False) 78 | else: 79 | pd.DataFrame(data={'imgs': train_imgs}).to_csv( 80 | 'train.csv', index=False) 81 | 82 | val_imgs = [os.path.join(path, 'img', 'img{}.nii.gz'.format(f)) 83 | for f in validation] 84 | if not no_label: 85 | val_lbls = [os.path.join(path, 'label', 'label{}.nii.gz'.format(f)) 86 | for f in validation] 87 | 88 | pd.DataFrame(data={'imgs': val_imgs, 'lbls': val_lbls}).to_csv( 89 | 'val.csv', index=False) 90 | else: 91 | pd.DataFrame(data={'imgs': val_imgs}).to_csv( 92 | 'val.csv', index=False) 93 | 94 | 95 | def preprocess(args): 96 | files = os.listdir(os.path.join(args.data_path, 'img')) 97 | 98 | split_data(files, args.output_path, args.no_split, args.no_label) 99 | 100 | if not os.path.exists(os.path.join(args.output_path, 'img')): 101 | os.makedirs(os.path.join(args.output_path, 'img')) 102 | 103 | if not args.no_label: 104 | if not os.path.exists(os.path.join(args.output_path, 'label')): 105 | os.makedirs(os.path.join(args.output_path, 'label')) 106 | 107 | for f in files: 108 | fid = f[3:7] 109 | f1 = os.path.join(args.data_path, 'img', f) 110 | 111 | nii_f1 = sitk.ReadImage(f1) 112 | res_nii_f1 = resample_img(nii_f1) 113 | scaled = normalise(res_nii_f1) 114 | sitk.WriteImage(scaled, os.path.join(args.output_path, 'img', f)) 115 | 116 | if not args.no_label: 117 | l1 = os.path.join( 118 | args.data_path, 'label', 'label{}.nii.gz'.format(fid)) 119 | nii_l1 = sitk.ReadImage(l1) 120 | res_nii_l1 = resample_img(nii_l1, is_label=True) 121 | sitk.WriteImage(res_nii_l1, os.path.join( 122 | args.output_path, 'label', 'label{}.nii.gz'.format(fid))) 123 | 124 | 125 | if __name__ == '__main__': 126 | # Set up argument parser 127 | parser = argparse.ArgumentParser( 128 | description='Example: Synapse CT example preprocessing script') 129 | 130 | parser.add_argument('--data_path', '-d') 131 | 132 | parser.add_argument('--output_path', '-p') 133 | 134 | parser.add_argument('--no_split', '-s', default=False, action='store_true') 135 | 136 | parser.add_argument('--no_label', '-n', default=False, action='store_true') 137 | 138 | args = parser.parse_args() 139 | 140 | # Call training 141 | preprocess(args) 142 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/reader.py: -------------------------------------------------------------------------------- 1 | import 
SimpleITK as sitk 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | from dltk.io.augmentation import extract_class_balanced_example_array 6 | 7 | 8 | def read_fn(file_references, mode, params=None): 9 | """Summary 10 | 11 | Args: 12 | file_references (TYPE): Description 13 | mode (TYPE): Description 14 | params (TYPE): Description 15 | 16 | Returns: 17 | TYPE: Description 18 | """ 19 | for f in file_references: 20 | img_fn = str(f[0]) 21 | 22 | img_name = img_fn.split('/')[-1].split('.')[0] 23 | 24 | # Use a SimpleITK reader to load the multi channel 25 | # nii images and labels for training 26 | img_sitk = sitk.ReadImage(img_fn) 27 | images = sitk.GetArrayFromImage(img_sitk) 28 | 29 | images = np.expand_dims(images, axis=3) 30 | 31 | if mode == tf.estimator.ModeKeys.PREDICT: 32 | yield {'features': {'x': images}, 'labels': None, 33 | 'img_name': img_name, 'sitk': img_sitk} 34 | else: 35 | lbl_fn = str(f[1]) 36 | lbl = sitk.GetArrayFromImage( 37 | sitk.ReadImage(lbl_fn)).astype(np.int32) 38 | 39 | # Augment if used in training mode 40 | if mode == tf.estimator.ModeKeys.TRAIN: 41 | pass 42 | 43 | # Check if the reader is supposed to return 44 | # training examples or full images 45 | if params['extract_examples']: 46 | n_examples = params['n_examples'] 47 | example_size = params['example_size'] 48 | 49 | images, lbl = extract_class_balanced_example_array( 50 | images, lbl, example_size=example_size, 51 | n_examples=n_examples, classes=14) 52 | 53 | for e in range(len(images)): 54 | yield {'features': {'x': images[e].astype(np.float32)}, 55 | 'labels': {'y': lbl[e].astype(np.int32)}} 56 | else: 57 | yield {'features': {'x': images}, 'labels': {'y': lbl}, 58 | 'img_name': img_name, 'sitk': img_sitk} 59 | 60 | return 61 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/test.csv: -------------------------------------------------------------------------------- 1 | id 2 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0063.nii.gz 3 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0065.nii.gz 4 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0067.nii.gz 5 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0078.nii.gz 6 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0077.nii.gz 7 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0066.nii.gz 8 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0074.nii.gz 9 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0073.nii.gz 10 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0070.nii.gz 11 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0071.nii.gz 12 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0072.nii.gz 13 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0061.nii.gz 14 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0079.nii.gz 15 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0080.nii.gz 16 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0068.nii.gz 17 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0062.nii.gz 18 | 
/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0075.nii.gz 19 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0076.nii.gz 20 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0069.nii.gz 21 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed_testing/img/img0064.nii.gz 22 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/train.csv: -------------------------------------------------------------------------------- 1 | imgs,lbls 2 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0001.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0001.nii.gz 3 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0031.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0031.nii.gz 4 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0026.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0026.nii.gz 5 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0005.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0005.nii.gz 6 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0023.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0023.nii.gz 7 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0022.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0022.nii.gz 8 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0007.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0007.nii.gz 9 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0030.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0030.nii.gz 10 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0029.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0029.nii.gz 11 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0002.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0002.nii.gz 12 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0006.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0006.nii.gz 13 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0003.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0003.nii.gz 14 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0027.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0027.nii.gz 15 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0037.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0037.nii.gz 16 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0025.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0025.nii.gz 17 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0039.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0039.nii.gz 18 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0034.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0034.nii.gz 19 | 
/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0035.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0035.nii.gz 20 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0024.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0024.nii.gz 21 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0038.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0038.nii.gz 22 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0009.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0009.nii.gz 23 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0040.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0040.nii.gz 24 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0010.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0010.nii.gz 25 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0008.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0008.nii.gz 26 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import tensorflow as tf 11 | 12 | from dltk.core.metrics import dice 13 | from dltk.core.losses import sparse_balanced_crossentropy 14 | from dltk.networks.segmentation.unet import residual_unet_3d 15 | from dltk.networks.segmentation.unet import asymmetric_residual_unet_3d 16 | from dltk.networks.segmentation.fcn import residual_fcn_3d 17 | from dltk.core.activations import leaky_relu 18 | from dltk.io.abstract_reader import Reader 19 | from reader import read_fn 20 | import json 21 | 22 | # PARAMS 23 | EVAL_EVERY_N_STEPS = 1000 24 | EVAL_STEPS = 1 25 | 26 | NUM_CLASSES = 14 27 | NUM_CHANNELS = 1 28 | 29 | BATCH_SIZE = 4 30 | SHUFFLE_CACHE_SIZE = 128 31 | 32 | MAX_STEPS = 100000 33 | 34 | 35 | # MODEL 36 | def model_fn(features, labels, mode, params): 37 | """Summary 38 | 39 | Args: 40 | features (TYPE): Description 41 | labels (TYPE): Description 42 | mode (TYPE): Description 43 | params (TYPE): Description 44 | 45 | Returns: 46 | TYPE: Description 47 | """ 48 | # 1. 
create a model and its outputs 49 | 50 | filters = params["filters"] 51 | strides = params["strides"] 52 | num_residual_units = params["num_residual_units"] 53 | loss_type = params["loss"] 54 | net = params["net"] 55 | 56 | def lrelu(x): 57 | return leaky_relu(x, 0.1) 58 | 59 | if net == 'fcn': 60 | net_output_ops = residual_fcn_3d( 61 | features['x'], NUM_CLASSES, 62 | num_res_units=num_residual_units, 63 | filters=filters, 64 | strides=strides, 65 | activation=lrelu, 66 | mode=mode) 67 | elif net == 'unet': 68 | net_output_ops = residual_unet_3d( 69 | features['x'], NUM_CLASSES, 70 | num_res_units=num_residual_units, 71 | filters=filters, 72 | strides=strides, 73 | activation=lrelu, 74 | mode=mode) 75 | elif net == 'asym_unet': 76 | net_output_ops = asymmetric_residual_unet_3d( 77 | features['x'], 78 | NUM_CLASSES, 79 | num_res_units=num_residual_units, 80 | filters=filters, 81 | strides=strides, 82 | activation=lrelu, 83 | mode=mode) 84 | 85 | # 1.1 Generate predictions only (for `ModeKeys.PREDICT`) 86 | if mode == tf.estimator.ModeKeys.PREDICT: 87 | return tf.estimator.EstimatorSpec( 88 | mode=mode, predictions=net_output_ops, 89 | export_outputs={'out': tf.estimator.export.PredictOutput( 90 | net_output_ops)}) 91 | 92 | # 2. set up a loss function 93 | if loss_type == 'ce': 94 | ce = tf.nn.sparse_softmax_cross_entropy_with_logits( 95 | logits=net_output_ops['logits'], labels=labels['y']) 96 | loss = tf.reduce_mean(ce) 97 | elif loss_type == 'balce': 98 | loss = sparse_balanced_crossentropy( 99 | net_output_ops['logits'], labels['y']) 100 | 101 | # 3. define a training op and ops for updating 102 | # moving averages (i.e. for batch normalisation) 103 | global_step = tf.train.get_global_step() 104 | if params["opt"] == 'adam': 105 | optimiser = tf.train.AdamOptimizer( 106 | learning_rate=params["learning_rate"], epsilon=1e-5) 107 | elif params["opt"] == 'momentum': 108 | optimiser = tf.train.MomentumOptimizer( 109 | learning_rate=params["learning_rate"], momentum=0.9) 110 | elif params["opt"] == 'rmsprop': 111 | optimiser = tf.train.RMSPropOptimizer( 112 | learning_rate=params["learning_rate"], momentum=0.9) 113 | 114 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 115 | with tf.control_dependencies(update_ops): 116 | train_op = optimiser.minimize(loss, global_step=global_step) 117 | 118 | # 4.1 (optional) create custom image summaries for tensorboard 119 | my_image_summaries = {} 120 | my_image_summaries['feat_t1'] = tf.expand_dims( 121 | features['x'][:, 0, :, :, 0], 3) 122 | my_image_summaries['labels'] = tf.expand_dims( 123 | tf.cast(labels['y'], tf.float32)[:, 0, :, :], 3) 124 | my_image_summaries['predictions'] = tf.expand_dims( 125 | tf.cast(net_output_ops['y_'], tf.float32)[:, 0, :, :], 3) 126 | 127 | [tf.summary.image(name, image) 128 | for name, image in my_image_summaries.items()] 129 | 130 | # 4.2 (optional) create custom metric summaries for tensorboard 131 | dice_tensor = tf.py_func( 132 | dice, [net_output_ops['y_'], labels['y'], 133 | tf.constant(NUM_CLASSES)], tf.float32) 134 | 135 | [tf.summary.scalar('dsc_l{}'.format(i), dice_tensor[i]) 136 | for i in range(NUM_CLASSES)] 137 | 138 | # 5. 
Return EstimatorSpec object 139 | return tf.estimator.EstimatorSpec( 140 | mode=mode, predictions=net_output_ops, 141 | loss=loss, train_op=train_op, 142 | eval_metric_ops=None) 143 | 144 | 145 | def train(args): 146 | np.random.seed(42) 147 | tf.set_random_seed(42) 148 | 149 | print('Setting up...') 150 | 151 | with open(args.config) as f: 152 | run_config = json.load(f) 153 | 154 | # Parse csv files for file names 155 | train_filenames = pd.read_csv( 156 | args.train_csv, dtype=object, keep_default_na=False, 157 | na_values=[]).as_matrix() 158 | 159 | val_filenames = pd.read_csv( 160 | args.val_csv, dtype=object, keep_default_na=False, 161 | na_values=[]).as_matrix() 162 | 163 | # Set up a data reader to handle the file i/o. 164 | reader_params = { 165 | 'n_examples': 32, 166 | 'example_size': [64, 64, 64], 167 | 'extract_examples': True 168 | } 169 | 170 | reader_example_shapes = { 171 | 'features': {'x': reader_params['example_size'] + [NUM_CHANNELS, ]}, 172 | 'labels': {'y': reader_params['example_size']}} 173 | 174 | reader = Reader(read_fn, {'features': {'x': tf.float32}, 175 | 'labels': {'y': tf.int32}}) 176 | 177 | # Get input functions and queue initialisation hooks 178 | # for training and validation data 179 | train_input_fn, train_qinit_hook = reader.get_inputs( 180 | train_filenames, 181 | tf.estimator.ModeKeys.TRAIN, 182 | example_shapes=reader_example_shapes, 183 | batch_size=BATCH_SIZE, 184 | shuffle_cache_size=SHUFFLE_CACHE_SIZE, 185 | params=reader_params) 186 | 187 | val_input_fn, val_qinit_hook = reader.get_inputs( 188 | val_filenames, 189 | tf.estimator.ModeKeys.EVAL, 190 | example_shapes=reader_example_shapes, 191 | batch_size=BATCH_SIZE, 192 | shuffle_cache_size=min(SHUFFLE_CACHE_SIZE, EVAL_STEPS), 193 | params=reader_params) 194 | 195 | config = tf.ConfigProto() 196 | # config.gpu_options.allow_growth = True 197 | 198 | # Instantiate the neural network estimator 199 | nn = tf.estimator.Estimator( 200 | model_fn=model_fn, 201 | model_dir=args.save_path, 202 | params=run_config, 203 | config=tf.estimator.RunConfig(session_config=config)) 204 | 205 | # Hooks for validation summaries 206 | val_summary_hook = tf.contrib.training.SummaryAtEndHook( 207 | os.path.join(args.save_path, 'eval')) 208 | step_cnt_hook = tf.train.StepCounterHook( 209 | every_n_steps=EVAL_EVERY_N_STEPS, output_dir=args.save_path) 210 | 211 | print('Starting training...') 212 | try: 213 | for _ in range(MAX_STEPS // EVAL_EVERY_N_STEPS): 214 | nn.train( 215 | input_fn=train_input_fn, 216 | hooks=[train_qinit_hook, step_cnt_hook], 217 | steps=EVAL_EVERY_N_STEPS) 218 | 219 | results_val = nn.evaluate( 220 | input_fn=val_input_fn, 221 | hooks=[val_qinit_hook, val_summary_hook], 222 | steps=EVAL_STEPS) 223 | print('Step = {}; val loss = {:.5f};'.format( 224 | results_val['global_step'], results_val['loss'])) 225 | 226 | except KeyboardInterrupt: 227 | pass 228 | 229 | print('Stopping now.') 230 | export_dir = nn.export_savedmodel( 231 | export_dir_base=args.save_path, 232 | serving_input_receiver_fn=reader.serving_input_receiver_fn(reader_example_shapes)) 233 | print('Model saved to {}.'.format(export_dir)) 234 | 235 | 236 | if __name__ == '__main__': 237 | # Set up argument parser 238 | parser = argparse.ArgumentParser( 239 | description='Example: Synapse CT example segmentation training script') 240 | parser.add_argument('--resume', default=False, action='store_true') 241 | parser.add_argument('--verbose', default=False, action='store_true') 242 | parser.add_argument('--cuda_devices', '-c', 
default='0') 243 | 244 | parser.add_argument('--save_path', '-p', default='/tmp/synapse_ct_seg/') 245 | parser.add_argument('--train_csv', default='train.csv') 246 | parser.add_argument('--val_csv', default='val.csv') 247 | parser.add_argument('--config', default="config.json") 248 | 249 | args = parser.parse_args() 250 | 251 | # Set verbosity 252 | if args.verbose: 253 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 254 | tf.logging.set_verbosity(tf.logging.INFO) 255 | else: 256 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 257 | tf.logging.set_verbosity(tf.logging.ERROR) 258 | 259 | # GPU allocation options 260 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices 261 | 262 | # Create model save path 263 | os.system("rm -rf %s" % args.save_path) 264 | os.system("mkdir -p %s" % args.save_path) 265 | 266 | # Call training 267 | train(args) 268 | -------------------------------------------------------------------------------- /synapse_btcv_abdominal_ct_segmentation/val.csv: -------------------------------------------------------------------------------- 1 | imgs,lbls 2 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0004.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0004.nii.gz 3 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0032.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0032.nii.gz 4 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0036.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0036.nii.gz 5 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0021.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0021.nii.gz 6 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0033.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0033.nii.gz 7 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0028.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0028.nii.gz 8 | /vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/img/img0036.nii.gz,/vol/biomedic/users/np716/data/SynapseAbdominalCT/preprocessed/label/label0036.nii.gz 9 | -------------------------------------------------------------------------------- /ukbb_cardiac_segmentation_cine_mri/README.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | 3 | Code for segmenting cardiovascular magnetic resonance (CMR) images from the [UK Biobank Imaging Study](http://imaging.ukbiobank.ac.uk/) using fully convolutional networks. 4 | 5 | **Note** This repository only contains the code, not the imaging data. To know more about how to access the UK Biobank imaging data, please go to the [UK Biobank Imaging Study](http://imaging.ukbiobank.ac.uk/) website. Researchers can [apply](http://www.ukbiobank.ac.uk/register-apply/) to use the UK Biobank data resource for health-related research in the public interest. 6 | 7 | ## Usage 8 | 9 | **A quick demo** You can run a quick demo: 10 | ``` 11 | python3 demo.py 12 | ``` 13 | There is one parameter in the script, *CUDA_VISIBLE_DEVICES*, which controls which GPU device to use on your machine. Currently, I set it to 0, which means the first GPU on your machine. 
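For example, on a machine with more than one GPU, you could run the demo on the second GPU by editing this variable near the top of demo.py:
```
# demo.py: use the second GPU (device id 1) instead of the first
CUDA_VISIBLE_DEVICES = 1
```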
14 | 15 | This script will automatically download [two exemplar short-axis cardiac MR images](https://www.doc.ic.ac.uk/~wbai/data/ukbb_cardiac/demo_image/) and [a pre-trained network](https://www.doc.ic.ac.uk/~wbai/data/ukbb_cardiac/trained_model/), then segment the left and right ventricles using the network, saving the segmentation results *seg_sa.nii.gz* and also saving the clinical measures in a spreadsheet *clinical_measure.csv*, including the left ventricular end-diastolic volume (LVEDV), end-systolic volume (LVESV), myocardial mass (LVM) and the right ventricular end-diastolic volume (RVEDV), end-systolic volume (RVESV). The script will also download exemplar long-axis cardiac MR images and segment the left and right atria. 16 | 17 | **To know more** If you want to know more about how the network works and how it is trained, you can read these following files: 18 | * network.py, which describes the neural network architecture; 19 | * train_network.py, which trains a network on a dataset with both images and manual annotations; 20 | * deploy_network.py, which deploys the trained network onto new images. If you are interested in deploying the pre-trained network to more UK Biobank cardiac image set, this is the file that you need to read. 21 | 22 | ## References 23 | 24 | We would like to thank all the UK Biobank participants and staff who make the CMR imaging dataset possible and also people from Queen Mary's University London and Oxford University who performed the hard work of manual annotation. In case you find the toolbox or a certain part of it useful, please consider giving appropriate credit to it by citing one or some of the papers here, which respectively describes the segmentation method [1] and the manual annotation dataset [2]. Thanks. 25 | 26 | [1] W. Bai, et al. Human-level CMR image analysis with deep fully convolutional networks. arXiv:1710.09289. [arxiv](https://arxiv.org/abs/1710.09289) 27 | 28 | [2] S. Petersen, et al. Reference ranges for cardiac structure and function using cardiovascular magnetic resonance (CMR) in Caucasians from the UK Biobank population cohort. Journal of Cardiovascular Magnetic Resonance, 19:18, 2017. [doi](https://doi.org/10.1186/s12968-017-0327-9) -------------------------------------------------------------------------------- /ukbb_cardiac_segmentation_cine_mri/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017, Wenjia Bai. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """ 16 | This script demonstrates the segmentation of a test cardiac MR image using 17 | a pre-trained neural network. 
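It downloads two exemplar subjects and the pre-trained networks, then calls
deploy_network.py for each imaging sequence ('sa', 'la_2ch' and 'la_4ch'),
saving the segmentations and, for the short-axis sequence, the clinical
measures alongside the downloaded images.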
18 | """ 19 | import os 20 | import urllib.request 21 | 22 | 23 | if __name__ == '__main__': 24 | # Set the GPU device id 25 | CUDA_VISIBLE_DEVICES = 0 26 | 27 | # Set seq_name to 'sa', 'la_2ch' or 'la_4ch' for different imaging sequence 28 | for seq_name in ['sa', 'la_2ch', 'la_4ch']: 29 | print('Demo for {0} imaging sequence ...'.format(seq_name)) 30 | 31 | # Download exemplar images 32 | URL = 'https://www.doc.ic.ac.uk/~wbai/data/ukbb_cardiac/' 33 | print('Downloading images ...') 34 | for i in [1, 2]: 35 | if not os.path.exists('demo_image/{0}'.format(i)): 36 | os.makedirs('demo_image/{0}'.format(i)) 37 | f = 'demo_image/{0}/{1}.nii.gz'.format(i, seq_name) 38 | urllib.request.urlretrieve(URL + f, f) 39 | 40 | # Download the trained network 41 | print('Downloading the trained network ...') 42 | if not os.path.exists('trained_model'): 43 | os.makedirs('trained_model') 44 | for f in ['trained_model/FCN_{0}.meta'.format(seq_name), 45 | 'trained_model/FCN_{0}.index'.format(seq_name), 46 | 'trained_model/FCN_{0}.data-00000-of-00001'.format(seq_name)]: 47 | urllib.request.urlretrieve(URL + f, f) 48 | 49 | # Perform segmentation 50 | print('Performing segmentation ...') 51 | os.system('CUDA_VISIBLE_DEVICES={0} python3 deploy_network.py ' 52 | '--test_dir demo_image --dest_dir demo_image ' 53 | '--seq_name {1} --model_path trained_model/FCN_{1} ' 54 | '--process_seq --clinical_measures'.format(CUDA_VISIBLE_DEVICES, 55 | seq_name)) 56 | print('Done.') 57 | -------------------------------------------------------------------------------- /ukbb_cardiac_segmentation_cine_mri/deploy_network.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017, Wenjia Bai. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import os 16 | import time 17 | import math 18 | import numpy as np 19 | import nibabel as nib 20 | import pandas as pd 21 | import tensorflow as tf 22 | from image_utils import rescale_intensity 23 | 24 | 25 | """ Deployment parameters """ 26 | FLAGS = tf.app.flags.FLAGS 27 | tf.app.flags._global_parser.add_argument('--seq_name', 28 | choices=['sa', 'la_2ch', 'la_4ch'], 29 | default='sa', help="Sequence name.") 30 | tf.app.flags.DEFINE_string('test_dir', '/vol/biomedic2/wbai/tmp/github/test', 31 | 'Path to the test set directory, under which images ' 32 | 'are organised in subdirectories for each subject.') 33 | tf.app.flags.DEFINE_string('dest_dir', '/vol/biomedic2/wbai/tmp/github/output', 34 | 'Path to the destination directory, where the ' 35 | 'segmentations will be saved.') 36 | tf.app.flags.DEFINE_string('model_path', 37 | '/vol/biomedic2/wbai/tmp/github/model/FCN_sa.ckpt-50000', 38 | 'Path to the saved trained model.') 39 | tf.app.flags.DEFINE_boolean('process_seq', True, 40 | 'Process a time sequence of images.') 41 | tf.app.flags.DEFINE_boolean('save_seg', True, 42 | 'Save segmentation.') 43 | tf.app.flags.DEFINE_boolean('clinical_measure', True, 44 | 'Calculate clinical measures.') 45 | 46 | 47 | if __name__ == '__main__': 48 | with tf.Session() as sess: 49 | sess.run(tf.global_variables_initializer()) 50 | 51 | # Import the computation graph and restore the variable values 52 | saver = tf.train.import_meta_graph('{0}.meta'.format(FLAGS.model_path)) 53 | saver.restore(sess, '{0}'.format(FLAGS.model_path)) 54 | 55 | print('Start evaluating on the test set ...') 56 | start_time = time.time() 57 | 58 | # Process each subject subdirectory 59 | data_list = sorted(os.listdir(FLAGS.test_dir)) 60 | processed_list = [] 61 | table = [] 62 | table_time = [] 63 | for data in data_list: 64 | print(data) 65 | data_dir = os.path.join(FLAGS.test_dir, data) 66 | 67 | if FLAGS.process_seq: 68 | # Process the temporal sequence 69 | image_name = '{0}/{1}.nii.gz'.format(data_dir, FLAGS.seq_name) 70 | 71 | if not os.path.exists(image_name): 72 | print(' Directory {0} does not contain an image with file ' 73 | 'name {1}. Skip.'.format(data_dir, 74 | os.path.basename(image_name))) 75 | continue 76 | 77 | # Read the image 78 | print(' Reading {} ...'.format(image_name)) 79 | nim = nib.load(image_name) 80 | image = nim.get_data() 81 | X, Y, Z, T = image.shape 82 | orig_image = image 83 | 84 | print(' Segmenting full sequence ...') 85 | start_seg_time = time.time() 86 | 87 | # Intensity rescaling 88 | image = rescale_intensity(image, (1, 99)) 89 | 90 | # Prediction (segmentation) 91 | pred = np.zeros(image.shape) 92 | 93 | # Pad the image size to be a factor of 16 so that the 94 | # downsample and upsample procedures in the network will 95 | # result in the same image size at each resolution level. 
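                # e.g. X = 210 gives X2 = ceil(210 / 16) * 16 = 224, so 7 rows
                # are padded on each side (x_pre = 7, x_post = 7) and cropped
                # off again after inference.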
96 | X2, Y2 = int(math.ceil(X / 16.0)) * 16, int(math.ceil(Y / 16.0)) * 16 97 | x_pre, y_pre = int((X2 - X) / 2), int((Y2 - Y) / 2) 98 | x_post, y_post = (X2 - X) - x_pre, (Y2 - Y) - y_pre 99 | image = np.pad(image, 100 | ((x_pre, x_post), (y_pre, y_post), (0, 0), (0, 0)), 101 | 'constant') 102 | 103 | # Process each time frame 104 | for t in range(T): 105 | # Transpose the shape to NXYC 106 | image_fr = image[:, :, :, t] 107 | image_fr = np.transpose(image_fr, axes=(2, 0, 1)).astype(np.float32) 108 | image_fr = np.expand_dims(image_fr, axis=-1) 109 | 110 | # Evaluate the network 111 | prob_fr, pred_fr = sess.run(['prob:0', 'pred:0'], 112 | feed_dict={'image:0': image_fr, 113 | 'training:0': False}) 114 | 115 | # Transpose and crop segmentation to recover the original size 116 | pred_fr = np.transpose(pred_fr, axes=(1, 2, 0)) 117 | pred_fr = pred_fr[x_pre:x_pre + X, y_pre:y_pre + Y] 118 | pred[:, :, :, t] = pred_fr 119 | 120 | seg_time = time.time() - start_seg_time 121 | print(' Segmentation time = {:3f}s'.format(seg_time)) 122 | table_time += [seg_time] 123 | processed_list += [data] 124 | 125 | # ED frame defaults to be the first time frame. 126 | # Determine ES frame according to the minimum LV volume. 127 | k = {} 128 | k['ED'] = 0 129 | if FLAGS.seq_name == 'sa': 130 | k['ES'] = np.argmin(np.sum(pred == 1, axis=(0, 1, 2))) 131 | else: 132 | k['ES'] = np.argmax(np.sum(pred == 1, axis=(0, 1, 2))) 133 | print(' ED frame = {:d}, ES frame = {:d}'.format(k['ED'], k['ES'])) 134 | 135 | # Save the segmentation 136 | if FLAGS.save_seg: 137 | print(' Saving segmentation ...') 138 | dest_data_dir = os.path.join(FLAGS.dest_dir, data) 139 | if not os.path.exists(dest_data_dir): 140 | os.makedirs(dest_data_dir) 141 | 142 | nim2 = nib.Nifti1Image(pred, nim.affine) 143 | nim2.header['pixdim'] = nim.header['pixdim'] 144 | nib.save(nim2, '{0}/seg_{1}.nii.gz'.format(dest_data_dir, 145 | FLAGS.seq_name)) 146 | 147 | for fr in ['ED', 'ES']: 148 | nib.save(nib.Nifti1Image(orig_image[:, :, :, k[fr]], 149 | nim.affine), 150 | '{0}/{1}_{2}.nii.gz'.format(dest_data_dir, 151 | FLAGS.seq_name, 152 | fr)) 153 | nib.save(nib.Nifti1Image(pred[:, :, :, k[fr]], 154 | nim.affine), 155 | '{0}/seg_{1}_{2}.nii.gz'.format(dest_data_dir, 156 | FLAGS.seq_name, 157 | fr)) 158 | 159 | # Evaluate the clinical measures 160 | if FLAGS.seq_name == 'sa' and FLAGS.clinical_measure: 161 | print(' Evaluating clinical measures ...') 162 | measure = {} 163 | dx, dy, dz = nim.header['pixdim'][1:4] 164 | volume_per_voxel = dx * dy * dz * 1e-3 165 | density = 1.05 166 | 167 | for fr in ['ED', 'ES']: 168 | measure[fr] = {} 169 | measure[fr]['LVV'] = np.sum(pred[:, :, :, k[fr]] == 1) * volume_per_voxel 170 | measure[fr]['LVM'] = np.sum(pred[:, :, :, k[fr]] == 2) * volume_per_voxel * density 171 | measure[fr]['RVV'] = np.sum(pred[:, :, :, k[fr]] == 3) * volume_per_voxel 172 | 173 | line = [measure['ED']['LVV'], measure['ES']['LVV'], 174 | measure['ED']['LVM'], 175 | measure['ED']['RVV'], measure['ES']['RVV']] 176 | table += [line] 177 | else: 178 | # Process ED and ES time frames 179 | image_ED_name = '{0}/{1}_{2}.nii.gz'.format(data_dir, 180 | FLAGS.seq_name, 181 | 'ED') 182 | image_ES_name = '{0}/{1}_{2}.nii.gz'.format(data_dir, 183 | FLAGS.seq_name, 184 | 'ES') 185 | if not os.path.exists(image_ED_name) \ 186 | or not os.path.exists(image_ES_name): 187 | print(' Directory {0} does not contain an image with ' 188 | 'file name {1} or {2}. 
Skip.'.format(data_dir, 189 | os.path.basename(image_ED_name), 190 | os.path.basename(image_ES_name))) 191 | continue 192 | 193 | measure = {} 194 | for fr in ['ED', 'ES']: 195 | image_name = '{0}/{1}_{2}.nii.gz'.format(data_dir, FLAGS.seq_name, fr) 196 | 197 | # Read the image 198 | print(' Reading {} ...'.format(image_name)) 199 | nim = nib.load(image_name) 200 | image = nim.get_data() 201 | X, Y = image.shape[:2] 202 | if image.ndim == 2: 203 | image = np.expand_dims(image, axis=2) 204 | 205 | print(' Segmenting {} frame ...'.format(fr)) 206 | start_seg_time = time.time() 207 | 208 | # Intensity rescaling 209 | image = rescale_intensity(image, (1, 99)) 210 | 211 | # Pad the image size to be a factor of 16 so that 212 | # the downsample and upsample procedures in the network 213 | # will result in the same image size at each resolution 214 | # level. 215 | X2, Y2 = int(math.ceil(X / 16.0)) * 16, int(math.ceil(Y / 16.0)) * 16 216 | x_pre, y_pre = int((X2 - X) / 2), int((Y2 - Y) / 2) 217 | x_post, y_post = (X2 - X) - x_pre, (Y2 - Y) - y_pre 218 | image = np.pad(image, 219 | ((x_pre, x_post), (y_pre, y_post), (0, 0)), 220 | 'constant') 221 | 222 | # Transpose the shape to NXYC 223 | image = np.transpose(image, axes=(2, 0, 1)).astype(np.float32) 224 | image = np.expand_dims(image, axis=-1) 225 | 226 | # Evaluate the network 227 | prob, pred = sess.run(['prob:0', 'pred:0'], 228 | feed_dict={'image:0': image, 229 | 'training:0': False}) 230 | 231 | # Transpose and crop the segmentation to recover the original size 232 | pred = np.transpose(pred, axes=(1, 2, 0)) 233 | pred = pred[x_pre:x_pre + X, y_pre:y_pre + Y] 234 | 235 | seg_time = time.time() - start_seg_time 236 | print(' Segmentation time = {:3f}s'.format(seg_time)) 237 | table_time += [seg_time] 238 | processed_list += [data] 239 | 240 | # Save the segmentation 241 | if FLAGS.save_seg: 242 | print(' Saving segmentation ...') 243 | dest_data_dir = os.path.join(FLAGS.dest_dir, data) 244 | if not os.path.exists(dest_data_dir): 245 | os.makedirs(dest_data_dir) 246 | 247 | nim2 = nib.Nifti1Image(pred, nim.affine) 248 | nim2.header['pixdim'] = nim.header['pixdim'] 249 | nib.save(nim2, 250 | '{0}/seg_{1}_{2}.nii.gz'.format(dest_data_dir, 251 | FLAGS.seq_name, 252 | fr)) 253 | 254 | # Evaluate the clinical measures 255 | if FLAGS.seq_name == 'sa' and FLAGS.clinical_measure: 256 | print(' Evaluating clinical measures ...') 257 | dx, dy, dz = nim.header['pixdim'][1:4] 258 | volume_per_voxel = dx * dy * dz * 1e-3 259 | density = 1.05 260 | 261 | measure[fr] = {} 262 | measure[fr]['LVV'] = np.sum(pred == 1) * volume_per_voxel 263 | measure[fr]['LVM'] = np.sum(pred == 2) * volume_per_voxel * density 264 | measure[fr]['RVV'] = np.sum(pred == 3) * volume_per_voxel 265 | 266 | if FLAGS.clinical_measure and FLAGS.seq_name == 'sa': 267 | line = [measure['ED']['LVV'], measure['ES']['LVV'], 268 | measure['ED']['LVM'], 269 | measure['ED']['RVV'], measure['ES']['RVV']] 270 | table += [line] 271 | 272 | # Save the spreadsheet for the clinical measures 273 | if FLAGS.seq_name == 'sa' and FLAGS.clinical_measure: 274 | column_names = ['LVEDV (mL)', 'LVESV (mL)', 'LVM (g)', 275 | 'RVEDV (mL)', 'RVESV (mL)'] 276 | df = pd.DataFrame(table, index=processed_list, columns=column_names) 277 | csv_name = os.path.join(FLAGS.dest_dir, 'clinical_measure.csv') 278 | print(' Saving clinical measures at {0} ...'.format(csv_name)) 279 | df.to_csv(csv_name) 280 | 281 | if FLAGS.process_seq: 282 | print('Average segmentation time = {:.3f}s per sequence'.format( 283 | 
np.mean(table_time))) 284 | else: 285 | print('Average segmentation time = {:.3f}s per frame'.format( 286 | np.mean(table_time))) 287 | process_time = time.time() - start_time 288 | print('Including image I/O, CUDA resource allocation, ' 289 | 'it took {:.3f}s for processing {:d} subjects ' 290 | '({:.3f}s per subjects).'.format(process_time, 291 | len(processed_list), 292 | process_time / len(processed_list))) 293 | -------------------------------------------------------------------------------- /ukbb_cardiac_segmentation_cine_mri/image_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017, Wenjia Bai. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import cv2 16 | import numpy as np 17 | import tensorflow as tf 18 | from scipy import ndimage 19 | 20 | 21 | def tf_categorical_accuracy(pred, truth): 22 | """ Accuracy metric """ 23 | return tf.reduce_mean(tf.cast(tf.equal(pred, truth), dtype=tf.float32)) 24 | 25 | 26 | def tf_categorical_dice(pred, truth, k): 27 | """ Dice overlap metric for label k """ 28 | A = tf.cast(tf.equal(pred, k), dtype=tf.float32) 29 | B = tf.cast(tf.equal(truth, k), dtype=tf.float32) 30 | return 2 * tf.reduce_sum(tf.multiply(A, B)) / (tf.reduce_sum(A) + tf.reduce_sum(B)) 31 | 32 | 33 | def crop_image(image, cx, cy, size): 34 | """ Crop a 3D image using a bounding box centred at (cx, cy) with specified size """ 35 | X, Y = image.shape[:2] 36 | r = int(size / 2) 37 | x1, x2 = cx - r, cx + r 38 | y1, y2 = cy - r, cy + r 39 | x1_, x2_ = max(x1, 0), min(x2, X) 40 | y1_, y2_ = max(y1, 0), min(y2, Y) 41 | # Crop the image 42 | crop = image[x1_: x2_, y1_: y2_] 43 | # Pad the image if the specified size is larger than the input image size 44 | if crop.ndim == 3: 45 | crop = np.pad(crop, 46 | ((x1_ - x1, x2 - x2_), (y1_ - y1, y2 - y2_), (0, 0)), 47 | 'constant') 48 | elif crop.ndim == 4: 49 | crop = np.pad(crop, 50 | ((x1_ - x1, x2 - x2_), (y1_ - y1, y2 - y2_), (0, 0), (0, 0)), 51 | 'constant') 52 | else: 53 | print('Error: unsupported dimension, crop.ndim = {0}.'.format(crop.ndim)) 54 | exit(0) 55 | return crop 56 | 57 | 58 | def rescale_intensity(image, thres=(1.0, 99.0)): 59 | """ Rescale the image intensity to the range of [0, 1] """ 60 | val_l, val_h = np.percentile(image, thres) 61 | image2 = image 62 | image2[image < val_l] = val_l 63 | image2[image > val_h] = val_h 64 | image2 = (image2.astype(np.float32) - val_l) / (val_h - val_l) 65 | return image2 66 | 67 | 68 | def data_augmenter(image, label, shift, rotate, scale, intensity, flip): 69 | """ 70 | Online data augmentation 71 | Perform affine transformation on image and label, 72 | which are 4D tensor of shape (N, H, W, C) and 3D tensor of shape (N, H, W). 
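    shift, rotate, scale and intensity set the spread of the random
    translation (in pixels), rotation (in degrees), zoom and intensity
    scaling drawn for each slice; flip enables a random flip along one
    of the two in-plane axes.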
73 | """ 74 | image2 = np.zeros(image.shape, dtype=np.float32) 75 | label2 = np.zeros(label.shape, dtype=np.int32) 76 | for i in range(image.shape[0]): 77 | # For each image slice, generate random affine transformation parameters 78 | # using the Gaussian distribution 79 | shift_val = [np.clip(np.random.normal(), -3, 3) * shift, 80 | np.clip(np.random.normal(), -3, 3) * shift] 81 | rotate_val = np.clip(np.random.normal(), -3, 3) * rotate 82 | scale_val = 1 + np.clip(np.random.normal(), -3, 3) * scale 83 | intensity_val = 1 + np.clip(np.random.normal(), -3, 3) * intensity 84 | 85 | # Apply the affine transformation (rotation + scale + shift) to the image 86 | row, col = image.shape[1:3] 87 | M = cv2.getRotationMatrix2D((row / 2, col / 2), rotate_val, 1.0 / scale_val) 88 | M[:, 2] += shift_val 89 | for c in range(image.shape[3]): 90 | image2[i, :, :, c] = ndimage.interpolation.affine_transform(image[i, :, :, c], 91 | M[:, :2], M[:, 2], order=1) 92 | 93 | # Apply the affine transformation (rotation + scale + shift) to the label map 94 | label2[i, :, :] = ndimage.interpolation.affine_transform(label[i, :, :], 95 | M[:, :2], M[:, 2], order=0) 96 | 97 | # Apply intensity variation 98 | image2[i] *= intensity_val 99 | 100 | # Apply random horizontal or vertical flipping 101 | if flip: 102 | if np.random.uniform() >= 0.5: 103 | image2[i] = image2[i, ::-1, :, :] 104 | label2[i] = label2[i, ::-1, :] 105 | else: 106 | image2[i] = image2[i, :, ::-1, :] 107 | label2[i] = label2[i, :, ::-1] 108 | return image2, label2 109 | 110 | 111 | def np_categorical_dice(pred, truth, k): 112 | """ Dice overlap metric for label k """ 113 | A = (pred == k).astype(np.float32) 114 | B = (truth == k).astype(np.float32) 115 | return 2 * np.sum(A * B) / (np.sum(A) + np.sum(B)) 116 | 117 | 118 | def distance_metric(seg_A, seg_B, dx): 119 | """ 120 | Measure the distance errors between the contours of two segmentations. 121 | The manual contours are drawn on 2D slices. 122 | We calculate contour to contour distance for each slice. 
123 | """ 124 | table_md = [] 125 | table_hd = [] 126 | X, Y, Z = seg_A.shape 127 | for z in range(Z): 128 | # Binary mask at this slice 129 | slice_A = seg_A[:, :, z].astype(np.uint8) 130 | slice_B = seg_B[:, :, z].astype(np.uint8) 131 | 132 | # The distance is defined only when both contours exist on this slice 133 | if np.sum(slice_A) > 0 and np.sum(slice_B) > 0: 134 | # Find contours and retrieve all the points 135 | _, contours, _ = cv2.findContours(cv2.inRange(slice_A, 1, 1), 136 | cv2.RETR_EXTERNAL, 137 | cv2.CHAIN_APPROX_NONE) 138 | pts_A = contours[0] 139 | for i in range(1, len(contours)): 140 | pts_A = np.vstack((pts_A, contours[i])) 141 | 142 | _, contours, _ = cv2.findContours(cv2.inRange(slice_B, 1, 1), 143 | cv2.RETR_EXTERNAL, 144 | cv2.CHAIN_APPROX_NONE) 145 | pts_B = contours[0] 146 | for i in range(1, len(contours)): 147 | pts_B = np.vstack((pts_B, contours[i])) 148 | 149 | # Distance matrix between point sets 150 | M = np.zeros((len(pts_A), len(pts_B))) 151 | for i in range(len(pts_A)): 152 | for j in range(len(pts_B)): 153 | M[i, j] = np.linalg.norm(pts_A[i, 0] - pts_B[j, 0]) 154 | 155 | # Mean distance and hausdorff distance 156 | md = 0.5 * (np.mean(np.min(M, axis=0)) + np.mean(np.min(M, axis=1))) * dx 157 | hd = np.max([np.max(np.min(M, axis=0)), np.max(np.min(M, axis=1))]) * dx 158 | table_md += [md] 159 | table_hd += [hd] 160 | 161 | # Return the mean distance and Hausdorff distance across 2D slices 162 | mean_md = np.mean(table_md) if table_md else None 163 | mean_hd = np.mean(table_hd) if table_hd else None 164 | return mean_md, mean_hd 165 | -------------------------------------------------------------------------------- /ukbb_cardiac_segmentation_cine_mri/network.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017, Wenjia Bai. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import tensorflow as tf 16 | import numpy as np 17 | 18 | 19 | def conv2d_bn_relu(x, filters, training, kernel_size=3, strides=1): 20 | """ Basic Conv + BN + ReLU unit """ 21 | x_conv = tf.layers.conv2d(x, filters=filters, kernel_size=kernel_size, 22 | strides=strides, 23 | padding='same', use_bias=False) 24 | x_bn = tf.layers.batch_normalization(x_conv, training=training) 25 | x_relu = tf.nn.relu(x_bn) 26 | return x_relu 27 | 28 | 29 | def residual_unit(x, filters, training, strides=1): 30 | """ 31 | Basic residual learning unit, which implements the unit illustrated 32 | by Figure 1(b) in He et al. Identity Mappings in Deep Residual 33 | Networks, ECCV 2016. 
https://arxiv.org/pdf/1603.05027 34 | """ 35 | orig_x = x 36 | with tf.name_scope('sub1'): 37 | x = tf.layers.batch_normalization(x, training=training) 38 | x = tf.nn.relu(x) 39 | x = tf.layers.conv2d(x, filters=filters, kernel_size=3, 40 | strides=strides, 41 | padding='same', use_bias=False) 42 | with tf.name_scope('sub2'): 43 | x = tf.layers.batch_normalization(x, training=training) 44 | x = tf.nn.relu(x) 45 | x = tf.layers.conv2d(x, filters=filters, kernel_size=3, 46 | strides=1, 47 | padding='same', use_bias=False) 48 | # Use projection for increased dimension and identity mapping 49 | # for the same dimension i.e. option B in 50 | # He at al. Deep Residual Learning for Image Recognition, CVPR 2016. 51 | # https://arxiv.org/pdf/1512.03385 52 | with tf.name_scope('add'): 53 | if orig_x.shape[3] == x.shape[3] and strides == 1: 54 | shortcut = orig_x 55 | else: 56 | shortcut = tf.layers.conv2d(orig_x, filters=filters, 57 | kernel_size=1, strides=strides, 58 | padding='same', use_bias=False) 59 | x = shortcut + x 60 | return x 61 | 62 | 63 | def bottleneck_unit(x, filters, training, strides=1): 64 | """ 65 | Bottleneck residual learning unit, which implements the unit illustrated 66 | on the right of Figure 5 in 67 | He at al. Deep Residual Learning for Image Recognition, CVPR 2016. 68 | https://arxiv.org/pdf/1512.03385 69 | We also move batch_norm and relu before conv, as shown in He's next paper 70 | He at al. Identity Mappings in Deep Residual Networks, ECCV 2016. 71 | https://arxiv.org/pdf/1603.05027. 72 | """ 73 | orig_x = x 74 | with tf.name_scope('sub1'): 75 | x = tf.layers.batch_normalization(x, training=training) 76 | x = tf.nn.relu(x) 77 | x = tf.layers.conv2d(x, filters=filters / 4, kernel_size=1, 78 | strides=strides, 79 | padding='same', use_bias=False) 80 | with tf.name_scope('sub2'): 81 | x = tf.layers.batch_normalization(x, training=training) 82 | x = tf.nn.relu(x) 83 | x = tf.layers.conv2d(x, filters=filters / 4, kernel_size=3, 84 | strides=1, 85 | padding='same', use_bias=False) 86 | with tf.name_scope('sub3'): 87 | x = tf.layers.batch_normalization(x, training=training) 88 | x = tf.nn.relu(x) 89 | x = tf.layers.conv2d(x, filters=filters, kernel_size=1, 90 | strides=1, 91 | padding='same', use_bias=False) 92 | with tf.name_scope('add'): 93 | if orig_x.shape[3] == x.shape[3] and strides == 1: 94 | shortcut = orig_x 95 | else: 96 | shortcut = tf.layers.conv2d(orig_x, filters=filters, 97 | kernel_size=1, strides=strides, 98 | padding='same', use_bias=False) 99 | x = shortcut + x 100 | return x 101 | 102 | 103 | def linear_1d(sz): 104 | """ 1D linear interpolation kernel """ 105 | if sz % 2 == 0: 106 | raise NotImplementedError('`Linear kernel` requires odd filter size.') 107 | c = int((sz + 1) / 2) 108 | h = np.array(list(range(1, c + 1)) + list(range(c - 1, 0, -1)), dtype=np.float32) 109 | h /= float(c) 110 | return h 111 | 112 | 113 | def linear_2d(sz): 114 | """ 2D linear interpolation kernel """ 115 | W = np.ones((sz, sz), dtype=np.float32) 116 | h = linear_1d(sz) 117 | for i in range(sz): 118 | W[i, :] *= h 119 | for j in range(sz): 120 | W[:, j] *= h 121 | return W 122 | 123 | 124 | def transpose_upsample2d(x, factor, constant=True): 125 | """ 2D upsampling operator using transposed convolution """ 126 | x_shape = tf.shape(x) 127 | output_shape = tf.stack([x_shape[0], 128 | x_shape[1] * factor, 129 | x_shape[2] * factor, 130 | x.shape[3].value]) 131 | 132 | # The bilinear interpolation weight for the upsampling filter 133 | sz = factor * 2 - 1 134 | W = linear_2d(sz) 
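    # e.g. factor = 2 gives sz = 3, linear_1d(3) = [0.5, 1, 0.5] and W is its
    # 3x3 outer product [[0.25, 0.5, 0.25], [0.5, 1, 0.5], [0.25, 0.5, 0.25]].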
135 | n = x.shape[3].value 136 | filt_val = np.zeros((sz, sz, n, n), dtype=np.float32) 137 | for i in range(n): 138 | filt_val[:, :, i, i] = W 139 | 140 | # Currently, we simply use the fixed bilinear interpolation weights. 141 | # However, it is possible to set the filt to a trainable variable. 142 | if constant: 143 | filt = tf.constant(filt_val, dtype=tf.float32) 144 | else: 145 | filt = tf.Variable(filt_val, dtype=tf.float32) 146 | 147 | # Currently, if output_shape is an unknown shape, conv2d_transpose() 148 | # will output an unknown shape during graph construction. This will be 149 | # a problem for the next step tf.concat(), which requires a known shape. 150 | # A workaround is to reshape this tensor to the expected shape size. 151 | # Refer to https://github.com/tensorflow/tensorflow/issues/833#issuecomment-278016198 152 | x_up = tf.nn.conv2d_transpose(x, 153 | filter=filt, 154 | output_shape=output_shape, 155 | strides=[1, factor, factor, 1], 156 | padding='SAME') 157 | x_out = tf.reshape(x_up, 158 | (x_shape[0], 159 | x_shape[1] * factor, 160 | x_shape[2] * factor, 161 | x.shape[3].value)) 162 | return x_out 163 | 164 | 165 | def build_FCN(image, n_class, n_level, n_filter, n_block, training, same_dim=32, fc=64): 166 | """ 167 | Build a fully convolutional network for segmenting an input image 168 | into n_class classes and return the logits map. 169 | """ 170 | net = {} 171 | x = image 172 | 173 | # Learn fine-to-coarse features at each resolution level 174 | for level in range(0, n_level): 175 | with tf.name_scope('conv{0}'.format(level)): 176 | # If this is the first level (l = 0), keep the resolution. 177 | # Otherwise, convolve with a stride of 2, i.e. downsample 178 | # by a factor of 2。 179 | strides = 1 if level == 0 else 2 180 | # For each resolution level, perform n_block[l] times convolutions 181 | x = conv2d_bn_relu(x, 182 | filters=n_filter[level], 183 | training=training, 184 | kernel_size=3, 185 | strides=strides) 186 | 187 | for i in range(1, n_block[level]): 188 | x = conv2d_bn_relu(x, 189 | filters=n_filter[level], 190 | training=training, 191 | kernel_size=3) 192 | net['conv{0}'.format(level)] = x 193 | 194 | # Before upsampling back to the original resolution level, map all the 195 | # feature maps to have same_dim dimensions. Otherwise, the upsampled 196 | # feature maps will have both a large size (e.g. 192 x 192) and a high 197 | # dimension (e.g. 256 features), which may exhaust the GPU memory (e.g. 198 | # 12 GB for Nvidia Titan K80). 199 | # Exemplar calculation: 200 | # batch size 20 x image size 192 x 192 x feature dimension 256 x floating data type 4 201 | # = 755 MB for a feature map 202 | # Apart from this, there is also associated memory of the same size 203 | # used for gradient calculation. 
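    # For comparison (illustrative arithmetic), after the 1x1 mapping to
    # same_dim = 32 dimensions below, the same feature map needs only
    # batch size 20 x image size 192 x 192 x feature dimension 32 x floating data type 4
    # = 94 MB, i.e. roughly an 8-fold reduction in memory.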
204 | with tf.name_scope('same_dim'): 205 | for level in range(0, n_level): 206 | net['conv{0}_same_dim'.format(level)] = conv2d_bn_relu( 207 | net['conv{0}'.format(level)], 208 | filters=same_dim, 209 | training=training, 210 | kernel_size=1) 211 | 212 | # Upsample the feature maps at each resolution level to the original resolution 213 | with tf.name_scope('up'): 214 | net['conv0_up'] = net['conv0_same_dim'] 215 | for level in range(1, n_level): 216 | net['conv{0}_up'.format(level)] = transpose_upsample2d( 217 | net['conv{0}_same_dim'.format(level)], 218 | factor=int(pow(2, level))) 219 | 220 | # Concatenate the multi-level feature maps 221 | with tf.name_scope('concat'): 222 | list_up = [] 223 | for level in range(0, n_level): 224 | list_up += [net['conv{0}_up'.format(level)]] 225 | net['concat'] = tf.concat(list_up, axis=-1) 226 | 227 | # Perform prediction using the multi-level feature maps 228 | with tf.name_scope('out'): 229 | # We only calculate logits, instead of softmax here because the loss 230 | # function tf.nn.softmax_cross_entropy() accepts the unscaled logits 231 | # and performs softmax internally for efficiency and numerical stability 232 | # reasons. Refer to https://github.com/tensorflow/tensorflow/issues/2462 233 | x = net['concat'] 234 | x = conv2d_bn_relu(x, filters=fc, training=training, kernel_size=1) 235 | x = conv2d_bn_relu(x, filters=fc, training=training, kernel_size=1) 236 | logits = tf.layers.conv2d(x, filters=n_class, kernel_size=1, padding='same') 237 | return logits 238 | 239 | 240 | def build_ResNet(image, n_class, n_level, n_filter, n_block, training, 241 | use_bottleneck=False, same_dim=32, fc=64): 242 | """ 243 | Build a fully convolutional network with residual learning units 244 | for segmenting an input image into n_class classes and return the 245 | logits map. 246 | """ 247 | if use_bottleneck: 248 | res_func = bottleneck_unit 249 | else: 250 | res_func = residual_unit 251 | net = {} 252 | x = image 253 | 254 | # Learn fine-to-coarse features at each resolution level 255 | # As Figure 1 in 256 | # He at al. Deep Residual Learning for Image Recognition, CVPR 2016. 257 | # https://arxiv.org/pdf/1512.03385 258 | # shows, the original residual network for ImageNet classification only 259 | # starts using the residual units from the third resolution level. 260 | # We do the same here. 261 | for level in range(0, 2): 262 | with tf.name_scope('conv{0}'.format(level)): 263 | strides = 1 if level == 0 else 2 264 | x = conv2d_bn_relu(x, 265 | filters=n_filter[level], 266 | training=training, 267 | kernel_size=3, 268 | strides=strides) 269 | for i in range(1, n_block[level]): 270 | x = conv2d_bn_relu(x, 271 | filters=n_filter[level], 272 | training=training, 273 | kernel_size=3) 274 | net['conv{0}'.format(level)] = x 275 | 276 | for level in range(2, n_level): 277 | with tf.name_scope('conv{0}'.format(level)): 278 | x = res_func(x, filters=n_filter[level], training=training, strides=2) 279 | for i in range(1, n_block[level]): 280 | x = res_func(x, filters=n_filter[level], training=training) 281 | net['conv{0}'.format(level)] = x 282 | 283 | # Before upsampling back to the original resolution level, map all the feature maps 284 | # to have same_dim dimensions. 
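    # Note (illustrative): with same_dim = 32 and n_level = 5 (the values used
    # in train_network.py), the concatenation of the upsampled per-level
    # feature maps below has 5 x 32 = 160 channels, which the two 1x1 fc
    # convolutions then reduce before the final prediction layer.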
285 | with tf.name_scope('same_dim'): 286 | for level in range(0, n_level): 287 | net['conv{0}_same_dim'.format(level)] = conv2d_bn_relu( 288 | net['conv{0}'.format(level)], 289 | training=training, 290 | filters=same_dim, 291 | kernel_size=1) 292 | 293 | # Upsample the feature maps at each resolution level to the original resolution 294 | with tf.name_scope('up'): 295 | net['conv0_up'] = net['conv0_same_dim'] 296 | for level in range(1, n_level): 297 | net['conv{0}_up'.format(level)] = transpose_upsample2d( 298 | net['conv{0}_same_dim'.format(level)], 299 | factor=int(pow(2, level))) 300 | 301 | # Concatenate the multi-level feature maps 302 | with tf.name_scope('concat'): 303 | list_up = [] 304 | for level in range(0, n_level): 305 | list_up += [net['conv{0}_up'.format(level)]] 306 | net['concat'] = tf.concat(list_up, axis=-1) 307 | 308 | # Perform prediction using the multi-level feature maps 309 | with tf.name_scope('out'): 310 | # We only calculate logits, instead of softmax here because the loss function 311 | # tf.nn.softmax_cross_entropy() accepts the unscaled logits and performs softmax 312 | # internally for efficiency and numerical stability reasons. 313 | # Refer to https://github.com/tensorflow/tensorflow/issues/2462 314 | x = net['concat'] 315 | x = conv2d_bn_relu(x, filters=fc, training=training, kernel_size=1) 316 | x = conv2d_bn_relu(x, filters=fc, training=training, kernel_size=1) 317 | logits = tf.layers.conv2d(x, filters=n_class, kernel_size=1, padding='same') 318 | return logits 319 | -------------------------------------------------------------------------------- /ukbb_cardiac_segmentation_cine_mri/train_network.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017, Wenjia Bai. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import os 16 | import time 17 | import random 18 | import numpy as np 19 | import nibabel as nib 20 | import tensorflow as tf 21 | from network import build_FCN, build_ResNet 22 | from image_utils import tf_categorical_accuracy, tf_categorical_dice 23 | from image_utils import crop_image, rescale_intensity, data_augmenter 24 | 25 | 26 | """ Training parameters """ 27 | FLAGS = tf.app.flags.FLAGS 28 | tf.app.flags.DEFINE_integer('image_size', 192, 29 | 'Image size after cropping.') 30 | tf.app.flags.DEFINE_integer('train_batch_size', 2, 31 | 'Number of images for each training batch.') 32 | tf.app.flags.DEFINE_integer('validation_batch_size', 2, 33 | 'Number of images for each validation batch.') 34 | tf.app.flags.DEFINE_integer('train_iteration', 50000, 35 | 'Number of training iterations.') 36 | tf.app.flags.DEFINE_integer('num_filter', 16, 37 | 'Number of filters for the first convolution layer.') 38 | tf.app.flags.DEFINE_integer('num_level', 5, 39 | 'Number of network levels.') 40 | tf.app.flags.DEFINE_float('learning_rate', 1e-3, 41 | 'Learning rate.') 42 | tf.app.flags._global_parser.add_argument('--seq_name', 43 | choices=['sa', 'la_2ch', 'la_4ch'], 44 | default='sa', help='Sequence name for training.') 45 | tf.app.flags._global_parser.add_argument('--model', 46 | choices=['FCN', 'ResNet'], 47 | default='FCN', help='Model name.') 48 | tf.app.flags._global_parser.add_argument('--optimizer', 49 | choices=['Adam', 'SGD', 'Momentum'], 50 | default='Adam', help='Optimizer.') 51 | tf.app.flags.DEFINE_string('dataset_dir', 52 | '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa', 53 | 'Path to the dataset directory, which is split into ' 54 | 'training, validation and test subdirectories.') 55 | tf.app.flags.DEFINE_string('log_dir', 56 | '/vol/bitbucket/wbai/ukbb_cardiac/log', 57 | 'Directory for saving the log file.') 58 | tf.app.flags.DEFINE_string('checkpoint_dir', 59 | '/vol/bitbucket/wbai/ukbb_cardiac/model', 60 | 'Directory for saving the trained model.') 61 | 62 | 63 | def get_random_batch(filename_list, batch_size, image_size=192, data_augmentation=False, 64 | shift=0.0, rotate=0.0, scale=0.0, intensity=0.0, flip=False): 65 | # Randomly select batch_size images from filename_list 66 | n_file = len(filename_list) 67 | n_selected = 0 68 | images = [] 69 | labels = [] 70 | while n_selected < batch_size: 71 | rand_index = random.randrange(n_file) 72 | image_name, label_name = filename_list[rand_index] 73 | if os.path.exists(image_name) and os.path.exists(label_name): 74 | print(' Select {0} {1}'.format(image_name, label_name)) 75 | 76 | # Read image and label 77 | image = nib.load(image_name).get_data() 78 | label = nib.load(label_name).get_data() 79 | 80 | # Handle exceptions 81 | if image.shape != label.shape: 82 | print('Error: mismatched size, image.shape = {0}, ' 83 | 'label.shape = {1}'.format(image.shape, label.shape)) 84 | print('Skip {0}, {1}'.format(image_name, label_name)) 85 | continue 86 | 87 | if image.max() < 1e-6: 88 | print('Error: blank image, image.max = {0}'.format(image.max())) 89 | print('Skip {0} {1}'.format(image_name, label_name)) 90 | continue 91 | 92 | # Normalise the image size 93 | X, Y, Z = image.shape 94 | cx, cy = int(X / 2), int(Y / 2) 95 | image = crop_image(image, cx, cy, image_size) 96 | label = crop_image(label, cx, cy, image_size) 97 | 98 | # Intensity rescaling 99 | image = rescale_intensity(image, (1.0, 99.0)) 100 | 101 | # Append the image slices to 
the batch 102 | # Use list for appending, which is much faster than numpy array 103 | for z in range(Z): 104 | images += [image[:, :, z]] 105 | labels += [label[:, :, z]] 106 | 107 | # Increase the counter 108 | n_selected += 1 109 | 110 | # Convert to a numpy array 111 | images = np.array(images, dtype=np.float32) 112 | labels = np.array(labels, dtype=np.int32) 113 | 114 | # Add the channel dimension 115 | # tensorflow by default assumes NHWC format 116 | images = np.expand_dims(images, axis=3) 117 | 118 | # Perform data augmentation 119 | if data_augmentation: 120 | images, labels = data_augmenter(images, labels, 121 | shift=shift, rotate=rotate, 122 | scale=scale, 123 | intensity=intensity, flip=flip) 124 | return images, labels 125 | 126 | 127 | def main(argv=None): 128 | """ Main function """ 129 | # Go through each subset (training, validation, test) under the data directory 130 | # and list the file names of the subjects 131 | data_list = {} 132 | for k in ['train', 'validation', 'test']: 133 | subset_dir = os.path.join(FLAGS.dataset_dir, k) 134 | data_list[k] = [] 135 | 136 | for data in sorted(os.listdir(subset_dir)): 137 | data_dir = os.path.join(subset_dir, data) 138 | # Check the existence of the image and label map at ED and ES time frames 139 | # and add their file names to the list 140 | for fr in ['ED', 'ES']: 141 | image_name = '{0}/{1}_{2}.nii.gz'.format(data_dir, 142 | FLAGS.seq_name, 143 | fr) 144 | label_name = '{0}/label_{1}_{2}.nii.gz'.format(data_dir, 145 | FLAGS.seq_name, 146 | fr) 147 | if os.path.exists(image_name) and os.path.exists(label_name): 148 | data_list[k] += [[image_name, label_name]] 149 | 150 | # Prepare tensors for the image and label map pairs 151 | # Use int32 for label_pl as tf.one_hot uses int32 152 | image_pl = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='image') 153 | label_pl = tf.placeholder(tf.int32, shape=[None, None, None], name='label') 154 | 155 | # Print out the placeholders' names, which will be useful when deploying the network 156 | print('Placeholder image_pl.name = ' + image_pl.name) 157 | print('Placeholder label_pl.name = ' + label_pl.name) 158 | 159 | # Placeholder for the training phase 160 | # This flag is important for the batch_normalization layer to function properly. 161 | training_pl = tf.placeholder(tf.bool, shape=[], name='training') 162 | print('Placeholder training_pl.name = ' + training_pl.name) 163 | 164 | # Determine the number of label classes according to the manual annotation procedure 165 | # for each image sequence. 
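    # n_class defines both the number of output channels of the network
    # (the logits) and the depth of the one-hot label encoding used by the
    # cross-entropy loss further below.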
166 | n_class = 0 167 | if FLAGS.seq_name == 'sa': 168 | # sa, short-axis images 169 | # 4 classes (background, LV cavity, LV myocardium, RV cavity) 170 | n_class = 4 171 | elif FLAGS.seq_name == 'la_2ch': 172 | # la_2ch, long-axis 2 chamber view images 173 | # 2 classes (background, LA cavity) 174 | n_class = 2 175 | elif FLAGS.seq_name == 'la_4ch': 176 | # la_4ch, long-axis 4 chamber views 177 | # 3 classes (background, LA cavity, RA cavity) 178 | n_class = 3 179 | else: 180 | print('Error: unknown seq_name {0}.'.format(FLAGS.seq_name)) 181 | exit(0) 182 | 183 | # The number of resolution levels 184 | n_level = FLAGS.num_level 185 | 186 | # The number of filters at each resolution level 187 | # Follow the VGG philosophy, increasing the dimension 188 | # by a factor of 2 for each level 189 | n_filter = [] 190 | for i in range(n_level): 191 | n_filter += [FLAGS.num_filter * pow(2, i)] 192 | print('Number of filters at each level =', n_filter) 193 | print('Note: The connection between neurons is proportional to ' 194 | 'n_filter * n_filter. Increasing n_filter by a factor of 2 ' 195 | 'will increase the number of parameters by a factor of 4. ' 196 | 'So it is better to start experiments with a small n_filter ' 197 | 'and increase it later.') 198 | 199 | # Build the neural network, which outputs the logits, 200 | # i.e. the unscaled values just before the softmax layer, 201 | # which will then normalise the logits into the probabilities. 202 | n_block = [] 203 | if FLAGS.model == 'FCN': 204 | n_block = [2, 2, 3, 3, 3] 205 | logits = build_FCN(image_pl, n_class, n_level=n_level, 206 | n_filter=n_filter, n_block=n_block, 207 | training=training_pl, same_dim=32, fc=64) 208 | elif FLAGS.model == 'ResNet': 209 | n_block = [2, 2, 3, 4, 6] 210 | logits = build_ResNet(image_pl, n_class, n_level=n_level, 211 | n_filter=n_filter, n_block=n_block, 212 | training=training_pl, use_bottleneck=False, 213 | same_dim=32, fc=64) 214 | else: 215 | print('Error: unknown model {0}.'.format(FLAGS.model)) 216 | exit(0) 217 | 218 | # The softmax probability and the predicted segmentation 219 | prob = tf.nn.softmax(logits, name='prob') 220 | pred = tf.cast(tf.argmax(prob, axis=-1), dtype=tf.int32, name='pred') 221 | print('prob.name = ' + prob.name) 222 | print('pred.name = ' + pred.name) 223 | 224 | # Loss 225 | label_1hot = tf.one_hot(indices=label_pl, depth=n_class) 226 | label_loss = tf.nn.softmax_cross_entropy_with_logits(labels=label_1hot, 227 | logits=logits) 228 | loss = tf.reduce_mean(label_loss) 229 | 230 | # Evaluation metrics 231 | accuracy = tf_categorical_accuracy(pred, label_pl) 232 | dice_lv = tf_categorical_dice(pred, label_pl, 1) 233 | dice_myo = tf_categorical_dice(pred, label_pl, 2) 234 | dice_rv = tf_categorical_dice(pred, label_pl, 3) 235 | dice_la = tf_categorical_dice(pred, label_pl, 1) 236 | dice_ra = tf_categorical_dice(pred, label_pl, 2) 237 | 238 | # Optimiser 239 | lr = FLAGS.learning_rate 240 | 241 | # We need to add the operators associated with batch_normalization 242 | # to the optimiser, according to 243 | # https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization 244 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 245 | with tf.control_dependencies(update_ops): 246 | if FLAGS.optimizer == 'SGD': 247 | print('Using SGD optimizer.') 248 | train_op = tf.train.GradientDescentOptimizer(learning_rate=lr).minimize(loss) 249 | elif FLAGS.optimizer == 'Adam': 250 | print('Using Adam optimizer.') 251 | train_op = 
tf.train.AdamOptimizer(learning_rate=lr).minimize(loss) 252 | elif FLAGS.optimizer == 'Momentum': 253 | print('Using Momentum optimizer with Nesterov momentum.') 254 | train_op = tf.train.MomentumOptimizer(learning_rate=lr, 255 | momentum=0.9, 256 | use_nesterov=True).minimize(loss) 257 | else: 258 | print('Error: unknown optimizer {0}.'.format(FLAGS.optimizer)) 259 | exit(0) 260 | 261 | # Model name and directory 262 | model_name = '{0}_{1}_level{2}_filter{3}_{4}_{5}_batch{6}_iter{7}_lr{8}'.format( 263 | FLAGS.model, FLAGS.seq_name, n_level, n_filter[0], 264 | ''.join([str(x) for x in n_block]), 265 | FLAGS.optimizer, FLAGS.train_batch_size, 266 | FLAGS.train_iteration, FLAGS.learning_rate) 267 | model_dir = os.path.join(FLAGS.checkpoint_dir, model_name) 268 | if not os.path.exists(model_dir): 269 | os.makedirs(model_dir) 270 | 271 | # Create a logger 272 | if not os.path.exists(FLAGS.log_dir): 273 | os.makedirs(FLAGS.log_dir) 274 | csv_name = os.path.join(FLAGS.log_dir, '{0}_log.csv'.format(model_name)) 275 | f_log = open(csv_name, 'w') 276 | if FLAGS.seq_name == 'sa': 277 | f_log.write('iteration,time,train_loss,train_acc,test_loss,test_acc,' 278 | 'test_dice_lv,test_dice_myo,test_dice_rv\n') 279 | elif FLAGS.seq_name == 'la_2ch': 280 | f_log.write('iteration,time,train_loss,train_acc,test_loss,test_acc,' 281 | 'test_dice_la\n') 282 | elif FLAGS.seq_name == 'la_4ch': 283 | f_log.write('iteration,time,train_loss,train_acc,test_loss,test_acc,' 284 | 'test_dice_la,test_dice_ra\n') 285 | 286 | # Start the tensorflow session 287 | with tf.Session() as sess: 288 | print('Start training...') 289 | start_time = time.time() 290 | 291 | # Create a saver 292 | saver = tf.train.Saver(max_to_keep=20) 293 | 294 | # Summary writer 295 | summary_dir = os.path.join(FLAGS.log_dir, model_name) 296 | if os.path.exists(summary_dir): 297 | os.system('rm -rf {0}'.format(summary_dir)) 298 | train_writer = tf.summary.FileWriter(os.path.join(summary_dir, 'train'), 299 | graph=sess.graph) 300 | validation_writer = tf.summary.FileWriter(os.path.join(summary_dir, 'validation'), 301 | graph=sess.graph) 302 | 303 | # Initialise variables 304 | sess.run(tf.global_variables_initializer()) 305 | 306 | # Iterate 307 | for iteration in range(1, 1 + FLAGS.train_iteration): 308 | # For each iteration, we randomly choose a batch of subjects 309 | print('Iteration {0}: training...'.format(iteration)) 310 | start_time_iter = time.time() 311 | 312 | images, labels = get_random_batch(data_list['train'], 313 | FLAGS.train_batch_size, 314 | image_size=FLAGS.image_size, 315 | data_augmentation=True, 316 | shift=10, rotate=10, scale=0.1, 317 | intensity=0.1, flip=False) 318 | 319 | # Stochastic optimisation using this batch 320 | _, train_loss, train_acc = sess.run([train_op, loss, accuracy], 321 | {image_pl: images, 322 | label_pl: labels, 323 | training_pl: True}) 324 | 325 | summary = tf.Summary() 326 | summary.value.add(tag='loss', simple_value=train_loss) 327 | summary.value.add(tag='accuracy', simple_value=train_acc) 328 | train_writer.add_summary(summary, iteration) 329 | 330 | # After every ten iterations, we perform validation 331 | if iteration % 10 == 0: 332 | print('Iteration {0}: validation...'.format(iteration)) 333 | images, labels = get_random_batch(data_list['validation'], 334 | FLAGS.validation_batch_size, 335 | image_size=FLAGS.image_size, 336 | data_augmentation=False) 337 | 338 | if FLAGS.seq_name == 'sa': 339 | validation_loss, validation_acc, validation_dice_lv, validation_dice_myo, validation_dice_rv = \ 340 
| sess.run([loss, accuracy, dice_lv, dice_myo, dice_rv], 341 | {image_pl: images, label_pl: labels, 342 | training_pl: False}) 343 | elif FLAGS.seq_name == 'la_2ch': 344 | validation_loss, validation_acc, validation_dice_la = \ 345 | sess.run([loss, accuracy, dice_la], 346 | {image_pl: images, label_pl: labels, 347 | training_pl: False}) 348 | elif FLAGS.seq_name == 'la_4ch': 349 | validation_loss, validation_acc, validation_dice_la, validation_dice_ra = \ 350 | sess.run([loss, accuracy, dice_la, dice_ra], 351 | {image_pl: images, label_pl: labels, 352 | training_pl: False}) 353 | 354 | summary = tf.Summary() 355 | summary.value.add(tag='loss', simple_value=validation_loss) 356 | summary.value.add(tag='accuracy', simple_value=validation_acc) 357 | if FLAGS.seq_name == 'sa': 358 | summary.value.add(tag='dice_lv', simple_value=validation_dice_lv) 359 | summary.value.add(tag='dice_myo', simple_value=validation_dice_myo) 360 | summary.value.add(tag='dice_rv', simple_value=validation_dice_rv) 361 | elif FLAGS.seq_name == 'la_2ch': 362 | summary.value.add(tag='dice_la', simple_value=validation_dice_la) 363 | elif FLAGS.seq_name == 'la_4ch': 364 | summary.value.add(tag='dice_la', simple_value=validation_dice_la) 365 | summary.value.add(tag='dice_ra', simple_value=validation_dice_ra) 366 | validation_writer.add_summary(summary, iteration) 367 | 368 | # Print the results for this iteration 369 | print('Iteration {} of {} took {:.3f}s'.format(iteration, 370 | FLAGS.train_iteration, 371 | time.time() - start_time_iter)) 372 | print(' training loss:\t\t{:.6f}'.format(train_loss)) 373 | print(' training accuracy:\t\t{:.2f}%'.format(train_acc * 100)) 374 | print(' validation loss: \t\t{:.6f}'.format(validation_loss)) 375 | print(' validation accuracy:\t\t{:.2f}%'.format(validation_acc * 100)) 376 | if FLAGS.seq_name == 'sa': 377 | print(' validation Dice LV:\t\t{:.6f}'.format(validation_dice_lv)) 378 | print(' validation Dice Myo:\t\t{:.6f}'.format(validation_dice_myo)) 379 | print(' validation Dice RV:\t\t{:.6f}\n'.format(validation_dice_rv)) 380 | elif FLAGS.seq_name == 'la_2ch': 381 | print(' validation Dice LA:\t\t{:.6f}'.format(validation_dice_la)) 382 | elif FLAGS.seq_name == 'la_4ch': 383 | print(' validation Dice LA:\t\t{:.6f}'.format(validation_dice_la)) 384 | print(' validation Dice RA:\t\t{:.6f}'.format(validation_dice_ra)) 385 | 386 | # Log 387 | if FLAGS.seq_name == 'sa': 388 | f_log.write('{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}\n'.format( 389 | iteration, time.time() - start_time, 390 | train_loss, train_acc, 391 | validation_loss, validation_acc, 392 | validation_dice_lv, validation_dice_myo, 393 | validation_dice_rv)) 394 | elif FLAGS.seq_name == 'la_2ch': 395 | f_log.write('{0}, {1}, {2}, {3}, {4}, {5}, {6}\n'.format( 396 | iteration, time.time() - start_time, 397 | train_loss, train_acc, 398 | validation_loss, validation_acc, 399 | validation_dice_la)) 400 | elif FLAGS.seq_name == 'la_4ch': 401 | f_log.write('{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}\n'.format( 402 | iteration, time.time() - start_time, 403 | train_loss, train_acc, 404 | validation_loss, validation_acc, 405 | validation_dice_la, validation_dice_ra)) 406 | f_log.flush() 407 | else: 408 | # Print the results for this iteration 409 | print('Iteration {} of {} took {:.3f}s'.format(iteration, 410 | FLAGS.train_iteration, 411 | time.time() - start_time_iter)) 412 | print(' training loss:\t\t{:.6f}'.format(train_loss)) 413 | print(' training accuracy:\t\t{:.2f}%'.format(train_acc * 100)) 414 | 415 | # Save models after 
every 1000 iterations (1 epoch)
416 |             # One epoch needs to go through
417 |             # 1000 subjects * 2 time frames = 2000 images = 1000 training iterations
418 |             # if one iteration processes 2 images.
419 |             if iteration % 1000 == 0:
420 |                 saver.save(sess, save_path=os.path.join(model_dir, '{0}.ckpt'.format(model_name)),
421 |                            global_step=iteration)
422 | 
423 |         # Close the logger and summary writers
424 |         f_log.close()
425 |         train_writer.close()
426 |         validation_writer.close()
427 |         print('Training took {:.3f}s in total.\n'.format(time.time() - start_time))
428 | 
429 | 
430 | if __name__ == '__main__':
431 |     tf.app.run()
432 | 
--------------------------------------------------------------------------------
/ukbb_neuronet_brain_segmentation/README.md:
--------------------------------------------------------------------------------
1 | ## Fast and Robust Reproduction of Multiple Brain Image Segmentation Pipelines
2 | 
3 | ![Example prediction on test data](figures/example_seg.png)
4 | 
5 | ### Contact and referencing this work
6 | If there are any issues, please contact the corresponding author of this implementation. If you employ this model in your work, please cite the [paper](https://openreview.net/pdf?id=Hks1TRisM) as follows.
7 | ```
8 | @inproceedings{rajchl2018neuronet,
9 | title={NeuroNet: Fast and Robust Reproduction of Multiple Brain Image Segmentation Pipelines},
10 | author={Martin Rajchl and Nick Pawlowski and Daniel Rueckert and Paul M. Matthews and Ben Glocker},
11 | booktitle={International conference on Medical Imaging with Deep Learning (MIDL)},
12 | year={2018}
13 | }
14 | ```
15 | 
16 | 
17 | ### Data
18 | The data can be downloaded after registration from the [UK Biobank Imaging Enhancement Study website](https://imaging.ukbiobank.ac.uk/).
19 | 
20 | Images and segmentations are read from a csv file in the format below. The original files (*.csv) are provided in this repo.
21 | 
22 | These are parsed to extract tf.Tensor examples for training and evaluation in `reader.py`, using [SimpleITK](http://www.simpleitk.org/) for I/O of the .nii files.
23 | 
24 | 
25 | ### Usage
26 | Files:
27 | - `parse_csvs.ipynb` creates training/validation/testing .csv files from the data paths and splits the subject ids into these categories.
28 | - `sandbox.ipynb` visually assesses the outputs of `reader.py` as a check of the network inputs
29 | - `eval.ipynb` computes the visual and numerical results for the paper
30 | 
31 | - `reader.py` dltk reader, containing the label mappings to and from consecutive ids and the python generator creating input tensors to the network, using a SimpleITK interface
32 | - `train.py` main training script to run all experiments with
33 | - `deploy.py` generic deploy script for all experiments
34 | 
35 | - `config*.json` are configuration files that determine the dataset(s) to train on, the scaling of the flexible NeuroNet architecture and a few exposed training parameters.
36 | - `*.csv` csv files generated with `parse_csvs.ipynb`, containing the paths to all .nii image files
37 | 
38 | 
39 | #### Data Preprocessing
40 | We did not apply any data preprocessing, such as brain stripping or additional bias correction. The input to the network is a single MNI-registered 1mm isotropic T1-weighted MR image (as produced by the UK Biobank). Please refer to the [UKB Neuroimaging documentation](https://biobank.ctsu.ox.ac.uk/crystal/docs/brain_mri.pdf) for additional information.
41 | 
42 | #### Training
43 | You can use the code (train.py) to train the model on the data yourself.
Alternatively, we provide pretrained models from the paper here: 44 | - [neuronet_all](http://www.doc.ic.ac.uk/~mrajchl/dltk_models/model_zoo/neuronet/neuronet_all.tar.gz) 45 | - [neuronet_tissue](http://www.doc.ic.ac.uk/~mrajchl/dltk_models/model_zoo/neuronet/neuronet_tissue.tar.gz) 46 | - [neuronet_single fsl fast](http://www.doc.ic.ac.uk/~mrajchl/dltk_models/model_zoo/neuronet/fsl_fast.tar.gz) 47 | - [neuronet_single fsl first](http://www.doc.ic.ac.uk/~mrajchl/dltk_models/model_zoo/neuronet/fsl_first.tar.gz) 48 | - [neuronet_single spm tissue](http://www.doc.ic.ac.uk/~mrajchl/dltk_models/model_zoo/neuronet/spm_tissue.tar.gz) 49 | - [neuronet_single malp_em tissue](http://www.doc.ic.ac.uk/~mrajchl/dltk_models/model_zoo/neuronet/malp_em_tissue.tar.gz) 50 | - [neuronet_single malp_em](http://www.doc.ic.ac.uk/~mrajchl/dltk_models/model_zoo/neuronet/malp_em.tar.gz) 51 | 52 | 53 | Depending on the model, the number of output volumes will correspond with the number of segmentation tasks (i.e. neuronet_single will produce one volume, neuronet_all will produce 5 segmentation volumes). 54 | 55 | You can start a basic training with 56 | ``` 57 | python train.py -c CUDA_DEVICE --config MY_CONFIG 58 | ``` 59 | that will load the file paths from the previously created csvs, according to the config parameters. 60 | 61 | #### Deploy 62 | To deploy a model and run inference, run the deploy.py script and point to the model save_path: 63 | 64 | ``` 65 | python deploy.py -p path/to/saved/model -c CUDA_DEVICE --config MY_CONFIG 66 | ``` -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/config_all.json: -------------------------------------------------------------------------------- 1 | { 2 | "protocols": ["fsl_fast", "fsl_first", "spm_tissue", "malp_em", "malp_em_tissue"], 3 | "num_classes": [4, 16, 4, 139, 6], 4 | "model_path": "/tmp/neuronet/models/neuronet_all", 5 | "out_segm_path": "/tmp/neuronet/out/neuronet_all", 6 | "learning_rate": 0.001, 7 | "network": { 8 | "filters": [16, 32, 64, 128], 9 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 10 | "num_residual_units": 2 11 | } 12 | } -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/config_fsl_fast.json: -------------------------------------------------------------------------------- 1 | { 2 | "protocols": ["fsl_fast"], 3 | "num_classes": [4], 4 | "model_path": "/tmp/neuronet/models/fsl_fast", 5 | "out_segm_path": "/tmp/neuronet/out/fsl_fast", 6 | "learning_rate": 0.001, 7 | "network": { 8 | "filters": [16, 32, 64, 128], 9 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 10 | "num_residual_units": 2 11 | } 12 | } -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/config_fsl_first.json: -------------------------------------------------------------------------------- 1 | { 2 | "protocols": ["fsl_first"], 3 | "num_classes": [16], 4 | "model_path": "/tmp/neuronet/models/fsl_first", 5 | "out_segm_path": "/tmp/neuronet/out/fsl_first", 6 | "learning_rate": 0.001, 7 | "network": { 8 | "filters": [16, 32, 64, 128], 9 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 10 | "num_residual_units": 2 11 | } 12 | } -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/config_malp_em.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"protocols": ["malp_em"], 3 | "num_classes": [139], 4 | "model_path": "/tmp/neuronet/models/malp_em", 5 | "out_segm_path": "/tmp/neuronet/out/malp_em", 6 | "learning_rate": 0.001, 7 | "network": { 8 | "filters": [16, 32, 64, 128], 9 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 10 | "num_residual_units": 2 11 | } 12 | } -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/config_malp_em_tissue.json: -------------------------------------------------------------------------------- 1 | { 2 | "protocols": ["malp_em_tissue"], 3 | "num_classes": [6], 4 | "model_path": "/tmp/neuronet/models/malp_em_tissue", 5 | "out_segm_path": "/tmp/neuronet/out/malp_em_tissue", 6 | "learning_rate": 0.001, 7 | "network": { 8 | "filters": [16, 32, 64, 128], 9 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 10 | "num_residual_units": 2 11 | } 12 | } -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/config_spm_tissue.json: -------------------------------------------------------------------------------- 1 | { 2 | "protocols": ["spm_tissue"], 3 | "num_classes": [4], 4 | "model_path": "/tmp/neuronet/models/spm_tissue", 5 | "out_segm_path": "/tmp/neuronet/out/spm_tissue", 6 | "learning_rate": 0.001, 7 | "network": { 8 | "filters": [16, 32, 64, 128], 9 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 10 | "num_residual_units": 2 11 | } 12 | } -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/config_tissue.json: -------------------------------------------------------------------------------- 1 | { 2 | "protocols": ["fsl_fast", "spm_tissue", "malp_em_tissue"], 3 | "num_classes": [4, 4, 6], 4 | "model_path": "/tmp/neuronet/models/neuronet_tissue", 5 | "out_segm_path": "/tmp/neuronet/out/neuronet_tissue", 6 | "learning_rate": 0.001, 7 | "network": { 8 | "filters": [16, 32, 64, 128], 9 | "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 10 | "num_residual_units": 2 11 | } 12 | } -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/deploy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os 7 | import time 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import tensorflow as tf 12 | import SimpleITK as sitk 13 | import json 14 | 15 | from tensorflow.contrib import predictor 16 | 17 | from dltk.utils import sliding_window_segmentation_inference 18 | 19 | from reader import read_fn, map_labels 20 | 21 | 22 | def predict(args, config): 23 | 24 | # Read in the csv with the file names you would want to predict on 25 | file_names = pd.read_csv(args.csv, 26 | dtype=object, 27 | keep_default_na=False, 28 | na_values=[]).as_matrix() 29 | 30 | # From the model model_path, parse the latest saved estimator model 31 | # and restore a predictor from it 32 | export_dir = [os.path.join(config["model_path"], o) for o in os.listdir(config["model_path"]) 33 | if os.path.isdir(os.path.join(config["model_path"], o)) and o.isdigit()][-1] 34 | print('Loading from {}'.format(export_dir)) 35 | my_predictor = predictor.from_saved_model(export_dir) 36 | 37 | protocols = config["protocols"] 38 | # Fetch the output probability ops of the trained network 39 | y_probs = 
[my_predictor._fetch_tensors['y_prob_{}'.format(p)] for p in protocols] 40 | 41 | # Iterate through the files, predict on the full volumes and 42 | # compute a Dice similariy coefficient 43 | for output in read_fn(file_references=file_names, 44 | mode=tf.estimator.ModeKeys.PREDICT, 45 | params={'extract_examples': False, 46 | 'protocols': protocols}): 47 | 48 | print('Running file {}'.format(output['img_id'])) 49 | t0 = time.time() 50 | 51 | # Parse the read function output and add a dummy batch dimension 52 | # as required 53 | img = np.expand_dims(output['features']['x'], axis=0) 54 | 55 | # Do a sliding window inference with our DLTK wrapper 56 | preds = sliding_window_segmentation_inference( 57 | session=my_predictor.session, 58 | ops_list=y_probs, 59 | sample_dict={my_predictor._feed_tensors['x']: img}, 60 | batch_size=2) 61 | 62 | # Calculate the prediction from the probabilities 63 | preds = [np.squeeze(np.argmax(pred, -1), axis=0) for pred in preds] 64 | 65 | # Map the consecutive integer label ids back to the original ones 66 | for i in range(len(protocols)): 67 | preds[i] = map_labels(preds[i], 68 | protocol=protocols[i], 69 | convert_to_protocol=True) 70 | 71 | # Save the file as .nii.gz using the header information from the 72 | # original sitk image 73 | out_folder = os.path.join(config["out_segm_path"], '{}'.format(output['img_id'])) 74 | os.system('mkdir -p {}'.format(out_folder)) 75 | 76 | for i in range(len(protocols)): 77 | output_fn = os.path.join(out_folder, protocols[i] + '.nii.gz') 78 | new_sitk = sitk.GetImageFromArray(preds[i].astype(np.int32)) 79 | new_sitk.CopyInformation(output['sitk']) 80 | sitk.WriteImage(new_sitk, output_fn) 81 | 82 | # Print outputs 83 | print('ID={}; input_dim={}; time={};'.format( 84 | output['img_id'], img.shape, time.time() - t0)) 85 | 86 | 87 | if __name__ == '__main__': 88 | # Set up argument parser 89 | parser = argparse.ArgumentParser(description='Neuronet deploy script') 90 | parser.add_argument('--verbose', default=False, action='store_true') 91 | parser.add_argument('--cuda_devices', '-c', default='0') 92 | 93 | parser.add_argument('--csv', default='test.csv') 94 | parser.add_argument('--config', default='config_all.json') 95 | 96 | args = parser.parse_args() 97 | 98 | # Set verbosity 99 | if args.verbose: 100 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 101 | tf.logging.set_verbosity(tf.logging.INFO) 102 | else: 103 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 104 | tf.logging.set_verbosity(tf.logging.ERROR) 105 | 106 | # GPU allocation options 107 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices 108 | 109 | # Parse the run config 110 | with open(args.config) as f: 111 | config = json.load(f) 112 | 113 | # Call training 114 | predict(args, config) 115 | -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/figures/boxplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLTK/models/7fd907a325cd7a23ccca62d2def2f9f770020cff/ukbb_neuronet_brain_segmentation/figures/boxplot.png -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/figures/ex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLTK/models/7fd907a325cd7a23ccca62d2def2f9f770020cff/ukbb_neuronet_brain_segmentation/figures/ex.png -------------------------------------------------------------------------------- 
/ukbb_neuronet_brain_segmentation/figures/example_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLTK/models/7fd907a325cd7a23ccca62d2def2f9f770020cff/ukbb_neuronet_brain_segmentation/figures/example_seg.png -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/figures/fail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLTK/models/7fd907a325cd7a23ccca62d2def2f9f770020cff/ukbb_neuronet_brain_segmentation/figures/fail.png -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/neuronet.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | from __future__ import print_function 3 | from __future__ import division 4 | from __future__ import absolute_import 5 | 6 | import tensorflow as tf 7 | 8 | from dltk.core.residual_unit import vanilla_residual_unit_3d 9 | from dltk.core.upsample import linear_upsample_3d 10 | 11 | 12 | def upscore_layer_3d(inputs, 13 | inputs2, 14 | out_filters, 15 | in_filters=None, 16 | strides=(2, 2, 2), 17 | mode=tf.estimator.ModeKeys.EVAL, use_bias=False, 18 | kernel_initializer=tf.initializers.variance_scaling(distribution='uniform'), 19 | bias_initializer=tf.zeros_initializer(), 20 | kernel_regularizer=None, 21 | bias_regularizer=None): 22 | """Upscore layer according to [1]. 23 | 24 | [1] J. Long et al. Fully convolutional networks for semantic segmentation. 25 | CVPR 2015. 26 | 27 | Args: 28 | inputs (tf.Tensor): Input features to be upscored. 29 | inputs2 (tf.Tensor): Higher resolution features from the encoder to add. 30 | out_filters (int): Number of output filters (typically, number of 31 | segmentation classes) 32 | in_filters (None, optional): None or number of input filters. 33 | strides (tuple, optional): Upsampling factor for a strided transpose 34 | convolution. 35 | mode (TYPE, optional): One of the tf.estimator.ModeKeys strings: TRAIN, 36 | EVAL or PREDICT 37 | use_bias (bool, optional): Boolean, whether the layer uses a bias. 38 | kernel_initializer (TYPE, optional): An initializer for the convolution 39 | kernel. 40 | bias_initializer (TYPE, optional): An initializer for the bias vector. 41 | If None, no bias will be applied. 42 | kernel_regularizer (None, optional): Optional regularizer for the 43 | convolution kernel. 44 | bias_regularizer (None, optional): Optional regularizer for the bias 45 | vector. 46 | 47 | Returns: 48 | tf.Tensor: Upscore tensor 49 | 50 | """ 51 | conv_params = {'use_bias': use_bias, 52 | 'kernel_initializer': kernel_initializer, 53 | 'bias_initializer': bias_initializer, 54 | 'kernel_regularizer': kernel_regularizer, 55 | 'bias_regularizer': bias_regularizer} 56 | 57 | # Compute an upsampling shape dynamically from the input tensor. Input 58 | # filters are required to be static. 59 | if in_filters is None: 60 | in_filters = inputs.get_shape().as_list()[-1] 61 | 62 | assert len(inputs.get_shape().as_list()) == 5, \ 63 | 'inputs are required to have a rank of 5.' 
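    # Shape sketch (illustrative): with strides=(2, 2, 2), `inputs` of shape
    # [N, d, h, w, in_filters] is mapped to out_filters channels (via a 1x1x1
    # convolution if the filter counts differ), linearly upsampled to
    # [N, 2d, 2h, 2w, out_filters] and summed with the 1x1x1-convolved,
    # batch-normalised encoder features `inputs2` of that spatial shape.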
64 | assert len(inputs.get_shape().as_list()) == len(inputs2.get_shape().as_list()), \ 65 | 'Ranks of input and input2 differ' 66 | 67 | # Account for differences in the number of input and output filters 68 | if in_filters != out_filters: 69 | x = tf.layers.conv3d(inputs=inputs, 70 | filters=out_filters, 71 | kernel_size=(1, 1, 1), 72 | strides=(1, 1, 1), 73 | padding='same', 74 | name='filter_conversion', 75 | **conv_params) 76 | else: 77 | x = inputs 78 | 79 | # Upsample inputs 80 | x = linear_upsample_3d(inputs=x, strides=strides) 81 | 82 | # Skip connection 83 | x2 = tf.layers.conv3d(inputs=inputs2, 84 | filters=out_filters, 85 | kernel_size=(1, 1, 1), 86 | strides=(1, 1, 1), 87 | padding='same', 88 | **conv_params) 89 | 90 | x2 = tf.layers.batch_normalization( 91 | x2, training=mode == tf.estimator.ModeKeys.TRAIN) 92 | 93 | # Return the element-wise sum 94 | return tf.add(x, x2) 95 | 96 | 97 | def neuronet_3d(inputs, 98 | num_classes, 99 | protocols, 100 | num_res_units=2, 101 | filters=(16, 32, 64, 128), 102 | strides=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2)), 103 | mode=tf.estimator.ModeKeys.EVAL, 104 | use_bias=False, 105 | activation=tf.nn.relu6, 106 | kernel_initializer=tf.initializers.variance_scaling(distribution='uniform'), 107 | bias_initializer=tf.zeros_initializer(), 108 | kernel_regularizer=None, 109 | bias_regularizer=None): 110 | """ 111 | NeuroNet [1] is a multi-task image segmentation network based on an FCN 112 | architecture [2] using residual units [3] as feature extractors. 113 | Downsampling and upsampling of features is done via strided convolutions 114 | and transpose convolutions, respectively. On each resolution scale s 115 | are num_residual_units with filter size = filters[s]. strides[s] determine 116 | the downsampling factor at each resolution scale. 117 | 118 | [1] M. Rajchl et al. NeuroNet: Fast and Robust Reproduction of Multiple 119 | Brain Image Segmentation Pipelines. MIDL 2018. 120 | 121 | [2] J. Long et al. Fully convolutional networks for semantic segmentation. 122 | CVPR 2015. 123 | [3] K. He et al. Identity Mappings in Deep Residual Networks. ECCV 2016. 124 | 125 | Args: 126 | inputs (tf.Tensor): Input feature tensor to the network (rank 5 127 | required). 128 | num_classes (int): Number of output classes. 129 | num_res_units (int, optional): Number of residual units at each 130 | resolution scale. 131 | filters (tuple, optional): Number of filters for all residual units at 132 | each resolution scale. 133 | strides (tuple, optional): Stride of the first unit on a resolution 134 | scale. 135 | mode (TYPE, optional): One of the tf.estimator.ModeKeys strings: 136 | TRAIN, EVAL or PREDICT 137 | use_bias (bool, optional): Boolean, whether the layer uses a bias. 138 | kernel_initializer (TYPE, optional): An initializer for the convolution 139 | kernel. 140 | bias_initializer (TYPE, optional): An initializer for the bias vector. 141 | If None, no bias will be applied. 142 | kernel_regularizer (None, optional): Optional regularizer for the 143 | convolution kernel. 144 | bias_regularizer (None, optional): Optional regularizer for the bias 145 | vector. 146 | 147 | Returns: 148 | dict: dictionary of output tensors 149 | """ 150 | outputs = {} 151 | assert len(strides) == len(filters) 152 | assert len(inputs.get_shape().as_list()) == 5, \ 153 | 'inputs are required to have a rank of 5.' 
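    # Example configuration (from config_all.json in this folder):
    #   protocols   = ['fsl_fast', 'fsl_first', 'spm_tissue', 'malp_em', 'malp_em_tissue']
    #   num_classes = [4, 16, 4, 139, 6]
    #   filters     = (16, 32, 64, 128)
    #   strides     = ((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2))
    # i.e. one shared encoder and five decoder tails, one per protocol.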
154 | assert len(protocols) == len(num_classes) 155 | 156 | conv_params = {'use_bias': use_bias, 157 | 'kernel_initializer': kernel_initializer, 158 | 'bias_initializer': bias_initializer, 159 | 'kernel_regularizer': kernel_regularizer, 160 | 'bias_regularizer': bias_regularizer} 161 | 162 | x = inputs 163 | 164 | # Inital convolution with filters[0] 165 | x = tf.layers.conv3d(inputs=x, 166 | filters=filters[0], 167 | kernel_size=(3, 3, 3), 168 | strides=strides[0], 169 | padding='same', 170 | **conv_params) 171 | 172 | tf.logging.info('Init conv tensor shape {}'.format(x.get_shape())) 173 | 174 | # Residual feature encoding blocks with num_res_units at different 175 | # resolution scales res_scales 176 | res_scales = [x] 177 | saved_strides = [] 178 | with tf.variable_scope('encoder'): 179 | for res_scale in range(1, len(filters)): 180 | 181 | # Features are downsampled via strided convolutions. These are defined 182 | # in `strides` and subsequently saved 183 | with tf.variable_scope('unit_{}_0'.format(res_scale)): 184 | 185 | x = vanilla_residual_unit_3d( 186 | inputs=x, 187 | out_filters=filters[res_scale], 188 | strides=strides[res_scale], 189 | mode=mode) 190 | saved_strides.append(strides[res_scale]) 191 | 192 | for i in range(1, num_res_units): 193 | 194 | with tf.variable_scope('unit_{}_{}'.format(res_scale, i)): 195 | 196 | x = vanilla_residual_unit_3d( 197 | inputs=x, 198 | out_filters=filters[res_scale], 199 | strides=(1, 1, 1), 200 | mode=mode) 201 | res_scales.append(x) 202 | 203 | tf.logging.info('Encoder at res_scale {} tensor shape: {}'.format( 204 | res_scale, x.get_shape())) 205 | 206 | outputs['encoder_out'] = x 207 | 208 | tails = [] 209 | for tail in range(len(num_classes)): 210 | # Create a separate prediction tail for each labeling protocol to learn 211 | with tf.variable_scope('tail_{}'.format(tail)): 212 | x = outputs['encoder_out'] 213 | 214 | for res_scale in range(len(filters) - 2, -1, -1): 215 | # Upscore layers [2] reconstruct the predictions to 216 | # higher resolution scales 217 | with tf.variable_scope('upscore_{}'.format(res_scale)): 218 | x = upscore_layer_3d( 219 | inputs=x, 220 | inputs2=res_scales[res_scale], 221 | out_filters=num_classes[tail], 222 | strides=saved_strides[res_scale], 223 | mode=mode, 224 | **conv_params) 225 | 226 | tf.logging.info('Decoder at res_scale {} tensor shape: {}'.format( 227 | res_scale, x.get_shape())) 228 | 229 | # Last convolution 230 | with tf.variable_scope('last'): 231 | tails.append(tf.layers.conv3d(inputs=x, 232 | filters=num_classes[tail], 233 | kernel_size=(1, 1, 1), 234 | strides=(1, 1, 1), 235 | padding='same', 236 | **conv_params)) 237 | 238 | tf.logging.info('Output tensor shape {}'.format(x.get_shape())) 239 | 240 | # Define the outputs 241 | for i in range(len(tails)): 242 | outputs['logits_{}'.format(protocols[i])] = tails[i] 243 | 244 | with tf.variable_scope('pred'): 245 | outputs['y_prob_{}'.format(protocols[i])] = tf.nn.softmax(tails[i]) 246 | outputs['y_{}'.format(protocols[i])] = tf.argmax(tails[i], axis=-1) 247 | 248 | return outputs 249 | -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/parse_csvs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pandas as pd\n", 11 | "import numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": 
"code", 16 | "execution_count": 8, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "['fsl_fast']\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import json\n", 29 | "config_fn = 'fsl_fast_config.json'\n", 30 | "with open(config_fn) as f:\n", 31 | " config = json.load(f)\n", 32 | "print (config['targets']['protocols'])" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 35, 38 | "metadata": { 39 | "collapsed": true 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "# paths\n", 44 | "base_path = '/vol/biobank/12579/brain/images'\n", 45 | "\n", 46 | "# list all ids \n", 47 | "all_ids = sorted(os.listdir(base_path))\n", 48 | "\n", 49 | "# check if all of them contain a T1w image and required segmentation(s)\n", 50 | "valid_ids = []\n", 51 | "for i in all_ids:\n", 52 | " if (os.path.isfile(os.path.join(base_path, i,'T1.nii.gz')) and \n", 53 | " os.path.isfile(os.path.join(base_path, i,'T1_first_all_fast_firstseg.nii.gz')) and\n", 54 | " os.path.isfile(os.path.join(base_path, i,'T1_brain_seg.nii.gz')) and\n", 55 | " os.path.isfile(os.path.join(base_path, i,'T1_brain_seg_spm.nii.gz')) and\n", 56 | " os.path.isfile(os.path.join(base_path, i,'T1_MALPEM.nii.gz')) and\n", 57 | " os.path.isfile(os.path.join(base_path, i,'T1_MALPEM_tissues.nii.gz'))):\n", 58 | " valid_ids.append(i)\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 36, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "5723\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "print(len(valid_ids))" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 37, 81 | "metadata": { 82 | "collapsed": true 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "def get_full_paths(_id, fn):\n", 87 | " return os.path.join(base_path, _id, fn)\n", 88 | "\n", 89 | "hdr = ['id', 't1', 'fsl_fast', 'fsl_first', 'spm_tissue', 'malp_em', 'malp_em_tissue']\n", 90 | "valid_df = []\n", 91 | "for i in valid_ids:\n", 92 | " valid_df.append([i, \n", 93 | " get_full_paths(i, 'T1.nii.gz'),\n", 94 | " get_full_paths(i, 'T1_brain_seg.nii.gz'),\n", 95 | " get_full_paths(i, 'T1_first_all_fast_firstseg.nii.gz'),\n", 96 | " get_full_paths(i, 'T1_brain_seg_spm.nii.gz'),\n", 97 | " get_full_paths(i, 'T1_MALPEM.nii.gz'),\n", 98 | " get_full_paths(i, 'T1_MALPEM_tissues.nii.gz')])\n", 99 | "\n" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 38, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "['1000845', '/vol/biobank/12579/brain/images/1000845/T1.nii.gz', '/vol/biobank/12579/brain/images/1000845/T1_brain_seg.nii.gz', '/vol/biobank/12579/brain/images/1000845/T1_first_all_fast_firstseg.nii.gz', '/vol/biobank/12579/brain/images/1000845/T1_brain_seg_spm.nii.gz', '/vol/biobank/12579/brain/images/1000845/T1_MALPEM.nii.gz', '/vol/biobank/12579/brain/images/1000845/T1_MALPEM_tissues.nii.gz']\n", 112 | "(5723, 7)\n" 113 | ] 114 | } 115 | ], 116 | "source": [ 117 | "print(valid_df[0])\n", 118 | "print(np.array(valid_df).shape)\n" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 39, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | "5000\n", 131 | "10\n", 132 | "713\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "\n", 138 | "# 5k training ids\n", 139 | "write_df = 
valid_df[:5000]\n", 140 | "pd.DataFrame(write_df).to_csv('train.csv', index=False, header=hdr)\n", 141 | "print(len(write_df))\n", 142 | "\n", 143 | "# 10 validation ids\n", 144 | "write_df = valid_df[5000:5010]\n", 145 | "pd.DataFrame(write_df).to_csv('val.csv', index=False, header=hdr)\n", 146 | "print(len(write_df))\n", 147 | "\n", 148 | "# 713 test ids\n", 149 | "write_df = valid_df[5010:]\n", 150 | "pd.DataFrame(write_df).to_csv('test.csv', index=False, header=hdr)\n", 151 | "print(len(write_df))" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [] 160 | } 161 | ], 162 | "metadata": { 163 | "kernelspec": { 164 | "display_name": "Python 3", 165 | "language": "python", 166 | "name": "python3" 167 | }, 168 | "language_info": { 169 | "codemirror_mode": { 170 | "name": "ipython", 171 | "version": 3 172 | }, 173 | "file_extension": ".py", 174 | "mimetype": "text/x-python", 175 | "name": "python", 176 | "nbconvert_exporter": "python", 177 | "pygments_lexer": "ipython3", 178 | "version": "3.5.2" 179 | } 180 | }, 181 | "nbformat": 4, 182 | "nbformat_minor": 2 183 | } 184 | -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/reader.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | from dltk.io.augmentation import extract_random_example_array 6 | from dltk.io.preprocessing import whitening 7 | 8 | ALL_PROTOCOLS = ['fsl_fast', 'fsl_first', 'spm_tissue', 'malp_em', 'malp_em_tissue'] 9 | NUM_CLASSES = [4, 16, 4, 139, 6] 10 | 11 | 12 | def map_labels(lbl, protocol=None, convert_to_protocol=False): 13 | """ 14 | Map dataset specific label id protocols to consecutive integer ids for training and back. 
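    For example, for protocol='fsl_first' the FSL FIRST label id 10
    (Left-Thalamus-Proper) is mapped to the consecutive training id 1;
    calling map_labels with convert_to_protocol=True inverts this mapping.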
15 | Parameters 16 | ---------- 17 | lbl : np.array 18 | a label map 19 | protocol : str 20 | a string describing the labeling protocol 21 | convert_to_protocol : bool 22 | flag to determine to convert from or to the protocol ids 23 | """ 24 | 25 | """ 26 | SPM tissue ids: 27 | 0 Background 28 | 1 CSF 29 | 2 GM 30 | 3 WM 31 | """ 32 | spm_tissue_ids = range(4) 33 | 34 | """ 35 | Fast ids: 36 | 0 Background 37 | 1 CSF 38 | 2 GM 39 | 3 WM 40 | """ 41 | fast_ids = range(4) 42 | 43 | """ 44 | First ids: 45 | 0 Background 46 | 10 Left-Thalamus-Proper 40 47 | 11 Left-Caudate 30 48 | 12 Left-Putamen 40 49 | 13 Left-Pallidum 40 50 | 16 Brain-Stem /4th Ventricle 40 51 | 17 Left-Hippocampus 30 52 | 18 Left-Amygdala 50 53 | 26 Left-Accumbens-area 50 54 | 49 Right-Thalamus-Proper 40 55 | 50 Right-Caudate 30 56 | 51 Right-Putamen 40 57 | 52 Right-Pallidum 40 58 | 53 Right-Hippocampus 30 59 | 54 Right-Amygdala 50 60 | 58 Right-Accumbens-area 50 61 | """ 62 | first_ids = [0, 10, 11, 12, 13, 16, 17, 18, 26, 49, 50, 51, 52, 53, 54, 58] 63 | 64 | """ 65 | MALP-EM tissue ids: 66 | 0 Background 67 | 1 Ventricles 68 | 2 Sub-cortical and cerebellum GM 69 | 3 WM 70 | 4 Cortical GM 71 | 5 72 | """ 73 | malpem_tissue_ids = range(6) 74 | 75 | """ 76 | MALP-EM ids: 77 | 0 Background 78 | 1 3rdVentricle 79 | 2 4thVentricle 80 | 3 RightAccumbensArea 81 | 4 LeftAccumbensArea 82 | 5 RightAmygdala 83 | 6 LeftAmygdala 84 | 7 BrainStem 85 | 8 RightCaudate 86 | 9 LeftCaudate 87 | 10 RightCerebellumExterior 88 | 11 LeftCerebellumExterior 89 | 12 RightCerebellumWhiteMatter 90 | 13 LeftCerebellumWhiteMatter 91 | 14 RightCerebralExterior 92 | 15 LeftCerebralExterior 93 | 16 RightCerebralWhiteMatter 94 | 17 LeftCerebralWhiteMatter 95 | 18 CSF 96 | 19 RightHippocampus 97 | 20 LeftHippocampus 98 | 21 RightInfLatVent 99 | 22 LeftInfLatVent 100 | 23 RightLateralVentricle 101 | 24 LeftLateralVentricle 102 | 25 RightPallidum 103 | 26 LeftPallidum 104 | 27 RightPutamen 105 | 28 LeftPutamen 106 | 29 RightThalamusProper 107 | 30 LeftThalamusProper 108 | 31 RightVentralDC 109 | 32 LeftVentralDC 110 | 33 Rightvessel 111 | 34 Leftvessel 112 | 35 OpticChiasm 113 | 36 CerebellarVermalLobulesI-V 114 | 37 CerebellarVermalLobulesVI-VII 115 | 38 CerebellarVermalLobulesVIII-X 116 | 39 LeftBasalForebrain 117 | 40 RightBasalForebrain 118 | 41 RightACg Ganteriorcingulategyrus Right 119 | 42 LeftACg Ganteriorcingulategyrus Left 120 | 43 RightAIns Anteriorinsula Right 121 | 44 LeftAIns Anteriorinsula Left 122 | 45 RightAOrG Anteriororbitalgyrus Right 123 | 46 LeftAOrG Anteriororbitalgyrus Left 124 | 47 RightAnG Angulargyrus Right 125 | 48 LeftAnG Angulargyrus Left 126 | 49 RightCalc Calcarinecortex Right 127 | 50 LeftCalc Calcarinecortex Left 128 | 51 RightCO Centraloperculum Right 129 | 52 LeftCO Centraloperculum Left 130 | 53 RightCun Cuneus Right 131 | 54 LeftCun Cuneus Left 132 | 55 RightEntA Ententorhinalarea Right 133 | 56 LeftEntA Ententorhinalarea Left 134 | 57 RightFO Frontaloperculum Right 135 | 58 LeftFO Frontaloperculum Left 136 | 59 RightFRP Frontalpole Right 137 | 60 LeftFRP Frontalpole Left 138 | 61 RightFuG Fusiformgyrus Right 139 | 62 LeftFuG Fusiformgyrus Left 140 | 63 RightGRe Gyrusrectus Right 141 | 64 LeftGRe Gyrusrectus Left 142 | 65 RightIOG Inferioroccipitalgyrus Right 143 | 66 LeftIOG Inferioroccipitalgyrus Left 144 | 67 RightITG Inferiortemporalgyrus Right 145 | 68 LeftITG Inferiortemporalgyrus Left 146 | 69 RightLiG Lingualgyrus Right 147 | 70 LeftLiG Lingualgyrus Left 148 | 71 RightLOrG Lateralorbitalgyrus Right 149 | 72 
LeftLOrG Lateralorbitalgyrus Left 150 | 73 RightMCgG Middlecingulategyrus Right 151 | 74 LeftMCgG Middlecingulategyrus Left 152 | 75 RightMFC Medialfrontalcortex Right 153 | 76 LeftMFC Medialfrontalcortex Left 154 | 77 RightMFG Middlefrontalgyrus Right 155 | 78 LeftMFG Middlefrontalgyrus Left 156 | 79 RightMOG Middleoccipitalgyrus Right 157 | 80 LeftMOG Middleoccipitalgyrus Left 158 | 81 RightMOrG Medialorbitalgyrus Right 159 | 82 LeftMOrG Medialorbitalgyrus Left 160 | 83 RightMPoG Postcentralgyrusmedialsegment Right 161 | 84 LeftMPoG Postcentralgyrusmedialsegment Left 162 | 85 RightMPrG Precentralgyrusmedialsegment Right 163 | 86 LeftMPrG Precentralgyrusmedialsegment Left 164 | 87 RightMSFG Superiorfrontalgyrusmedialsegment Right 165 | 88 LeftMSFG Superiorfrontalgyrusmedialsegment Left 166 | 89 RightMTG Middletemporalgyrus Right 167 | 90 LeftMTG Middletemporalgyrus Left 168 | 91 RightOCP Occipitalpole Right 169 | 92 LeftOCP Occipitalpole Left 170 | 93 RightOFuG Occipitalfusiformgyrus Right 171 | 94 LeftOFuG Occipitalfusiformgyrus Left 172 | 95 RightOpIFG Opercularpartoftheinferiorfrontalgyrus Right 173 | 96 LeftOpIFG Opercularpartoftheinferiorfrontalgyrus Left 174 | 97 RightOrIFG Orbitalpartoftheinferiorfrontalgyrus Right 175 | 98 LeftOrIFG Orbitalpartoftheinferiorfrontalgyrus Left 176 | 99 RightPCgG Posteriorcingulategyrus Right 177 | 100 LeftPCgG Posteriorcingulategyrus Left 178 | 101 RightPCu Precuneus Right 179 | 102 LeftPCu Precuneus Left 180 | 103 RightPHG Parahippocampalgyrus Right 181 | 104 LeftPHG Parahippocampalgyrus Left 182 | 105 RightPIns Posteriorinsula Right 183 | 106 LeftPIns Posteriorinsula Left 184 | 107 RightPO Parietaloperculum Right 185 | 108 LeftPO Parietaloperculum Left 186 | 109 RightPoG Postcentralgyrus Right 187 | 110 LeftPoG Postcentralgyrus Left 188 | 111 RightPOrG Posteriororbitalgyrus Right 189 | 112 LeftPOrG Posteriororbitalgyrus Left 190 | 113 RightPP Planumpolare Right 191 | 114 LeftPP Planumpolare Left 192 | 115 RightPrG Precentralgyrus Right 193 | 116 LeftPrG Precentralgyrus Left 194 | 117 RightPT Planumtemporale Right 195 | 118 LeftPT Planumtemporale Left 196 | 119 RightSCA Subcallosalarea Right 197 | 120 LeftSCA Subcallosalarea Left 198 | 121 RightSFG Superiorfrontalgyrus Right 199 | 122 LeftSFG Superiorfrontalgyrus Left 200 | 123 RightSMC Supplementarymotorcortex Right 201 | 124 LeftSMC Supplementarymotorcortex Left 202 | 125 RightSMG Supramarginalgyrus Right 203 | 126 LeftSMG Supramarginalgyrus Left 204 | 127 RightSOG Superioroccipitalgyrus Right 205 | 128 LeftSOG Superioroccipitalgyrus Left 206 | 129 RightSPL Superiorparietallobule Right 207 | 130 LeftSPL Superiorparietallobule Left 208 | 131 RightSTG Superiortemporalgyrus Right 209 | 132 LeftSTG Superiortemporalgyrus Left 210 | 133 RightTMP Temporalpole Right 211 | 134 LeftTMP Temporalpole Left 212 | 135 RightTrIFG Triangularpartoftheinferiorfrontalgyrus Right 213 | 136 LeftTrIFG Triangularpartoftheinferiorfrontalgyrus Left 214 | 137 RightTTG Transversetemporalgyrus Right 215 | 138 LeftTTG Transversetemporalgyrus Left 216 | """ 217 | malpem_ids = range(139) 218 | 219 | out_lbl = np.zeros_like(lbl) 220 | 221 | if protocol == 'fsl_fast': 222 | ids = fast_ids 223 | elif protocol == 'fsl_first': 224 | ids = first_ids 225 | elif protocol == 'spm_tissue': 226 | ids = spm_tissue_ids 227 | elif protocol == 'malp_em': 228 | ids = malpem_ids 229 | elif protocol == 'malp_em_tissue': 230 | ids = malpem_tissue_ids 231 | else: 232 | print("Method is not recognised. 
Exiting.") 233 | return -1 234 | 235 | if convert_to_protocol: 236 | # map from consecutive ints to protocol labels 237 | for i in range(len(ids)): 238 | out_lbl[lbl == i] = ids[i] 239 | else: 240 | # map from protocol labels to consecutive ints 241 | for i in range(len(ids)): 242 | out_lbl[lbl == ids[i]] = i 243 | 244 | return out_lbl 245 | 246 | 247 | def read_fn(file_references, mode, params=None): 248 | """A custom python read function for interfacing with nii image files. 249 | 250 | Args: 251 | file_references (list): A list of lists containing file references, 252 | such as [['id_0', 'image_filename_0', target_value_0], ..., 253 | ['id_N', 'image_filename_N', target_value_N]]. 254 | mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL 255 | or PREDICT. 256 | params (dict, optional): A dictionary to parameterise read_fn ouputs 257 | (e.g. reader_params = {'n_examples': 10, 'example_size': 258 | [64, 64, 64], 'extract_examples': True}, etc.). 259 | 260 | Yields: 261 | dict: A dictionary of reader outputs for dltk.io.abstract_reader. 262 | """ 263 | 264 | if mode == tf.estimator.ModeKeys.TRAIN: 265 | np.random.shuffle(file_references) 266 | 267 | for f in file_references: 268 | 269 | # Read the image nii with sitk 270 | img_id = f[0] 271 | img_fn = f[1] 272 | img_sitk = sitk.ReadImage(str(img_fn)) 273 | img = sitk.GetArrayFromImage(img_sitk) 274 | 275 | # Normalise volume image 276 | img = whitening(img) 277 | 278 | # Create a 4D image (i.e. [x, y, z, channels]) 279 | img = np.expand_dims(img, axis=-1).astype(np.float32) 280 | 281 | if mode == tf.estimator.ModeKeys.PREDICT: 282 | yield {'features': {'x': img}, 283 | 'labels': None, 284 | 'sitk': img_sitk, 285 | 'img_id': img_id} 286 | continue 287 | 288 | # Read the label nii with sitk for each of the protocols 289 | lbls = [] 290 | for p in params['protocols']: 291 | idx = ALL_PROTOCOLS.index(p) 292 | lbl_fn = f[2 + idx] 293 | lbl = sitk.GetArrayFromImage(sitk.ReadImage(str(lbl_fn))).astype(np.int32) 294 | 295 | # Map the label ids to consecutive integers 296 | lbl = map_labels(lbl, protocol=p) 297 | lbls.append(lbl) 298 | 299 | # Check if the reader is supposed to return training examples or 300 | # full images 301 | if params['extract_examples']: 302 | # Concatenate into a list of images and labels and extract 303 | img_lbls_list = [img] + lbls 304 | img_lbls_list = extract_random_example_array( 305 | img_lbls_list, 306 | example_size=params['example_size'], 307 | n_examples=params['n_examples']) 308 | 309 | # Yield each image example and corresponding label protocols 310 | for e in range(params['n_examples']): 311 | yield {'features': {'x': img_lbls_list[0][e].astype(np.float32)}, 312 | 'labels': {params['protocols'][i]: img_lbls_list[1 + i][e] 313 | for i in range(len(params['protocols']))}} 314 | else: 315 | yield {'features': {'x': img}, 316 | 'labels': {params['protocols'][i]: 317 | lbls[i] for i in range(len(params['protocols']))}, 318 | 'sitk': img_sitk, 319 | 'img_id': img_id} 320 | return 321 | -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import tensorflow as tf 11 | 12 | from dltk.core.metrics import dice 13 | from dltk.core.activations import leaky_relu 
14 | from dltk.io.abstract_reader import Reader 15 | 16 | from neuronet import neuronet_3d 17 | 18 | from reader import read_fn 19 | import json 20 | 21 | # PARAMS 22 | EVAL_EVERY_N_STEPS = 1000 23 | EVAL_STEPS = 10 24 | 25 | NUM_CHANNELS = 1 26 | 27 | BATCH_SIZE = 1 28 | SHUFFLE_CACHE_SIZE = 16 29 | 30 | MAX_STEPS = 100000 31 | 32 | 33 | # MODEL 34 | def model_fn(features, labels, mode, params): 35 | 36 | # 1. create a model and its outputs 37 | def lrelu(x): 38 | return leaky_relu(x, 0.1) 39 | 40 | protocols = params["protocols"] 41 | 42 | net_output_ops = neuronet_3d(features['x'], 43 | num_classes=params["num_classes"], 44 | protocols=protocols, 45 | num_res_units=params["network"]["num_residual_units"], 46 | filters=params["network"]["filters"], 47 | strides=params["network"]["strides"], 48 | activation=lrelu, 49 | mode=mode) 50 | 51 | # 1.1 Generate predictions only (for `ModeKeys.PREDICT`) 52 | if mode == tf.estimator.ModeKeys.PREDICT: 53 | return tf.estimator.EstimatorSpec( 54 | mode=mode, 55 | predictions=net_output_ops, 56 | export_outputs={'out': tf.estimator.export.PredictOutput(net_output_ops)}) 57 | 58 | # 2. set up a loss function 59 | ce = [] 60 | for p in protocols: 61 | ce.append(tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( 62 | logits=net_output_ops['logits_{}'.format(p)], 63 | labels=labels[p]))) 64 | 65 | # Sum the crossentropy losses and divide through number of protocols to be predicted 66 | loss = tf.div(tf.add_n(ce), tf.constant(len(protocols), dtype=tf.float32)) 67 | 68 | # 3. define a training op and ops for updating moving averages (i.e. for batch normalisation) 69 | global_step = tf.train.get_global_step() 70 | optimiser = tf.train.AdamOptimizer(learning_rate=params["learning_rate"], epsilon=1e-5) 71 | 72 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 73 | with tf.control_dependencies(update_ops): 74 | train_op = optimiser.minimize(loss, global_step=global_step) 75 | 76 | # 4.1 (optional) create custom image summaries for tensorboard 77 | my_image_summaries = {} 78 | my_image_summaries['feat_t1'] = features['x'][0, 64, :, :, 0] 79 | for p in protocols: 80 | my_image_summaries['{}/lbl'.format(p)] = tf.cast(labels[p], tf.float32)[0, 64, :, :] 81 | my_image_summaries['{}/pred'.format(p)] = tf.cast(net_output_ops['y_{}'.format(p)], tf.float32)[0, 64, :, :] 82 | 83 | expected_output_size = [1, 128, 128, 1] # [B, W, H, C] 84 | [tf.summary.image(name, tf.reshape(image, expected_output_size)) 85 | for name, image in my_image_summaries.items()] 86 | 87 | # 4.2 (optional) create custom metric summaries for tensorboard 88 | for i in range(len(protocols)): 89 | p = protocols[i] 90 | c = tf.constant(params["num_classes"][i]) 91 | 92 | mean_dice = tf.reduce_mean(tf.py_func( 93 | dice, [net_output_ops['y_{}'.format(p)], labels[p], c], tf.float32)[1:]) 94 | tf.summary.scalar('dsc_{}'.format(p), mean_dice) 95 | 96 | # 5. 
Return EstimatorSpec object 97 | return tf.estimator.EstimatorSpec(mode=mode, 98 | predictions=None, 99 | loss=loss, 100 | train_op=train_op, 101 | eval_metric_ops=None) 102 | 103 | 104 | def train(args, config): 105 | 106 | np.random.seed(42) 107 | tf.set_random_seed(42) 108 | 109 | print('Setting up...') 110 | # Parse csv files for file names 111 | train_filenames = pd.read_csv(args.train_csv, 112 | dtype=object, 113 | keep_default_na=False, 114 | na_values=[]).as_matrix() 115 | 116 | val_filenames = pd.read_csv(args.val_csv, 117 | dtype=object, 118 | keep_default_na=False, 119 | na_values=[]).as_matrix() 120 | 121 | # Set up a data reader to handle the file i/o. 122 | reader_params = { 123 | 'n_examples': 8, 124 | 'example_size': [128, 128, 128], 125 | 'extract_examples': True, 126 | 'protocols': config["protocols"]} 127 | 128 | reader_example_shapes = { 129 | 'features': {'x': reader_params['example_size'] + [NUM_CHANNELS, ]}, 130 | 'labels': {p: reader_params['example_size'] for p in config["protocols"]}} 131 | 132 | reader = Reader(read_fn, 133 | {'features': {'x': tf.float32}, 134 | 'labels': {p: tf.int32 for p in config["protocols"]}}) 135 | 136 | # Get input functions and queue initialisation hooks for training and validation data 137 | train_input_fn, train_qinit_hook = reader.get_inputs( 138 | train_filenames, 139 | tf.estimator.ModeKeys.TRAIN, 140 | example_shapes=reader_example_shapes, 141 | batch_size=BATCH_SIZE, 142 | shuffle_cache_size=SHUFFLE_CACHE_SIZE, 143 | params=reader_params) 144 | 145 | val_input_fn, val_qinit_hook = reader.get_inputs( 146 | val_filenames, 147 | tf.estimator.ModeKeys.EVAL, 148 | example_shapes=reader_example_shapes, 149 | batch_size=BATCH_SIZE, 150 | shuffle_cache_size=SHUFFLE_CACHE_SIZE, 151 | params=reader_params) 152 | 153 | # Instantiate the neural network estimator 154 | nn = tf.estimator.Estimator(model_fn=model_fn, 155 | model_dir=config["model_path"], 156 | params=config, 157 | config=tf.estimator.RunConfig(session_config=tf.ConfigProto())) 158 | 159 | # Hooks for validation summaries 160 | val_summary_hook = tf.contrib.training.SummaryAtEndHook( 161 | os.path.join(config["model_path"], 'eval')) 162 | step_cnt_hook = tf.train.StepCounterHook( 163 | every_n_steps=EVAL_EVERY_N_STEPS, output_dir=config["model_path"]) 164 | 165 | print('Starting training...') 166 | try: 167 | for _ in range(MAX_STEPS // EVAL_EVERY_N_STEPS): 168 | nn.train(input_fn=train_input_fn, 169 | hooks=[train_qinit_hook, step_cnt_hook], 170 | steps=EVAL_EVERY_N_STEPS) 171 | 172 | results_val = nn.evaluate(input_fn=val_input_fn, 173 | hooks=[val_qinit_hook, val_summary_hook], 174 | steps=EVAL_STEPS) 175 | print('Step = {}; val loss = {:.5f};'.format(results_val['global_step'], results_val['loss'])) 176 | 177 | except KeyboardInterrupt: 178 | pass 179 | 180 | print('Stopping now.') 181 | export_dir = nn.export_savedmodel( 182 | export_dir_base=config["model_path"], 183 | serving_input_receiver_fn=reader.serving_input_receiver_fn(reader_example_shapes)) 184 | print('Model saved to {}.'.format(export_dir)) 185 | 186 | 187 | if __name__ == '__main__': 188 | 189 | # Set up argument parser 190 | parser = argparse.ArgumentParser(description='NeuroNet training script') 191 | parser.add_argument('--restart', default=False, action='store_true') 192 | parser.add_argument('--verbose', default=False, action='store_true') 193 | parser.add_argument('--cuda_devices', '-c', default='0') 194 | 195 | parser.add_argument('--train_csv', default='train.csv') 196 | 
parser.add_argument('--val_csv', default='val.csv') 197 | parser.add_argument('--config', default='config_all.json') 198 | 199 | args = parser.parse_args() 200 | 201 | # Set verbosity 202 | if args.verbose: 203 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 204 | tf.logging.set_verbosity(tf.logging.INFO) 205 | else: 206 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 207 | tf.logging.set_verbosity(tf.logging.ERROR) 208 | 209 | # GPU allocation options 210 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices 211 | 212 | # Parse the run config 213 | with open(args.config) as f: 214 | config = json.load(f) 215 | 216 | # Handle restarting and resuming training 217 | if args.restart: 218 | print('Restarting training from scratch.') 219 | os.system('rm -rf {}'.format(config["model_path"])) 220 | 221 | if not os.path.isdir(config["model_path"]): 222 | os.system('mkdir -p {}'.format(config["model_path"])) 223 | else: 224 | print('Resuming training on model_path {}'.format(config["model_path"])) 225 | 226 | # Call training 227 | train(args, config) 228 | -------------------------------------------------------------------------------- /ukbb_neuronet_brain_segmentation/val.csv: -------------------------------------------------------------------------------- 1 | id,t1,fsl_fast,fsl_first,spm_tissue,malp_em,malp_em_tissue 2 | 5396937,/vol/biobank/12579/brain/images/5396937/T1.nii.gz,/vol/biobank/12579/brain/images/5396937/T1_brain_seg.nii.gz,/vol/biobank/12579/brain/images/5396937/T1_first_all_fast_firstseg.nii.gz,/vol/biobank/12579/brain/images/5396937/T1_brain_seg_spm.nii.gz,/vol/biobank/12579/brain/images/5396937/T1_MALPEM.nii.gz,/vol/biobank/12579/brain/images/5396937/T1_MALPEM_tissues.nii.gz 3 | 5397135,/vol/biobank/12579/brain/images/5397135/T1.nii.gz,/vol/biobank/12579/brain/images/5397135/T1_brain_seg.nii.gz,/vol/biobank/12579/brain/images/5397135/T1_first_all_fast_firstseg.nii.gz,/vol/biobank/12579/brain/images/5397135/T1_brain_seg_spm.nii.gz,/vol/biobank/12579/brain/images/5397135/T1_MALPEM.nii.gz,/vol/biobank/12579/brain/images/5397135/T1_MALPEM_tissues.nii.gz 4 | 5397714,/vol/biobank/12579/brain/images/5397714/T1.nii.gz,/vol/biobank/12579/brain/images/5397714/T1_brain_seg.nii.gz,/vol/biobank/12579/brain/images/5397714/T1_first_all_fast_firstseg.nii.gz,/vol/biobank/12579/brain/images/5397714/T1_brain_seg_spm.nii.gz,/vol/biobank/12579/brain/images/5397714/T1_MALPEM.nii.gz,/vol/biobank/12579/brain/images/5397714/T1_MALPEM_tissues.nii.gz 5 | 5397841,/vol/biobank/12579/brain/images/5397841/T1.nii.gz,/vol/biobank/12579/brain/images/5397841/T1_brain_seg.nii.gz,/vol/biobank/12579/brain/images/5397841/T1_first_all_fast_firstseg.nii.gz,/vol/biobank/12579/brain/images/5397841/T1_brain_seg_spm.nii.gz,/vol/biobank/12579/brain/images/5397841/T1_MALPEM.nii.gz,/vol/biobank/12579/brain/images/5397841/T1_MALPEM_tissues.nii.gz 6 | 5397960,/vol/biobank/12579/brain/images/5397960/T1.nii.gz,/vol/biobank/12579/brain/images/5397960/T1_brain_seg.nii.gz,/vol/biobank/12579/brain/images/5397960/T1_first_all_fast_firstseg.nii.gz,/vol/biobank/12579/brain/images/5397960/T1_brain_seg_spm.nii.gz,/vol/biobank/12579/brain/images/5397960/T1_MALPEM.nii.gz,/vol/biobank/12579/brain/images/5397960/T1_MALPEM_tissues.nii.gz 7 | 
5398988,/vol/biobank/12579/brain/images/5398988/T1.nii.gz,/vol/biobank/12579/brain/images/5398988/T1_brain_seg.nii.gz,/vol/biobank/12579/brain/images/5398988/T1_first_all_fast_firstseg.nii.gz,/vol/biobank/12579/brain/images/5398988/T1_brain_seg_spm.nii.gz,/vol/biobank/12579/brain/images/5398988/T1_MALPEM.nii.gz,/vol/biobank/12579/brain/images/5398988/T1_MALPEM_tissues.nii.gz 8 | 5399804,/vol/biobank/12579/brain/images/5399804/T1.nii.gz,/vol/biobank/12579/brain/images/5399804/T1_brain_seg.nii.gz,/vol/biobank/12579/brain/images/5399804/T1_first_all_fast_firstseg.nii.gz,/vol/biobank/12579/brain/images/5399804/T1_brain_seg_spm.nii.gz,/vol/biobank/12579/brain/images/5399804/T1_MALPEM.nii.gz,/vol/biobank/12579/brain/images/5399804/T1_MALPEM_tissues.nii.gz 9 | 5401227,/vol/biobank/12579/brain/images/5401227/T1.nii.gz,/vol/biobank/12579/brain/images/5401227/T1_brain_seg.nii.gz,/vol/biobank/12579/brain/images/5401227/T1_first_all_fast_firstseg.nii.gz,/vol/biobank/12579/brain/images/5401227/T1_brain_seg_spm.nii.gz,/vol/biobank/12579/brain/images/5401227/T1_MALPEM.nii.gz,/vol/biobank/12579/brain/images/5401227/T1_MALPEM_tissues.nii.gz 10 | 5402562,/vol/biobank/12579/brain/images/5402562/T1.nii.gz,/vol/biobank/12579/brain/images/5402562/T1_brain_seg.nii.gz,/vol/biobank/12579/brain/images/5402562/T1_first_all_fast_firstseg.nii.gz,/vol/biobank/12579/brain/images/5402562/T1_brain_seg_spm.nii.gz,/vol/biobank/12579/brain/images/5402562/T1_MALPEM.nii.gz,/vol/biobank/12579/brain/images/5402562/T1_MALPEM_tissues.nii.gz 11 | 5403873,/vol/biobank/12579/brain/images/5403873/T1.nii.gz,/vol/biobank/12579/brain/images/5403873/T1_brain_seg.nii.gz,/vol/biobank/12579/brain/images/5403873/T1_first_all_fast_firstseg.nii.gz,/vol/biobank/12579/brain/images/5403873/T1_brain_seg_spm.nii.gz,/vol/biobank/12579/brain/images/5403873/T1_MALPEM.nii.gz,/vol/biobank/12579/brain/images/5403873/T1_MALPEM_tissues.nii.gz 12 | --------------------------------------------------------------------------------
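The reader and training script above fit together as follows: map_labels converts each protocol's native label ids to consecutive class ids (and back with convert_to_protocol=True), and read_fn yields one dictionary per extracted example for dltk's abstract_reader. The short sketch below is illustrative only and not part of the repository; it assumes dltk, SimpleITK, pandas and TensorFlow 1.x are installed and that the image and label paths listed in val.csv resolve locally. The reader parameters mirror those hard-coded in train.py, except that n_examples is reduced from 8 to 1.

# Usage sketch (illustrative only, not a file in this repository): exercises
# map_labels() and read_fn() from ukbb_neuronet_brain_segmentation/reader.py.
# Assumes dltk, SimpleITK, pandas and TensorFlow 1.x are installed and that
# the image/label paths listed in val.csv exist locally.
import numpy as np
import pandas as pd
import tensorflow as tf

from reader import map_labels, read_fn

# 1. map_labels() round trip for the FSL FIRST protocol:
#    raw FIRST ids -> consecutive class ids -> back to FIRST ids.
first_lbl = np.array([0, 10, 49, 58])
consecutive = map_labels(first_lbl, protocol='fsl_first')   # 0, 1, 9, 15
restored = map_labels(consecutive, protocol='fsl_first',
                      convert_to_protocol=True)             # 0, 10, 49, 58
print(consecutive, restored)

# 2. Draw one training example per subject from read_fn(), mirroring the
#    reader parameters hard-coded in train.py (which uses n_examples=8).
reader_params = {'n_examples': 1,
                 'example_size': [128, 128, 128],
                 'extract_examples': True,
                 'protocols': ['fsl_fast', 'fsl_first', 'spm_tissue',
                               'malp_em', 'malp_em_tissue']}

# .values is the modern pandas equivalent of the .as_matrix() call in train.py.
file_references = pd.read_csv('val.csv',
                              dtype=object,
                              keep_default_na=False,
                              na_values=[]).values

for output in read_fn(file_references, tf.estimator.ModeKeys.TRAIN, reader_params):
    print(output['features']['x'].shape)                    # expected (128, 128, 128, 1)
    print({p: l.shape for p, l in output['labels'].items()})  # expected (128, 128, 128) each
    break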