├── .gitignore ├── LICENSE.txt ├── README.md ├── SynthSeg ├── __init__.py ├── brain_generator.py ├── estimate_priors.py ├── evaluate.py ├── labels_to_image_model.py ├── metrics_model.py ├── model_inputs.py ├── predict.py ├── predict_denoiser.py ├── predict_group.py ├── predict_qc.py ├── predict_synthseg.py ├── sample_segmentation_pairs_d.py ├── training.py ├── training_denoiser.py ├── training_group.py ├── training_qc.py ├── training_supervised.py ├── validate.py ├── validate_denoiser.py ├── validate_group.py └── validate_qc.py ├── bibtex.bib ├── data ├── README_figures │ ├── new_features.png │ ├── overview.png │ ├── robust.png │ ├── segmentations.png │ ├── table_versions.png │ ├── training_data.png │ └── youtube_link.png ├── labels table.txt ├── labels_classes_priors │ ├── generation_classes.npy │ ├── generation_classes_contrast_specific.npy │ ├── generation_labels.npy │ ├── prior_means_t1.npy │ ├── prior_stds_t1.npy │ ├── synthseg_denoiser_labels_2.0.npy │ ├── synthseg_parcellation_labels.npy │ ├── synthseg_parcellation_names.npy │ ├── synthseg_qc_labels.npy │ ├── synthseg_qc_labels_2.0.npy │ ├── synthseg_qc_names.npy │ ├── synthseg_qc_names_2.0.npy │ ├── synthseg_segmentation_labels.npy │ ├── synthseg_segmentation_labels_2.0.npy │ ├── synthseg_segmentation_names.npy │ ├── synthseg_segmentation_names_2.0.npy │ ├── synthseg_topological_classes.npy │ └── synthseg_topological_classes_2.0.npy ├── training_label_maps │ ├── training_seg_01.nii.gz │ ├── training_seg_02.nii.gz │ ├── training_seg_03.nii.gz │ ├── training_seg_04.nii.gz │ ├── training_seg_05.nii.gz │ ├── training_seg_06.nii.gz │ ├── training_seg_07.nii.gz │ ├── training_seg_08.nii.gz │ ├── training_seg_09.nii.gz │ ├── training_seg_10.nii.gz │ ├── training_seg_11.nii.gz │ ├── training_seg_12.nii.gz │ ├── training_seg_13.nii.gz │ ├── training_seg_14.nii.gz │ ├── training_seg_15.nii.gz │ ├── training_seg_16.nii.gz │ ├── training_seg_17.nii.gz │ ├── training_seg_18.nii.gz │ ├── training_seg_19.nii.gz │ └── training_seg_20.nii.gz └── tutorial_7 │ ├── noisy_segmentations_d │ ├── 0001.nii.gz │ ├── 0002.nii.gz │ └── 0003.nii.gz │ ├── segmentation_labels_s1.npy │ └── target_segmentations_d │ ├── 0001.nii.gz │ ├── 0002.nii.gz │ └── 0003.nii.gz ├── ext ├── __init__.py ├── lab2im │ ├── __init__.py │ ├── edit_tensors.py │ ├── edit_volumes.py │ ├── image_generator.py │ ├── lab2im_model.py │ ├── layers.py │ └── utils.py └── neuron │ ├── __init__.py │ ├── layers.py │ ├── models.py │ └── utils.py ├── models └── synthseg_1.0.h5 ├── requirements_python3.6.txt ├── requirements_python3.8.txt ├── scripts ├── commands │ ├── SynthSeg_predict.py │ ├── predict.py │ ├── training.py │ └── training_supervised.py └── tutorials │ ├── 1-generation_visualisation.py │ ├── 2-generation_explained.py │ ├── 3-training.py │ ├── 4-prediction.py │ ├── 5-generation_advanced.py │ ├── 6-intensity_estimation.py │ └── 7-synthseg+.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | SynthSeg.egg-info 4 | venv 5 | **/__pycache__/ 6 | .idea 7 | 8 | ext/pynd 9 | ext/statannot 10 | 11 | data/labels_classes_priors/generation_classes_2.0.npy 12 | data/labels_classes_priors/generation_labels_2.0.npy 13 | 14 | models/synthseg_2.0.h5 15 | models/synthseg_parc_2.0.h5 16 | models/synthseg_qc_2.0.h5 17 | models/synthseg_robust_2.0.h5 18 | 19 | scripts/checks 20 | scripts/commands/SynthSeg_predict_claustrum.py 21 | scripts/previous_papers 22 | scripts/*.py 23 | 24 | SynthSeg/auc_roc_delong.py 25 | SynthSeg/boxplots.py 26 | SynthSeg/brain_generator_* 27 | SynthSeg/check_* 28 | SynthSeg/predict_synthseg_* 29 | SynthSeg/predict_vae.py 30 | SynthSeg/qc.py 31 | SynthSeg/training_heart.py 32 | SynthSeg/training_vae.py 33 | SynthSeg/validate_vae.py 34 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SynthSeg 2 | 3 | 4 | In this repository, we present SynthSeg, the first deep learning tool for segmentation of brain scans of 5 | any contrast and resolution. SynthSeg works out-of-the-box without any retraining, and is robust to: 6 | - any contrast 7 | - any resolution up to 10mm slice spacing 8 | - a wide array of populations: from young and healthy to ageing and diseased 9 | - scans with or without preprocessing: bias field correction, skull stripping, normalisation, etc. 10 | - white matter lesions. 11 | \ 12 | \ 13 | ![Generation examples](data/README_figures/segmentations.png) 14 | 15 | 16 | \ 17 | SynthSeg was first presented for the automated segmentation of brain scans of any contrast and resolution. 18 | 19 | **SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining** \ 20 | B. Billot, D.N. Greve, O. Puonti, A. Thielscher, K. Van Leemput, B. Fischl, A.V. Dalca, J.E. Iglesias \ 21 | Medical Image Analysis (2023) \ 22 | [ [article](https://www.sciencedirect.com/science/article/pii/S1361841523000506) | [arxiv](https://arxiv.org/abs/2107.09559) | [bibtex](bibtex.bib) ] 23 | \ 24 | \ 25 | Then, we extended it to work on heterogeneous clinical scans, and to perform cortical parcellation and automated 26 | quality control. 27 | 28 | **Robust machine learning segmentation for large-scale analysis of heterogeneous clinical brain MRI datasets** \ 29 | B. Billot, M. Colin, Y. Cheng, S.E. Arnold, S. Das, J.E. Iglesias \ 30 | PNAS (2023) \ 31 | [ [article](https://www.pnas.org/doi/full/10.1073/pnas.2216399120#bibliography) | [arxiv](https://arxiv.org/abs/2203.01969) | [bibtex](bibtex.bib) ] 32 | 33 | \ 34 | Here, we distribute our model to enable users to run SynthSeg on their own data. We emphasise that 35 | predictions are always given at 1mm isotropic resolution (regardless of the input resolution). The code can be run on 36 | the GPU (~15s per scan) or on the CPU (~1min). 37 | 38 | 39 | ---------------- 40 | 41 | ### New features and updates 42 | 43 | \ 44 | 01/03/2023: **The papers for SynthSeg and SynthSeg 2.0 are out! :open_book: :open_book:** \ 45 | After a long review process for SynthSeg (Medical Image Analysis), and a much faster one for SynthSeg 2.0 (PNAS), both 46 | papers have been accepted nearly at the same time ! See the references above, or in the citation section. 47 | 48 | \ 49 | 04/10/2022: **SynthSeg is available with Matlab!** :star: \ 50 | We are delighted that Matlab 2022b (and onwards) now includes SynthSeg in its Medical Image 51 | Toolbox. They have a [documented example](https://www.mathworks.com/help/medical-imaging/ug/Brain-MRI-Segmentation-Using-Trained-3-D-U-Net.html) 52 | on how to use it. But, to simplify things, we wrote our own Matlab wrapper, which you can call in one single line. 53 | Just download [this zip file](https://liveuclac-my.sharepoint.com/:u:/g/personal/rmappmb_ucl_ac_uk/EctEe3hOP8dDh1hYHlFS_rUBo80yFg7MQY5WnagHlWcS6A?e=e8bK0f), 54 | uncompress it, open Matlab, and type `help SynthSeg` for instructions. 55 | 56 | \ 57 | 29/06/2022: **SynthSeg 2.0 is out !** :v: \ 58 | In addition to whole-brain segmentation, it now also performs **Cortical parcellation, automated QC, and intracranial 59 | volume (ICV) estimation** (see figure below). Also, most of these features are compatible with SynthSeg 1.0. (see table). 60 | \ 61 | \ 62 | ![new features](data/README_figures/new_features.png) 63 | 64 | ![table versions](data/README_figures/table_versions.png) 65 | 66 | \ 67 | 01/03/2022: **Robust version** :hammer: \ 68 | SynthSeg sometimes falters on scans with low signal-to-noise ratio, or with very low tissue contrast. For this reason, 69 | we developed a new model for increased robustness, named "SynthSeg-robust". You can use this mode when SynthSeg gives 70 | results like in the figure below: 71 | \ 72 | \ 73 | ![Robust](data/README_figures/robust.png) 74 | 75 | \ 76 | 29/10/2021: **SynthSeg is now available on the dev version of 77 | [FreeSurfer](https://surfer.nmr.mgh.harvard.edu/fswiki/DownloadAndInstall) !!** :tada: \ 78 | See [here](https://surfer.nmr.mgh.harvard.edu/fswiki/SynthSeg) on how to use it. 79 | 80 | ---------------- 81 | 82 | ### Try it in one command ! 83 | 84 | Once all the python packages are installed (see below), you can simply test SynthSeg on your own data with: 85 | ``` 86 | python ./scripts/commands/SynthSeg_predict.py --i --o [--parc --robust --ct --vol --qc --post --resample ] 87 | ``` 88 | 89 | 90 | where: 91 | - `` path to a scan to segment, or to a folder. This can also be the path to a text file, where each line is the 92 | path of an image to segment. 93 | - `` path where the output segmentations will be saved. This must be the same type as `` (i.e., the path 94 | to a file, a folder, or a text file where each line is the path to an output segmentation). 95 | - `--parc` (optional) to perform cortical parcellation in addition to whole-brain segmentation. 96 | - `--robust` (optional) to use the variant for increased robustness (e.g., when analysing clinical data with large space 97 | spacing). This can be slower than the other model. 98 | - `--ct` (optional) use on CT scans in Hounsfield scale. It clips intensities to [0, 80]. 99 | - `` (optional) path to a CSV file where the volumes (in mm3) of all segmented regions will be saved for all scans 100 | (e.g. /path/to/volumes.csv). If `` is a text file, so must be ``, for which each line is the path to a 101 | different CSV file corresponding to one subject only. 102 | - `` (optional) path to a CSV file where QC scores will be saved. The same formatting requirements as `` apply. 103 | - `` (optional) path where the posteriors, given as soft probability maps, will be saved (same formatting 104 | requirements as for ``). 105 | - `` (optional) SynthSeg segmentations are always given at 1mm isotropic resolution. Hence, 106 | images are always resampled internally to this resolution (except if they are already at 1mm resolution). 107 | Use this flag to save the resampled images (same formatting requirements as for ``). 108 | 109 | Additional optional flags are also available: 110 | - `--cpu`: (optional) to enforce the code to run on the CPU, even if a GPU is available. 111 | - `--threads`: (optional) number of threads to be used by Tensorflow (default uses one core). Increase it to decrease 112 | the runtime when using the CPU version. 113 | - `--crop`: (optional) to crop the input images to a given shape before segmentation. This must be divisible by 32. 114 | Images are cropped around their centre, and their segmentations are given at the original size. It can be given as a 115 | single (i.e., `--crop 160`), or several integers (i.e, `--crop 160 128 192`, ordered in RAS coordinates). By default the 116 | whole image is processed. Use this flag for faster analysis or to fit in your GPU. 117 | - `--fast`: (optional) to disable some operations for faster prediction (twice as fast, but slightly less accurate). 118 | This doesn't apply when the --robust flag is used. 119 | - `--v1`: (optional) to run the first version of SynthSeg (SynthSeg 1.0, updated 29/06/2022). 120 | 121 | 122 | **IMPORTANT:** SynthSeg always give results at 1mm isotropic resolution, regardless of the input. However, this can 123 | cause some viewers to not correctly overlay segmentations on their corresponding images. In this case, you can use the 124 | `--resample` flag to obtain a resampled image that lives in the same space as the segmentation, such that they can be 125 | visualised together with any viewer. 126 | 127 | The complete list of segmented structures is available in [labels table.txt](data/labels%20table.txt) along with their 128 | corresponding values. This table also details the order in which the posteriors maps are sorted. 129 | 130 | 131 | ---------------- 132 | 133 | ### Installation 134 | 135 | 1. Clone this repository. 136 | 137 | 2. Create a virtual environment (i.e., with pip or conda) and install all the required packages. \ 138 | These depend on your python version, and here we list the requirements for Python 3.6 139 | ([requirements_3.6](requirements_python3.6.txt)) and Python 3.8 (see [requirements_3.8](requirements_python3.8.txt)). 140 | The choice is yours, but in each case, please stick to the exact package versions.\ 141 | A first solution to install the dependencies, if you use pip, is to run setup.py (with and activated virtual 142 | environment): `python setup.py install`. Otherwise, we also give here the minimal commands to install the required 143 | packages using pip/conda for Python 3.6/3.8. 144 | 145 | ``` 146 | # Conda, Python 3.6: 147 | conda create -n synthseg_36 python=3.6 tensorflow-gpu=2.0.0 keras=2.3.1 h5py==2.10.0 nibabel matplotlib -c anaconda -c conda-forge 148 | 149 | # Conda, Python 3.8: 150 | conda create -n synthseg_38 python=3.8 tensorflow-gpu=2.2.0 keras=2.3.1 nibabel matplotlib -c anaconda -c conda-forge 151 | 152 | # Pip, Python 3.6: 153 | pip install tensorflow-gpu==2.0.0 keras==2.3.1 nibabel==3.2.2 matplotlib==3.3.4 154 | 155 | # Pip, Python 3.8: 156 | pip install tensorflow-gpu==2.2.0 keras==2.3.1 protobuf==3.20.3 numpy==1.23.5 nibabel==5.0.1 matplotlib==3.6.2 157 | ``` 158 | 159 | 3. Go to this link [UCL dropbox](https://liveuclac-my.sharepoint.com/:f:/g/personal/rmappmb_ucl_ac_uk/EtlNnulBSUtAvOP6S99KcAIBYzze7jTPsmFk2_iHqKDjEw?e=rBP0RO), and download the missing models. Then simply copy them to [models](models). 160 | 161 | 4. If you wish to run on the GPU, you will also need to install Cuda (10.0 for Python 3.6, 10.1 for Python 3.8), and 162 | CUDNN (7.6.5 for both). Note that if you used conda, these were already automatically installed. 163 | 164 | That's it ! You're now ready to use SynthSeg ! :tada: 165 | 166 | 167 | ---------------- 168 | 169 | ### How does it work ? 170 | 171 | In short, we train a network with synthetic images sampled on the fly from a generative model based on the forward 172 | model of Bayesian segmentation. Crucially, we adopt a domain randomisation strategy where we fully randomise the 173 | generation parameters which are drawn at each minibatch from uninformative uniform priors. By exposing the network to 174 | extremely variable input data, we force it to learn domain-agnostic features. As a result, SynthSeg is able to readily 175 | segment real scans of any target domain, without retraining or fine-tuning. 176 | 177 | The following figure first illustrates the workflow of a training iteration, and then provides an overview of the 178 | different steps of the generative model: 179 | \ 180 | \ 181 | ![Overview](data/README_figures/overview.png) 182 | \ 183 | \ 184 | Finally we show additional examples of the synthesised images along with an overlay of their target segmentations: 185 | \ 186 | \ 187 | ![Training data](data/README_figures/training_data.png) 188 | \ 189 | \ 190 | If you are interested to learn more about SynthSeg, you can read the associated publication (see below), and watch this 191 | presentation, which was given at MIDL 2020 for a related article on a preliminary version of SynthSeg (robustness to 192 | MR contrast but not resolution). 193 | \ 194 | \ 195 | [![Talk SynthSeg](data/README_figures/youtube_link.png)](https://www.youtube.com/watch?v=Bfp3cILSKZg&t=1s) 196 | 197 | 198 | ---------------- 199 | 200 | ### Train your own model 201 | 202 | This repository contains all the code and data necessary to train, validate, and test your own network. Importantly, the 203 | proposed method only requires a set of anatomical segmentations to be trained (no images), which we include in 204 | [data](data/training_label_maps). While the provided functions are thoroughly documented, we highly recommend to start 205 | with the following tutorials: 206 | 207 | - [1-generation_visualisation](scripts/tutorials/1-generation_visualisation.py): This very simple script shows examples 208 | of the synthetic images used to train SynthSeg. 209 | 210 | - [2-generation_explained](scripts/tutorials/2-generation_explained.py): This second script describes all the parameters 211 | used to control the generative model. We advise you to thoroughly follow this tutorial, as it is essential to understand 212 | how the synthetic data is formed before you start training your own models. 213 | 214 | - [3-training](scripts/tutorials/3-training.py): This scripts re-uses the parameters explained in the previous tutorial 215 | and focuses on the learning/architecture parameters. The script here is the very one we used to train SynthSeg ! 216 | 217 | - [4-prediction](scripts/tutorials/4-prediction.py): This scripts shows how to make predictions, once the network has 218 | been trained. 219 | 220 | - [5-generation_advanced](scripts/tutorials/5-generation_advanced.py): Here we detail more advanced generation options, 221 | in the case of training a version of SynthSeg that is specific to a given contrast and/or resolution (although these 222 | types of variants were shown to be outperformed by the SynthSeg model trained in the 3rd tutorial). 223 | 224 | - [6-intensity_estimation](scripts/tutorials/6-intensity_estimation.py): This script shows how to estimate the 225 | Gaussian priors of the GMM when training a contrast-specific version of SynthSeg. 226 | 227 | - [7-synthseg+](scripts/tutorials/7-synthseg+.py): Finally, we show how the robust version of SynthSeg was 228 | trained. 229 | 230 | These tutorials cover a lot of materials and will enable you to train your own SynthSeg model. Moreover, even more 231 | detailed information is provided in the docstrings of all functions, so don't hesitate to have a look at these ! 232 | 233 | 234 | ---------------- 235 | 236 | ### Content 237 | 238 | - [SynthSeg](SynthSeg): this is the main folder containing the generative model and training function: 239 | 240 | - [labels_to_image_model.py](SynthSeg/labels_to_image_model.py): contains the generative model for MRI scans. 241 | 242 | - [brain_generator.py](SynthSeg/brain_generator.py): contains the class `BrainGenerator`, which is a wrapper around 243 | `labels_to_image_model`. New images can simply be generated by instantiating an object of this class, and call the 244 | method `generate_image()`. 245 | 246 | - [training.py](SynthSeg/training.py): contains code to train the segmentation network (with explanations for all 247 | training parameters). This function also shows how to integrate the generative model in a training setting. 248 | 249 | - [predict.py](SynthSeg/predict.py): prediction and testing. 250 | 251 | - [validate.py](SynthSeg/validate.py): includes code for validation (which has to be done offline on real images). 252 | 253 | - [models](models): this is where you will find the trained model for SynthSeg. 254 | 255 | - [data](data): this folder contains some examples of brain label maps if you wish to train your own SynthSeg model. 256 | 257 | - [script](scripts): contains tutorials as well as scripts to launch trainings and testings from a terminal. 258 | 259 | - [ext](ext): includes external packages, especially the *lab2im* package, and a modified version of *neuron*. 260 | 261 | 262 | ---------------- 263 | 264 | ### Citation/Contact 265 | 266 | This code is under [Apache 2.0](LICENSE.txt) licensing. 267 | 268 | - If you use the **cortical parcellation**, **automated QC**, or **robust version**, please cite the following paper: 269 | 270 | **Robust machine learning segmentation for large-scale analysisof heterogeneous clinical brain MRI datasets** \ 271 | B. Billot, M. Colin, Y. Cheng, S.E. Arnold, S. Das, J.E. Iglesias \ 272 | PNAS (2023) \ 273 | [ [article](https://www.pnas.org/doi/full/10.1073/pnas.2216399120#bibliography) | [arxiv](https://arxiv.org/abs/2203.01969) | [bibtex](bibtex.bib) ] 274 | 275 | 276 | - Otherwise, please cite: 277 | 278 | **SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining** \ 279 | B. Billot, D.N. Greve, O. Puonti, A. Thielscher, K. Van Leemput, B. Fischl, A.V. Dalca, J.E. Iglesias \ 280 | Medical Image Analysis (2023) \ 281 | [ [article](https://www.sciencedirect.com/science/article/pii/S1361841523000506) | [arxiv](https://arxiv.org/abs/2107.09559) | [bibtex](bibtex.bib) ] 282 | 283 | If you have any question regarding the usage of this code, or any suggestions to improve it, please raise an issue or 284 | contact us at: bbillot@mit.edu 285 | -------------------------------------------------------------------------------- /SynthSeg/__init__.py: -------------------------------------------------------------------------------- 1 | from . import brain_generator 2 | from . import estimate_priors 3 | from . import evaluate 4 | from . import labels_to_image_model 5 | from . import metrics_model 6 | from . import model_inputs 7 | from . import predict 8 | from . import training_supervised 9 | from . import training 10 | -------------------------------------------------------------------------------- /SynthSeg/estimate_priors.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # python imports 18 | import os 19 | import numpy as np 20 | try: 21 | from scipy.stats import median_absolute_deviation 22 | except ImportError: 23 | from scipy.stats import median_abs_deviation as median_absolute_deviation 24 | 25 | 26 | # third-party imports 27 | from ext.lab2im import utils 28 | from ext.lab2im import edit_volumes 29 | 30 | 31 | def sample_intensity_stats_from_image(image, segmentation, labels_list, classes_list=None, keep_strictly_positive=True): 32 | """This function takes an image and corresponding segmentation as inputs. It estimates the mean and std intensity 33 | for all specified label values. Labels can share the same statistics by being regrouped into K classes. 34 | :param image: image from which to evaluate mean intensity and std deviation. 35 | :param segmentation: segmentation of the input image. Must have the same size as image. 36 | :param labels_list: list of labels for which to evaluate mean and std intensity. 37 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array. 38 | :param classes_list: (optional) enables to regroup structures into classes of similar intensity statistics. 39 | Intensities associated to regrouped labels will thus contribute to the same Gaussian during statistics estimation. 40 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array. 41 | It should have the same length as labels_list, and contain values between 0 and K-1, where K is the total number of 42 | classes. Default is all labels have different classes (K=len(labels_list)). 43 | :param keep_strictly_positive: (optional) whether to only keep strictly positive intensity values when 44 | computing stats. This doesn't apply to the first label in label_list (or class if class_list is provided), for 45 | which we keep positive and zero values, as we consider it to be the background label. 46 | :return: a numpy array of size (2, K), the first row being the mean intensity for each structure, 47 | and the second being the median absolute deviation (robust estimation of std). 48 | """ 49 | 50 | # reformat labels and classes 51 | labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int')) 52 | if classes_list is not None: 53 | classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int')) 54 | else: 55 | classes_list = np.arange(labels_list.shape[0]) 56 | assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length' 57 | 58 | # get unique classes 59 | unique_classes, unique_indices = np.unique(classes_list, return_index=True) 60 | n_classes = len(unique_classes) 61 | if not np.array_equal(unique_classes, np.arange(n_classes)): 62 | raise ValueError('classes_list should only contain values between 0 and K-1, ' 63 | 'where K is the total number of classes. Here K = %d' % n_classes) 64 | 65 | # compute mean/std of specified classes 66 | means = np.zeros(n_classes) 67 | stds = np.zeros(n_classes) 68 | for idx, tmp_class in enumerate(unique_classes): 69 | 70 | # get list of all intensity values for the current class 71 | class_labels = labels_list[classes_list == tmp_class] 72 | intensities = np.empty(0) 73 | for label in class_labels: 74 | tmp_intensities = image[segmentation == label] 75 | intensities = np.concatenate([intensities, tmp_intensities]) 76 | if tmp_class: # i.e. if not background 77 | if keep_strictly_positive: 78 | intensities = intensities[intensities > 0] 79 | 80 | # compute stats for class and put them to the location of corresponding label values 81 | if len(intensities) != 0: 82 | means[idx] = np.nanmedian(intensities) 83 | stds[idx] = median_absolute_deviation(intensities, nan_policy='omit') 84 | 85 | return np.stack([means, stds]) 86 | 87 | 88 | def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_list, classes_list=None, max_channel=3, 89 | rescale=True): 90 | """This function aims at estimating the intensity distributions of K different structure types from a set of images. 91 | The distribution of each structure type is modelled as a Gaussian, parametrised by a mean and a standard deviation. 92 | Because the intensity distribution of structures can vary across images, we additionally use Gaussian priors for the 93 | parameters of each Gaussian distribution. Therefore, the intensity distribution of each structure type is described 94 | by 4 parameters: a mean/std for the mean intensity, and a mean/std for the std deviation. 95 | This function uses a set of images along with corresponding segmentations to estimate the 4*K parameters. 96 | Structures can share the same statistics by being regrouped into classes of similar structure types. 97 | Images can be multi-modal (n_channels), in which case different statistics are estimated for each modality. 98 | :param image_dir: path of directory with images to estimate the intensity distribution 99 | :param labels_dir: path of directory with segmentation of input images. 100 | They are matched with images by sorting order. 101 | :param labels_list: list of labels for which to evaluate mean and std intensity. 102 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array. 103 | :param classes_list: (optional) enables to regroup structures into classes of similar intensity statistics. 104 | Intensities associated to regrouped labels will thus contribute to the same Gaussian during statistics estimation. 105 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array. 106 | It should have the same length as labels_list, and contain values between 0 and K-1, where K is the total number of 107 | classes. Default is all labels have different classes (K=len(labels_list)). 108 | :param max_channel: (optional) maximum number of channels to consider if the data is multi-spectral. Default is 3. 109 | :param rescale: (optional) whether to rescale images between 0 and 255 before intensity estimation 110 | :return: 2 numpy arrays of size (2*n_channels, K), one with the evaluated means/std for the mean 111 | intensity, and one for the mean/std for the standard deviation. 112 | Each block of two rows correspond to a different modality (channel). For each block of two rows, the first row 113 | represents the mean, and the second represents the std. 114 | """ 115 | 116 | # list files 117 | path_images = utils.list_images_in_folder(image_dir) 118 | path_labels = utils.list_images_in_folder(labels_dir) 119 | assert len(path_images) == len(path_labels), 'image and labels folders do not have the same number of files' 120 | 121 | # reformat list labels and classes 122 | labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int')) 123 | if classes_list is not None: 124 | classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int')) 125 | else: 126 | classes_list = np.arange(labels_list.shape[0]) 127 | assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length' 128 | 129 | # get unique classes 130 | unique_classes, unique_indices = np.unique(classes_list, return_index=True) 131 | n_classes = len(unique_classes) 132 | if not np.array_equal(unique_classes, np.arange(n_classes)): 133 | raise ValueError('classes_list should only contain values between 0 and K-1, ' 134 | 'where K is the total number of classes. Here K = %d' % n_classes) 135 | 136 | # initialise result arrays 137 | n_dims, n_channels = utils.get_dims(utils.load_volume(path_images[0]).shape, max_channels=max_channel) 138 | means = np.zeros((len(path_images), n_classes, n_channels)) 139 | stds = np.zeros((len(path_images), n_classes, n_channels)) 140 | 141 | # loop over images 142 | loop_info = utils.LoopInfo(len(path_images), 10, 'estimating', print_time=True) 143 | for idx, (path_im, path_la) in enumerate(zip(path_images, path_labels)): 144 | loop_info.update(idx) 145 | 146 | # load image and label map 147 | image = utils.load_volume(path_im) 148 | la = utils.load_volume(path_la) 149 | if n_channels == 1: 150 | image = utils.add_axis(image, -1) 151 | 152 | # loop over channels 153 | for channel in range(n_channels): 154 | im = image[..., channel] 155 | if rescale: 156 | im = edit_volumes.rescale_volume(im) 157 | stats = sample_intensity_stats_from_image(im, la, labels_list, classes_list=classes_list) 158 | means[idx, :, channel] = stats[0, :] 159 | stds[idx, :, channel] = stats[1, :] 160 | 161 | # compute prior parameters for mean/std 162 | mean_means = np.mean(means, axis=0) 163 | std_means = np.std(means, axis=0) 164 | mean_stds = np.mean(stds, axis=0) 165 | std_stds = np.std(stds, axis=0) 166 | 167 | # regroup prior parameters in two different arrays: one for the mean and one for the std 168 | prior_means = np.zeros((2 * n_channels, n_classes)) 169 | prior_stds = np.zeros((2 * n_channels, n_classes)) 170 | for channel in range(n_channels): 171 | prior_means[2 * channel, :] = mean_means[:, channel] 172 | prior_means[2 * channel + 1, :] = std_means[:, channel] 173 | prior_stds[2 * channel, :] = mean_stds[:, channel] 174 | prior_stds[2 * channel + 1, :] = std_stds[:, channel] 175 | 176 | return prior_means, prior_stds 177 | 178 | 179 | def build_intensity_stats(list_image_dir, 180 | list_labels_dir, 181 | result_dir, 182 | estimation_labels, 183 | estimation_classes=None, 184 | max_channel=3, 185 | rescale=True): 186 | """This function aims at estimating the intensity distributions of K different structure types from a set of images. 187 | The distribution of each structure type is modelled as a Gaussian, parametrised by a mean and a standard deviation. 188 | Because the intensity distribution of structures can vary across images, we additionally use Gaussian priors for the 189 | parameters of each Gaussian distribution. Therefore, the intensity distribution of each structure type is described 190 | by 4 parameters: a mean/std for the mean intensity, and a mean/std for the std deviation. 191 | This function uses a set of images along with corresponding segmentations to estimate the 4*K parameters. 192 | Additionally, it can estimate the 4*K parameters for several image datasets, that we call here n_datasets. 193 | This function writes 2 numpy arrays of size (2*n_datasets, K), one with the evaluated means/std for the mean 194 | intensities, and one for the mean/std for the standard deviations. 195 | In these arrays, each block of two rows refer to a different dataset. 196 | Within each block of two rows, the first row represents the mean, and the second represents the std. 197 | :param list_image_dir: path of folders with images for intensity distribution estimation. 198 | Can be the path of single directory (n_datasets=1), or a list of folders, each being a separate dataset. 199 | Images can be multimodal, in which case each modality is treated as a different dataset, i.e. each modality will 200 | have a separate block (of size (2, K)) in the result arrays. 201 | :param list_labels_dir: path of folders with label maps corresponding to input images. 202 | If list_image_dir is a list of several folders, list_labels_dir can either be a list of folders (one for each image 203 | folder), or the path to a single folder, which will be used for all datasets. 204 | If a dataset has multi-modal images, the same label map is applied to all modalities. 205 | :param result_dir: path of directory where estimated priors will be writen. 206 | :param estimation_labels: labels to estimate intensity statistics from. 207 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array. 208 | :param estimation_classes: (optional) enables to regroup structures into classes of similar intensity statistics. 209 | Intensities associated to regrouped labels will thus contribute to the same Gaussian during statistics estimation. 210 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array. 211 | It should have the same length as labels_list, and contain values between 0 and K-1, where K is the total number of 212 | classes. Default is all labels have different classes (K=len(estimation_labels)). 213 | :param max_channel: (optional) maximum number of channels to consider if the data is multi-spectral. Default is 3. 214 | :param rescale: (optional) whether to rescale images between 0 and 255 before intensity estimation 215 | """ 216 | 217 | # handle results directories 218 | utils.mkdir(result_dir) 219 | 220 | # reformat image/labels dir into lists 221 | list_image_dir = utils.reformat_to_list(list_image_dir) 222 | list_labels_dir = utils.reformat_to_list(list_labels_dir, length=len(list_image_dir)) 223 | 224 | # reformat list estimation labels and classes 225 | estimation_labels = np.array(utils.reformat_to_list(estimation_labels, load_as_numpy=True, dtype='int')) 226 | if estimation_classes is not None: 227 | estimation_classes = np.array(utils.reformat_to_list(estimation_classes, load_as_numpy=True, dtype='int')) 228 | else: 229 | estimation_classes = np.arange(estimation_labels.shape[0]) 230 | assert len(estimation_classes) == len(estimation_labels), 'estimation labels and classes should be of same length' 231 | 232 | # get unique classes 233 | unique_estimation_classes, unique_indices = np.unique(estimation_classes, return_index=True) 234 | n_classes = len(unique_estimation_classes) 235 | if not np.array_equal(unique_estimation_classes, np.arange(n_classes)): 236 | raise ValueError('estimation_classes should only contain values between 0 and N-1, ' 237 | 'where K is the total number of classes. Here N = %d' % n_classes) 238 | 239 | # loop over dataset 240 | list_datasets_prior_means = list() 241 | list_datasets_prior_stds = list() 242 | for image_dir, labels_dir in zip(list_image_dir, list_labels_dir): 243 | 244 | # get prior stats for dataset 245 | tmp_prior_means, tmp_prior_stds = sample_intensity_stats_from_single_dataset(image_dir, 246 | labels_dir, 247 | estimation_labels, 248 | estimation_classes, 249 | max_channel=max_channel, 250 | rescale=rescale) 251 | 252 | # add stats arrays to list of datasets-wise statistics 253 | list_datasets_prior_means.append(tmp_prior_means) 254 | list_datasets_prior_stds.append(tmp_prior_stds) 255 | 256 | # stack all modalities together 257 | prior_means = np.concatenate(list_datasets_prior_means, axis=0) 258 | prior_stds = np.concatenate(list_datasets_prior_stds, axis=0) 259 | 260 | # save files 261 | np.save(os.path.join(result_dir, 'prior_means.npy'), prior_means) 262 | np.save(os.path.join(result_dir, 'prior_stds.npy'), prior_stds) 263 | 264 | return prior_means, prior_stds 265 | -------------------------------------------------------------------------------- /SynthSeg/metrics_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # python imports 18 | import numpy as np 19 | import tensorflow as tf 20 | import keras.layers as KL 21 | from keras.models import Model 22 | 23 | # third-party imports 24 | from ext.lab2im import layers 25 | 26 | 27 | def metrics_model(input_model, label_list, metrics='dice'): 28 | 29 | # get prediction 30 | last_tensor = input_model.outputs[0] 31 | input_shape = last_tensor.get_shape().as_list()[1:] 32 | 33 | # check shapes 34 | n_labels = input_shape[-1] 35 | label_list = np.unique(label_list) 36 | assert n_labels == len(label_list), 'label_list should be as long as the posteriors channels' 37 | 38 | # get GT and convert it to probabilistic values 39 | labels_gt = input_model.get_layer('labels_out').output 40 | labels_gt = layers.ConvertLabels(label_list)(labels_gt) 41 | labels_gt = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(labels_gt) 42 | labels_gt = KL.Reshape(input_shape)(labels_gt) 43 | 44 | # make sure the tensors have the right keras shape 45 | last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list()) 46 | labels_gt._keras_shape = tuple(labels_gt.get_shape().as_list()) 47 | 48 | if metrics == 'dice': 49 | last_tensor = layers.DiceLoss()([labels_gt, last_tensor]) 50 | 51 | elif metrics == 'wl2': 52 | last_tensor = layers.WeightedL2Loss(target_value=5)([labels_gt, last_tensor]) 53 | 54 | else: 55 | raise Exception('metrics should either be "dice or "wl2, got {}'.format(metrics)) 56 | 57 | # create the model and return 58 | model = Model(inputs=input_model.inputs, outputs=last_tensor) 59 | return model 60 | 61 | 62 | class IdentityLoss(object): 63 | """Very simple loss, as the computation of the loss as been directly implemented in the model.""" 64 | def __init__(self, keepdims=True): 65 | self.keepdims = keepdims 66 | 67 | def loss(self, y_true, y_predicted): 68 | """Because the metrics is already calculated in the model, we simply return y_predicted. 69 | We still need to put y_true in the inputs, as it's expected by keras.""" 70 | loss = y_predicted 71 | 72 | tf.debugging.check_numerics(loss, 'Loss not finite') 73 | return loss 74 | -------------------------------------------------------------------------------- /SynthSeg/model_inputs.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # python imports 18 | import numpy as np 19 | import numpy.random as npr 20 | 21 | # third-party imports 22 | from ext.lab2im import utils 23 | 24 | 25 | def build_model_inputs(path_label_maps, 26 | n_labels, 27 | batchsize=1, 28 | n_channels=1, 29 | subjects_prob=None, 30 | generation_classes=None, 31 | prior_distributions='uniform', 32 | prior_means=None, 33 | prior_stds=None, 34 | use_specific_stats_for_channel=False, 35 | mix_prior_and_random=False): 36 | """ 37 | This function builds a generator that will be used to give the necessary inputs to the label_to_image model: the 38 | input label maps, as well as the means and stds defining the parameters of the GMM (which change at each minibatch). 39 | :param path_label_maps: list of the paths of the input label maps. 40 | :param n_labels: number of labels in the input label maps. 41 | :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1. 42 | :param n_channels: (optional) number of channels to be synthesised. Default is 1. 43 | :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to pick 44 | the provided label maps at each minibatch. Must be a 1D numpy array, as long as path_label_maps. 45 | :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity 46 | distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence or a 47 | 1d numpy array. It should have the same length as generation_labels, and contain values between 0 and K-1, where K 48 | is the total number of classes. Default is all labels have different classes. 49 | :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters. 50 | Can either be 'uniform', or 'normal'. Default is 'uniform'. 51 | :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because 52 | these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be: 53 | 1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is 54 | uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each 55 | mini_batch from the same distribution. 56 | 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is 57 | not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each mini-batch 58 | from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, or from 59 | N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal. 60 | 3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived 61 | from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a 62 | modality from the n_mod possibilities, and we sample the GMM means like in 2). 63 | If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel 64 | (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it. 65 | 4) the path to such a numpy array. 66 | Default is None, which corresponds to prior_means = [25, 225]. 67 | :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM. 68 | Default is None, which corresponds to prior_stds = [5, 25]. 69 | :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be 70 | only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False. 71 | :param mix_prior_and_random: (optional) if prior_means is not None, enables to reset the priors to their default 72 | values for half of these cases, and thus generate images of random contrast. 73 | """ 74 | 75 | # allocate unique class to each label if generation classes is not given 76 | if generation_classes is None: 77 | generation_classes = np.arange(n_labels) 78 | n_classes = len(np.unique(generation_classes)) 79 | 80 | # make sure subjects_prob sums to 1 81 | subjects_prob = utils.load_array_if_path(subjects_prob) 82 | if subjects_prob is not None: 83 | subjects_prob /= np.sum(subjects_prob) 84 | 85 | # Generate! 86 | while True: 87 | 88 | # randomly pick as many images as batchsize 89 | indices = npr.choice(np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob) 90 | 91 | # initialise input lists 92 | list_label_maps = [] 93 | list_means = [] 94 | list_stds = [] 95 | 96 | for idx in indices: 97 | 98 | # load input label map 99 | lab = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4)) 100 | if (npr.uniform() > 0.7) & ('seg_cerebral' in path_label_maps[idx]): 101 | lab[lab == 24] = 0 102 | 103 | # add label map to inputs 104 | list_label_maps.append(utils.add_axis(lab, axis=[0, -1])) 105 | 106 | # add means and standard deviations to inputs 107 | means = np.empty((1, n_labels, 0)) 108 | stds = np.empty((1, n_labels, 0)) 109 | for channel in range(n_channels): 110 | 111 | # retrieve channel specific stats if necessary 112 | if isinstance(prior_means, np.ndarray): 113 | if (prior_means.shape[0] > 2) & use_specific_stats_for_channel: 114 | if prior_means.shape[0] / 2 != n_channels: 115 | raise ValueError("the number of blocks in prior_means does not match n_channels. This " 116 | "message is printed because use_specific_stats_for_channel is True.") 117 | tmp_prior_means = prior_means[2 * channel:2 * channel + 2, :] 118 | else: 119 | tmp_prior_means = prior_means 120 | else: 121 | tmp_prior_means = prior_means 122 | if (prior_means is not None) & mix_prior_and_random & (npr.uniform() > 0.5): 123 | tmp_prior_means = None 124 | if isinstance(prior_stds, np.ndarray): 125 | if (prior_stds.shape[0] > 2) & use_specific_stats_for_channel: 126 | if prior_stds.shape[0] / 2 != n_channels: 127 | raise ValueError("the number of blocks in prior_stds does not match n_channels. This " 128 | "message is printed because use_specific_stats_for_channel is True.") 129 | tmp_prior_stds = prior_stds[2 * channel:2 * channel + 2, :] 130 | else: 131 | tmp_prior_stds = prior_stds 132 | else: 133 | tmp_prior_stds = prior_stds 134 | if (prior_stds is not None) & mix_prior_and_random & (npr.uniform() > 0.5): 135 | tmp_prior_stds = None 136 | 137 | # draw means and std devs from priors 138 | tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_classes, prior_distributions, 139 | 125., 125., positive_only=True) 140 | tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_classes, prior_distributions, 141 | 15., 15., positive_only=True) 142 | random_coef = npr.uniform() 143 | if random_coef > 0.95: # reset the background to 0 in 5% of cases 144 | tmp_classes_means[0] = 0 145 | tmp_classes_stds[0] = 0 146 | elif random_coef > 0.7: # reset the background to low Gaussian in 25% of cases 147 | tmp_classes_means[0] = npr.uniform(0, 15) 148 | tmp_classes_stds[0] = npr.uniform(0, 5) 149 | tmp_means = utils.add_axis(tmp_classes_means[generation_classes], axis=[0, -1]) 150 | tmp_stds = utils.add_axis(tmp_classes_stds[generation_classes], axis=[0, -1]) 151 | means = np.concatenate([means, tmp_means], axis=-1) 152 | stds = np.concatenate([stds, tmp_stds], axis=-1) 153 | list_means.append(means) 154 | list_stds.append(stds) 155 | 156 | # build list of inputs for generation model 157 | list_inputs = [list_label_maps, list_means, list_stds] 158 | if batchsize > 1: # concatenate each input type if batchsize > 1 159 | list_inputs = [np.concatenate(item, 0) for item in list_inputs] 160 | else: 161 | list_inputs = [item[0] for item in list_inputs] 162 | 163 | yield list_inputs 164 | -------------------------------------------------------------------------------- /SynthSeg/predict_qc.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # python imports 18 | import os 19 | import numpy as np 20 | import tensorflow as tf 21 | import keras.layers as KL 22 | from keras.models import Model 23 | 24 | # project imports 25 | from SynthSeg import evaluate 26 | 27 | # third-party imports 28 | from ext.lab2im import utils 29 | from ext.lab2im import edit_volumes 30 | from ext.neuron import models as nrn_models 31 | 32 | 33 | def predict(path_predictions, 34 | path_qc_results, 35 | path_model, 36 | labels_list, 37 | labels_to_convert=None, 38 | convert_gt=False, 39 | shape=224, 40 | n_levels=5, 41 | nb_conv_per_level=3, 42 | conv_size=5, 43 | unet_feat_count=24, 44 | feat_multiplier=2, 45 | activation='relu', 46 | path_gts=None, 47 | verbose=True): 48 | 49 | # prepare input/output filepaths 50 | path_predictions, path_gts, path_qc_results, path_gt_results, path_diff = \ 51 | prepare_output_files(path_predictions, path_gts, path_qc_results) 52 | 53 | # get label list 54 | labels_list, _ = utils.get_list_labels(label_list=labels_list) 55 | labels_list_unique, _ = np.unique(labels_list, return_index=True) 56 | if labels_to_convert is not None: 57 | labels_to_convert, _ = utils.get_list_labels(label_list=labels_to_convert) 58 | 59 | # prepare qc results 60 | pred_qc_results = np.zeros((len(labels_list_unique) + 1, len(path_predictions))) 61 | gt_qc_results = np.zeros((len(labels_list_unique), len(path_predictions))) if path_gt_results is not None else None 62 | 63 | # build network 64 | model_input_shape = [None, None, None, 1] 65 | net = build_qc_model(path_model=path_model, 66 | input_shape=model_input_shape, 67 | label_list=labels_list_unique, 68 | n_levels=n_levels, 69 | nb_conv_per_level=nb_conv_per_level, 70 | conv_size=conv_size, 71 | unet_feat_count=unet_feat_count, 72 | feat_multiplier=feat_multiplier, 73 | activation=activation) 74 | 75 | # perform segmentation 76 | loop_info = utils.LoopInfo(len(path_predictions), 10, 'predicting', True) 77 | for idx, (path_prediction, path_gt) in enumerate(zip(path_predictions, path_gts)): 78 | 79 | # compute segmentation only if needed 80 | if verbose: 81 | loop_info.update(idx) 82 | 83 | # preprocessing 84 | prediction, gt_scores = preprocess(path_prediction, path_gt, shape, labels_list, labels_to_convert, convert_gt) 85 | 86 | # get predicted scores 87 | pred_qc_results[-1, idx] = np.sum(prediction > 0) 88 | pred_qc_results[:-1, idx] = np.clip(np.squeeze(net.predict(prediction)), 0, 1) 89 | np.save(path_qc_results, pred_qc_results) 90 | 91 | # save GT scores if necessary 92 | if gt_scores is not None: 93 | gt_qc_results[:, idx] = gt_scores 94 | np.save(path_gt_results, gt_qc_results) 95 | 96 | if path_diff is not None: 97 | diff = pred_qc_results[:-1, :] - gt_qc_results 98 | np.save(path_diff, diff) 99 | 100 | 101 | def prepare_output_files(path_predictions, path_gts, path_qc_results): 102 | 103 | # check inputs 104 | assert path_predictions is not None, 'please specify an input file/folder (--i)' 105 | assert path_qc_results is not None, 'please specify an output file/folder (--o)' 106 | 107 | # convert path to absolute paths 108 | path_predictions = os.path.abspath(path_predictions) 109 | path_qc_results = os.path.abspath(path_qc_results) 110 | 111 | # list input predictions 112 | path_predictions = utils.list_images_in_folder(path_predictions) 113 | 114 | # build path output with qc results 115 | if path_qc_results[-4:] != '.npy': 116 | print('Path for QC outputs provided without npy extension. Adding npy extension.') 117 | path_qc_results += '.npy' 118 | utils.mkdir(os.path.dirname(path_qc_results)) 119 | 120 | if path_gts is not None: 121 | path_gts = utils.list_images_in_folder(path_gts) 122 | assert len(path_gts) == len(path_predictions), 'not the same number of predictions and GTs' 123 | path_gt_results = path_qc_results.replace('.npy', '_gt.npy') 124 | path_diff = path_qc_results.replace('.npy', '_diff.npy') 125 | else: 126 | path_gts = [None] * len(path_predictions) 127 | path_gt_results = path_diff = None 128 | 129 | return path_predictions, path_gts, path_qc_results, path_gt_results, path_diff 130 | 131 | 132 | def preprocess(path_prediction, path_gt=None, shape=224, labels_list=None, labels_to_convert=None, convert_gt=False): 133 | 134 | # read image and corresponding info 135 | pred, _, aff_pred, n_dims, _, _, _ = utils.get_volume_info(path_prediction, True) 136 | gt = utils.load_volume(path_gt, aff_ref=np.eye(4)) if path_gt is not None else None 137 | 138 | # align 139 | pred = edit_volumes.align_volume_to_ref(pred, aff_pred, aff_ref=np.eye(4), n_dims=n_dims) 140 | 141 | # pad/crop to 224, such that segmentations are in the middle of the patch 142 | if gt is not None: 143 | pred, gt = make_shape(pred, gt, shape, n_dims) 144 | else: 145 | pred, _ = edit_volumes.crop_volume_around_region(pred, cropping_shape=shape) 146 | 147 | # convert labels if necessary 148 | if labels_to_convert is not None: 149 | lut = utils.get_mapping_lut(labels_to_convert, labels_list) 150 | pred = lut[pred.astype('int32')] 151 | if convert_gt & (gt is not None): 152 | gt = lut[gt.astype('int32')] 153 | 154 | # compute GT dice scores 155 | gt_scores = evaluate.fast_dice(pred, gt, np.unique(labels_list)) if gt is not None else None 156 | 157 | # add batch and channel axes 158 | pred = utils.add_axis(pred, axis=0) # channel axis will be added later when computing one-hot 159 | 160 | return pred, gt_scores 161 | 162 | 163 | def make_shape(pred, gt, shape, n_dims): 164 | 165 | mask = ((pred > 0) & (pred != 24)) | (gt > 0) 166 | vol_shape = np.array(pred.shape[:n_dims]) 167 | 168 | if np.any(mask): 169 | 170 | # find cropping indices 171 | indices = np.nonzero(mask) 172 | min_idx = np.maximum(np.array([np.min(idx) for idx in indices]), 0) 173 | max_idx = np.minimum(np.array([np.max(idx) for idx in indices]) + 1, vol_shape) 174 | 175 | # expand/retract (depending on the desired shape) the cropping region around the centre 176 | intermediate_vol_shape = max_idx - min_idx 177 | cropping_shape = np.array(utils.reformat_to_list(shape, length=n_dims)) 178 | min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape) / 2)) 179 | max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape) / 2)) 180 | 181 | # crop volume 182 | cropping = np.concatenate([np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)]) 183 | pred = edit_volumes.crop_volume_with_idx(pred, cropping, n_dims=n_dims) 184 | gt = edit_volumes.crop_volume_with_idx(gt, cropping, n_dims=n_dims) 185 | 186 | # check if we need to pad the output to the desired shape 187 | min_padding = np.abs(np.minimum(min_idx, 0)) 188 | max_padding = np.maximum(max_idx - vol_shape, 0) 189 | if np.any(min_padding > 0) | np.any(max_padding > 0): 190 | pad_margins = tuple([(min_padding[i], max_padding[i]) for i in range(n_dims)]) 191 | pred = np.pad(pred, pad_margins, mode='constant', constant_values=0) 192 | gt = np.pad(gt, pad_margins, mode='constant', constant_values=0) 193 | 194 | return pred, gt 195 | 196 | 197 | def build_qc_model(path_model, 198 | input_shape, 199 | label_list, 200 | n_levels, 201 | nb_conv_per_level, 202 | conv_size, 203 | unet_feat_count, 204 | feat_multiplier, 205 | activation): 206 | 207 | assert os.path.isfile(path_model), "The provided model path does not exist." 208 | label_list_unique = np.unique(label_list) 209 | n_labels = len(label_list_unique) 210 | 211 | # one-hot encoding of the input prediction as the network expects soft probabilities 212 | input_labels = KL.Input(input_shape[:-1]) 213 | labels = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(input_labels) 214 | net = Model(inputs=input_labels, outputs=labels) 215 | 216 | # build model 217 | model = nrn_models.conv_enc(input_model=net, 218 | input_shape=input_shape, 219 | nb_levels=n_levels, 220 | nb_conv_per_level=nb_conv_per_level, 221 | conv_size=conv_size, 222 | nb_features=unet_feat_count, 223 | feat_mult=feat_multiplier, 224 | activation=activation, 225 | batch_norm=-1, 226 | use_residuals=True, 227 | name='qc') 228 | last = model.outputs[0] 229 | 230 | conv_kwargs = {'padding': 'same', 'activation': 'relu', 'data_format': 'channels_last'} 231 | last = KL.MaxPool3D(pool_size=(2, 2, 2), name='qc_maxpool_%s' % (n_levels - 1), padding='same')(last) 232 | last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name='qc_final_conv_0')(last) 233 | last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name='qc_final_conv_1')(last) 234 | last = KL.Lambda(lambda x: tf.reduce_mean(x, axis=[1, 2, 3]), name='qc_final_pred')(last) 235 | 236 | net = Model(inputs=net.inputs, outputs=last) 237 | net.load_weights(path_model, by_name=True) 238 | 239 | return net 240 | -------------------------------------------------------------------------------- /SynthSeg/training_denoiser.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # python imports 18 | import os 19 | import numpy as np 20 | import tensorflow as tf 21 | from keras import models 22 | from keras import layers as KL 23 | 24 | # project imports 25 | from SynthSeg import metrics_model as metrics 26 | from SynthSeg.training import train_model 27 | from SynthSeg.labels_to_image_model import get_shapes 28 | from SynthSeg.training_supervised import build_model_inputs 29 | 30 | # third-party imports 31 | from ext.lab2im import utils, layers 32 | from ext.neuron import models as nrn_models 33 | 34 | 35 | def training(list_paths_input_labels, 36 | list_paths_target_labels, 37 | model_dir, 38 | input_segmentation_labels, 39 | target_segmentation_labels=None, 40 | subjects_prob=None, 41 | batchsize=1, 42 | output_shape=None, 43 | scaling_bounds=.2, 44 | rotation_bounds=15, 45 | shearing_bounds=.012, 46 | nonlin_std=3., 47 | nonlin_scale=.04, 48 | prob_erosion_dilation=0.3, 49 | min_erosion_dilation=4, 50 | max_erosion_dilation=5, 51 | n_levels=5, 52 | nb_conv_per_level=2, 53 | conv_size=5, 54 | unet_feat_count=16, 55 | feat_multiplier=2, 56 | activation='elu', 57 | skip_n_concatenations=2, 58 | lr=1e-4, 59 | wl2_epochs=1, 60 | dice_epochs=50, 61 | steps_per_epoch=10000, 62 | checkpoint=None): 63 | """ 64 | 65 | This function trains a UNet to segment MRI images with synthetic scans generated by sampling a GMM conditioned on 66 | label maps. We regroup the parameters in four categories: General, Augmentation, Architecture, Training. 67 | 68 | # IMPORTANT !!! 69 | # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence), 70 | # these values refer to the RAS axes. 71 | 72 | :param list_paths_input_labels: list of all the paths of the input label maps. These correspond to "noisy" 73 | segmentations that the denoiser will be trained to correct. 74 | :param list_paths_target_labels: list of all the paths of the output label maps. Must have the same order as 75 | list_paths_input_labels. These are the target label maps that the network will learn to produce given the "noisy" 76 | input label maps. 77 | :param model_dir: path of a directory where the models will be saved during training. 78 | :param input_segmentation_labels: list of all the label values present in the input label maps. 79 | :param target_segmentation_labels: list of all the label values present in the output label maps. By default (None) 80 | this will be taken to be the same as input_segmentation_labels. 81 | 82 | # ----------------------------------------------- General parameters ----------------------------------------------- 83 | # label maps parameters 84 | :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to pick 85 | the provided label maps at each minibatch. Can be a sequence, a 1D numpy array, or the path to such an array, and it 86 | must be as long as path_label_maps. By default, all label maps are chosen with the same importance. 87 | 88 | # output-related parameters 89 | :param batchsize: (optional) number of images to generate per mini-batch. Default is 1. 90 | :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image 91 | Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array. 92 | Default is None, where no cropping is performed. 93 | 94 | # --------------------------------------------- Augmentation parameters -------------------------------------------- 95 | # spatial deformation parameters 96 | :param scaling_bounds: (optional) if apply_linear_trans is True, the scaling factor for each dimension is 97 | sampled from a uniform distribution of predefined bounds. Can either be: 98 | 1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds 99 | (1-scaling_bounds, 1+scaling_bounds) for each dimension. 100 | 2) the path to a numpy array of shape (2, n_dims), in which case the scaling factor in dimension i is sampled from 101 | the uniform distribution of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension. 102 | 3) False, in which case scaling is completely turned off. 103 | Default is scaling_bounds = 0.2 (case 1) 104 | :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for case 1 the 105 | bounds are centred on 0 rather than 1, i.e. (0+rotation_bounds[i], 0-rotation_bounds[i]). 106 | Default is rotation_bounds = 15. 107 | :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012. 108 | :param nonlin_std: (optional) Standard deviation of the normal distribution from which we sample the first 109 | tensor for synthesising the deformation field. Set to 0 to completely deactivate elastic deformation. 110 | :param nonlin_scale: (optional) Ratio between the size of the input label maps and the size of the sampled 111 | tensor for synthesising the elastic deformation field. 112 | 113 | # degradation of the input labels 114 | :param prob_erosion_dilation: (optional) probability with which to degrade the input label maps with erosion or 115 | dilation. If 0, then no erosion/dilation is applied to the label maps given as inputs to the network. 116 | :param min_erosion_dilation: (optional) when prob_erosion_dilation is not zero, erosion and dilation of random 117 | coefficients are applied. Set the minimum erosion/dilation coefficient here. 118 | :param max_erosion_dilation: (optional) Set the maximum erosion/dilation coefficient here. 119 | 120 | # ------------------------------------------ UNet architecture parameters ------------------------------------------ 121 | :param n_levels: (optional) number of level for the Unet. Default is 5. 122 | :param nb_conv_per_level: (optional) number of convolutional layers per level. Default is 2. 123 | :param conv_size: (optional) size of the convolution kernels. Default is 2. 124 | :param unet_feat_count: (optional) number of feature for the first layer of the UNet. Default is 24. 125 | :param feat_multiplier: (optional) multiply the number of feature by this number at each new level. Default is 2. 126 | :param activation: (optional) activation function. Can be 'elu', 'relu'. 127 | :param skip_n_concatenations: (optional) number of levels for which to remove the traditional skip connections of 128 | the UNet architecture. default is zero, which corresponds to the classic UNet architecture. Example: 129 | If skip_n_concatenations = 2, then we will remove the concatenation link between the two top levels of the UNet. 130 | 131 | # ----------------------------------------------- Training parameters ---------------------------------------------- 132 | :param lr: (optional) learning rate for the training. Default is 1e-4 133 | :param wl2_epochs: (optional) number of epochs for which the network (except the soft-max layer) is trained with L2 134 | norm loss function. Default is 1. 135 | :param dice_epochs: (optional) number of epochs with the soft Dice loss function. Default is 50. 136 | :param steps_per_epoch: (optional) number of steps per epoch. Default is 10000. Since no online validation is 137 | possible, this is equivalent to the frequency at which the models are saved. 138 | :param checkpoint: (optional) path of an already saved model to load before starting the training. 139 | """ 140 | 141 | # check epochs 142 | assert (wl2_epochs > 0) | (dice_epochs > 0), \ 143 | 'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs) 144 | 145 | # prepare data files 146 | input_label_list, _ = utils.get_list_labels(label_list=input_segmentation_labels) 147 | if target_segmentation_labels is None: 148 | target_label_list = input_label_list 149 | else: 150 | target_label_list, _ = utils.get_list_labels(label_list=target_segmentation_labels) 151 | n_labels = np.size(target_label_list) 152 | 153 | # create augmentation model 154 | labels_shape, _, _, _, _, _ = utils.get_volume_info(list_paths_input_labels[0], aff_ref=np.eye(4)) 155 | augmentation_model = build_augmentation_model(labels_shape, 156 | input_label_list, 157 | crop_shape=output_shape, 158 | output_div_by_n=2 ** n_levels, 159 | scaling_bounds=scaling_bounds, 160 | rotation_bounds=rotation_bounds, 161 | shearing_bounds=shearing_bounds, 162 | nonlin_std=nonlin_std, 163 | nonlin_scale=nonlin_scale, 164 | prob_erosion_dilation=prob_erosion_dilation, 165 | min_erosion_dilation=min_erosion_dilation, 166 | max_erosion_dilation=max_erosion_dilation) 167 | unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:] 168 | 169 | # prepare the segmentation model 170 | l2l_model = nrn_models.unet(input_model=augmentation_model, 171 | input_shape=unet_input_shape, 172 | nb_labels=n_labels, 173 | nb_levels=n_levels, 174 | nb_conv_per_level=nb_conv_per_level, 175 | conv_size=conv_size, 176 | nb_features=unet_feat_count, 177 | feat_mult=feat_multiplier, 178 | activation=activation, 179 | batch_norm=-1, 180 | skip_n_concatenations=skip_n_concatenations, 181 | name='l2l') 182 | 183 | # input generator 184 | model_inputs = build_model_inputs(path_inputs=list_paths_input_labels, 185 | path_outputs=list_paths_target_labels, 186 | batchsize=batchsize, 187 | subjects_prob=subjects_prob, 188 | dtype_input='int32') 189 | input_generator = utils.build_training_generator(model_inputs, batchsize) 190 | 191 | # pre-training with weighted L2, input is fit to the softmax rather than the probabilities 192 | if wl2_epochs > 0: 193 | wl2_model = models.Model(l2l_model.inputs, [l2l_model.get_layer('l2l_likelihood').output]) 194 | wl2_model = metrics.metrics_model(wl2_model, target_label_list, 'wl2') 195 | train_model(wl2_model, input_generator, lr, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint) 196 | checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs) 197 | 198 | # fine-tuning with dice metric 199 | dice_model = metrics.metrics_model(l2l_model, target_label_list, 'dice') 200 | train_model(dice_model, input_generator, lr, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint) 201 | 202 | 203 | def build_augmentation_model(labels_shape, 204 | segmentation_labels, 205 | crop_shape=None, 206 | output_div_by_n=None, 207 | scaling_bounds=0.15, 208 | rotation_bounds=15, 209 | shearing_bounds=0.012, 210 | translation_bounds=False, 211 | nonlin_std=3., 212 | nonlin_scale=.0625, 213 | prob_erosion_dilation=0.3, 214 | min_erosion_dilation=4, 215 | max_erosion_dilation=7): 216 | 217 | # reformat resolutions and get shapes 218 | labels_shape = utils.reformat_to_list(labels_shape) 219 | n_dims, _ = utils.get_dims(labels_shape) 220 | n_labels = len(segmentation_labels) 221 | 222 | # get shapes 223 | crop_shape, _ = get_shapes(labels_shape, crop_shape, np.array([1]*n_dims), np.array([1]*n_dims), output_div_by_n) 224 | 225 | # define model inputs 226 | net_input = KL.Input(shape=labels_shape + [1], name='l2l_noisy_labels_input', dtype='int32') 227 | target_input = KL.Input(shape=labels_shape + [1], name='l2l_target_input', dtype='int32') 228 | 229 | # deform labels 230 | noisy_labels, target = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, 231 | rotation_bounds=rotation_bounds, 232 | shearing_bounds=shearing_bounds, 233 | translation_bounds=translation_bounds, 234 | nonlin_std=nonlin_std, 235 | nonlin_scale=nonlin_scale, 236 | inter_method='nearest')([net_input, target_input]) 237 | 238 | # cropping 239 | if crop_shape != labels_shape: 240 | noisy_labels, target = layers.RandomCrop(crop_shape)([noisy_labels, target]) 241 | 242 | # random erosion 243 | if prob_erosion_dilation > 0: 244 | noisy_labels = layers.RandomDilationErosion(min_erosion_dilation, 245 | max_erosion_dilation, 246 | prob=prob_erosion_dilation)(noisy_labels) 247 | 248 | # convert input labels (i.e. noisy_labels) to [0, ... N-1] and make them one-hot 249 | noisy_labels = layers.ConvertLabels(np.unique(segmentation_labels))(noisy_labels) 250 | target = KL.Lambda(lambda x: tf.cast(x[..., 0], 'int32'), name='labels_out')(target) 251 | noisy_labels = KL.Lambda(lambda x: tf.one_hot(x[0][..., 0], depth=n_labels), 252 | name='noisy_labels_out')([noisy_labels, target]) 253 | 254 | # build model and return 255 | brain_model = models.Model(inputs=[net_input, target_input], outputs=[noisy_labels, target]) 256 | return brain_model 257 | -------------------------------------------------------------------------------- /SynthSeg/validate_denoiser.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # python imports 18 | import os 19 | 20 | # project imports 21 | from SynthSeg.predict_denoiser import predict 22 | 23 | # third-party imports 24 | from ext.lab2im import utils 25 | 26 | 27 | def validate_training(prediction_dir, 28 | gt_dir, 29 | models_dir, 30 | validation_main_dir, 31 | target_segmentation_labels, 32 | input_segmentation_labels=None, 33 | evaluation_labels=None, 34 | step_eval=1, 35 | min_pad=None, 36 | cropping=None, 37 | topology_classes=None, 38 | sigma_smoothing=0, 39 | keep_biggest_component=False, 40 | n_levels=5, 41 | nb_conv_per_level=2, 42 | conv_size=3, 43 | unet_feat_count=24, 44 | feat_multiplier=2, 45 | activation='elu', 46 | skip_n_concatenations=0, 47 | recompute=True): 48 | 49 | # create result folder 50 | utils.mkdir(validation_main_dir) 51 | 52 | # loop over models 53 | list_models = utils.list_files(models_dir, expr=['dice', '.h5'], cond_type='and')[::step_eval] 54 | # list_models = [p for p in list_models if int(os.path.basename(p)[-6:-3]) % 2 == 0] 55 | loop_info = utils.LoopInfo(len(list_models), 1, 'validating', True) 56 | for model_idx, path_model in enumerate(list_models): 57 | 58 | # build names and create folders 59 | model_val_dir = os.path.join(validation_main_dir, os.path.basename(path_model).replace('.h5', '')) 60 | dice_path = os.path.join(model_val_dir, 'dice.npy') 61 | utils.mkdir(model_val_dir) 62 | 63 | if (not os.path.isfile(dice_path)) | recompute: 64 | loop_info.update(model_idx) 65 | predict(path_predictions=prediction_dir, 66 | path_corrections=model_val_dir, 67 | path_model=path_model, 68 | target_segmentation_labels=target_segmentation_labels, 69 | input_segmentation_labels=input_segmentation_labels, 70 | min_pad=min_pad, 71 | cropping=cropping, 72 | topology_classes=topology_classes, 73 | sigma_smoothing=sigma_smoothing, 74 | keep_biggest_component=keep_biggest_component, 75 | n_levels=n_levels, 76 | nb_conv_per_level=nb_conv_per_level, 77 | conv_size=conv_size, 78 | unet_feat_count=unet_feat_count, 79 | feat_multiplier=feat_multiplier, 80 | activation=activation, 81 | skip_n_concatenations=skip_n_concatenations, 82 | gt_folder=gt_dir, 83 | evaluation_labels=evaluation_labels, 84 | recompute=recompute, 85 | verbose=False) 86 | -------------------------------------------------------------------------------- /SynthSeg/validate_group.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # python imports 18 | import os 19 | 20 | # project imports 21 | from SynthSeg.predict_group import predict 22 | 23 | # third-party imports 24 | from ext.lab2im import utils 25 | 26 | 27 | def validate_training(image_dir, 28 | mask_dir, 29 | gt_dir, 30 | models_dir, 31 | validation_main_dir, 32 | labels_segmentation, 33 | labels_mask, 34 | evaluation_labels=None, 35 | step_eval=1, 36 | min_pad=None, 37 | cropping=None, 38 | sigma_smoothing=0, 39 | strict_masking=False, 40 | keep_biggest_component=False, 41 | n_levels=5, 42 | nb_conv_per_level=2, 43 | conv_size=3, 44 | unet_feat_count=24, 45 | feat_multiplier=2, 46 | activation='elu', 47 | list_incorrect_labels=None, 48 | list_correct_labels=None, 49 | recompute=False): 50 | 51 | # create result folder 52 | utils.mkdir(validation_main_dir) 53 | 54 | # loop over models 55 | list_models = utils.list_files(models_dir, expr=['dice', '.h5'], cond_type='and')[::step_eval] 56 | # list_models = [p for p in list_models if int(os.path.basename(p)[-6:-3]) % 10 == 0] 57 | loop_info = utils.LoopInfo(len(list_models), 1, 'validating', True) 58 | for model_idx, path_model in enumerate(list_models): 59 | 60 | # build names and create folders 61 | model_val_dir = os.path.join(validation_main_dir, os.path.basename(path_model).replace('.h5', '')) 62 | dice_path = os.path.join(model_val_dir, 'dice.npy') 63 | utils.mkdir(model_val_dir) 64 | 65 | if (not os.path.isfile(dice_path)) | recompute: 66 | loop_info.update(model_idx) 67 | predict(path_images=image_dir, 68 | path_masks=mask_dir, 69 | path_segmentations=model_val_dir, 70 | path_model=path_model, 71 | labels_segmentation=labels_segmentation, 72 | labels_mask=labels_mask, 73 | min_pad=min_pad, 74 | cropping=cropping, 75 | sigma_smoothing=sigma_smoothing, 76 | strict_masking=strict_masking, 77 | keep_biggest_component=keep_biggest_component, 78 | n_levels=n_levels, 79 | nb_conv_per_level=nb_conv_per_level, 80 | conv_size=conv_size, 81 | unet_feat_count=unet_feat_count, 82 | feat_multiplier=feat_multiplier, 83 | activation=activation, 84 | gt_folder=gt_dir, 85 | evaluation_labels=evaluation_labels, 86 | list_incorrect_labels=list_incorrect_labels, 87 | list_correct_labels=list_correct_labels, 88 | recompute=recompute, 89 | verbose=False) 90 | -------------------------------------------------------------------------------- /SynthSeg/validate_qc.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # python imports 18 | import os 19 | import re 20 | import logging 21 | import numpy as np 22 | import matplotlib.pyplot as plt 23 | from tensorflow.python.summary.summary_iterator import summary_iterator 24 | 25 | # project imports 26 | from SynthSeg.predict_qc import predict 27 | 28 | # third-party imports 29 | from ext.lab2im import utils 30 | 31 | 32 | def validate_training(prediction_dir, 33 | gt_dir, 34 | models_dir, 35 | validation_main_dir, 36 | labels_list, 37 | labels_to_convert=None, 38 | convert_gt=False, 39 | shape=224, 40 | n_levels=5, 41 | nb_conv_per_level=3, 42 | conv_size=5, 43 | unet_feat_count=24, 44 | feat_multiplier=2, 45 | activation='relu', 46 | step_eval=1, 47 | recompute=False): 48 | 49 | # create result folder 50 | utils.mkdir(validation_main_dir) 51 | 52 | # loop over models 53 | list_models = utils.list_files(models_dir, expr=['qc', '.h5'], cond_type='and')[::step_eval] 54 | # list_models = [p for p in list_models if int(os.path.basename(p)[-6:-3]) % 10 == 0] 55 | loop_info = utils.LoopInfo(len(list_models), 1, 'validating', True) 56 | for model_idx, path_model in enumerate(list_models): 57 | 58 | # build names and create folders 59 | model_val_dir = os.path.join(validation_main_dir, os.path.basename(path_model).replace('.h5', '')) 60 | score_path = os.path.join(model_val_dir, 'pred_qc_results.npy') 61 | utils.mkdir(model_val_dir) 62 | 63 | if (not os.path.isfile(score_path)) | recompute: 64 | loop_info.update(model_idx) 65 | predict(path_predictions=prediction_dir, 66 | path_qc_results=score_path, 67 | path_model=path_model, 68 | labels_list=labels_list, 69 | labels_to_convert=labels_to_convert, 70 | convert_gt=convert_gt, 71 | shape=shape, 72 | n_levels=n_levels, 73 | nb_conv_per_level=nb_conv_per_level, 74 | conv_size=conv_size, 75 | unet_feat_count=unet_feat_count, 76 | feat_multiplier=feat_multiplier, 77 | activation=activation, 78 | path_gts=gt_dir, 79 | verbose=False) 80 | 81 | 82 | def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_indices=None, 83 | skip_first_dice_row=True, size_max_circle=100, figsize=(11, 6), y_lim=None, fontsize=18, 84 | list_linestyles=None, list_colours=None, plot_legend=False): 85 | """This function plots the validation curves of several networks, based on the results of validate_training(). 86 | It takes as input a list of validation folders (one for each network), each containing subfolders with dice scores 87 | for the corresponding validated epoch. 88 | :param list_validation_dirs: list of all the validation folders of the trainings to plot. 89 | :param eval_indices: (optional) compute the average Dice on a subset of labels indicated by the specified indices. 90 | Can be a 1d numpy array, the path to such an array, or a list of 1d numpy arrays as long as list_validation_dirs. 91 | :param skip_first_dice_row: if eval_indices is None, skip the first row of the dice matrices (usually background) 92 | :param size_max_circle: (optional) size of the marker for epochs achieving the best validation scores. 93 | :param figsize: (optional) size of the figure to draw. 94 | :param fontsize: (optional) fontsize used for the graph.""" 95 | 96 | n_curves = len(list_validation_dirs) 97 | 98 | if eval_indices is not None: 99 | if isinstance(eval_indices, (np.ndarray, str)): 100 | if isinstance(eval_indices, str): 101 | eval_indices = np.load(eval_indices) 102 | eval_indices = np.squeeze(utils.reformat_to_n_channels_array(eval_indices, n_dims=len(eval_indices))) 103 | eval_indices = [eval_indices] * len(list_validation_dirs) 104 | elif isinstance(eval_indices, list): 105 | for (i, e) in enumerate(eval_indices): 106 | if isinstance(e, np.ndarray): 107 | eval_indices[i] = np.squeeze(utils.reformat_to_n_channels_array(e, n_dims=len(e))) 108 | else: 109 | raise TypeError('if provided as a list, eval_indices should only contain numpy arrays') 110 | else: 111 | raise TypeError('eval_indices can be a numpy array, a path to a numpy array, or a list of numpy arrays.') 112 | else: 113 | eval_indices = [None] * len(list_validation_dirs) 114 | 115 | # reformat model names 116 | if architecture_names is None: 117 | architecture_names = [os.path.basename(os.path.dirname(d)) for d in list_validation_dirs] 118 | else: 119 | architecture_names = utils.reformat_to_list(architecture_names, len(list_validation_dirs)) 120 | 121 | # prepare legend labels 122 | if plot_legend is False: 123 | list_legend_labels = ['_nolegend_'] * n_curves 124 | elif plot_legend is True: 125 | list_legend_labels = architecture_names 126 | else: 127 | list_legend_labels = architecture_names 128 | list_legend_labels = ['_nolegend_' if i >= plot_legend else list_legend_labels[i] for i in range(n_curves)] 129 | 130 | # prepare linestyles 131 | if list_linestyles is not None: 132 | list_linestyles = utils.reformat_to_list(list_linestyles) 133 | else: 134 | list_linestyles = [None] * n_curves 135 | 136 | # prepare curve colours 137 | if list_colours is not None: 138 | list_colours = utils.reformat_to_list(list_colours) 139 | else: 140 | list_colours = [None] * n_curves 141 | 142 | # loop over architectures 143 | plt.figure(figsize=figsize) 144 | for idx, (net_val_dir, net_name, linestyle, colour, legend_label, eval_idx) in enumerate(zip(list_validation_dirs, 145 | architecture_names, 146 | list_linestyles, 147 | list_colours, 148 | list_legend_labels, 149 | eval_indices)): 150 | 151 | list_epochs_dir = utils.list_subfolders(net_val_dir, whole_path=False) 152 | 153 | # loop over epochs 154 | list_net_scores = list() 155 | list_epochs = list() 156 | for epoch_dir in list_epochs_dir: 157 | 158 | # build names and create folders 159 | path_epoch_scores = utils.list_files(os.path.join(net_val_dir, epoch_dir), expr='diff') 160 | if len(path_epoch_scores) == 1: 161 | path_epoch_scores = path_epoch_scores[0] 162 | if eval_idx is not None: 163 | list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)[eval_idx, :]))) 164 | else: 165 | if skip_first_dice_row: 166 | list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)[1:, :]))) 167 | else: 168 | list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)))) 169 | list_epochs.append(int(re.sub('[^0-9]', '', epoch_dir))) 170 | 171 | # plot validation scores for current architecture 172 | if list_net_scores: # check that archi has been validated for at least 1 epoch 173 | list_net_scores = np.array(list_net_scores) 174 | list_epochs = np.array(list_epochs) 175 | list_epochs, idx = np.unique(list_epochs, return_index=True) 176 | list_net_scores = list_net_scores[idx] 177 | min_score = np.min(list_net_scores) 178 | epoch_min_score = list_epochs[np.argmin(list_net_scores)] 179 | print('\n'+net_name) 180 | print('epoch min score: %d' % epoch_min_score) 181 | print('min score: %0.3f' % min_score) 182 | plt.plot(list_epochs, list_net_scores, label=legend_label, linestyle=linestyle, color=colour) 183 | plt.scatter(epoch_min_score, min_score, s=size_max_circle, color=colour) 184 | 185 | # finalise plot 186 | plt.grid() 187 | plt.tick_params(axis='both', labelsize=fontsize) 188 | plt.ylabel('Scores', fontsize=fontsize) 189 | plt.xlabel('Epochs', fontsize=fontsize) 190 | if y_lim is not None: 191 | plt.ylim(y_lim[0], y_lim[1] + 0.01) # set right/left limits of plot 192 | plt.title('Validation curves', fontsize=fontsize) 193 | if plot_legend: 194 | plt.legend(fontsize=fontsize) 195 | plt.tight_layout(pad=1) 196 | plt.show() 197 | 198 | 199 | def draw_learning_curve(path_tensorboard_files, architecture_names, figsize=(11, 6), fontsize=18, 200 | y_lim=None, remove_legend=False): 201 | """This function draws the learning curve of several trainings on the same graph. 202 | :param path_tensorboard_files: list of tensorboard files corresponding to the models to plot. 203 | :param architecture_names: list of the names of the models 204 | :param figsize: (optional) size of the figure to draw. 205 | :param fontsize: (optional) fontsize used for the graph. 206 | """ 207 | 208 | # reformat inputs 209 | path_tensorboard_files = utils.reformat_to_list(path_tensorboard_files) 210 | architecture_names = utils.reformat_to_list(architecture_names) 211 | assert len(path_tensorboard_files) == len(architecture_names), 'names and tensorboard lists should have same length' 212 | 213 | # loop over architectures 214 | plt.figure(figsize=figsize) 215 | for path_tensorboard_file, name in zip(path_tensorboard_files, architecture_names): 216 | 217 | path_tensorboard_file = utils.reformat_to_list(path_tensorboard_file) 218 | 219 | # extract loss at the end of all epochs 220 | list_losses = list() 221 | list_epochs = list() 222 | logging.getLogger('tensorflow').disabled = True 223 | for path in path_tensorboard_file: 224 | for e in summary_iterator(path): 225 | for v in e.summary.value: 226 | if v.tag == 'loss' or v.tag == 'accuracy' or v.tag == 'epoch_loss': 227 | list_losses.append(v.simple_value) 228 | list_epochs.append(e.step) 229 | plt.plot(np.array(list_epochs), np.array(list_losses), label=name, linewidth=2) 230 | 231 | # finalise plot 232 | plt.grid() 233 | if not remove_legend: 234 | plt.legend(fontsize=fontsize) 235 | plt.xlabel('Epochs', fontsize=fontsize) 236 | plt.ylabel('Scores', fontsize=fontsize) 237 | if y_lim is not None: 238 | plt.ylim(y_lim[0], y_lim[1] + 0.01) # set right/left limits of plot 239 | plt.tick_params(axis='both', labelsize=fontsize) 240 | plt.title('Learning curves', fontsize=fontsize) 241 | plt.tight_layout(pad=1) 242 | plt.show() 243 | -------------------------------------------------------------------------------- /bibtex.bib: -------------------------------------------------------------------------------- 1 | @article{billot_synthseg_2023, 2 | title = {SynthSeg: {Segmentation} of brain {MRI} scans of any contrast and resolution without retraining}, 3 | author = {Billot, Benjamin and Greve, Douglas N. and Puonti, Oula and Thielscher, Axel and Van Leemput, Koen and Fischl, Bruce and Dalca, Adrian V. and Iglesias, Juan Eugenio}, 4 | journal = {{Medical} {Image} {Analysis}}, 5 | year = {2023}, 6 | volume = {86}, 7 | pages = {102789}, 8 | issn = {1361-8415}, 9 | doi = {10.1016/j.media.2023.102789}, 10 | } 11 | 12 | @article{billot_robust_2023, 13 | title = {{Robust} machine learning segmentation for large-scale analysis of heterogeneous clinical brain {MRI} datasets}, 14 | author = {Billot, Benjamin and Colin, Magdamo Cheng, You and Das, Sudeshna and Iglesias, Juan Eugenio}, 15 | journal = {{Proceedings} of the {National} {Academy} of {Sciences} ({PNAS})}, 16 | year = {2023}, 17 | volume = {120}, 18 | number = {9}, 19 | pages = {1--10}, 20 | doi = {10.1073/pnas.2216399120}, 21 | } 22 | -------------------------------------------------------------------------------- /data/README_figures/new_features.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/new_features.png -------------------------------------------------------------------------------- /data/README_figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/overview.png -------------------------------------------------------------------------------- /data/README_figures/robust.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/robust.png -------------------------------------------------------------------------------- /data/README_figures/segmentations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/segmentations.png -------------------------------------------------------------------------------- /data/README_figures/table_versions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/table_versions.png -------------------------------------------------------------------------------- /data/README_figures/training_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/training_data.png -------------------------------------------------------------------------------- /data/README_figures/youtube_link.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/youtube_link.png -------------------------------------------------------------------------------- /data/labels table.txt: -------------------------------------------------------------------------------- 1 | List of the structures segmented by SynthSeg along with their corresponding label values. 2 | 3 | The structures are given in the same order as they appear in the posteriors, i.e. the first map of the posteriors 4 | corresponds to the background, then the second map is associated to the left cerebral white matter, etc. 5 | 6 | Please note that the label values follow the FreeSurfer classification. Also, we do not provide any colour scheme, 7 | as the colour displayed for each structure depends on the used image viewer. 8 | 9 | WARNING: if you use the --v1 flag, note that it won't segment the CSF (label 24), since this was introduced in version 2 10 | 11 | labels structures 12 | 0 background 13 | 2 left cerebral white matter 14 | 3 left cerebral cortex 15 | 4 left lateral ventricle 16 | 5 left inferior lateral ventricle 17 | 7 left cerebellum white matter 18 | 8 left cerebellum cortex 19 | 10 left thalamus 20 | 11 left caudate 21 | 12 left putamen 22 | 13 left pallidum 23 | 14 3rd ventricle 24 | 15 4th ventricle 25 | 16 brain-stem 26 | 17 left hippocampus 27 | 18 left amygdala 28 | 26 left accumbens area 29 | 24 CSF 30 | 28 left ventral DC 31 | 41 right cerebral white matter 32 | 42 right cerebral cortex 33 | 43 right lateral ventricle 34 | 44 right inferior lateral ventricle 35 | 46 right cerebellum white matter 36 | 47 right cerebellum cortex 37 | 49 right thalamus 38 | 50 right caudate 39 | 51 right putamen 40 | 52 right pallidum 41 | 53 right hippocampus 42 | 54 right amygdala 43 | 58 right accumbens area 44 | 60 right ventral DC -------------------------------------------------------------------------------- /data/labels_classes_priors/generation_classes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/generation_classes.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/generation_classes_contrast_specific.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/generation_classes_contrast_specific.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/generation_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/generation_labels.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/prior_means_t1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/prior_means_t1.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/prior_stds_t1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/prior_stds_t1.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_denoiser_labels_2.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_denoiser_labels_2.0.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_parcellation_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_parcellation_labels.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_parcellation_names.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_parcellation_names.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_qc_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_qc_labels.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_qc_labels_2.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_qc_labels_2.0.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_qc_names.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_qc_names.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_qc_names_2.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_qc_names_2.0.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_segmentation_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_segmentation_labels.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_segmentation_labels_2.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_segmentation_labels_2.0.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_segmentation_names.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_segmentation_names.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_segmentation_names_2.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_segmentation_names_2.0.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_topological_classes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_topological_classes.npy -------------------------------------------------------------------------------- /data/labels_classes_priors/synthseg_topological_classes_2.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_topological_classes_2.0.npy -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_01.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_01.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_02.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_02.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_03.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_03.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_04.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_04.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_05.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_05.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_06.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_06.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_07.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_07.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_08.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_08.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_09.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_09.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_10.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_10.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_11.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_11.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_12.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_12.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_13.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_13.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_14.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_14.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_15.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_15.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_16.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_16.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_17.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_17.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_18.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_18.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_19.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_19.nii.gz -------------------------------------------------------------------------------- /data/training_label_maps/training_seg_20.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_20.nii.gz -------------------------------------------------------------------------------- /data/tutorial_7/noisy_segmentations_d/0001.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/noisy_segmentations_d/0001.nii.gz -------------------------------------------------------------------------------- /data/tutorial_7/noisy_segmentations_d/0002.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/noisy_segmentations_d/0002.nii.gz -------------------------------------------------------------------------------- /data/tutorial_7/noisy_segmentations_d/0003.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/noisy_segmentations_d/0003.nii.gz -------------------------------------------------------------------------------- /data/tutorial_7/segmentation_labels_s1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/segmentation_labels_s1.npy -------------------------------------------------------------------------------- /data/tutorial_7/target_segmentations_d/0001.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/target_segmentations_d/0001.nii.gz -------------------------------------------------------------------------------- /data/tutorial_7/target_segmentations_d/0002.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/target_segmentations_d/0002.nii.gz -------------------------------------------------------------------------------- /data/tutorial_7/target_segmentations_d/0003.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/target_segmentations_d/0003.nii.gz -------------------------------------------------------------------------------- /ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/ext/__init__.py -------------------------------------------------------------------------------- /ext/lab2im/__init__.py: -------------------------------------------------------------------------------- 1 | from . import edit_tensors 2 | from . import edit_volumes 3 | from . import image_generator 4 | from . import lab2im_model 5 | from . import layers 6 | from . import utils 7 | -------------------------------------------------------------------------------- /ext/lab2im/image_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite the first SynthSeg paper: 3 | https://github.com/BBillot/lab2im/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # python imports 18 | import numpy as np 19 | import numpy.random as npr 20 | 21 | # project imports 22 | from ext.lab2im import utils 23 | from ext.lab2im import edit_volumes 24 | from ext.lab2im.lab2im_model import lab2im_model 25 | 26 | 27 | class ImageGenerator: 28 | 29 | def __init__(self, 30 | labels_dir, 31 | generation_labels=None, 32 | output_labels=None, 33 | batchsize=1, 34 | n_channels=1, 35 | target_res=None, 36 | output_shape=None, 37 | output_div_by_n=None, 38 | generation_classes=None, 39 | prior_distributions='uniform', 40 | prior_means=None, 41 | prior_stds=None, 42 | use_specific_stats_for_channel=False, 43 | blur_range=1.15): 44 | """ 45 | This class is wrapper around the lab2im_model model. It contains the GPU model that generates images from labels 46 | maps, and a python generator that supplies the input data for this model. 47 | To generate pairs of image/labels you can just call the method generate_image() on an object of this class. 48 | 49 | :param labels_dir: path of folder with all input label maps, or to a single label map. 50 | 51 | # IMPORTANT !!! 52 | # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence), 53 | # these values refer to the RAS axes. 54 | 55 | # label maps-related parameters 56 | :param generation_labels: (optional) list of all possible label values in the input label maps. 57 | Default is None, where the label values are directly gotten from the provided label maps. 58 | If not None, can be a sequence or a 1d numpy array, or the path to a 1d numpy array. 59 | :param output_labels: (optional) list of the same length as generation_labels to indicate which values to use in 60 | the label maps returned by this function, i.e. all occurrences of generation_labels[i] in the input label maps 61 | will be converted to output_labels[i] in the returned label maps. Examples: 62 | Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. 63 | Set output_labels[i]=generation_labels[i] to keep the value generation_labels[i] in the returned maps. 64 | Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. 65 | 66 | # output-related parameters 67 | :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1. 68 | :param n_channels: (optional) number of channels to be synthetised. Default is 1. 69 | :param target_res: (optional) target resolution of the generated images and corresponding label maps. 70 | If None, the outputs will have the same resolution as the input label maps. 71 | Can be a number (isotropic resolution), a sequence, a 1d numpy array, or the path to a 1d numpy array. 72 | :param output_shape: (optional) shape of the output image, obtained by randomly cropping the generated image. 73 | Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array. 74 | :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites 75 | output_shape if necessary. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or 76 | the path to a 1d numpy array. 77 | 78 | # GMM-sampling parameters 79 | :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity 80 | distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence, a 81 | 1d numpy array, or the path to a 1d numpy array. 82 | It should have the same length as generation_labels, and contain values between 0 and K-1, where K is the total 83 | number of classes. Default is all labels have different classes (K=len(generation_labels)). 84 | :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters. 85 | Can either be 'uniform', or 'normal'. Default is 'uniform'. 86 | :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because 87 | these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be: 88 | 1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is 89 | uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each 90 | mini_batch from the same distribution. 91 | 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is 92 | not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each 93 | mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, and from 94 | N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal. 95 | 3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived 96 | from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a 97 | modality from the n_mod possibilities, and we sample the GMM means like in 2). 98 | If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel 99 | (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it. 100 | 4) the path to such a numpy array. 101 | Default is None, which corresponds to prior_means = [25, 225]. 102 | :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM. 103 | Default is None, which corresponds to prior_stds = [5, 25]. 104 | :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be 105 | only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False. 106 | 107 | # blurring parameters 108 | :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is 109 | given or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a c 110 | coefficient sampled from a uniform distribution with bounds [1/blur_range, blur_range]. 111 | If None, no randomisation. Default is 1.15. 112 | """ 113 | 114 | # prepare data files 115 | self.labels_paths = utils.list_images_in_folder(labels_dir) 116 | 117 | # generation parameters 118 | self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \ 119 | utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4)) 120 | self.n_channels = n_channels 121 | if generation_labels is not None: 122 | self.generation_labels = utils.load_array_if_path(generation_labels) 123 | else: 124 | self.generation_labels, _ = utils.get_list_labels(labels_dir=labels_dir) 125 | if output_labels is not None: 126 | self.output_labels = utils.load_array_if_path(output_labels) 127 | else: 128 | self.output_labels = self.generation_labels 129 | self.target_res = utils.load_array_if_path(target_res) 130 | self.batchsize = batchsize 131 | # preliminary operations 132 | self.output_shape = utils.load_array_if_path(output_shape) 133 | self.output_div_by_n = output_div_by_n 134 | # GMM parameters 135 | self.prior_distributions = prior_distributions 136 | if generation_classes is not None: 137 | self.generation_classes = utils.load_array_if_path(generation_classes) 138 | assert self.generation_classes.shape == self.generation_labels.shape, \ 139 | 'if provided, generation labels should have the same shape as generation_labels' 140 | unique_classes = np.unique(self.generation_classes) 141 | assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \ 142 | 'generation_classes should a linear range between 0 and its maximum value.' 143 | else: 144 | self.generation_classes = np.arange(self.generation_labels.shape[0]) 145 | self.prior_means = utils.load_array_if_path(prior_means) 146 | self.prior_stds = utils.load_array_if_path(prior_stds) 147 | self.use_specific_stats_for_channel = use_specific_stats_for_channel 148 | 149 | # blurring parameters 150 | self.blur_range = blur_range 151 | 152 | # build transformation model 153 | self.labels_to_image_model, self.model_output_shape = self._build_lab2im_model() 154 | 155 | # build generator for model inputs 156 | self.model_inputs_generator = self._build_model_inputs(len(self.generation_labels)) 157 | 158 | # build brain generator 159 | self.image_generator = self._build_image_generator() 160 | 161 | def _build_lab2im_model(self): 162 | # build_model 163 | lab_to_im_model = lab2im_model(labels_shape=self.labels_shape, 164 | n_channels=self.n_channels, 165 | generation_labels=self.generation_labels, 166 | output_labels=self.output_labels, 167 | atlas_res=self.atlas_res, 168 | target_res=self.target_res, 169 | output_shape=self.output_shape, 170 | output_div_by_n=self.output_div_by_n, 171 | blur_range=self.blur_range) 172 | out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:] 173 | return lab_to_im_model, out_shape 174 | 175 | def _build_image_generator(self): 176 | while True: 177 | model_inputs = next(self.model_inputs_generator) 178 | [image, labels] = self.labels_to_image_model.predict(model_inputs) 179 | yield image, labels 180 | 181 | def generate_image(self): 182 | """call this method when an object of this class has been instantiated to generate new brains""" 183 | (image, labels) = next(self.image_generator) 184 | # put back images in native space 185 | list_images = list() 186 | list_labels = list() 187 | for i in range(self.batchsize): 188 | list_images.append(edit_volumes.align_volume_to_ref(image[i], np.eye(4), aff_ref=self.aff, 189 | n_dims=self.n_dims)) 190 | list_labels.append(edit_volumes.align_volume_to_ref(labels[i], np.eye(4), aff_ref=self.aff, 191 | n_dims=self.n_dims)) 192 | image = np.stack(list_images, axis=0) 193 | labels = np.stack(list_labels, axis=0) 194 | return np.squeeze(image), np.squeeze(labels) 195 | 196 | def _build_model_inputs(self, n_labels): 197 | 198 | # get label info 199 | _, _, n_dims, _, _, _ = utils.get_volume_info(self.labels_paths[0]) 200 | 201 | # Generate! 202 | while True: 203 | 204 | # randomly pick as many images as batchsize 205 | indices = npr.randint(len(self.labels_paths), size=self.batchsize) 206 | 207 | # initialise input lists 208 | list_label_maps = [] 209 | list_means = [] 210 | list_stds = [] 211 | 212 | for idx in indices: 213 | 214 | # load label in identity space, and add them to inputs 215 | y = utils.load_volume(self.labels_paths[idx], dtype='int', aff_ref=np.eye(4)) 216 | list_label_maps.append(utils.add_axis(y, axis=[0, -1])) 217 | 218 | # add means and standard deviations to inputs 219 | means = np.empty((1, n_labels, 0)) 220 | stds = np.empty((1, n_labels, 0)) 221 | for channel in range(self.n_channels): 222 | 223 | # retrieve channel specific stats if necessary 224 | if isinstance(self.prior_means, np.ndarray): 225 | if (self.prior_means.shape[0] > 2) & self.use_specific_stats_for_channel: 226 | if self.prior_means.shape[0] / 2 != self.n_channels: 227 | raise ValueError("the number of blocks in prior_means does not match n_channels. This " 228 | "message is printed because use_specific_stats_for_channel is True.") 229 | tmp_prior_means = self.prior_means[2 * channel:2 * channel + 2, :] 230 | else: 231 | tmp_prior_means = self.prior_means 232 | else: 233 | tmp_prior_means = self.prior_means 234 | if isinstance(self.prior_stds, np.ndarray): 235 | if (self.prior_stds.shape[0] > 2) & self.use_specific_stats_for_channel: 236 | if self.prior_stds.shape[0] / 2 != self.n_channels: 237 | raise ValueError("the number of blocks in prior_stds does not match n_channels. This " 238 | "message is printed because use_specific_stats_for_channel is True.") 239 | tmp_prior_stds = self.prior_stds[2 * channel:2 * channel + 2, :] 240 | else: 241 | tmp_prior_stds = self.prior_stds 242 | else: 243 | tmp_prior_stds = self.prior_stds 244 | 245 | # draw means and std devs from priors 246 | tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_labels, 247 | self.prior_distributions, 125., 100., 248 | positive_only=True) 249 | tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_labels, 250 | self.prior_distributions, 15., 10., 251 | positive_only=True) 252 | tmp_means = utils.add_axis(tmp_classes_means[self.generation_classes], axis=[0, -1]) 253 | tmp_stds = utils.add_axis(tmp_classes_stds[self.generation_classes], axis=[0, -1]) 254 | means = np.concatenate([means, tmp_means], axis=-1) 255 | stds = np.concatenate([stds, tmp_stds], axis=-1) 256 | list_means.append(means) 257 | list_stds.append(stds) 258 | 259 | # build list of inputs of augmentation model 260 | list_inputs = [list_label_maps, list_means, list_stds] 261 | if self.batchsize > 1: # concatenate individual input types if batchsize > 1 262 | list_inputs = [np.concatenate(item, 0) for item in list_inputs] 263 | else: 264 | list_inputs = [item[0] for item in list_inputs] 265 | 266 | yield list_inputs 267 | -------------------------------------------------------------------------------- /ext/lab2im/lab2im_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite the first SynthSeg paper: 3 | https://github.com/BBillot/lab2im/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # python imports 18 | import numpy as np 19 | import keras.layers as KL 20 | from keras.models import Model 21 | 22 | # project imports 23 | from ext.lab2im import utils 24 | from ext.lab2im import layers 25 | from ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling 26 | 27 | 28 | def lab2im_model(labels_shape, 29 | n_channels, 30 | generation_labels, 31 | output_labels, 32 | atlas_res, 33 | target_res, 34 | output_shape=None, 35 | output_div_by_n=None, 36 | blur_range=1.15): 37 | """ 38 | This function builds a keras/tensorflow model to generate images from provided label maps. 39 | The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map. 40 | The model will take as inputs: 41 | -a label map 42 | -a vector containing the means of the Gaussian Mixture Model for each label, 43 | -a vector containing the standard deviations of the Gaussian Mixture Model for each label, 44 | -an array of size batch*(n_dims+1)*(n_dims+1) representing a linear transformation 45 | The model returns: 46 | -the generated image normalised between 0 and 1. 47 | -the corresponding label map, with only the labels present in output_labels (the other are reset to zero). 48 | :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array. 49 | :param n_channels: number of channels to be synthetised. 50 | :param generation_labels: list of all possible label values in the input label maps. 51 | Can be a sequence or a 1d numpy array. 52 | :param output_labels: list of the same length as generation_labels to indicate which values to use in the label maps 53 | returned by this model, i.e. all occurrences of generation_labels[i] in the input label maps will be converted to 54 | output_labels[i] in the returned label maps. Examples: 55 | Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. 56 | Set output_labels[i]=generation_labels[i] if you wish to keep the value generation_labels[i] in the returned maps. 57 | Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. 58 | :param atlas_res: resolution of the input label maps. 59 | Can be a number (isotropic resolution), a sequence, or a 1d numpy array. 60 | :param target_res: target resolution of the generated images and corresponding label maps. 61 | Can be a number (isotropic resolution), a sequence, or a 1d numpy array. 62 | :param output_shape: (optional) desired shape of the output images. 63 | If the atlas and target resolutions are the same, the output will be cropped to output_shape, and if the two 64 | resolutions are different, the output will be resized with trilinear interpolation to output_shape. 65 | Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array. 66 | :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape 67 | if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array. 68 | :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is given 69 | or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a coefficient sampled 70 | from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15. 71 | """ 72 | 73 | # reformat resolutions 74 | labels_shape = utils.reformat_to_list(labels_shape) 75 | n_dims, _ = utils.get_dims(labels_shape) 76 | atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims=n_dims)[0] 77 | target_res = atlas_res if (target_res is None) else utils.reformat_to_n_channels_array(target_res, n_dims)[0] 78 | 79 | # get shapes 80 | crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n) 81 | 82 | # define model inputs 83 | labels_input = KL.Input(shape=labels_shape+[1], name='labels_input', dtype='int32') 84 | means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input') 85 | stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='stds_input') 86 | 87 | # deform labels 88 | labels = layers.RandomSpatialDeformation(inter_method='nearest')(labels_input) 89 | 90 | # cropping 91 | if crop_shape != labels_shape: 92 | labels._keras_shape = tuple(labels.get_shape().as_list()) 93 | labels = layers.RandomCrop(crop_shape)(labels) 94 | 95 | # build synthetic image 96 | labels._keras_shape = tuple(labels.get_shape().as_list()) 97 | image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input]) 98 | 99 | # apply bias field 100 | image._keras_shape = tuple(image.get_shape().as_list()) 101 | image = layers.BiasFieldCorruption(.3, .025, same_bias_for_all_channels=False)(image) 102 | 103 | # intensity augmentation 104 | image._keras_shape = tuple(image.get_shape().as_list()) 105 | image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.2)(image) 106 | 107 | # blur image 108 | sigma = blurring_sigma_for_downsampling(atlas_res, target_res) 109 | image._keras_shape = tuple(image.get_shape().as_list()) 110 | image = layers.GaussianBlur(sigma=sigma, random_blur_range=blur_range)(image) 111 | 112 | # resample to target res 113 | if crop_shape != output_shape: 114 | image = resample_tensor(image, output_shape, interp_method='linear') 115 | labels = resample_tensor(labels, output_shape, interp_method='nearest') 116 | 117 | # reset unwanted labels to zero 118 | labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels) 119 | 120 | # build model (dummy layer enables to keep the labels when plugging this model to other models) 121 | image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels]) 122 | brain_model = Model(inputs=[labels_input, means_input, stds_input], outputs=[image, labels]) 123 | 124 | return brain_model 125 | 126 | 127 | def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n): 128 | 129 | n_dims = len(atlas_res) 130 | 131 | # get resampling factor 132 | if atlas_res.tolist() != target_res.tolist(): 133 | resample_factor = [atlas_res[i] / float(target_res[i]) for i in range(n_dims)] 134 | else: 135 | resample_factor = None 136 | 137 | # output shape specified, need to get cropping shape, and resample shape if necessary 138 | if output_shape is not None: 139 | output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int') 140 | 141 | # make sure that output shape is smaller or equal to label shape 142 | if resample_factor is not None: 143 | output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)] 144 | else: 145 | output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)] 146 | 147 | # make sure output shape is divisible by output_div_by_n 148 | if output_div_by_n is not None: 149 | tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) 150 | for s in output_shape] 151 | if output_shape != tmp_shape: 152 | print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n, 153 | tmp_shape)) 154 | output_shape = tmp_shape 155 | 156 | # get cropping and resample shape 157 | if resample_factor is not None: 158 | cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)] 159 | else: 160 | cropping_shape = output_shape 161 | 162 | # no output shape specified, so no cropping unless label_shape is not divisible by output_div_by_n 163 | else: 164 | cropping_shape = labels_shape 165 | if resample_factor is not None: 166 | output_shape = [int(np.around(cropping_shape[i]*resample_factor[i], 0)) for i in range(n_dims)] 167 | else: 168 | output_shape = cropping_shape 169 | # make sure output shape is divisible by output_div_by_n 170 | if output_div_by_n is not None: 171 | output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n, answer_type='closer') 172 | for s in output_shape] 173 | 174 | return cropping_shape, output_shape 175 | -------------------------------------------------------------------------------- /ext/neuron/__init__.py: -------------------------------------------------------------------------------- 1 | from . import layers 2 | from . import models 3 | from . import utils 4 | -------------------------------------------------------------------------------- /models/synthseg_1.0.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/models/synthseg_1.0.h5 -------------------------------------------------------------------------------- /requirements_python3.6.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | astor==0.8.1 3 | cached-property==1.5.2 4 | cachetools==4.2.4 5 | certifi==2022.12.7 6 | charset-normalizer==2.0.12 7 | cycler==0.11.0 8 | dataclasses==0.8 9 | gast==0.2.2 10 | google-auth==1.35.0 11 | google-auth-oauthlib==0.4.6 12 | google-pasta==0.2.0 13 | grpcio==1.48.2 14 | h5py==2.10.0 15 | idna==3.4 16 | importlib-metadata==4.8.3 17 | Keras==2.3.1 18 | Keras-Applications==1.0.8 19 | Keras-Preprocessing==1.1.2 20 | kiwisolver==1.3.1 21 | Markdown==3.3.7 22 | matplotlib==3.3.4 23 | nibabel==3.2.2 24 | numpy==1.19.5 25 | oauthlib==3.2.2 26 | opt-einsum==3.3.0 27 | packaging==21.3 28 | Pillow==8.4.0 29 | protobuf==3.19.6 30 | pyasn1==0.4.8 31 | pyasn1-modules==0.2.8 32 | pyparsing==3.0.9 33 | python-dateutil==2.8.2 34 | PyYAML==6.0 35 | requests==2.27.1 36 | requests-oauthlib==1.3.1 37 | rsa==4.9 38 | scipy==1.5.4 39 | six==1.16.0 40 | tensorboard==2.0.2 41 | tensorflow-estimator==2.0.1 42 | tensorflow-gpu==2.0.0 43 | termcolor==1.1.0 44 | typing_extensions==4.1.1 45 | urllib3==1.26.15 46 | Werkzeug==2.0.3 47 | wrapt==1.15.0 48 | zipp==3.6.0 49 | -------------------------------------------------------------------------------- /requirements_python3.8.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | astunparse==1.6.3 3 | cachetools==4.2.4 4 | certifi==2022.12.7 5 | charset-normalizer==3.1.0 6 | contourpy==1.0.7 7 | cycler==0.11.0 8 | fonttools==4.39.2 9 | gast==0.3.3 10 | google-auth==1.35.0 11 | google-auth-oauthlib==0.4.6 12 | google-pasta==0.2.0 13 | grpcio==1.51.3 14 | h5py==2.10.0 15 | idna==3.4 16 | importlib-metadata==6.0.0 17 | Keras==2.3.1 18 | Keras-Applications==1.0.8 19 | Keras-Preprocessing==1.1.2 20 | kiwisolver==1.4.4 21 | Markdown==3.4.1 22 | MarkupSafe==2.1.2 23 | matplotlib==3.6.2 24 | nibabel==5.0.1 25 | numpy==1.23.5 26 | oauthlib==3.2.2 27 | opt-einsum==3.3.0 28 | packaging==23.0 29 | Pillow==9.4.0 30 | protobuf==3.20.3 31 | pyasn1==0.4.8 32 | pyasn1-modules==0.2.8 33 | pyparsing==3.0.9 34 | python-dateutil==2.8.2 35 | PyYAML==6.0 36 | requests==2.28.2 37 | requests-oauthlib==1.3.1 38 | rsa==4.9 39 | scipy==1.4.1 40 | six==1.16.0 41 | tensorboard==2.2.2 42 | tensorboard-plugin-wit==1.8.1 43 | tensorflow-estimator==2.2.0 44 | tensorflow-gpu==2.2.0 45 | termcolor==2.2.0 46 | urllib3==1.26.15 47 | Werkzeug==2.2.3 48 | wrapt==1.15.0 49 | zipp==3.15.0 50 | -------------------------------------------------------------------------------- /scripts/commands/SynthSeg_predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script enables to launch predictions with SynthSeg from the terminal. 3 | 4 | If you use this code, please cite one of the SynthSeg papers: 5 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 6 | 7 | Copyright 2020 Benjamin Billot 8 | 9 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 10 | compliance with the License. You may obtain a copy of the License at 11 | https://www.apache.org/licenses/LICENSE-2.0 12 | Unless required by applicable law or agreed to in writing, software distributed under the License is 13 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 14 | implied. See the License for the specific language governing permissions and limitations under the 15 | License. 16 | """ 17 | 18 | # python imports 19 | import os 20 | import sys 21 | from argparse import ArgumentParser 22 | 23 | # add main folder to python path and import ./SynthSeg/predict_synthseg.py 24 | synthseg_home = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))) 25 | sys.path.append(synthseg_home) 26 | model_dir = os.path.join(synthseg_home, 'models') 27 | labels_dir = os.path.join(synthseg_home, 'data/labels_classes_priors') 28 | from SynthSeg.predict_synthseg import predict 29 | 30 | 31 | # parse arguments 32 | parser = ArgumentParser(description="SynthSeg", epilog='\n') 33 | 34 | # input/outputs 35 | parser.add_argument("--i", help="Image(s) to segment. Can be a path to an image or to a folder.") 36 | parser.add_argument("--o", help="Segmentation output(s). Must be a folder if --i designates a folder.") 37 | parser.add_argument("--parc", action="store_true", help="(optional) Whether to perform cortex parcellation.") 38 | parser.add_argument("--robust", action="store_true", help="(optional) Whether to use robust predictions (slower).") 39 | parser.add_argument("--fast", action="store_true", help="(optional) Bypass some postprocessing for faster predictions.") 40 | parser.add_argument("--ct", action="store_true", help="(optional) Clip intensities to [0,80] for CT scans.") 41 | parser.add_argument("--vol", help="(optional) Path to output CSV file with volumes (mm3) for all regions and subjects.") 42 | parser.add_argument("--qc", help="(optional) Path to output CSV file with qc scores for all subjects.") 43 | parser.add_argument("--post", help="(optional) Posteriors output(s). Must be a folder if --i designates a folder.") 44 | parser.add_argument("--resample", help="(optional) Resampled image(s). Must be a folder if --i designates a folder.") 45 | parser.add_argument("--crop", nargs='+', type=int, help="(optional) Size of 3D patches to analyse. Default is 192.") 46 | parser.add_argument("--threads", type=int, default=1, help="(optional) Number of cores to be used. Default is 1.") 47 | parser.add_argument("--cpu", action="store_true", help="(optional) Enforce running with CPU rather than GPU.") 48 | parser.add_argument("--v1", action="store_true", help="(optional) Use SynthSeg 1.0 (updated 25/06/22).") 49 | 50 | # check for no arguments 51 | if len(sys.argv) < 2: 52 | parser.print_help() 53 | sys.exit(1) 54 | 55 | # parse commandline 56 | args = vars(parser.parse_args()) 57 | 58 | # print SynthSeg version and checks boolean params for SynthSeg-robust 59 | if args['robust']: 60 | args['fast'] = True 61 | assert not args['v1'], 'The flag --v1 cannot be used with --robust since SynthSeg-robust only came out with 2.0.' 62 | version = 'SynthSeg-robust 2.0' 63 | else: 64 | version = 'SynthSeg 1.0' if args['v1'] else 'SynthSeg 2.0' 65 | if args['fast']: 66 | version += ' (fast)' 67 | print('\n' + version + '\n') 68 | 69 | # enforce CPU processing if necessary 70 | if args['cpu']: 71 | print('using CPU, hiding all CUDA_VISIBLE_DEVICES') 72 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 73 | 74 | # limit the number of threads to be used if running on CPU 75 | import tensorflow as tf 76 | if args['threads'] == 1: 77 | print('using 1 thread') 78 | else: 79 | print('using %s threads' % args['threads']) 80 | tf.config.threading.set_inter_op_parallelism_threads(args['threads']) 81 | tf.config.threading.set_intra_op_parallelism_threads(args['threads']) 82 | 83 | # path models 84 | if args['robust']: 85 | args['path_model_segmentation'] = os.path.join(model_dir, 'synthseg_robust_2.0.h5') 86 | else: 87 | args['path_model_segmentation'] = os.path.join(model_dir, 'synthseg_2.0.h5') 88 | args['path_model_parcellation'] = os.path.join(model_dir, 'synthseg_parc_2.0.h5') 89 | args['path_model_qc'] = os.path.join(model_dir, 'synthseg_qc_2.0.h5') 90 | 91 | # path labels 92 | args['labels_segmentation'] = os.path.join(labels_dir, 'synthseg_segmentation_labels_2.0.npy') 93 | args['labels_denoiser'] = os.path.join(labels_dir, 'synthseg_denoiser_labels_2.0.npy') 94 | args['labels_parcellation'] = os.path.join(labels_dir, 'synthseg_parcellation_labels.npy') 95 | args['labels_qc'] = os.path.join(labels_dir, 'synthseg_qc_labels_2.0.npy') 96 | args['names_segmentation_labels'] = os.path.join(labels_dir, 'synthseg_segmentation_names_2.0.npy') 97 | args['names_parcellation_labels'] = os.path.join(labels_dir, 'synthseg_parcellation_names.npy') 98 | args['names_qc_labels'] = os.path.join(labels_dir, 'synthseg_qc_names_2.0.npy') 99 | args['topology_classes'] = os.path.join(labels_dir, 'synthseg_topological_classes_2.0.npy') 100 | args['n_neutral_labels'] = 19 101 | 102 | # use previous model if needed 103 | if args['v1']: 104 | args['path_model_segmentation'] = os.path.join(model_dir, 'synthseg_1.0.h5') 105 | args['labels_segmentation'] = args['labels_segmentation'].replace('_2.0.npy', '.npy') 106 | args['labels_qc'] = args['labels_qc'].replace('_2.0.npy', '.npy') 107 | args['names_segmentation_labels'] = args['names_segmentation_labels'].replace('_2.0.npy', '.npy') 108 | args['names_qc_labels'] = args['names_qc_labels'].replace('_2.0.npy', '.npy') 109 | args['topology_classes'] = args['topology_classes'].replace('_2.0.npy', '.npy') 110 | args['n_neutral_labels'] = 18 111 | 112 | # run prediction 113 | predict(path_images=args['i'], 114 | path_segmentations=args['o'], 115 | path_model_segmentation=args['path_model_segmentation'], 116 | labels_segmentation=args['labels_segmentation'], 117 | robust=args['robust'], 118 | fast=args['fast'], 119 | v1=args['v1'], 120 | do_parcellation=args['parc'], 121 | n_neutral_labels=args['n_neutral_labels'], 122 | names_segmentation=args['names_segmentation_labels'], 123 | labels_denoiser=args['labels_denoiser'], 124 | path_posteriors=args['post'], 125 | path_resampled=args['resample'], 126 | path_volumes=args['vol'], 127 | path_model_parcellation=args['path_model_parcellation'], 128 | labels_parcellation=args['labels_parcellation'], 129 | names_parcellation=args['names_parcellation_labels'], 130 | path_model_qc=args['path_model_qc'], 131 | labels_qc=args['labels_qc'], 132 | path_qc_scores=args['qc'], 133 | names_qc=args['names_qc_labels'], 134 | cropping=args['crop'], 135 | topology_classes=args['topology_classes'], 136 | ct=args['ct']) 137 | -------------------------------------------------------------------------------- /scripts/commands/predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # imports 18 | from argparse import ArgumentParser 19 | from SynthSeg.predict import predict 20 | 21 | parser = ArgumentParser() 22 | 23 | # Positional arguments 24 | parser.add_argument("path_images", type=str, help="path single image or path of the folders with training labels") 25 | parser.add_argument("path_segmentations", type=str, help="segmentations folder/path") 26 | parser.add_argument("path_model", type=str, help="model file path") 27 | 28 | # labels parameters 29 | parser.add_argument("labels_segmentation", type=str, help="path label list") 30 | parser.add_argument("--neutral_labels", type=int, dest="n_neutral_labels", default=None) 31 | parser.add_argument("--names_list", type=str, dest="names_segmentation", default=None, 32 | help="path list of label names, only used if --vol is specified") 33 | 34 | # Saving paths 35 | parser.add_argument("--post", type=str, dest="path_posteriors", default=None, help="posteriors folder/path") 36 | parser.add_argument("--resampled", type=str, dest="path_resampled", default=None, 37 | help="path/folder of the images resampled at the given target resolution") 38 | parser.add_argument("--vol", type=str, dest="path_volumes", default=None, help="path volume file") 39 | 40 | # Processing parameters 41 | parser.add_argument("--min_pad", type=int, dest="min_pad", default=None, 42 | help="margin of the padding") 43 | parser.add_argument("--cropping", type=int, dest="cropping", default=None, 44 | help="crop volume before processing. Segmentations will have the same size as input image.") 45 | parser.add_argument("--target_res", type=float, dest="target_res", default=1., 46 | help="Target resolution at which segmentations will be given.") 47 | parser.add_argument("--flip", action='store_true', dest="flip", 48 | help="to activate test-time augmentation (right/left flipping)") 49 | parser.add_argument("--topology_classes", type=str, dest="topology_classes", default=None, 50 | help="path list of classes, for topologically enhanced biggest connected component analysis") 51 | parser.add_argument("--smoothing", type=float, dest="sigma_smoothing", default=0.5, 52 | help="var for gaussian blurring of the posteriors") 53 | parser.add_argument("--biggest_component", action='store_true', dest="keep_biggest_component", 54 | help="only keep biggest component in segmentation (recommended)") 55 | 56 | # Architecture parameters 57 | parser.add_argument("--conv_size", type=int, dest="conv_size", default=3, help="size of unet convolution masks") 58 | parser.add_argument("--n_levels", type=int, dest="n_levels", default=5, help="number of levels for unet") 59 | parser.add_argument("--conv_per_level", type=int, dest="nb_conv_per_level", default=2, help="conv par level") 60 | parser.add_argument("--unet_feat", type=int, dest="unet_feat_count", default=24, 61 | help="number of features of unet first layer") 62 | parser.add_argument("--feat_mult", type=int, dest="feat_multiplier", default=2, 63 | help="factor of new feature maps per level") 64 | parser.add_argument("--activation", type=str, dest="activation", default='elu', help="activation function") 65 | 66 | # Evaluation parameters 67 | parser.add_argument("--gt", type=str, default=None, dest="gt_folder", 68 | help="folder containing ground truth segmentations, which triggers the evaluation.") 69 | parser.add_argument("--eval_label_list", type=str, dest="evaluation_labels", default=None, 70 | help="labels to evaluate Dice scores on if gt is provided. Default is the same as label_list.") 71 | parser.add_argument("--incorrect_labels", type=str, default=None, dest="list_incorrect_labels", 72 | help="path list labels to correct.") 73 | parser.add_argument("--correct_labels", type=str, default=None, dest="list_correct_labels", 74 | help="path list correct labels.") 75 | 76 | args = parser.parse_args() 77 | predict(**vars(args)) 78 | -------------------------------------------------------------------------------- /scripts/commands/training.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | from argparse import ArgumentParser 18 | from SynthSeg.training import training 19 | from ext.lab2im.utils import infer 20 | 21 | parser = ArgumentParser() 22 | 23 | # ------------------------------------------------- General parameters ------------------------------------------------- 24 | # Positional arguments 25 | parser.add_argument("labels_dir", type=str) 26 | parser.add_argument("model_dir", type=str) 27 | 28 | # ---------------------------------------------- Generation parameters ---------------------------------------------- 29 | # label maps parameters 30 | parser.add_argument("--generation_labels", type=str, dest="generation_labels", default=None) 31 | parser.add_argument("--neutral_labels", type=int, dest="n_neutral_labels", default=None) 32 | parser.add_argument("--segmentation_labels", type=str, dest="segmentation_labels", default=None) 33 | parser.add_argument("--subjects_prob", type=str, dest="subjects_prob", default=None) 34 | 35 | # output-related parameters 36 | parser.add_argument("--batch_size", type=int, dest="batchsize", default=1) 37 | parser.add_argument("--channels", type=int, dest="n_channels", default=1) 38 | parser.add_argument("--target_res", type=float, dest="target_res", default=None) 39 | parser.add_argument("--output_shape", type=int, dest="output_shape", default=None) 40 | 41 | # GMM-sampling parameters 42 | parser.add_argument("--generation_classes", type=str, dest="generation_classes", default=None) 43 | parser.add_argument("--prior_type", type=str, dest="prior_distributions", default='uniform') 44 | parser.add_argument("--prior_means", type=str, dest="prior_means", default=None) 45 | parser.add_argument("--prior_stds", type=str, dest="prior_stds", default=None) 46 | parser.add_argument("--specific_stats", action='store_true', dest="use_specific_stats_for_channel") 47 | parser.add_argument("--mix_prior_and_random", action='store_true', dest="mix_prior_and_random") 48 | 49 | # spatial deformation parameters 50 | parser.add_argument("--no_flipping", action='store_false', dest="flipping") 51 | parser.add_argument("--scaling", dest="scaling_bounds", type=infer, default=0.2) 52 | parser.add_argument("--rotation", dest="rotation_bounds", type=infer, default=15) 53 | parser.add_argument("--shearing", dest="shearing_bounds", type=infer, default=.012) 54 | parser.add_argument("--translation", dest="translation_bounds", type=infer, default=False) 55 | parser.add_argument("--nonlin_std", type=float, dest="nonlin_std", default=4.) 56 | parser.add_argument("--nonlin_scale", type=float, dest="nonlin_scale", default=.04) 57 | 58 | # blurring/resampling parameters 59 | parser.add_argument("--randomise_res", action='store_true', dest="randomise_res") 60 | parser.add_argument("--max_res_iso", type=float, dest="max_res_iso", default=4.) 61 | parser.add_argument("--max_res_aniso", type=float, dest="max_res_aniso", default=8.) 62 | parser.add_argument("--data_res", dest="data_res", type=infer, default=None) 63 | parser.add_argument("--thickness", dest="thickness", type=infer, default=None) 64 | 65 | # bias field parameters 66 | parser.add_argument("--bias_std", type=float, dest="bias_field_std", default=.7) 67 | parser.add_argument("--bias_scale", type=float, dest="bias_scale", default=.025) 68 | 69 | parser.add_argument("--gradients", action='store_true', dest="return_gradients") 70 | 71 | # -------------------------------------------- UNet architecture parameters -------------------------------------------- 72 | parser.add_argument("--n_levels", type=int, dest="n_levels", default=5) 73 | parser.add_argument("--conv_per_level", type=int, dest="nb_conv_per_level", default=2) 74 | parser.add_argument("--conv_size", type=int, dest="conv_size", default=3) 75 | parser.add_argument("--unet_feat", type=int, dest="unet_feat_count", default=24) 76 | parser.add_argument("--feat_mult", type=int, dest="feat_multiplier", default=2) 77 | parser.add_argument("--activation", type=str, dest="activation", default='elu') 78 | 79 | # ------------------------------------------------- Training parameters ------------------------------------------------ 80 | parser.add_argument("--lr", type=float, dest="lr", default=1e-4) 81 | parser.add_argument("--wl2_epochs", type=int, dest="wl2_epochs", default=1) 82 | parser.add_argument("--dice_epochs", type=int, dest="dice_epochs", default=50) 83 | parser.add_argument("--steps_per_epoch", type=int, dest="steps_per_epoch", default=10000) 84 | parser.add_argument("--checkpoint", type=str, dest="checkpoint", default=None) 85 | 86 | args = parser.parse_args() 87 | training(**vars(args)) 88 | -------------------------------------------------------------------------------- /scripts/commands/training_supervised.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use this code, please cite one of the SynthSeg papers: 3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 4 | 5 | Copyright 2020 Benjamin Billot 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 8 | compliance with the License. You may obtain a copy of the License at 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is 11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 | implied. See the License for the specific language governing permissions and limitations under the 13 | License. 14 | """ 15 | 16 | 17 | # imports 18 | from argparse import ArgumentParser 19 | from SynthSeg.training_supervised import training 20 | from ext.lab2im.utils import infer 21 | 22 | parser = ArgumentParser() 23 | 24 | # ------------------------------------------------- General parameters ------------------------------------------------- 25 | # Positional arguments 26 | parser.add_argument("image_dir", type=str) 27 | parser.add_argument("labels_dir", type=str) 28 | parser.add_argument("model_dir", type=str) 29 | 30 | # label maps parameters 31 | parser.add_argument("--segmentation_labels", type=str, dest="segmentation_labels", default=None) 32 | parser.add_argument("--neutral_labels", type=int, dest="n_neutral_labels", default=None) 33 | parser.add_argument("--subjects_prob", type=str, dest="subjects_prob", default=None) 34 | 35 | # output-related parameters 36 | parser.add_argument("--batch_size", type=int, dest="batchsize", default=1) 37 | parser.add_argument("--target_res", type=int, dest="target_res", default=None) 38 | parser.add_argument("--output_shape", type=int, dest="output_shape", default=None) 39 | 40 | # ----------------------------------------------- Augmentation parameters ---------------------------------------------- 41 | # spatial deformation parameters 42 | parser.add_argument("--no_flipping", action='store_false', dest="flipping") 43 | parser.add_argument("--scaling", dest="scaling_bounds", type=infer, default=.2) 44 | parser.add_argument("--rotation", dest="rotation_bounds", type=infer, default=15) 45 | parser.add_argument("--shearing", dest="shearing_bounds", type=infer, default=.012) 46 | parser.add_argument("--translation", dest="translation_bounds", type=infer, default=False) 47 | parser.add_argument("--nonlin_std", type=float, dest="nonlin_std", default=4.) 48 | parser.add_argument("--nonlin_scale", type=float, dest="nonlin_scale", default=.04) 49 | 50 | # resampling parameters 51 | parser.add_argument("--randomise_res", action='store_true', dest="randomise_res") 52 | parser.add_argument("--max_res_iso", type=float, dest="max_res_iso", default=4.) 53 | parser.add_argument("--max_res_aniso", type=float, dest="max_res_aniso", default=8.) 54 | parser.add_argument("--data_res", dest="data_res", type=infer, default=None) 55 | parser.add_argument("--thickness", dest="thickness", type=infer, default=None) 56 | 57 | # bias field parameters 58 | parser.add_argument("--bias_std", type=float, dest="bias_field_std", default=.7) 59 | parser.add_argument("--bias_scale", type=float, dest="bias_scale", default=.025) 60 | 61 | parser.add_argument("--gradients", action='store_true', dest="return_gradients") 62 | 63 | # -------------------------------------------- UNet architecture parameters -------------------------------------------- 64 | parser.add_argument("--n_levels", type=int, dest="n_levels", default=5) 65 | parser.add_argument("--conv_per_level", type=int, dest="nb_conv_per_level", default=2) 66 | parser.add_argument("--conv_size", type=int, dest="conv_size", default=3) 67 | parser.add_argument("--unet_feat", type=int, dest="unet_feat_count", default=24) 68 | parser.add_argument("--feat_mult", type=int, dest="feat_multiplier", default=2) 69 | parser.add_argument("--activation", type=str, dest="activation", default='elu') 70 | 71 | # ------------------------------------------------- Training parameters ------------------------------------------------ 72 | parser.add_argument("--lr", type=float, dest="lr", default=1e-4) 73 | parser.add_argument("--wl2_epochs", type=int, dest="wl2_epochs", default=1) 74 | parser.add_argument("--dice_epochs", type=int, dest="dice_epochs", default=50) 75 | parser.add_argument("--steps_per_epoch", type=int, dest="steps_per_epoch", default=10000) 76 | parser.add_argument("--checkpoint", type=str, dest="checkpoint", default=None) 77 | 78 | args = parser.parse_args() 79 | training(**vars(args)) 80 | -------------------------------------------------------------------------------- /scripts/tutorials/1-generation_visualisation.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Very simple script to generate an example of the synthetic data used to train SynthSeg. 4 | This is for visualisation purposes, since it uses all the default parameters. 5 | 6 | 7 | 8 | If you use this code, please cite one of the SynthSeg papers: 9 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 10 | 11 | Copyright 2020 Benjamin Billot 12 | 13 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 14 | compliance with the License. You may obtain a copy of the License at 15 | https://www.apache.org/licenses/LICENSE-2.0 16 | Unless required by applicable law or agreed to in writing, software distributed under the License is 17 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 18 | implied. See the License for the specific language governing permissions and limitations under the 19 | License. 20 | """ 21 | 22 | 23 | from ext.lab2im import utils 24 | from SynthSeg.brain_generator import BrainGenerator 25 | 26 | # generate an image from the label map. 27 | brain_generator = BrainGenerator('../../data/training_label_maps/training_seg_01.nii.gz') 28 | im, lab = brain_generator.generate_brain() 29 | 30 | # save output image and label map under SynthSeg/generated_examples 31 | utils.save_volume(im, brain_generator.aff, brain_generator.header, './outputs_tutorial_1/image.nii.gz') 32 | utils.save_volume(lab, brain_generator.aff, brain_generator.header, './outputs_tutorial_1/labels.nii.gz') 33 | -------------------------------------------------------------------------------- /scripts/tutorials/2-generation_explained.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | This script explains how the different parameters controlling the generation of the synthetic data. 4 | These parameters will be reused in the training function, but we describe them here, as the synthetic images are saved, 5 | and thus can be visualised. 6 | Note that most of the parameters here are set to their default value, but we show them nonetheless, just to explain 7 | their effect. Moreover, we encourage the user to play with them to get a sense of their impact on the generation. 8 | 9 | 10 | 11 | If you use this code, please cite one of the SynthSeg papers: 12 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 13 | 14 | Copyright 2020 Benjamin Billot 15 | 16 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 17 | compliance with the License. You may obtain a copy of the License at 18 | https://www.apache.org/licenses/LICENSE-2.0 19 | Unless required by applicable law or agreed to in writing, software distributed under the License is 20 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 21 | implied. See the License for the specific language governing permissions and limitations under the 22 | License. 23 | """ 24 | 25 | 26 | import os 27 | from ext.lab2im import utils 28 | from SynthSeg.brain_generator import BrainGenerator 29 | 30 | # script parameters 31 | n_examples = 5 # number of examples to generate in this script 32 | result_dir = './outputs_tutorial_2' # folder where examples will be saved 33 | 34 | 35 | # ---------- Input label maps and associated values ---------- 36 | 37 | # folder containing label maps to generate images from (note that they must have a ".nii", ".nii.gz" or ".mgz" format) 38 | path_label_map = '../../data/training_label_maps' 39 | 40 | # Here we specify the structures in the label maps for which we want to generate intensities. 41 | # This is given as a list of label values, which do not necessarily need to be present in every label map. 42 | # However, these labels must follow a specific order: first the background, and then all the other labels. Moreover, if 43 | # 1) the label maps contain some right/left-specific label values, and 2) we activate flipping augmentation (which is 44 | # true by default), then the rest of the labels must follow a strict order: 45 | # first the non-sided labels (i.e. those which are not right/left specific), then all the left labels, and finally the 46 | # corresponding right labels (in the same order as the left ones). Please make sure each that each sided label has a 47 | # right and a left value (this is essential!!!). 48 | # 49 | # Example: generation_labels = [0, # background 50 | # 24, # CSF 51 | # 507, # extra-cerebral soft tissues 52 | # 2, # left white matter 53 | # 3, # left cerebral cortex 54 | # 4, # left lateral ventricle 55 | # 17, # left hippocampus 56 | # 25, # left lesions 57 | # 41, # right white matter 58 | # 42, # right cerebral cortex 59 | # 43, # right lateral ventricle 60 | # 53, # right hippocampus 61 | # 57] # right lesions 62 | # Note that plenty of structures are not represented here..... but it's just an example ! :) 63 | generation_labels = '../../data/labels_classes_priors/generation_labels.npy' 64 | 65 | 66 | # We also have to specify the number of non-sided labels in order to differentiate them from the labels with 67 | # right/left values. 68 | # Example: (continuing the previous one): in this example it would be 3 (background, CSF, extra-cerebral soft tissues). 69 | n_neutral_labels = 18 70 | 71 | # By default, the output label maps (i.e. the target segmentations) contain all the labels used for generation. 72 | # However, we may want not to predict all the generation labels (e.g. extra-cerebral soft tissues). 73 | # For this reason, we specify here the target segmentation label corresponding to every generation structure. 74 | # This new list must have the same length as generation_labels, and follow the same order. 75 | # 76 | # Example: (continuing the previous one) generation_labels = [0, 24, 507, 2, 3, 4, 17, 25, 41, 42, 43, 53, 57] 77 | # output_labels = [0, 0, 0, 2, 3, 4, 17, 2, 41, 42, 43, 53, 41] 78 | # Note that in this example the labels 24 (CSF), and 507 (extra-cerebral soft tissues) are not predicted, or said 79 | # differently they are segmented as background. 80 | # Also, the left and right lesions (labels 25 and 57) are segmented as left and right white matter (labels 2 and 41). 81 | output_labels = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy' 82 | 83 | 84 | # ---------- Shape and resolution of the outputs ---------- 85 | 86 | # number of channel to synthesise for multi-modality settings. Set this to 1 (default) in the uni-modality scenario. 87 | n_channels = 1 88 | 89 | # We have the possibility to generate training examples at a different resolution than the training label maps (e.g. 90 | # when using ultra HR training label maps). Here we want to generate at the same resolution as the training label maps, 91 | # so we set this to None. 92 | target_res = None 93 | 94 | # The generative model offers the possibility to randomly crop the training examples to a given size. 95 | # Here we crop them to 160^3, such that the produced images fit on the GPU during training. 96 | output_shape = 160 97 | 98 | 99 | # ---------- GMM sampling parameters ---------- 100 | 101 | # Here we use uniform prior distribution to sample the means/stds of the GMM. Because we don't specify prior_means and 102 | # prior_stds, those priors will have default bounds of [0, 250], and [0, 35]. Those values enable to generate a wide 103 | # range of contrasts (often unrealistic), which will make the segmentation network contrast-agnostic. 104 | prior_distributions = 'uniform' 105 | 106 | # We regroup labels with similar tissue types into K "classes", so that intensities of similar regions are sampled 107 | # from the same Gaussian distribution. This is achieved by providing a list indicating the class of each label. 108 | # It should have the same length as generation_labels, and follow the same order. Importantly the class values must be 109 | # between 0 and K-1, where K is the total number of different classes. 110 | # 111 | # Example: (continuing the previous one) generation_labels = [0, 24, 507, 2, 3, 4, 17, 25, 41, 42, 43, 53, 57] 112 | # generation_classes = [0, 1, 2, 3, 4, 5, 4, 6, 7, 8, 9, 8, 10] 113 | # In this example labels 3 and 17 are in the same *class* 4 (that has nothing to do with *label* 4), and thus will be 114 | # associated to the same Gaussian distribution when sampling the GMM. 115 | generation_classes = '../../data/labels_classes_priors/generation_classes.npy' 116 | 117 | 118 | # ---------- Spatial augmentation ---------- 119 | 120 | # We now introduce some parameters concerning the spatial deformation. They enable to set the range of the uniform 121 | # distribution from which the corresponding parameters are selected. 122 | # We note that because the label maps will be resampled with nearest neighbour interpolation, they can look less smooth 123 | # than the original segmentations. 124 | 125 | flipping = True # enable right/left flipping 126 | scaling_bounds = 0.2 # the scaling coefficients will be sampled from U(1-scaling_bounds; 1+scaling_bounds) 127 | rotation_bounds = 15 # the rotation angles will be sampled from U(-rotation_bounds; rotation_bounds) 128 | shearing_bounds = 0.012 # the shearing coefficients will be sampled from U(-shearing_bounds; shearing_bounds) 129 | translation_bounds = False # no translation is performed, as this is already modelled by the random cropping 130 | nonlin_std = 4. # this controls the maximum elastic deformation (higher = more deformation) 131 | bias_field_std = 0.7 # this controls the maximum bias field corruption (higher = more bias) 132 | 133 | 134 | # ---------- Resolution parameters ---------- 135 | 136 | # This enables us to randomise the resolution of the produces images. 137 | # Although being only one parameter, this is crucial !! 138 | randomise_res = True 139 | 140 | 141 | # ------------------------------------------------------ Generate ------------------------------------------------------ 142 | 143 | # instantiate BrainGenerator object 144 | brain_generator = BrainGenerator(labels_dir=path_label_map, 145 | generation_labels=generation_labels, 146 | n_neutral_labels=n_neutral_labels, 147 | prior_distributions=prior_distributions, 148 | generation_classes=generation_classes, 149 | output_labels=output_labels, 150 | n_channels=n_channels, 151 | target_res=target_res, 152 | output_shape=output_shape, 153 | flipping=flipping, 154 | scaling_bounds=scaling_bounds, 155 | rotation_bounds=rotation_bounds, 156 | shearing_bounds=shearing_bounds, 157 | translation_bounds=translation_bounds, 158 | nonlin_std=nonlin_std, 159 | bias_field_std=bias_field_std, 160 | randomise_res=randomise_res) 161 | 162 | for n in range(n_examples): 163 | 164 | # generate new image and corresponding labels 165 | im, lab = brain_generator.generate_brain() 166 | 167 | # save output image and label map 168 | utils.save_volume(im, brain_generator.aff, brain_generator.header, 169 | os.path.join(result_dir, 'image_%s.nii.gz' % n)) 170 | utils.save_volume(lab, brain_generator.aff, brain_generator.header, 171 | os.path.join(result_dir, 'labels_%s.nii.gz' % n)) 172 | -------------------------------------------------------------------------------- /scripts/tutorials/3-training.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | This script shows how we trained SynthSeg. 4 | Importantly, it reuses numerous parameters seen in the previous tutorial about image generation 5 | (i.e., 2-generation_explained.py), which we strongly recommend reading before this one. 6 | 7 | 8 | 9 | If you use this code, please cite one of the SynthSeg papers: 10 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 11 | 12 | Copyright 2020 Benjamin Billot 13 | 14 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 15 | compliance with the License. You may obtain a copy of the License at 16 | https://www.apache.org/licenses/LICENSE-2.0 17 | Unless required by applicable law or agreed to in writing, software distributed under the License is 18 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 19 | implied. See the License for the specific language governing permissions and limitations under the 20 | License. 21 | """ 22 | 23 | 24 | # project imports 25 | from SynthSeg.training import training 26 | 27 | 28 | # path training label maps 29 | path_training_label_maps = '../../data/training_label_maps' 30 | path_model_dir = './outputs_tutorial_3/' 31 | batchsize = 1 32 | 33 | # architecture parameters 34 | n_levels = 5 # number of resolution levels 35 | nb_conv_per_level = 2 # number of convolution per level 36 | conv_size = 3 # size of the convolution kernel (e.g. 3x3x3) 37 | unet_feat_count = 24 # number of feature maps after the first convolution 38 | activation = 'elu' # activation for all convolution layers except the last, which will use softmax regardless 39 | feat_multiplier = 2 # if feat_multiplier is set to 1, we will keep the number of feature maps constant throughout the 40 | # network; 2 will double them(resp. half) after each max-pooling (resp. upsampling); 41 | # 3 will triple them, etc. 42 | 43 | # training parameters 44 | lr = 1e-4 # learning rate 45 | wl2_epochs = 1 # number of pre-training epochs with wl2 metric w.r.t. the layer before the softmax 46 | dice_epochs = 100 # number of training epochs 47 | steps_per_epoch = 5000 # number of iteration per epoch 48 | 49 | 50 | # ---------- Generation parameters ---------- 51 | # these parameters are from the previous tutorial, and thus we do not explain them again here 52 | 53 | # generation and segmentation labels 54 | path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy' 55 | n_neutral_labels = 18 56 | path_segmentation_labels = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy' 57 | 58 | # shape and resolution of the outputs 59 | target_res = None 60 | output_shape = 160 61 | n_channels = 1 62 | 63 | # GMM sampling 64 | prior_distributions = 'uniform' 65 | path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy' 66 | 67 | # spatial deformation parameters 68 | flipping = True 69 | scaling_bounds = .2 70 | rotation_bounds = 15 71 | shearing_bounds = .012 72 | translation_bounds = False 73 | nonlin_std = 4. 74 | bias_field_std = .7 75 | 76 | # acquisition resolution parameters 77 | randomise_res = True 78 | 79 | # ------------------------------------------------------ Training ------------------------------------------------------ 80 | 81 | training(path_training_label_maps, 82 | path_model_dir, 83 | generation_labels=path_generation_labels, 84 | segmentation_labels=path_segmentation_labels, 85 | n_neutral_labels=n_neutral_labels, 86 | batchsize=batchsize, 87 | n_channels=n_channels, 88 | target_res=target_res, 89 | output_shape=output_shape, 90 | prior_distributions=prior_distributions, 91 | generation_classes=path_generation_classes, 92 | flipping=flipping, 93 | scaling_bounds=scaling_bounds, 94 | rotation_bounds=rotation_bounds, 95 | shearing_bounds=shearing_bounds, 96 | translation_bounds=translation_bounds, 97 | nonlin_std=nonlin_std, 98 | randomise_res=randomise_res, 99 | bias_field_std=bias_field_std, 100 | n_levels=n_levels, 101 | nb_conv_per_level=nb_conv_per_level, 102 | conv_size=conv_size, 103 | unet_feat_count=unet_feat_count, 104 | feat_multiplier=feat_multiplier, 105 | activation=activation, 106 | lr=lr, 107 | wl2_epochs=wl2_epochs, 108 | dice_epochs=dice_epochs, 109 | steps_per_epoch=steps_per_epoch) 110 | -------------------------------------------------------------------------------- /scripts/tutorials/4-prediction.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | This script shows how to perform inference after having trained your own model. 4 | Importantly, it reuses some of the parameters used in tutorial 3-training. 5 | Moreover, we emphasise that this tutorial explains how to perform inference on your own trained models. 6 | To predict segmentations based on the distributed mode for SynthSeg, please refer to the README.md file. 7 | 8 | 9 | 10 | If you use this code, please cite one of the SynthSeg papers: 11 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 12 | 13 | Copyright 2020 Benjamin Billot 14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 16 | compliance with the License. You may obtain a copy of the License at 17 | https://www.apache.org/licenses/LICENSE-2.0 18 | Unless required by applicable law or agreed to in writing, software distributed under the License is 19 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 20 | implied. See the License for the specific language governing permissions and limitations under the 21 | License. 22 | """ 23 | 24 | # project imports 25 | from SynthSeg.predict import predict 26 | 27 | # paths to input/output files 28 | # Here we assume the availability of an image that we wish to segment with a model we have just trained. 29 | # We emphasise that we do not provide such an image (this is just an example after all :)) 30 | # Input images must have a .nii, .nii.gz, or .mgz extension. 31 | # Note that path_images can also be the path to an entire folder, in which case all the images within this folder will 32 | # be segmented. In this case, please provide path_segm (and possibly path_posteriors, and path_resampled) as folder. 33 | path_images = '/a/path/to/an/image/im.nii.gz' 34 | # path to the output segmentation 35 | path_segm = './outputs_tutorial_4/predicted_segmentations/im_seg.nii.gz' 36 | # we can also provide paths for optional files containing the probability map for all predicted labels 37 | path_posteriors = './outputs_tutorial_4/predicted_information/im_post.nii.gz' 38 | # and for a csv file that will contain the volumes of each segmented structure 39 | path_vol = './outputs_tutorial_4/predicted_information/volumes.csv' 40 | 41 | # of course we need to provide the path to the trained model (here we use the main synthseg model). 42 | path_model = '../../models/synthseg_1.0.h5' 43 | # but we also need to provide the path to the segmentation labels used during training 44 | path_segmentation_labels = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy' 45 | # optionally we can give a numpy array with the names corresponding to the structures in path_segmentation_labels 46 | path_segmentation_names = '../../data/labels_classes_priors/synthseg_segmentation_names.npy' 47 | 48 | # We can now provide various parameters to control the preprocessing of the input. 49 | # First we can play with the size of the input. Remember that the size of input must be divisible by 2**n_levels, so the 50 | # input image will be automatically padded to the nearest shape divisible by 2**n_levels (this is just for processing, 51 | # the output will then be cropped to the original image size). 52 | # Alternatively, you can crop the input to a smaller shape for faster processing, or to make it fit on your GPU. 53 | cropping = 192 54 | # Finally, we finish preprocessing the input by resampling it to the resolution at which the network has been trained to 55 | # produce predictions. If the input image has a resolution outside the range [target_res-0.05, target_res+0.05], it will 56 | # automatically be resampled to target_res. 57 | target_res = 1. 58 | # Note that if the image is indeed resampled, you have the option to save the resampled image. 59 | path_resampled = './outputs_tutorial_4/predicted_information/im_resampled_target_res.nii.gz' 60 | 61 | # After the image has been processed by the network, there are again various options to postprocess it. 62 | # First, we can apply some test-time augmentation by flipping the input along the right-left axis and segmenting 63 | # the resulting image. In this case, and if the network has right/left specific labels, it is also very important to 64 | # provide the number of neutral labels. This must be the exact same as the one used during training. 65 | flip = True 66 | n_neutral_labels = 18 67 | # Second, we can smooth the probability maps produced by the network. This doesn't change much the results, but helps to 68 | # reduce high frequency noise in the obtained segmentations. 69 | sigma_smoothing = 0.5 70 | # Then we can operate some fancier version of biggest connected component, by regrouping structures within so-called 71 | # "topological classes". For each class we successively: 1) sum all the posteriors corresponding to the labels of this 72 | # class, 2) obtain a mask for this class by thresholding the summed posteriors by a low value (arbitrarily set to 0.1), 73 | # 3) keep the biggest connected component, and 4) individually apply the obtained mask to the posteriors of all the 74 | # labels for this class. 75 | # Example: (continuing the previous one) generation_labels = [0, 24, 507, 2, 3, 4, 17, 25, 41, 42, 43, 53, 57] 76 | # output_labels = [0, 0, 0, 2, 3, 4, 17, 2, 41, 42, 43, 53, 41] 77 | # topological_classes = [0, 0, 0, 1, 1, 2, 3, 1, 4, 4, 5, 6, 7] 78 | # Here we regroup labels 2 and 3 in the same topological class, same for labels 41 and 42. The topological class of 79 | # unsegmented structures must be set to 0 (like for 24 and 507). 80 | topology_classes = '../../data/labels_classes_priors/synthseg_topological_classes.npy' 81 | # Finally, we can also operate a strict version of biggest connected component, to get rid of unwanted noisy label 82 | # patch that can sometimes occur in the background. If so, we do recommend to use the smoothing option described above. 83 | keep_biggest_component = True 84 | 85 | # Regarding the architecture of the network, we must provide the predict function with the same parameters as during 86 | # training. 87 | n_levels = 5 88 | nb_conv_per_level = 2 89 | conv_size = 3 90 | unet_feat_count = 24 91 | activation = 'elu' 92 | feat_multiplier = 2 93 | 94 | # Finally, we can set up an evaluation step after all images have been segmented. 95 | # In this purpose, we need to provide the path to the ground truth corresponding to the input image(s). 96 | # This is done by using the "gt_folder" parameter, which must have the same type as path_images (i.e., the path to a 97 | # single image or to a folder). If provided as a folder, ground truths must be sorted in the same order as images in 98 | # path_images. 99 | # Just set this to None if you do not want to run evaluation. 100 | gt_folder = '/the/path/to/the/ground_truth/gt.nii.gz' 101 | # Dice scores will be computed and saved as a numpy array in the folder containing the segmentation(s). 102 | # This numpy array will be organised as follows: rows correspond to structures, and columns to subjects. Importantly, 103 | # rows are given in a sorted order. 104 | # Example: we segment 2 subjects, where output_labels = [0, 0, 0, 2, 3, 4, 17, 2, 41, 42, 43, 53, 41] 105 | # so sorted output_labels = [0, 2, 3, 4, 17, 41, 42, 43, 53] 106 | # dice = [[xxx, xxx], # scores for label 0 107 | # [xxx, xxx], # scores for label 2 108 | # [xxx, xxx], # scores for label 3 109 | # [xxx, xxx], # scores for label 4 110 | # [xxx, xxx], # scores for label 17 111 | # [xxx, xxx], # scores for label 41 112 | # [xxx, xxx], # scores for label 42 113 | # [xxx, xxx], # scores for label 43 114 | # [xxx, xxx]] # scores for label 53 115 | # / \ 116 | # subject 1 subject 2 117 | # 118 | # Also we can compute different surface distances (Hausdorff, Hausdorff99, Hausdorff95 and mean surface distance). The 119 | # results will be saved in arrays similar to the Dice scores. 120 | compute_distances = True 121 | 122 | # All right, we're ready to make predictions !! 123 | predict(path_images, 124 | path_segm, 125 | path_model, 126 | path_segmentation_labels, 127 | n_neutral_labels=n_neutral_labels, 128 | path_posteriors=path_posteriors, 129 | path_resampled=path_resampled, 130 | path_volumes=path_vol, 131 | names_segmentation=path_segmentation_names, 132 | cropping=cropping, 133 | target_res=target_res, 134 | flip=flip, 135 | topology_classes=topology_classes, 136 | sigma_smoothing=sigma_smoothing, 137 | keep_biggest_component=keep_biggest_component, 138 | n_levels=n_levels, 139 | nb_conv_per_level=nb_conv_per_level, 140 | conv_size=conv_size, 141 | unet_feat_count=unet_feat_count, 142 | feat_multiplier=feat_multiplier, 143 | activation=activation, 144 | gt_folder=gt_folder, 145 | compute_distances=compute_distances) 146 | -------------------------------------------------------------------------------- /scripts/tutorials/5-generation_advanced.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | This script shows how to generate synthetic images with narrowed intensity distributions (e.g. T1-weighted scans) and 4 | at a specific resolution. All the arguments shown here can be used in the training function. 5 | These parameters were not explained in the previous tutorials as they were not used for the training of SynthSeg. 6 | 7 | Specifically, this script generates 5 examples of training data simulating 3mm axial T1 scans, which have been resampled 8 | at 1mm resolution to be segmented. 9 | Contrast-specificity is achieved by now imposing Gaussian priors (instead of uniform) over the GMM parameters. 10 | Resolution-specificity is achieved by first blurring and downsampling to the simulated LR. The data will then be 11 | upsampled back to HR, so that the downstream network is trained to segment at HR. This upsampling step mimics the 12 | process that will happen at test time. 13 | 14 | 15 | 16 | If you use this code, please cite one of the SynthSeg papers: 17 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 18 | 19 | Copyright 2020 Benjamin Billot 20 | 21 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 22 | compliance with the License. You may obtain a copy of the License at 23 | https://www.apache.org/licenses/LICENSE-2.0 24 | Unless required by applicable law or agreed to in writing, software distributed under the License is 25 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 26 | implied. See the License for the specific language governing permissions and limitations under the 27 | License. 28 | """ 29 | 30 | 31 | import os 32 | import numpy as np 33 | from ext.lab2im import utils 34 | from SynthSeg.brain_generator import BrainGenerator 35 | 36 | # script parameters 37 | n_examples = 5 # number of examples to generate in this script 38 | result_dir = './outputs_tutorial_5' # folder where examples will be saved 39 | 40 | 41 | # path training label maps 42 | path_label_map = '../../data/training_label_maps' 43 | generation_labels = '../../data/labels_classes_priors/generation_labels.npy' 44 | output_labels = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy' 45 | n_neutral_labels = 18 46 | output_shape = 160 47 | 48 | 49 | # ---------- GMM sampling parameters ---------- 50 | 51 | # Here we use Gaussian priors to control the means and standard deviations of the GMM. 52 | prior_distributions = 'normal' 53 | 54 | # Here we still regroup labels into classes of similar tissue types: 55 | # Example: (continuing the example of tutorial 1) generation_labels = [0, 24, 507, 2, 3, 4, 17, 25, 41, 42, 43, 53, 57] 56 | # generation_classes = [0, 1, 2, 3, 4, 5, 4, 6, 3, 4, 5, 4, 6] 57 | # Note that structures with right/left labels are now associated with the same class. 58 | generation_classes = '../../data/labels_classes_priors/generation_classes_contrast_specific.npy' 59 | 60 | # We specify here the hyperparameters governing the prior distribution of the GMM. 61 | # As these prior distributions are Gaussian, they are each controlled by a mean and a standard deviation. 62 | # Therefore, the numpy array pointed by prior_means is of size (2, K), where K is the total number of classes specified 63 | # in generation_classes. The first row of prior_means correspond to the means of the Gaussian priors, and the second row 64 | # correspond to standard deviations. 65 | # 66 | # Example: (continuing the previous one) prior_means = np.array([[0, 30, 80, 110, 95, 40, 70] 67 | # [0, 10, 50, 15, 10, 15, 30]]) 68 | # This means that intensities of label 3 and 17, which are both in class 4, will be drawn from the Gaussian 69 | # distribution, whose mean will be sampled from the Gaussian distribution with index 4 in prior_means N(95, 10). 70 | # Here is the complete table of correspondence for this example: 71 | # mean of Gaussian for label 0 drawn from N(0,0)=0 72 | # mean of Gaussian for label 24 drawn from N(30,10) 73 | # mean of Gaussian for label 507 drawn from N(80,50) 74 | # mean of Gaussian for labels 2 and 41 drawn from N(110,15) 75 | # mean of Gaussian for labels 3, 17, 42, 53 drawn from N(95,10) 76 | # mean of Gaussian for labels 4 and 43 drawn from N(40,15) 77 | # mean of Gaussian for labels 25 and 57 drawn from N(70,30) 78 | # These hyperparameters were estimated with the function SynthSR/estimate_priors.py/build_intensity_stats() 79 | prior_means = '../../data/labels_classes_priors/prior_means_t1.npy' 80 | # same as for prior_means, but for the standard deviations of the GMM. 81 | prior_stds = '../../data/labels_classes_priors/prior_stds_t1.npy' 82 | 83 | # ---------- Resolution parameters ---------- 84 | 85 | # here we aim to synthesise data at a specific resolution, thus we do not randomise it anymore ! 86 | randomise_res = False 87 | 88 | # blurring/downsampling parameters 89 | # We specify here the slice spacing/thickness that we want the synthetic scans to mimic. The axes refer to the *RAS* 90 | # axes, as all the provided data (label maps and images) will be automatically aligned to those axes during training. 91 | # RAS refers to Right-left/Anterior-posterior/Superior-inferior axes, i.e. sagittal/coronal/axial directions. 92 | data_res = np.array([1., 1., 3.]) # slice spacing i.e. resolution to mimic 93 | thickness = np.array([1., 1., 3.]) # slice thickness 94 | 95 | # ------------------------------------------------------ Generate ------------------------------------------------------ 96 | 97 | # instantiate BrainGenerator object 98 | brain_generator = BrainGenerator(labels_dir=path_label_map, 99 | generation_labels=generation_labels, 100 | output_labels=output_labels, 101 | n_neutral_labels=n_neutral_labels, 102 | output_shape=output_shape, 103 | prior_distributions=prior_distributions, 104 | generation_classes=generation_classes, 105 | prior_means=prior_means, 106 | prior_stds=prior_stds, 107 | randomise_res=randomise_res, 108 | data_res=data_res, 109 | thickness=thickness) 110 | 111 | for n in range(n_examples): 112 | 113 | # generate new image and corresponding labels 114 | im, lab = brain_generator.generate_brain() 115 | 116 | # save output image and label map 117 | utils.save_volume(im, brain_generator.aff, brain_generator.header, 118 | os.path.join(result_dir, 'image_t1_%s.nii.gz' % n)) 119 | utils.save_volume(lab, brain_generator.aff, brain_generator.header, 120 | os.path.join(result_dir, 'labels_t1_%s.nii.gz' % n)) 121 | -------------------------------------------------------------------------------- /scripts/tutorials/6-intensity_estimation.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Examples to show how to estimate of the hyperparameters governing the GMM prior distributions. 4 | This in the case where you want to train contrast-specific versions of SynthSeg. 5 | Beware, if you do so, your model will not be able to segment any contrast at test time ! 6 | We do not provide example images and associated label maps, so do not try to run this directly ! 7 | 8 | 9 | 10 | 11 | If you use this code, please cite one of the SynthSeg papers: 12 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 13 | 14 | Copyright 2020 Benjamin Billot 15 | 16 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 17 | compliance with the License. You may obtain a copy of the License at 18 | https://www.apache.org/licenses/LICENSE-2.0 19 | Unless required by applicable law or agreed to in writing, software distributed under the License is 20 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 21 | implied. See the License for the specific language governing permissions and limitations under the 22 | License. 23 | """ 24 | 25 | 26 | from SynthSeg.estimate_priors import build_intensity_stats 27 | 28 | # ----------------------------------------------- simple uni-modal case ------------------------------------------------ 29 | 30 | # paths of directories containing the images and corresponding label maps 31 | image_dir = '/image_folder/t1' 32 | labels_dir = '/labels_folder' 33 | # list of labels from which we want to evaluate the GMM prior distributions 34 | estimation_labels = '../../data/labels_classes_priors/generation_labels.npy' 35 | # path of folder where to write estimated priors 36 | result_dir = './outputs_tutorial_6/t1_priors' 37 | 38 | build_intensity_stats(list_image_dir=image_dir, 39 | list_labels_dir=labels_dir, 40 | estimation_labels=estimation_labels, 41 | result_dir=result_dir, 42 | rescale=True) 43 | 44 | # ------------------------------------ building Gaussian priors from several labels ------------------------------------ 45 | 46 | # same as before 47 | image_dir = '/image_folder/t1' 48 | labels_dir = '/labels_folder' 49 | estimation_labels = '../../data/labels_classes_priors/generation_labels.npy' 50 | result_dir = './outputs_tutorial_6/estimated_t1_priors_classes' 51 | 52 | # In the previous example, each label value is used to build the priors of a single Gaussian distribution. 53 | # We show here how to build Gaussian priors from intensities associated to several label values. For example, that could 54 | # be building the Gaussian prior of white matter by using the labels of right and left white matter. 55 | # This is done by specifying a vector, which regroups label values into "classes". 56 | # Labels sharing the same class will contribute to the construction of the same Gaussian prior. 57 | estimation_classes = '../../data/labels_classes_priors/generation_classes.npy' 58 | 59 | build_intensity_stats(list_image_dir=image_dir, 60 | list_labels_dir=labels_dir, 61 | estimation_labels=estimation_labels, 62 | estimation_classes=estimation_classes, 63 | result_dir=result_dir, 64 | rescale=True) 65 | 66 | # ---------------------------------------------- simple multi-modal case ----------------------------------------------- 67 | 68 | # Here we have multi-modal images, where every image contains all channels. 69 | # Channels are supposed to be sorted in the same order for all subjects. 70 | image_dir = '/image_folder/multi-modal_t1_t2' 71 | 72 | # same as before 73 | labels_dir = '/labels_folder' 74 | estimation_labels = '../../data/labels_classes_priors/generation_labels.npy' 75 | estimation_classes = '../../data/labels_classes_priors/generation_classes.npy' 76 | result_dir = './outputs_tutorial_6/estimated_priors_multi_modal' 77 | 78 | build_intensity_stats(list_image_dir=image_dir, 79 | list_labels_dir=labels_dir, 80 | estimation_labels=estimation_labels, 81 | estimation_classes=estimation_classes, 82 | result_dir=result_dir, 83 | rescale=True) 84 | 85 | # ------------------------------------- multi-modal images with separate channels ------------------------------------- 86 | 87 | # Here we have multi-modal images, where the different channels are stored in separate directories. 88 | # We provide the these different directories as a list. 89 | list_image_dir = ['/image_folder/t1', '/image_folder/t2'] 90 | # In this example, we assume that channels are registered and at the same resolutions. 91 | # Therefore we can use the same label maps for all channels. 92 | labels_dir = '/labels_folder' 93 | 94 | # same as before 95 | estimation_labels = '../../data/labels_classes_priors/generation_labels.npy' 96 | estimation_classes = '../../data/labels_classes_priors/generation_classes.npy' 97 | result_dir = './outputs_tutorial_6/estimated_priors_multi_modal' 98 | 99 | build_intensity_stats(list_image_dir=list_image_dir, 100 | list_labels_dir=labels_dir, 101 | estimation_labels=estimation_labels, 102 | estimation_classes=estimation_classes, 103 | result_dir=result_dir, 104 | rescale=True) 105 | 106 | # ------------------------------------ multi-modal case with unregistered channels ------------------------------------- 107 | 108 | # Again, we have multi-modal images where the different channels are stored in separate directories. 109 | list_image_dir = ['/image_folder/t1', '/image_folder/t2'] 110 | # In this example, we assume that the channels are no longer registered. 111 | # Therefore we cannot use the same label maps for all channels, and must provide label maps for all modalities. 112 | labels_dir = ['/labels_folder/t1', '/labels_folder/t2'] 113 | 114 | # same as before 115 | estimation_labels = '../../data/labels_classes_priors/generation_labels.npy' 116 | estimation_classes = '../../data/labels_classes_priors/generation_classes.npy' 117 | result_dir = './outputs_tutorial_6/estimated_unregistered_multi_modal' 118 | 119 | build_intensity_stats(list_image_dir=list_image_dir, 120 | list_labels_dir=labels_dir, 121 | estimation_labels=estimation_labels, 122 | estimation_classes=estimation_classes, 123 | result_dir=result_dir, 124 | rescale=True) 125 | -------------------------------------------------------------------------------- /scripts/tutorials/7-synthseg+.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Very simple script to show how we trained SynthSeg+, which extends SynthSeg by building robustness to clinical 4 | acquisitions. For more details, please look at our MICCAI 2022 paper: 5 | 6 | Robust Segmentation of Brain MRI in the Wild with Hierarchical CNNs and no Retraining, 7 | Billot, Magdamo, Das, Arnold, Iglesias 8 | MICCAI 2022 9 | 10 | If you use this code, please cite one of the SynthSeg papers: 11 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib 12 | 13 | Copyright 2020 Benjamin Billot 14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 16 | compliance with the License. You may obtain a copy of the License at 17 | https://www.apache.org/licenses/LICENSE-2.0 18 | Unless required by applicable law or agreed to in writing, software distributed under the License is 19 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 20 | implied. See the License for the specific language governing permissions and limitations under the 21 | License. 22 | """ 23 | from SynthSeg.training import training as training_s1 24 | from SynthSeg.training_denoiser import training as training_d 25 | from SynthSeg.training_group import training as training_s2 26 | 27 | import numpy as np 28 | 29 | # ------------------ segmenter S1 30 | # Here the purpose is to train a first network to produce preliminary segmentations of input scans with five general 31 | # labels: 0-background, 1-white matter, 2-grey matter, 3-fluids, 4-cerebellum. 32 | 33 | # As in tutorial 3, S1 is trained with synthetic images with randomised contrasts/resolution/artefacts such that it can 34 | # readily segment a wide range of test scans without retraining. The synthetic scans are obtained from the same label 35 | # maps and generative model as in the previous tutorials. 36 | labels_dir_s1 = '../../data/training_label_maps' 37 | path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy' 38 | path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy' 39 | # However, because we now wish to segment scans using only five labels, we use a different list of segmentation labels 40 | # where all label values in generation_labels are assigned to a target value between [0, 4]. 41 | path_segmentation_labels_s1 = '../../data/tutorial_7/segmentation_labels_s1.npy' 42 | 43 | model_dir_s1 = './outputs_tutorial_7/training_s1' # folder where the models will be saved 44 | 45 | 46 | training_s1(labels_dir=labels_dir_s1, 47 | model_dir=model_dir_s1, 48 | generation_labels=path_generation_labels, 49 | segmentation_labels=path_segmentation_labels_s1, 50 | n_neutral_labels=18, 51 | generation_classes=path_generation_classes, 52 | target_res=1, 53 | output_shape=160, 54 | prior_distributions='uniform', 55 | prior_means=[0, 255], 56 | prior_stds=[0, 50], 57 | randomise_res=True) 58 | 59 | # ------------------ denoiser D 60 | # The purpose of this network is to perform label-to-label correction in order to correct potential mistakes made by S1 61 | # at test time. Therefore, D is trained with two sets of label maps: noisy segmentations from S1 (used as inputs to D), 62 | # and their corresponding ground truth (used as target to train D). In order to obtain input segmentations 63 | # representative of the mistakes of S1, these are obtained by degrading real images with extreme augmentation (spatial, 64 | # intensity, resolution, etc.), and feeding them to S1. 65 | 66 | # Obtaining the input/target segmentations is done offline by using the following function: sample_segmentation_pairs.py 67 | # In practice we sample a lot of them (i.e. 10,000), but we give here 8 example pairs. Note that these segmentations 68 | # have the same label values as the output of S1 (i.e. between [0, 4]). 69 | list_input_labels = ['../../data/tutorial_7/noisy_segmentations_d/0001.nii.gz', 70 | '../../data/tutorial_7/noisy_segmentations_d/0002.nii.gz', 71 | '../../data/tutorial_7/noisy_segmentations_d/0003.nii.gz'] 72 | list_target_labels = ['../../data/tutorial_7/target_segmentations_d/0001.nii.gz', 73 | '../../data/tutorial_7/target_segmentations_d/0002.nii.gz', 74 | '../../data/tutorial_7/target_segmentations_d/0003.nii.gz'] 75 | 76 | # Moreover, we perform spatial augmentation on the sampled pairs, in order to further increase the morphological 77 | # variability seen by the network. Furthermore, the input "noisy" segmentations are further augmented with random 78 | # erosion/dilation: 79 | prob_erosion_dilation = 0.3 # probability of performing random erosion/dilation 80 | min_erosion_dilation = 4, # minimum coefficient for erosion/dilation 81 | max_erosion_dilation = 5 # maximum coefficient for erosion/dilation 82 | 83 | # This is the list of label values included in the input/GT label maps. This list must contain unique values. 84 | input_segmentation_labels = np.array([0, 1, 2, 3, 4]) 85 | 86 | model_dir_d = './outputs_tutorial_7/training_d' # folder where the models will be saved 87 | 88 | training_d(list_paths_input_labels=list_input_labels, 89 | list_paths_target_labels=list_target_labels, 90 | model_dir=model_dir_d, 91 | input_segmentation_labels=input_segmentation_labels, 92 | output_shape=160, 93 | prob_erosion_dilation=prob_erosion_dilation, 94 | min_erosion_dilation=min_erosion_dilation, 95 | max_erosion_dilation=max_erosion_dilation, 96 | conv_size=5, 97 | unet_feat_count=16, 98 | skip_n_concatenations=2) 99 | 100 | # ------------------ segmenter S2 101 | # Final segmentations are obtained with a last segmenter S2, which takes as inputs an image as well as the preliminary 102 | # segmentations of S1 that are corrected by D. 103 | 104 | # Here S2 is trained with synthetic images sampled from the usual training label maps with associated generation labels, 105 | # classes. Also, we now use the same segmentation labels as in tutorials 2, 3, and 4, as we now segment all the usual 106 | # regions. 107 | labels_dir_s2 = '../../data/training_label_maps' # these are the same as for S1 108 | path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy' 109 | path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy' 110 | path_segmentation_labels_s2 = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy' 111 | 112 | # The preliminary segmentations are given as soft probability maps and are directly derived from the ground truth. 113 | # Specifically, we take the structures that were segmented by S1, and regroup them into the same "groups" as before. 114 | grouping_labels = '../../data/tutorial_7/segmentation_labels_s1.npy' 115 | # However, in order to simulate test-time imperfections made by D, we these soft probability maps are slightly 116 | # augmented with spatial transforms, and sometimes undergo a random dilation/erosion. 117 | 118 | model_dir_s2 = './outputs_tutorial_7/training_s2' # folder where the models will be saved 119 | 120 | training_s2(labels_dir=labels_dir_s2, 121 | model_dir=model_dir_s2, 122 | generation_labels=path_generation_labels, 123 | n_neutral_labels=18, 124 | segmentation_labels=path_segmentation_labels_s2, 125 | generation_classes=path_generation_classes, 126 | grouping_labels=grouping_labels, 127 | target_res=1, 128 | output_shape=160, 129 | prior_distributions='uniform', 130 | prior_means=[0, 255], 131 | prior_stds=[0, 50], 132 | randomise_res=True) 133 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import setuptools 5 | 6 | python_version = sys.version[:3] 7 | 8 | if (python_version != '3.6') & (python_version != '3.8'): 9 | raise Exception('Setup.py only works with python version 3.6 or 3.8, not {}'.format(python_version)) 10 | 11 | else: 12 | 13 | with open('requirements_python' + python_version + '.txt') as f: 14 | required_packages = [line.strip() for line in f.readlines()] 15 | 16 | print(setuptools.find_packages()) 17 | 18 | setuptools.setup(name='SynthSeg', 19 | version='2.0', 20 | license='Apache 2.0', 21 | description='Domain-agnostic segmentation of brain scans', 22 | author='Benjamin Billot', 23 | url='https://github.com/BBillot/SynthSeg', 24 | keywords=['segmentation', 'domain-agnostic', 'brain'], 25 | packages=setuptools.find_packages(), 26 | python_requires='>=3.6', 27 | install_requires=required_packages, 28 | include_package_data=True) 29 | --------------------------------------------------------------------------------