├── .gitignore
├── LICENSE.txt
├── README.md
├── SynthSeg
│   ├── __init__.py
│   ├── brain_generator.py
│   ├── estimate_priors.py
│   ├── evaluate.py
│   ├── labels_to_image_model.py
│   ├── metrics_model.py
│   ├── model_inputs.py
│   ├── predict.py
│   ├── predict_denoiser.py
│   ├── predict_group.py
│   ├── predict_qc.py
│   ├── predict_synthseg.py
│   ├── sample_segmentation_pairs_d.py
│   ├── training.py
│   ├── training_denoiser.py
│   ├── training_group.py
│   ├── training_qc.py
│   ├── training_supervised.py
│   ├── validate.py
│   ├── validate_denoiser.py
│   ├── validate_group.py
│   └── validate_qc.py
├── bibtex.bib
├── data
│   ├── README_figures
│   │   ├── new_features.png
│   │   ├── overview.png
│   │   ├── robust.png
│   │   ├── segmentations.png
│   │   ├── table_versions.png
│   │   ├── training_data.png
│   │   └── youtube_link.png
│   ├── labels table.txt
│   ├── labels_classes_priors
│   │   ├── generation_classes.npy
│   │   ├── generation_classes_contrast_specific.npy
│   │   ├── generation_labels.npy
│   │   ├── prior_means_t1.npy
│   │   ├── prior_stds_t1.npy
│   │   ├── synthseg_denoiser_labels_2.0.npy
│   │   ├── synthseg_parcellation_labels.npy
│   │   ├── synthseg_parcellation_names.npy
│   │   ├── synthseg_qc_labels.npy
│   │   ├── synthseg_qc_labels_2.0.npy
│   │   ├── synthseg_qc_names.npy
│   │   ├── synthseg_qc_names_2.0.npy
│   │   ├── synthseg_segmentation_labels.npy
│   │   ├── synthseg_segmentation_labels_2.0.npy
│   │   ├── synthseg_segmentation_names.npy
│   │   ├── synthseg_segmentation_names_2.0.npy
│   │   ├── synthseg_topological_classes.npy
│   │   └── synthseg_topological_classes_2.0.npy
│   ├── training_label_maps
│   │   ├── training_seg_01.nii.gz
│   │   ├── training_seg_02.nii.gz
│   │   ├── training_seg_03.nii.gz
│   │   ├── training_seg_04.nii.gz
│   │   ├── training_seg_05.nii.gz
│   │   ├── training_seg_06.nii.gz
│   │   ├── training_seg_07.nii.gz
│   │   ├── training_seg_08.nii.gz
│   │   ├── training_seg_09.nii.gz
│   │   ├── training_seg_10.nii.gz
│   │   ├── training_seg_11.nii.gz
│   │   ├── training_seg_12.nii.gz
│   │   ├── training_seg_13.nii.gz
│   │   ├── training_seg_14.nii.gz
│   │   ├── training_seg_15.nii.gz
│   │   ├── training_seg_16.nii.gz
│   │   ├── training_seg_17.nii.gz
│   │   ├── training_seg_18.nii.gz
│   │   ├── training_seg_19.nii.gz
│   │   └── training_seg_20.nii.gz
│   └── tutorial_7
│       ├── noisy_segmentations_d
│       │   ├── 0001.nii.gz
│       │   ├── 0002.nii.gz
│       │   └── 0003.nii.gz
│       ├── segmentation_labels_s1.npy
│       └── target_segmentations_d
│           ├── 0001.nii.gz
│           ├── 0002.nii.gz
│           └── 0003.nii.gz
├── ext
│   ├── __init__.py
│   ├── lab2im
│   │   ├── __init__.py
│   │   ├── edit_tensors.py
│   │   ├── edit_volumes.py
│   │   ├── image_generator.py
│   │   ├── lab2im_model.py
│   │   ├── layers.py
│   │   └── utils.py
│   └── neuron
│       ├── __init__.py
│       ├── layers.py
│       ├── models.py
│       └── utils.py
├── models
│   └── synthseg_1.0.h5
├── requirements_python3.6.txt
├── requirements_python3.8.txt
├── scripts
│   ├── commands
│   │   ├── SynthSeg_predict.py
│   │   ├── predict.py
│   │   ├── training.py
│   │   └── training_supervised.py
│   └── tutorials
│       ├── 1-generation_visualisation.py
│       ├── 2-generation_explained.py
│       ├── 3-training.py
│       ├── 4-prediction.py
│       ├── 5-generation_advanced.py
│       ├── 6-intensity_estimation.py
│       └── 7-synthseg+.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | dist
3 | SynthSeg.egg-info
4 | venv
5 | **/__pycache__/
6 | .idea
7 |
8 | ext/pynd
9 | ext/statannot
10 |
11 | data/labels_classes_priors/generation_classes_2.0.npy
12 | data/labels_classes_priors/generation_labels_2.0.npy
13 |
14 | models/synthseg_2.0.h5
15 | models/synthseg_parc_2.0.h5
16 | models/synthseg_qc_2.0.h5
17 | models/synthseg_robust_2.0.h5
18 |
19 | scripts/checks
20 | scripts/commands/SynthSeg_predict_claustrum.py
21 | scripts/previous_papers
22 | scripts/*.py
23 |
24 | SynthSeg/auc_roc_delong.py
25 | SynthSeg/boxplots.py
26 | SynthSeg/brain_generator_*
27 | SynthSeg/check_*
28 | SynthSeg/predict_synthseg_*
29 | SynthSeg/predict_vae.py
30 | SynthSeg/qc.py
31 | SynthSeg/training_heart.py
32 | SynthSeg/training_vae.py
33 | SynthSeg/validate_vae.py
34 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SynthSeg
2 |
3 |
4 | In this repository, we present SynthSeg, the first deep learning tool for segmentation of brain scans of
5 | any contrast and resolution. SynthSeg works out-of-the-box without any retraining, and is robust to:
6 | - any contrast
7 | - any resolution up to 10mm slice spacing
8 | - a wide array of populations: from young and healthy to ageing and diseased
9 | - scans with or without preprocessing: bias field correction, skull stripping, normalisation, etc.
10 | - white matter lesions.
11 | \
12 | \
13 | ![Segmentation examples](data/README_figures/segmentations.png)
14 |
15 |
16 | \
17 | SynthSeg was first presented for the automated segmentation of brain scans of any contrast and resolution.
18 |
19 | **SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining** \
20 | B. Billot, D.N. Greve, O. Puonti, A. Thielscher, K. Van Leemput, B. Fischl, A.V. Dalca, J.E. Iglesias \
21 | Medical Image Analysis (2023) \
22 | [ [article](https://www.sciencedirect.com/science/article/pii/S1361841523000506) | [arxiv](https://arxiv.org/abs/2107.09559) | [bibtex](bibtex.bib) ]
23 | \
24 | \
25 | Then, we extended it to work on heterogeneous clinical scans, and to perform cortical parcellation and automated
26 | quality control.
27 |
28 | **Robust machine learning segmentation for large-scale analysis of heterogeneous clinical brain MRI datasets** \
29 | B. Billot, C. Magdamo, Y. Cheng, S.E. Arnold, S. Das, J.E. Iglesias \
30 | PNAS (2023) \
31 | [ [article](https://www.pnas.org/doi/full/10.1073/pnas.2216399120#bibliography) | [arxiv](https://arxiv.org/abs/2203.01969) | [bibtex](bibtex.bib) ]
32 |
33 | \
34 | Here, we distribute our model to enable users to run SynthSeg on their own data. We emphasise that
35 | predictions are always given at 1mm isotropic resolution (regardless of the input resolution). The code can be run on
36 | the GPU (~15s per scan) or on the CPU (~1min).
37 |
38 |
39 | ----------------
40 |
41 | ### New features and updates
42 |
43 | \
44 | 01/03/2023: **The papers for SynthSeg and SynthSeg 2.0 are out! :open_book: :open_book:** \
45 | After a long review process for SynthSeg (Medical Image Analysis), and a much faster one for SynthSeg 2.0 (PNAS), both
46 | papers have been accepted nearly at the same time ! See the references above, or in the citation section.
47 |
48 | \
49 | 04/10/2022: **SynthSeg is available with Matlab!** :star: \
50 | We are delighted that Matlab 2022b (and onwards) now includes SynthSeg in its Medical Image
51 | Toolbox. They have a [documented example](https://www.mathworks.com/help/medical-imaging/ug/Brain-MRI-Segmentation-Using-Trained-3-D-U-Net.html)
52 | on how to use it. But, to simplify things, we wrote our own Matlab wrapper, which you can call in one single line.
53 | Just download [this zip file](https://liveuclac-my.sharepoint.com/:u:/g/personal/rmappmb_ucl_ac_uk/EctEe3hOP8dDh1hYHlFS_rUBo80yFg7MQY5WnagHlWcS6A?e=e8bK0f),
54 | uncompress it, open Matlab, and type `help SynthSeg` for instructions.
55 |
56 | \
57 | 29/06/2022: **SynthSeg 2.0 is out !** :v: \
58 | In addition to whole-brain segmentation, it now also performs **Cortical parcellation, automated QC, and intracranial
59 | volume (ICV) estimation** (see figure below). Also, most of these features are compatible with SynthSeg 1.0 (see table below).
60 | \
61 | \
62 | ![New features](data/README_figures/new_features.png)
63 |
64 | ![Table of versions](data/README_figures/table_versions.png)
65 |
66 | \
67 | 01/03/2022: **Robust version** :hammer: \
68 | SynthSeg sometimes falters on scans with low signal-to-noise ratio, or with very low tissue contrast. For this reason,
69 | we developed a new model for increased robustness, named "SynthSeg-robust". You can use this mode when SynthSeg gives
70 | results like in the figure below:
71 | \
72 | \
73 | ![Robust mode](data/README_figures/robust.png)
74 |
75 | \
76 | 29/10/2021: **SynthSeg is now available on the dev version of
77 | [FreeSurfer](https://surfer.nmr.mgh.harvard.edu/fswiki/DownloadAndInstall) !!** :tada: \
78 | See [here](https://surfer.nmr.mgh.harvard.edu/fswiki/SynthSeg) on how to use it.
79 |
80 | ----------------
81 |
82 | ### Try it in one command !
83 |
84 | Once all the python packages are installed (see below), you can simply test SynthSeg on your own data with:
85 | ```
86 | python ./scripts/commands/SynthSeg_predict.py --i <input> --o <output> [--parc --robust --ct --vol <vol> --qc <qc> --post <post> --resample <resample>]
87 | ```
88 |
89 |
90 | where:
91 | - `<input>` path to a scan to segment, or to a folder. This can also be the path to a text file, where each line is the
92 | path of an image to segment.
93 | - `<output>` path where the output segmentations will be saved. This must be of the same type as `<input>` (i.e., the path
94 | to a file, a folder, or a text file where each line is the path to an output segmentation).
95 | - `--parc` (optional) to perform cortical parcellation in addition to whole-brain segmentation.
96 | - `--robust` (optional) to use the variant for increased robustness (e.g., when analysing clinical data with large slice
97 | spacing). This can be slower than the other model.
98 | - `--ct` (optional) use on CT scans in Hounsfield scale. It clips intensities to [0, 80].
99 | - `<vol>` (optional) path to a CSV file where the volumes (in mm3) of all segmented regions will be saved for all scans
100 | (e.g. /path/to/volumes.csv). If `<input>` is a text file, so must be `<vol>`, for which each line is the path to a
101 | different CSV file corresponding to one subject only.
102 | - `<qc>` (optional) path to a CSV file where QC scores will be saved. The same formatting requirements as for `<vol>` apply.
103 | - `<post>` (optional) path where the posteriors, given as soft probability maps, will be saved (same formatting
104 | requirements as for `<output>`).
105 | - `<resample>` (optional) SynthSeg segmentations are always given at 1mm isotropic resolution. Hence,
106 | images are always resampled internally to this resolution (except if they are already at 1mm resolution).
107 | Use this flag to save the resampled images (same formatting requirements as for `<output>`).
108 |
109 | Additional optional flags are also available:
110 | - `--cpu`: (optional) to enforce the code to run on the CPU, even if a GPU is available.
111 | - `--threads`: (optional) number of threads to be used by Tensorflow (default uses one core). Increase it to decrease
112 | the runtime when using the CPU version.
113 | - `--crop`: (optional) to crop the input images to a given shape before segmentation. This must be divisible by 32.
114 | Images are cropped around their centre, and their segmentations are given back at the original size. It can be given as a
115 | single integer (e.g., `--crop 160`), or as several integers (e.g., `--crop 160 128 192`, ordered in RAS coordinates). By default the
116 | whole image is processed. Use this flag for faster analyses, or if the full image does not fit on your GPU.
117 | - `--fast`: (optional) to disable some operations for faster prediction (twice as fast, but slightly less accurate).
118 | This doesn't apply when the --robust flag is used.
119 | - `--v1`: (optional) to run the first version of SynthSeg (SynthSeg 1.0, updated 29/06/2022).
120 |
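For example, a typical call combining several of the flags above might look like the following (all paths here are placeholders, not files shipped with the repository):
```
python ./scripts/commands/SynthSeg_predict.py --i /data/scans --o /data/segmentations --parc --vol /data/volumes.csv --qc /data/qc_scores.csv --threads 8
```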
121 |
122 | **IMPORTANT:** SynthSeg always gives results at 1mm isotropic resolution, regardless of the input. However, this can
123 | cause some viewers to not correctly overlay segmentations on their corresponding images. In this case, you can use the
124 | `--resample` flag to obtain a resampled image that lives in the same space as the segmentation, such that they can be
125 | visualised together with any viewer.
126 |
127 | The complete list of segmented structures is available in [labels table.txt](data/labels%20table.txt) along with their
128 | corresponding values. This table also details the order in which the posteriors maps are sorted.
129 |
130 |
131 | ----------------
132 |
133 | ### Installation
134 |
135 | 1. Clone this repository.
136 |
137 | 2. Create a virtual environment (i.e., with pip or conda) and install all the required packages. \
138 | These depend on your python version, and here we list the requirements for Python 3.6
139 | ([requirements_3.6](requirements_python3.6.txt)) and Python 3.8 (see [requirements_3.8](requirements_python3.8.txt)).
140 | The choice is yours, but in each case, please stick to the exact package versions.\
141 | A first solution to install the dependencies, if you use pip, is to run setup.py (with an activated virtual
142 | environment): `python setup.py install`. Otherwise, we also give here the minimal commands to install the required
143 | packages using pip/conda for Python 3.6/3.8.
144 |
145 | ```
146 | # Conda, Python 3.6:
147 | conda create -n synthseg_36 python=3.6 tensorflow-gpu=2.0.0 keras=2.3.1 h5py==2.10.0 nibabel matplotlib -c anaconda -c conda-forge
148 |
149 | # Conda, Python 3.8:
150 | conda create -n synthseg_38 python=3.8 tensorflow-gpu=2.2.0 keras=2.3.1 nibabel matplotlib -c anaconda -c conda-forge
151 |
152 | # Pip, Python 3.6:
153 | pip install tensorflow-gpu==2.0.0 keras==2.3.1 nibabel==3.2.2 matplotlib==3.3.4
154 |
155 | # Pip, Python 3.8:
156 | pip install tensorflow-gpu==2.2.0 keras==2.3.1 protobuf==3.20.3 numpy==1.23.5 nibabel==5.0.1 matplotlib==3.6.2
157 | ```
158 |
159 | 3. Go to this link [UCL dropbox](https://liveuclac-my.sharepoint.com/:f:/g/personal/rmappmb_ucl_ac_uk/EtlNnulBSUtAvOP6S99KcAIBYzze7jTPsmFk2_iHqKDjEw?e=rBP0RO), and download the missing models. Then simply copy them to [models](models).
160 |
161 | 4. If you wish to run on the GPU, you will also need to install CUDA (10.0 for Python 3.6, 10.1 for Python 3.8) and
162 | cuDNN (7.6.5 for both). Note that if you used conda, these were already installed automatically.
163 |
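Once CUDA and cuDNN are installed, a quick way to check that TensorFlow actually sees your GPU is the snippet below (this relies on standard TensorFlow 2.x calls and is not specific to SynthSeg):
```
import tensorflow as tf
# should list at least one GPU device if CUDA/cuDNN are correctly installed
print(tf.config.experimental.list_physical_devices('GPU'))
```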
164 | That's it ! You're now ready to use SynthSeg ! :tada:
165 |
166 |
167 | ----------------
168 |
169 | ### How does it work ?
170 |
171 | In short, we train a network with synthetic images sampled on the fly from a generative model based on the forward
172 | model of Bayesian segmentation. Crucially, we adopt a domain randomisation strategy where we fully randomise the
173 | generation parameters which are drawn at each minibatch from uninformative uniform priors. By exposing the network to
174 | extremely variable input data, we force it to learn domain-agnostic features. As a result, SynthSeg is able to readily
175 | segment real scans of any target domain, without retraining or fine-tuning.
176 |
177 | The following figure first illustrates the workflow of a training iteration, and then provides an overview of the
178 | different steps of the generative model:
179 | \
180 | \
181 | ![Overview of the generative model](data/README_figures/overview.png)
182 | \
183 | \
184 | Finally we show additional examples of the synthesised images along with an overlay of their target segmentations:
185 | \
186 | \
187 | ![Examples of synthetic training data](data/README_figures/training_data.png)
188 | \
189 | \
190 | If you are interested in learning more about SynthSeg, you can read the associated publication (see below), and watch this
191 | presentation, which was given at MIDL 2020 for a related article on a preliminary version of SynthSeg (robustness to
192 | MR contrast but not resolution).
193 | \
194 | \
195 | [![SynthSeg talk](data/README_figures/youtube_link.png)](https://www.youtube.com/watch?v=Bfp3cILSKZg&t=1s)
196 |
197 |
198 | ----------------
199 |
200 | ### Train your own model
201 |
202 | This repository contains all the code and data necessary to train, validate, and test your own network. Importantly, the
203 | proposed method only requires a set of anatomical segmentations for training (no images), which we include in
204 | [data](data/training_label_maps). While the provided functions are thoroughly documented, we highly recommend starting
205 | with the following tutorials:
206 |
207 | - [1-generation_visualisation](scripts/tutorials/1-generation_visualisation.py): This very simple script shows examples
208 | of the synthetic images used to train SynthSeg.
209 |
210 | - [2-generation_explained](scripts/tutorials/2-generation_explained.py): This second script describes all the parameters
211 | used to control the generative model. We advise you to thoroughly follow this tutorial, as it is essential to understand
212 | how the synthetic data is formed before you start training your own models.
213 |
214 | - [3-training](scripts/tutorials/3-training.py): This script re-uses the parameters explained in the previous tutorial
215 | and focuses on the learning/architecture parameters. The script here is the very one we used to train SynthSeg !
216 |
217 | - [4-prediction](scripts/tutorials/4-prediction.py): This script shows how to make predictions once the network has
218 | been trained.
219 |
220 | - [5-generation_advanced](scripts/tutorials/5-generation_advanced.py): Here we detail more advanced generation options,
221 | in the case of training a version of SynthSeg that is specific to a given contrast and/or resolution (although these
222 | types of variants were shown to be outperformed by the SynthSeg model trained in the 3rd tutorial).
223 |
224 | - [6-intensity_estimation](scripts/tutorials/6-intensity_estimation.py): This script shows how to estimate the
225 | Gaussian priors of the GMM when training a contrast-specific version of SynthSeg.
226 |
227 | - [7-synthseg+](scripts/tutorials/7-synthseg+.py): Finally, we show how the robust version of SynthSeg was
228 | trained.
229 |
230 | These tutorials cover a lot of material and will enable you to train your own SynthSeg model. Moreover, even more
231 | detailed information is provided in the docstrings of all functions, so don't hesitate to have a look at these !
232 |
233 |
234 | ----------------
235 |
236 | ### Content
237 |
238 | - [SynthSeg](SynthSeg): this is the main folder containing the generative model and training function:
239 |
240 | - [labels_to_image_model.py](SynthSeg/labels_to_image_model.py): contains the generative model for MRI scans.
241 |
242 | - [brain_generator.py](SynthSeg/brain_generator.py): contains the class `BrainGenerator`, which is a wrapper around
243 | `labels_to_image_model`. New images can simply be generated by instantiating an object of this class and calling its
244 | method `generate_image()` (see the short sketch after this list).
245 |
246 | - [training.py](SynthSeg/training.py): contains code to train the segmentation network (with explanations for all
247 | training parameters). This function also shows how to integrate the generative model in a training setting.
248 |
249 | - [predict.py](SynthSeg/predict.py): prediction and testing.
250 |
251 | - [validate.py](SynthSeg/validate.py): includes code for validation (which has to be done offline on real images).
252 |
253 | - [models](models): this is where you will find the trained model for SynthSeg.
254 |
255 | - [data](data): this folder contains some examples of brain label maps if you wish to train your own SynthSeg model.
256 |
257 | - [scripts](scripts): contains tutorials as well as scripts to launch training and testing from a terminal.
258 |
259 | - [ext](ext): includes external packages, especially the *lab2im* package, and a modified version of *neuron*.
260 |
261 |
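As a minimal sketch of how the generation API described above is used (only the first constructor argument is shown; see tutorials 1 and 2 for the full set of generation options):
```
from SynthSeg.brain_generator import BrainGenerator

# instantiate the generator on the example label maps shipped with the repository
brain_generator = BrainGenerator('data/training_label_maps')

# sample one synthetic image and its corresponding target label map
image, labels = brain_generator.generate_image()
```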
262 | ----------------
263 |
264 | ### Citation/Contact
265 |
266 | This code is under [Apache 2.0](LICENSE.txt) licensing.
267 |
268 | - If you use the **cortical parcellation**, **automated QC**, or **robust version**, please cite the following paper:
269 |
270 | **Robust machine learning segmentation for large-scale analysis of heterogeneous clinical brain MRI datasets** \
271 | B. Billot, C. Magdamo, Y. Cheng, S.E. Arnold, S. Das, J.E. Iglesias \
272 | PNAS (2023) \
273 | [ [article](https://www.pnas.org/doi/full/10.1073/pnas.2216399120#bibliography) | [arxiv](https://arxiv.org/abs/2203.01969) | [bibtex](bibtex.bib) ]
274 |
275 |
276 | - Otherwise, please cite:
277 |
278 | **SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining** \
279 | B. Billot, D.N. Greve, O. Puonti, A. Thielscher, K. Van Leemput, B. Fischl, A.V. Dalca, J.E. Iglesias \
280 | Medical Image Analysis (2023) \
281 | [ [article](https://www.sciencedirect.com/science/article/pii/S1361841523000506) | [arxiv](https://arxiv.org/abs/2107.09559) | [bibtex](bibtex.bib) ]
282 |
283 | If you have any question regarding the usage of this code, or any suggestions to improve it, please raise an issue or
284 | contact us at: bbillot@mit.edu
285 |
--------------------------------------------------------------------------------
/SynthSeg/__init__.py:
--------------------------------------------------------------------------------
1 | from . import brain_generator
2 | from . import estimate_priors
3 | from . import evaluate
4 | from . import labels_to_image_model
5 | from . import metrics_model
6 | from . import model_inputs
7 | from . import predict
8 | from . import training_supervised
9 | from . import training
10 |
--------------------------------------------------------------------------------
/SynthSeg/estimate_priors.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # python imports
18 | import os
19 | import numpy as np
20 | try:
21 | from scipy.stats import median_absolute_deviation
22 | except ImportError:
23 | from scipy.stats import median_abs_deviation as median_absolute_deviation
24 |
25 |
26 | # third-party imports
27 | from ext.lab2im import utils
28 | from ext.lab2im import edit_volumes
29 |
30 |
31 | def sample_intensity_stats_from_image(image, segmentation, labels_list, classes_list=None, keep_strictly_positive=True):
32 | """This function takes an image and corresponding segmentation as inputs. It estimates the mean and std intensity
33 | for all specified label values. Labels can share the same statistics by being regrouped into K classes.
34 | :param image: image from which to evaluate mean intensity and std deviation.
35 | :param segmentation: segmentation of the input image. Must have the same size as image.
36 | :param labels_list: list of labels for which to evaluate mean and std intensity.
37 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
38 | :param classes_list: (optional) enables to regroup structures into classes of similar intensity statistics.
39 | Intensities associated to regrouped labels will thus contribute to the same Gaussian during statistics estimation.
40 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
41 | It should have the same length as labels_list, and contain values between 0 and K-1, where K is the total number of
42 | classes. Default is all labels have different classes (K=len(labels_list)).
43 | :param keep_strictly_positive: (optional) whether to only keep strictly positive intensity values when
44 | computing stats. This doesn't apply to the first label in labels_list (or class if classes_list is provided), for
45 | which we keep positive and zero values, as we consider it to be the background label.
46 | :return: a numpy array of size (2, K), the first row being the mean intensity for each structure,
47 | and the second being the median absolute deviation (robust estimation of std).
48 | """
49 |
50 | # reformat labels and classes
51 | labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int'))
52 | if classes_list is not None:
53 | classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int'))
54 | else:
55 | classes_list = np.arange(labels_list.shape[0])
56 | assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length'
57 |
58 | # get unique classes
59 | unique_classes, unique_indices = np.unique(classes_list, return_index=True)
60 | n_classes = len(unique_classes)
61 | if not np.array_equal(unique_classes, np.arange(n_classes)):
62 | raise ValueError('classes_list should only contain values between 0 and K-1, '
63 | 'where K is the total number of classes. Here K = %d' % n_classes)
64 |
65 | # compute mean/std of specified classes
66 | means = np.zeros(n_classes)
67 | stds = np.zeros(n_classes)
68 | for idx, tmp_class in enumerate(unique_classes):
69 |
70 | # get list of all intensity values for the current class
71 | class_labels = labels_list[classes_list == tmp_class]
72 | intensities = np.empty(0)
73 | for label in class_labels:
74 | tmp_intensities = image[segmentation == label]
75 | intensities = np.concatenate([intensities, tmp_intensities])
76 | if tmp_class: # i.e. if not background
77 | if keep_strictly_positive:
78 | intensities = intensities[intensities > 0]
79 |
80 | # compute stats for class and put them to the location of corresponding label values
81 | if len(intensities) != 0:
82 | means[idx] = np.nanmedian(intensities)
83 | stds[idx] = median_absolute_deviation(intensities, nan_policy='omit')
84 |
85 | return np.stack([means, stds])
86 |
87 |
88 | def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_list, classes_list=None, max_channel=3,
89 | rescale=True):
90 | """This function aims at estimating the intensity distributions of K different structure types from a set of images.
91 | The distribution of each structure type is modelled as a Gaussian, parametrised by a mean and a standard deviation.
92 | Because the intensity distribution of structures can vary across images, we additionally use Gaussian priors for the
93 | parameters of each Gaussian distribution. Therefore, the intensity distribution of each structure type is described
94 | by 4 parameters: a mean/std for the mean intensity, and a mean/std for the std deviation.
95 | This function uses a set of images along with corresponding segmentations to estimate the 4*K parameters.
96 | Structures can share the same statistics by being regrouped into classes of similar structure types.
97 | Images can be multi-modal (n_channels), in which case different statistics are estimated for each modality.
98 | :param image_dir: path of directory with images to estimate the intensity distribution
99 | :param labels_dir: path of directory with segmentation of input images.
100 | They are matched with images by sorting order.
101 | :param labels_list: list of labels for which to evaluate mean and std intensity.
102 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
103 | :param classes_list: (optional) enables to regroup structures into classes of similar intensity statistics.
104 | Intensities associated to regrouped labels will thus contribute to the same Gaussian during statistics estimation.
105 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
106 | It should have the same length as labels_list, and contain values between 0 and K-1, where K is the total number of
107 | classes. Default is all labels have different classes (K=len(labels_list)).
108 | :param max_channel: (optional) maximum number of channels to consider if the data is multi-spectral. Default is 3.
109 | :param rescale: (optional) whether to rescale images between 0 and 255 before intensity estimation
110 | :return: 2 numpy arrays of size (2*n_channels, K), one with the evaluated means/std for the mean
111 | intensity, and one for the mean/std for the standard deviation.
112 | Each block of two rows correspond to a different modality (channel). For each block of two rows, the first row
113 | represents the mean, and the second represents the std.
114 | """
115 |
116 | # list files
117 | path_images = utils.list_images_in_folder(image_dir)
118 | path_labels = utils.list_images_in_folder(labels_dir)
119 | assert len(path_images) == len(path_labels), 'image and labels folders do not have the same number of files'
120 |
121 | # reformat list labels and classes
122 | labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int'))
123 | if classes_list is not None:
124 | classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int'))
125 | else:
126 | classes_list = np.arange(labels_list.shape[0])
127 | assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length'
128 |
129 | # get unique classes
130 | unique_classes, unique_indices = np.unique(classes_list, return_index=True)
131 | n_classes = len(unique_classes)
132 | if not np.array_equal(unique_classes, np.arange(n_classes)):
133 | raise ValueError('classes_list should only contain values between 0 and K-1, '
134 | 'where K is the total number of classes. Here K = %d' % n_classes)
135 |
136 | # initialise result arrays
137 | n_dims, n_channels = utils.get_dims(utils.load_volume(path_images[0]).shape, max_channels=max_channel)
138 | means = np.zeros((len(path_images), n_classes, n_channels))
139 | stds = np.zeros((len(path_images), n_classes, n_channels))
140 |
141 | # loop over images
142 | loop_info = utils.LoopInfo(len(path_images), 10, 'estimating', print_time=True)
143 | for idx, (path_im, path_la) in enumerate(zip(path_images, path_labels)):
144 | loop_info.update(idx)
145 |
146 | # load image and label map
147 | image = utils.load_volume(path_im)
148 | la = utils.load_volume(path_la)
149 | if n_channels == 1:
150 | image = utils.add_axis(image, -1)
151 |
152 | # loop over channels
153 | for channel in range(n_channels):
154 | im = image[..., channel]
155 | if rescale:
156 | im = edit_volumes.rescale_volume(im)
157 | stats = sample_intensity_stats_from_image(im, la, labels_list, classes_list=classes_list)
158 | means[idx, :, channel] = stats[0, :]
159 | stds[idx, :, channel] = stats[1, :]
160 |
161 | # compute prior parameters for mean/std
162 | mean_means = np.mean(means, axis=0)
163 | std_means = np.std(means, axis=0)
164 | mean_stds = np.mean(stds, axis=0)
165 | std_stds = np.std(stds, axis=0)
166 |
167 | # regroup prior parameters in two different arrays: one for the mean and one for the std
168 | prior_means = np.zeros((2 * n_channels, n_classes))
169 | prior_stds = np.zeros((2 * n_channels, n_classes))
170 | for channel in range(n_channels):
171 | prior_means[2 * channel, :] = mean_means[:, channel]
172 | prior_means[2 * channel + 1, :] = std_means[:, channel]
173 | prior_stds[2 * channel, :] = mean_stds[:, channel]
174 | prior_stds[2 * channel + 1, :] = std_stds[:, channel]
175 |
176 | return prior_means, prior_stds
177 |
178 |
179 | def build_intensity_stats(list_image_dir,
180 | list_labels_dir,
181 | result_dir,
182 | estimation_labels,
183 | estimation_classes=None,
184 | max_channel=3,
185 | rescale=True):
186 | """This function aims at estimating the intensity distributions of K different structure types from a set of images.
187 | The distribution of each structure type is modelled as a Gaussian, parametrised by a mean and a standard deviation.
188 | Because the intensity distribution of structures can vary across images, we additionally use Gaussian priors for the
189 | parameters of each Gaussian distribution. Therefore, the intensity distribution of each structure type is described
190 | by 4 parameters: a mean/std for the mean intensity, and a mean/std for the std deviation.
191 | This function uses a set of images along with corresponding segmentations to estimate the 4*K parameters.
192 | Additionally, it can estimate the 4*K parameters for several image datasets, that we call here n_datasets.
193 | This function writes 2 numpy arrays of size (2*n_datasets, K), one with the evaluated means/std for the mean
194 | intensities, and one for the mean/std for the standard deviations.
195 | In these arrays, each block of two rows refer to a different dataset.
196 | Within each block of two rows, the first row represents the mean, and the second represents the std.
197 | :param list_image_dir: path of folders with images for intensity distribution estimation.
198 | Can be the path of single directory (n_datasets=1), or a list of folders, each being a separate dataset.
199 | Images can be multimodal, in which case each modality is treated as a different dataset, i.e. each modality will
200 | have a separate block (of size (2, K)) in the result arrays.
201 | :param list_labels_dir: path of folders with label maps corresponding to input images.
202 | If list_image_dir is a list of several folders, list_labels_dir can either be a list of folders (one for each image
203 | folder), or the path to a single folder, which will be used for all datasets.
204 | If a dataset has multi-modal images, the same label map is applied to all modalities.
205 | :param result_dir: path of directory where estimated priors will be writen.
206 | :param estimation_labels: labels to estimate intensity statistics from.
207 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
208 | :param estimation_classes: (optional) enables to regroup structures into classes of similar intensity statistics.
209 | Intensities associated to regrouped labels will thus contribute to the same Gaussian during statistics estimation.
210 | Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
211 | It should have the same length as labels_list, and contain values between 0 and K-1, where K is the total number of
212 | classes. Default is all labels have different classes (K=len(estimation_labels)).
213 | :param max_channel: (optional) maximum number of channels to consider if the data is multi-spectral. Default is 3.
214 | :param rescale: (optional) whether to rescale images between 0 and 255 before intensity estimation
215 | """
216 |
217 | # handle results directories
218 | utils.mkdir(result_dir)
219 |
220 | # reformat image/labels dir into lists
221 | list_image_dir = utils.reformat_to_list(list_image_dir)
222 | list_labels_dir = utils.reformat_to_list(list_labels_dir, length=len(list_image_dir))
223 |
224 | # reformat list estimation labels and classes
225 | estimation_labels = np.array(utils.reformat_to_list(estimation_labels, load_as_numpy=True, dtype='int'))
226 | if estimation_classes is not None:
227 | estimation_classes = np.array(utils.reformat_to_list(estimation_classes, load_as_numpy=True, dtype='int'))
228 | else:
229 | estimation_classes = np.arange(estimation_labels.shape[0])
230 | assert len(estimation_classes) == len(estimation_labels), 'estimation labels and classes should be of same length'
231 |
232 | # get unique classes
233 | unique_estimation_classes, unique_indices = np.unique(estimation_classes, return_index=True)
234 | n_classes = len(unique_estimation_classes)
235 | if not np.array_equal(unique_estimation_classes, np.arange(n_classes)):
236 | raise ValueError('estimation_classes should only contain values between 0 and K-1, '
237 | 'where K is the total number of classes. Here K = %d' % n_classes)
238 |
239 | # loop over dataset
240 | list_datasets_prior_means = list()
241 | list_datasets_prior_stds = list()
242 | for image_dir, labels_dir in zip(list_image_dir, list_labels_dir):
243 |
244 | # get prior stats for dataset
245 | tmp_prior_means, tmp_prior_stds = sample_intensity_stats_from_single_dataset(image_dir,
246 | labels_dir,
247 | estimation_labels,
248 | estimation_classes,
249 | max_channel=max_channel,
250 | rescale=rescale)
251 |
252 | # add stats arrays to list of datasets-wise statistics
253 | list_datasets_prior_means.append(tmp_prior_means)
254 | list_datasets_prior_stds.append(tmp_prior_stds)
255 |
256 | # stack all modalities together
257 | prior_means = np.concatenate(list_datasets_prior_means, axis=0)
258 | prior_stds = np.concatenate(list_datasets_prior_stds, axis=0)
259 |
260 | # save files
261 | np.save(os.path.join(result_dir, 'prior_means.npy'), prior_means)
262 | np.save(os.path.join(result_dir, 'prior_stds.npy'), prior_stds)
263 |
264 | return prior_means, prior_stds
265 |
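# Example usage (a minimal sketch: the image/label paths below are placeholders, and the label/class arrays
# are the ones shipped in data/labels_classes_priors):
#
#   from SynthSeg.estimate_priors import build_intensity_stats
#   build_intensity_stats(list_image_dir='/data/t1/images',
#                         list_labels_dir='/data/t1/labels',
#                         result_dir='/data/t1/estimated_priors',
#                         estimation_labels='data/labels_classes_priors/generation_labels.npy',
#                         estimation_classes='data/labels_classes_priors/generation_classes.npy')
#
# This writes prior_means.npy and prior_stds.npy to result_dir, which can then be used as contrast-specific
# GMM priors when training a contrast-specific version of SynthSeg (see tutorial 6).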
--------------------------------------------------------------------------------
/SynthSeg/metrics_model.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # python imports
18 | import numpy as np
19 | import tensorflow as tf
20 | import keras.layers as KL
21 | from keras.models import Model
22 |
23 | # third-party imports
24 | from ext.lab2im import layers
25 |
26 |
27 | def metrics_model(input_model, label_list, metrics='dice'):
28 |
29 | # get prediction
30 | last_tensor = input_model.outputs[0]
31 | input_shape = last_tensor.get_shape().as_list()[1:]
32 |
33 | # check shapes
34 | n_labels = input_shape[-1]
35 | label_list = np.unique(label_list)
36 | assert n_labels == len(label_list), 'label_list should be as long as the posteriors channels'
37 |
38 | # get GT and convert it to probabilistic values
39 | labels_gt = input_model.get_layer('labels_out').output
40 | labels_gt = layers.ConvertLabels(label_list)(labels_gt)
41 | labels_gt = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(labels_gt)
42 | labels_gt = KL.Reshape(input_shape)(labels_gt)
43 |
44 | # make sure the tensors have the right keras shape
45 | last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
46 | labels_gt._keras_shape = tuple(labels_gt.get_shape().as_list())
47 |
48 | if metrics == 'dice':
49 | last_tensor = layers.DiceLoss()([labels_gt, last_tensor])
50 |
51 | elif metrics == 'wl2':
52 | last_tensor = layers.WeightedL2Loss(target_value=5)([labels_gt, last_tensor])
53 |
54 | else:
55 | raise Exception('metrics should either be "dice" or "wl2", got {}'.format(metrics))
56 |
57 | # create the model and return
58 | model = Model(inputs=input_model.inputs, outputs=last_tensor)
59 | return model
60 |
61 |
62 | class IdentityLoss(object):
63 | """Very simple loss, as the computation of the loss as been directly implemented in the model."""
64 | def __init__(self, keepdims=True):
65 | self.keepdims = keepdims
66 |
67 | def loss(self, y_true, y_predicted):
68 | """Because the metrics is already calculated in the model, we simply return y_predicted.
69 | We still need to put y_true in the inputs, as it's expected by keras."""
70 | loss = y_predicted
71 |
72 | tf.debugging.check_numerics(loss, 'Loss not finite')
73 | return loss
74 |
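# Example wiring (a sketch only: 'generation_model' and 'segmentation_labels' are placeholders for the output
# of the generative model and the list of segmentation labels used during training):
#
#   model = metrics_model(input_model=generation_model, label_list=segmentation_labels, metrics='dice')
#   model.compile(optimizer='adam', loss=IdentityLoss().loss)
#
# Since the Dice (or weighted L2) loss is computed inside the model itself, IdentityLoss simply passes the
# model output through to Keras.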
--------------------------------------------------------------------------------
/SynthSeg/model_inputs.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # python imports
18 | import numpy as np
19 | import numpy.random as npr
20 |
21 | # third-party imports
22 | from ext.lab2im import utils
23 |
24 |
25 | def build_model_inputs(path_label_maps,
26 | n_labels,
27 | batchsize=1,
28 | n_channels=1,
29 | subjects_prob=None,
30 | generation_classes=None,
31 | prior_distributions='uniform',
32 | prior_means=None,
33 | prior_stds=None,
34 | use_specific_stats_for_channel=False,
35 | mix_prior_and_random=False):
36 | """
37 | This function builds a generator that will be used to give the necessary inputs to the label_to_image model: the
38 | input label maps, as well as the means and stds defining the parameters of the GMM (which change at each minibatch).
39 | :param path_label_maps: list of the paths of the input label maps.
40 | :param n_labels: number of labels in the input label maps.
41 | :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1.
42 | :param n_channels: (optional) number of channels to be synthesised. Default is 1.
43 | :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to pick
44 | the provided label maps at each minibatch. Must be a 1D numpy array, as long as path_label_maps.
45 | :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity
46 | distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence or a
47 | 1d numpy array. It should have the same length as generation_labels, and contain values between 0 and K-1, where K
48 | is the total number of classes. Default is all labels have different classes.
49 | :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters.
50 | Can either be 'uniform', or 'normal'. Default is 'uniform'.
51 | :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because
52 | these prior distributions are uniform or normal, they are each defined by 2 hyperparameters. Thus prior_means can be:
53 | 1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is
54 | uniform, [mean, std] if the distribution is normal. The GMM means are then independently sampled at each
55 | mini_batch from the same distribution.
56 | 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is
57 | not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each mini-batch
58 | from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, or from
59 | N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal.
60 | 3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived
61 | from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a
62 | modality from the n_mod possibilities, and we sample the GMM means like in 2).
63 | If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel
64 | (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it.
65 | 4) the path to such a numpy array.
66 | Default is None, which corresponds to prior_means = [25, 225].
67 | :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM.
68 | Default is None, which corresponds to prior_stds = [5, 25].
69 | :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be
70 | only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False.
71 | :param mix_prior_and_random: (optional) if prior_means is not None, enables resetting the priors to their default
72 | values in half of the cases, thus generating images of random contrast.
73 | """
74 |
75 | # allocate unique class to each label if generation classes is not given
76 | if generation_classes is None:
77 | generation_classes = np.arange(n_labels)
78 | n_classes = len(np.unique(generation_classes))
79 |
80 | # make sure subjects_prob sums to 1
81 | subjects_prob = utils.load_array_if_path(subjects_prob)
82 | if subjects_prob is not None:
83 | subjects_prob /= np.sum(subjects_prob)
84 |
85 | # Generate!
86 | while True:
87 |
88 | # randomly pick as many images as batchsize
89 | indices = npr.choice(np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob)
90 |
91 | # initialise input lists
92 | list_label_maps = []
93 | list_means = []
94 | list_stds = []
95 |
96 | for idx in indices:
97 |
98 | # load input label map
99 | lab = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4))
100 | if (npr.uniform() > 0.7) & ('seg_cerebral' in path_label_maps[idx]):
101 | lab[lab == 24] = 0
102 |
103 | # add label map to inputs
104 | list_label_maps.append(utils.add_axis(lab, axis=[0, -1]))
105 |
106 | # add means and standard deviations to inputs
107 | means = np.empty((1, n_labels, 0))
108 | stds = np.empty((1, n_labels, 0))
109 | for channel in range(n_channels):
110 |
111 | # retrieve channel specific stats if necessary
112 | if isinstance(prior_means, np.ndarray):
113 | if (prior_means.shape[0] > 2) & use_specific_stats_for_channel:
114 | if prior_means.shape[0] / 2 != n_channels:
115 | raise ValueError("the number of blocks in prior_means does not match n_channels. This "
116 | "message is printed because use_specific_stats_for_channel is True.")
117 | tmp_prior_means = prior_means[2 * channel:2 * channel + 2, :]
118 | else:
119 | tmp_prior_means = prior_means
120 | else:
121 | tmp_prior_means = prior_means
122 | if (prior_means is not None) & mix_prior_and_random & (npr.uniform() > 0.5):
123 | tmp_prior_means = None
124 | if isinstance(prior_stds, np.ndarray):
125 | if (prior_stds.shape[0] > 2) & use_specific_stats_for_channel:
126 | if prior_stds.shape[0] / 2 != n_channels:
127 | raise ValueError("the number of blocks in prior_stds does not match n_channels. This "
128 | "message is printed because use_specific_stats_for_channel is True.")
129 | tmp_prior_stds = prior_stds[2 * channel:2 * channel + 2, :]
130 | else:
131 | tmp_prior_stds = prior_stds
132 | else:
133 | tmp_prior_stds = prior_stds
134 | if (prior_stds is not None) & mix_prior_and_random & (npr.uniform() > 0.5):
135 | tmp_prior_stds = None
136 |
137 | # draw means and std devs from priors
138 | tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_classes, prior_distributions,
139 | 125., 125., positive_only=True)
140 | tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_classes, prior_distributions,
141 | 15., 15., positive_only=True)
142 | random_coef = npr.uniform()
143 | if random_coef > 0.95: # reset the background to 0 in 5% of cases
144 | tmp_classes_means[0] = 0
145 | tmp_classes_stds[0] = 0
146 | elif random_coef > 0.7: # reset the background to low Gaussian in 25% of cases
147 | tmp_classes_means[0] = npr.uniform(0, 15)
148 | tmp_classes_stds[0] = npr.uniform(0, 5)
149 | tmp_means = utils.add_axis(tmp_classes_means[generation_classes], axis=[0, -1])
150 | tmp_stds = utils.add_axis(tmp_classes_stds[generation_classes], axis=[0, -1])
151 | means = np.concatenate([means, tmp_means], axis=-1)
152 | stds = np.concatenate([stds, tmp_stds], axis=-1)
153 | list_means.append(means)
154 | list_stds.append(stds)
155 |
156 | # build list of inputs for generation model
157 | list_inputs = [list_label_maps, list_means, list_stds]
158 | if batchsize > 1: # concatenate each input type if batchsize > 1
159 | list_inputs = [np.concatenate(item, 0) for item in list_inputs]
160 | else:
161 | list_inputs = [item[0] for item in list_inputs]
162 |
163 | yield list_inputs
164 |
--------------------------------------------------------------------------------
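To make the prior_means / prior_stds format described in the docstring above concrete, here is a minimal numpy sketch of the four accepted forms. The class count (K = 4), the numeric values, and the save path are purely illustrative and are not taken from the repository:

import numpy as np

# Case 1: two scalars shared by all classes; with prior_distributions='uniform',
# every GMM mean is drawn from U(25, 225) at each mini-batch.
prior_means_case1 = [25, 225]

# Case 2: one [min, max] (or [mean, std]) pair per class, here for K = 4 illustrative classes.
# Row 0 holds the lower bounds (or means), row 1 the upper bounds (or stds).
prior_means_case2 = np.array([[0., 20., 90., 130.],
                              [30., 60., 130., 170.]])  # shape (2, K)

# Case 3: one block of two rows per modality (here n_mod = 2), stacked vertically.
# With use_specific_stats_for_channel=True, block i is used exclusively for channel i.
prior_means_case3 = np.concatenate([prior_means_case2,            # modality/channel 0
                                    prior_means_case2 + 20.], 0)  # modality/channel 1 -> shape (2*n_mod, K)

# Case 4: the same array saved to disk and passed as a path (illustrative location).
np.save('/tmp/prior_means.npy', prior_means_case3)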
/SynthSeg/predict_qc.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # python imports
18 | import os
19 | import numpy as np
20 | import tensorflow as tf
21 | import keras.layers as KL
22 | from keras.models import Model
23 |
24 | # project imports
25 | from SynthSeg import evaluate
26 |
27 | # third-party imports
28 | from ext.lab2im import utils
29 | from ext.lab2im import edit_volumes
30 | from ext.neuron import models as nrn_models
31 |
32 |
33 | def predict(path_predictions,
34 | path_qc_results,
35 | path_model,
36 | labels_list,
37 | labels_to_convert=None,
38 | convert_gt=False,
39 | shape=224,
40 | n_levels=5,
41 | nb_conv_per_level=3,
42 | conv_size=5,
43 | unet_feat_count=24,
44 | feat_multiplier=2,
45 | activation='relu',
46 | path_gts=None,
47 | verbose=True):
48 |
49 | # prepare input/output filepaths
50 | path_predictions, path_gts, path_qc_results, path_gt_results, path_diff = \
51 | prepare_output_files(path_predictions, path_gts, path_qc_results)
52 |
53 | # get label list
54 | labels_list, _ = utils.get_list_labels(label_list=labels_list)
55 | labels_list_unique, _ = np.unique(labels_list, return_index=True)
56 | if labels_to_convert is not None:
57 | labels_to_convert, _ = utils.get_list_labels(label_list=labels_to_convert)
58 |
59 | # prepare qc results
60 | pred_qc_results = np.zeros((len(labels_list_unique) + 1, len(path_predictions)))
61 | gt_qc_results = np.zeros((len(labels_list_unique), len(path_predictions))) if path_gt_results is not None else None
62 |
63 | # build network
64 | model_input_shape = [None, None, None, 1]
65 | net = build_qc_model(path_model=path_model,
66 | input_shape=model_input_shape,
67 | label_list=labels_list_unique,
68 | n_levels=n_levels,
69 | nb_conv_per_level=nb_conv_per_level,
70 | conv_size=conv_size,
71 | unet_feat_count=unet_feat_count,
72 | feat_multiplier=feat_multiplier,
73 | activation=activation)
74 |
75 | # perform segmentation
76 | loop_info = utils.LoopInfo(len(path_predictions), 10, 'predicting', True)
77 | for idx, (path_prediction, path_gt) in enumerate(zip(path_predictions, path_gts)):
78 |
79 | # print progress
80 | if verbose:
81 | loop_info.update(idx)
82 |
83 | # preprocessing
84 | prediction, gt_scores = preprocess(path_prediction, path_gt, shape, labels_list, labels_to_convert, convert_gt)
85 |
86 | # get predicted scores
87 | pred_qc_results[-1, idx] = np.sum(prediction > 0)  # last row stores the foreground volume (in voxels)
88 | pred_qc_results[:-1, idx] = np.clip(np.squeeze(net.predict(prediction)), 0, 1)
89 | np.save(path_qc_results, pred_qc_results)
90 |
91 | # save GT scores if necessary
92 | if gt_scores is not None:
93 | gt_qc_results[:, idx] = gt_scores
94 | np.save(path_gt_results, gt_qc_results)
95 |
96 | if path_diff is not None:
97 | diff = pred_qc_results[:-1, :] - gt_qc_results
98 | np.save(path_diff, diff)
99 |
100 |
101 | def prepare_output_files(path_predictions, path_gts, path_qc_results):
102 |
103 | # check inputs
104 | assert path_predictions is not None, 'please specify an input file/folder (--i)'
105 | assert path_qc_results is not None, 'please specify an output file/folder (--o)'
106 |
107 | # convert path to absolute paths
108 | path_predictions = os.path.abspath(path_predictions)
109 | path_qc_results = os.path.abspath(path_qc_results)
110 |
111 | # list input predictions
112 | path_predictions = utils.list_images_in_folder(path_predictions)
113 |
114 | # build path output with qc results
115 | if path_qc_results[-4:] != '.npy':
116 | print('Path for QC outputs provided without npy extension. Adding npy extension.')
117 | path_qc_results += '.npy'
118 | utils.mkdir(os.path.dirname(path_qc_results))
119 |
120 | if path_gts is not None:
121 | path_gts = utils.list_images_in_folder(path_gts)
122 | assert len(path_gts) == len(path_predictions), 'not the same number of predictions and GTs'
123 | path_gt_results = path_qc_results.replace('.npy', '_gt.npy')
124 | path_diff = path_qc_results.replace('.npy', '_diff.npy')
125 | else:
126 | path_gts = [None] * len(path_predictions)
127 | path_gt_results = path_diff = None
128 |
129 | return path_predictions, path_gts, path_qc_results, path_gt_results, path_diff
130 |
131 |
132 | def preprocess(path_prediction, path_gt=None, shape=224, labels_list=None, labels_to_convert=None, convert_gt=False):
133 |
134 | # read image and corresponding info
135 | pred, _, aff_pred, n_dims, _, _, _ = utils.get_volume_info(path_prediction, True)
136 | gt = utils.load_volume(path_gt, aff_ref=np.eye(4)) if path_gt is not None else None
137 |
138 | # align
139 | pred = edit_volumes.align_volume_to_ref(pred, aff_pred, aff_ref=np.eye(4), n_dims=n_dims)
140 |
141 | # pad/crop to the target shape (default 224), such that segmentations are in the middle of the patch
142 | if gt is not None:
143 | pred, gt = make_shape(pred, gt, shape, n_dims)
144 | else:
145 | pred, _ = edit_volumes.crop_volume_around_region(pred, cropping_shape=shape)
146 |
147 | # convert labels if necessary
148 | if labels_to_convert is not None:
149 | lut = utils.get_mapping_lut(labels_to_convert, labels_list)
150 | pred = lut[pred.astype('int32')]
151 | if convert_gt & (gt is not None):
152 | gt = lut[gt.astype('int32')]
153 |
154 | # compute GT dice scores
155 | gt_scores = evaluate.fast_dice(pred, gt, np.unique(labels_list)) if gt is not None else None
156 |
157 | # add batch and channel axes
158 | pred = utils.add_axis(pred, axis=0) # channel axis will be added later when computing one-hot
159 |
160 | return pred, gt_scores
161 |
162 |
163 | def make_shape(pred, gt, shape, n_dims):
164 |
165 | mask = ((pred > 0) & (pred != 24)) | (gt > 0)
166 | vol_shape = np.array(pred.shape[:n_dims])
167 |
168 | if np.any(mask):
169 |
170 | # find cropping indices
171 | indices = np.nonzero(mask)
172 | min_idx = np.maximum(np.array([np.min(idx) for idx in indices]), 0)
173 | max_idx = np.minimum(np.array([np.max(idx) for idx in indices]) + 1, vol_shape)
174 |
175 | # expand/retract (depending on the desired shape) the cropping region around the centre
176 | intermediate_vol_shape = max_idx - min_idx
177 | cropping_shape = np.array(utils.reformat_to_list(shape, length=n_dims))
178 | min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape) / 2))
179 | max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape) / 2))
180 |
181 | # crop volume
182 | cropping = np.concatenate([np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)])
183 | pred = edit_volumes.crop_volume_with_idx(pred, cropping, n_dims=n_dims)
184 | gt = edit_volumes.crop_volume_with_idx(gt, cropping, n_dims=n_dims)
185 |
186 | # check if we need to pad the output to the desired shape
187 | min_padding = np.abs(np.minimum(min_idx, 0))
188 | max_padding = np.maximum(max_idx - vol_shape, 0)
189 | if np.any(min_padding > 0) | np.any(max_padding > 0):
190 | pad_margins = tuple([(min_padding[i], max_padding[i]) for i in range(n_dims)])
191 | pred = np.pad(pred, pad_margins, mode='constant', constant_values=0)
192 | gt = np.pad(gt, pad_margins, mode='constant', constant_values=0)
193 |
194 | return pred, gt
195 |
196 |
197 | def build_qc_model(path_model,
198 | input_shape,
199 | label_list,
200 | n_levels,
201 | nb_conv_per_level,
202 | conv_size,
203 | unet_feat_count,
204 | feat_multiplier,
205 | activation):
206 |
207 | assert os.path.isfile(path_model), "The provided model path does not exist."
208 | label_list_unique = np.unique(label_list)
209 | n_labels = len(label_list_unique)
210 |
211 | # one-hot encoding of the input prediction as the network expects soft probabilities
212 | input_labels = KL.Input(input_shape[:-1])
213 | labels = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(input_labels)
214 | net = Model(inputs=input_labels, outputs=labels)
215 |
216 | # build model
217 | model = nrn_models.conv_enc(input_model=net,
218 | input_shape=input_shape,
219 | nb_levels=n_levels,
220 | nb_conv_per_level=nb_conv_per_level,
221 | conv_size=conv_size,
222 | nb_features=unet_feat_count,
223 | feat_mult=feat_multiplier,
224 | activation=activation,
225 | batch_norm=-1,
226 | use_residuals=True,
227 | name='qc')
228 | last = model.outputs[0]
229 |
230 | conv_kwargs = {'padding': 'same', 'activation': 'relu', 'data_format': 'channels_last'}
231 | last = KL.MaxPool3D(pool_size=(2, 2, 2), name='qc_maxpool_%s' % (n_levels - 1), padding='same')(last)
232 | last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name='qc_final_conv_0')(last)
233 | last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name='qc_final_conv_1')(last)
234 | last = KL.Lambda(lambda x: tf.reduce_mean(x, axis=[1, 2, 3]), name='qc_final_pred')(last)
235 |
236 | net = Model(inputs=net.inputs, outputs=last)
237 | net.load_weights(path_model, by_name=True)
238 |
239 | return net
240 |
--------------------------------------------------------------------------------
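The make_shape() routine above centres the foreground of a segmentation in a patch of fixed size by cropping where the bounding box exceeds the target shape and padding where it falls short. Below is a stand-alone numpy sketch of the same idea, simplified to a single volume without ground truth; the function name and toy data are illustrative only:

import numpy as np

def centre_in_shape(seg, target_shape):
    """Crop/pad an integer label volume so its foreground is centred in a patch of target_shape."""
    target_shape = np.array(target_shape)
    vol_shape = np.array(seg.shape)

    # bounding box of the foreground
    idx = np.nonzero(seg > 0)
    min_idx = np.array([np.min(i) for i in idx])
    max_idx = np.array([np.max(i) for i in idx]) + 1

    # expand/retract the box symmetrically around its centre to reach the target shape
    box_shape = max_idx - min_idx
    min_idx = min_idx - np.int32(np.ceil((target_shape - box_shape) / 2))
    max_idx = max_idx + np.int32(np.floor((target_shape - box_shape) / 2))

    # crop within the volume bounds
    crop_min = np.maximum(min_idx, 0)
    crop_max = np.minimum(max_idx, vol_shape)
    out = seg[tuple(slice(lo, hi) for lo, hi in zip(crop_min, crop_max))]

    # pad where the requested box spilled outside the volume
    pad = [(int(-min(lo, 0)), int(max(hi - s, 0))) for lo, hi, s in zip(min_idx, max_idx, vol_shape)]
    return np.pad(out, pad, mode='constant')

# example: a 3D toy segmentation centred in a 32^3 patch
toy = np.zeros((40, 50, 45), dtype='int32')
toy[10:20, 15:30, 12:22] = 1
print(centre_in_shape(toy, (32, 32, 32)).shape)   # -> (32, 32, 32)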
/SynthSeg/training_denoiser.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # python imports
18 | import os
19 | import numpy as np
20 | import tensorflow as tf
21 | from keras import models
22 | from keras import layers as KL
23 |
24 | # project imports
25 | from SynthSeg import metrics_model as metrics
26 | from SynthSeg.training import train_model
27 | from SynthSeg.labels_to_image_model import get_shapes
28 | from SynthSeg.training_supervised import build_model_inputs
29 |
30 | # third-party imports
31 | from ext.lab2im import utils, layers
32 | from ext.neuron import models as nrn_models
33 |
34 |
35 | def training(list_paths_input_labels,
36 | list_paths_target_labels,
37 | model_dir,
38 | input_segmentation_labels,
39 | target_segmentation_labels=None,
40 | subjects_prob=None,
41 | batchsize=1,
42 | output_shape=None,
43 | scaling_bounds=.2,
44 | rotation_bounds=15,
45 | shearing_bounds=.012,
46 | nonlin_std=3.,
47 | nonlin_scale=.04,
48 | prob_erosion_dilation=0.3,
49 | min_erosion_dilation=4,
50 | max_erosion_dilation=5,
51 | n_levels=5,
52 | nb_conv_per_level=2,
53 | conv_size=5,
54 | unet_feat_count=16,
55 | feat_multiplier=2,
56 | activation='elu',
57 | skip_n_concatenations=2,
58 | lr=1e-4,
59 | wl2_epochs=1,
60 | dice_epochs=50,
61 | steps_per_epoch=10000,
62 | checkpoint=None):
63 | """
64 |
65 | This function trains a denoising network ('l2l') to correct "noisy" segmentations: it learns to map the provided
66 | input label maps onto the corresponding target label maps. We regroup the parameters in four categories: General, Augmentation, Architecture, Training.
67 |
68 | # IMPORTANT !!!
69 | # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
70 | # these values refer to the RAS axes.
71 |
72 | :param list_paths_input_labels: list of all the paths of the input label maps. These correspond to "noisy"
73 | segmentations that the denoiser will be trained to correct.
74 | :param list_paths_target_labels: list of all the paths of the output label maps. Must have the same order as
75 | list_paths_input_labels. These are the target label maps that the network will learn to produce given the "noisy"
76 | input label maps.
77 | :param model_dir: path of a directory where the models will be saved during training.
78 | :param input_segmentation_labels: list of all the label values present in the input label maps.
79 | :param target_segmentation_labels: list of all the label values present in the output label maps. By default (None)
80 | this will be taken to be the same as input_segmentation_labels.
81 |
82 | # ----------------------------------------------- General parameters -----------------------------------------------
83 | # label maps parameters
84 | :param subjects_prob: (optional) relative importance (does not need to sum to 1) with which to pick the provided
85 | label maps at each mini-batch. Can be a sequence, a 1D numpy array, or the path to such an array, and it
86 | must be as long as list_paths_input_labels. By default, all label maps are chosen with the same probability.
87 |
88 | # output-related parameters
89 | :param batchsize: (optional) number of images to generate per mini-batch. Default is 1.
90 | :param output_shape: (optional) desired shape of the network inputs and targets, obtained by randomly cropping the deformed label maps.
91 | Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
92 | Default is None, where no cropping is performed.
93 |
94 | # --------------------------------------------- Augmentation parameters --------------------------------------------
95 | # spatial deformation parameters
96 | :param scaling_bounds: (optional) the scaling factor for each dimension is sampled at each mini-batch from a
97 | uniform distribution of predefined bounds. Can either be:
98 | 1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
99 | (1-scaling_bounds, 1+scaling_bounds) for each dimension.
100 | 2) the path to a numpy array of shape (2, n_dims), in which case the scaling factor in dimension i is sampled from
101 | the uniform distribution of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
102 | 3) False, in which case scaling is completely turned off.
103 | Default is scaling_bounds = 0.2 (case 1)
104 | :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for case 1 the
105 | bounds are centred on 0 rather than 1, i.e. (-rotation_bounds[i], rotation_bounds[i]).
106 | Default is rotation_bounds = 15.
107 | :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012.
108 | :param nonlin_std: (optional) Standard deviation of the normal distribution from which we sample the first
109 | tensor for synthesising the deformation field. Set to 0 to completely deactivate elastic deformation.
110 | :param nonlin_scale: (optional) Ratio between the size of the input label maps and the size of the sampled
111 | tensor for synthesising the elastic deformation field.
112 |
113 | # degradation of the input labels
114 | :param prob_erosion_dilation: (optional) probability with which to degrade the input label maps with erosion or
115 | dilation. If 0, then no erosion/dilation is applied to the label maps given as inputs to the network.
116 | :param min_erosion_dilation: (optional) when prob_erosion_dilation is not zero, erosion or dilation is applied with
117 | a randomly sampled coefficient. Set the minimum erosion/dilation coefficient here. Default is 4.
118 | :param max_erosion_dilation: (optional) Set the maximum erosion/dilation coefficient here. Default is 5.
119 |
120 | # ------------------------------------------ UNet architecture parameters ------------------------------------------
121 | :param n_levels: (optional) number of levels for the UNet. Default is 5.
122 | :param nb_conv_per_level: (optional) number of convolutional layers per level. Default is 2.
123 | :param conv_size: (optional) size of the convolution kernels. Default is 5.
124 | :param unet_feat_count: (optional) number of features for the first layer of the UNet. Default is 16.
125 | :param feat_multiplier: (optional) multiply the number of features by this number at each new level. Default is 2.
126 | :param activation: (optional) activation function. Can be 'elu' or 'relu'. Default is 'elu'.
127 | :param skip_n_concatenations: (optional) number of levels for which to remove the traditional skip connections of
128 | the UNet architecture (0 corresponds to the classic UNet). Default is 2, i.e. the concatenation links between
129 | the two top levels of the UNet are removed.
130 |
131 | # ----------------------------------------------- Training parameters ----------------------------------------------
132 | :param lr: (optional) learning rate for the training. Default is 1e-4
133 | :param wl2_epochs: (optional) number of epochs for which the network (except the soft-max layer) is trained with L2
134 | norm loss function. Default is 1.
135 | :param dice_epochs: (optional) number of epochs with the soft Dice loss function. Default is 50.
136 | :param steps_per_epoch: (optional) number of steps per epoch. Default is 10000. Since no online validation is
137 | possible, this is equivalent to the frequency at which the models are saved.
138 | :param checkpoint: (optional) path of an already saved model to load before starting the training.
139 | """
140 |
141 | # check epochs
142 | assert (wl2_epochs > 0) | (dice_epochs > 0), \
143 | 'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs)
144 |
145 | # prepare data files
146 | input_label_list, _ = utils.get_list_labels(label_list=input_segmentation_labels)
147 | if target_segmentation_labels is None:
148 | target_label_list = input_label_list
149 | else:
150 | target_label_list, _ = utils.get_list_labels(label_list=target_segmentation_labels)
151 | n_labels = np.size(target_label_list)
152 |
153 | # create augmentation model
154 | labels_shape, _, _, _, _, _ = utils.get_volume_info(list_paths_input_labels[0], aff_ref=np.eye(4))
155 | augmentation_model = build_augmentation_model(labels_shape,
156 | input_label_list,
157 | crop_shape=output_shape,
158 | output_div_by_n=2 ** n_levels,
159 | scaling_bounds=scaling_bounds,
160 | rotation_bounds=rotation_bounds,
161 | shearing_bounds=shearing_bounds,
162 | nonlin_std=nonlin_std,
163 | nonlin_scale=nonlin_scale,
164 | prob_erosion_dilation=prob_erosion_dilation,
165 | min_erosion_dilation=min_erosion_dilation,
166 | max_erosion_dilation=max_erosion_dilation)
167 | unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:]
168 |
169 | # prepare the segmentation model
170 | l2l_model = nrn_models.unet(input_model=augmentation_model,
171 | input_shape=unet_input_shape,
172 | nb_labels=n_labels,
173 | nb_levels=n_levels,
174 | nb_conv_per_level=nb_conv_per_level,
175 | conv_size=conv_size,
176 | nb_features=unet_feat_count,
177 | feat_mult=feat_multiplier,
178 | activation=activation,
179 | batch_norm=-1,
180 | skip_n_concatenations=skip_n_concatenations,
181 | name='l2l')
182 |
183 | # input generator
184 | model_inputs = build_model_inputs(path_inputs=list_paths_input_labels,
185 | path_outputs=list_paths_target_labels,
186 | batchsize=batchsize,
187 | subjects_prob=subjects_prob,
188 | dtype_input='int32')
189 | input_generator = utils.build_training_generator(model_inputs, batchsize)
190 |
191 | # pre-training with weighted L2, input is fit to the softmax rather than the probabilities
192 | if wl2_epochs > 0:
193 | wl2_model = models.Model(l2l_model.inputs, [l2l_model.get_layer('l2l_likelihood').output])
194 | wl2_model = metrics.metrics_model(wl2_model, target_label_list, 'wl2')
195 | train_model(wl2_model, input_generator, lr, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint)
196 | checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs)
197 |
198 | # fine-tuning with dice metric
199 | dice_model = metrics.metrics_model(l2l_model, target_label_list, 'dice')
200 | train_model(dice_model, input_generator, lr, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint)
201 |
202 |
203 | def build_augmentation_model(labels_shape,
204 | segmentation_labels,
205 | crop_shape=None,
206 | output_div_by_n=None,
207 | scaling_bounds=0.15,
208 | rotation_bounds=15,
209 | shearing_bounds=0.012,
210 | translation_bounds=False,
211 | nonlin_std=3.,
212 | nonlin_scale=.0625,
213 | prob_erosion_dilation=0.3,
214 | min_erosion_dilation=4,
215 | max_erosion_dilation=7):
216 |
217 | # reformat resolutions and get shapes
218 | labels_shape = utils.reformat_to_list(labels_shape)
219 | n_dims, _ = utils.get_dims(labels_shape)
220 | n_labels = len(segmentation_labels)
221 |
222 | # get shapes
223 | crop_shape, _ = get_shapes(labels_shape, crop_shape, np.array([1]*n_dims), np.array([1]*n_dims), output_div_by_n)
224 |
225 | # define model inputs
226 | net_input = KL.Input(shape=labels_shape + [1], name='l2l_noisy_labels_input', dtype='int32')
227 | target_input = KL.Input(shape=labels_shape + [1], name='l2l_target_input', dtype='int32')
228 |
229 | # deform labels
230 | noisy_labels, target = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds,
231 | rotation_bounds=rotation_bounds,
232 | shearing_bounds=shearing_bounds,
233 | translation_bounds=translation_bounds,
234 | nonlin_std=nonlin_std,
235 | nonlin_scale=nonlin_scale,
236 | inter_method='nearest')([net_input, target_input])
237 |
238 | # cropping
239 | if crop_shape != labels_shape:
240 | noisy_labels, target = layers.RandomCrop(crop_shape)([noisy_labels, target])
241 |
242 | # random erosion
243 | if prob_erosion_dilation > 0:
244 | noisy_labels = layers.RandomDilationErosion(min_erosion_dilation,
245 | max_erosion_dilation,
246 | prob=prob_erosion_dilation)(noisy_labels)
247 |
248 | # convert input labels (i.e. noisy_labels) to [0, ... N-1] and make them one-hot
249 | noisy_labels = layers.ConvertLabels(np.unique(segmentation_labels))(noisy_labels)
250 | target = KL.Lambda(lambda x: tf.cast(x[..., 0], 'int32'), name='labels_out')(target)
251 | noisy_labels = KL.Lambda(lambda x: tf.one_hot(x[0][..., 0], depth=n_labels),
252 | name='noisy_labels_out')([noisy_labels, target])
253 |
254 | # build model and return
255 | brain_model = models.Model(inputs=[net_input, target_input], outputs=[noisy_labels, target])
256 | return brain_model
257 |
--------------------------------------------------------------------------------
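A hedged usage sketch of the training() function above. The folders, the label-list file, and the chosen hyperparameter values below are placeholders (they do not necessarily match the repository layout); the only firm requirement is that the noisy and target folders list matching files in the same order:

from SynthSeg.training_denoiser import training
from ext.lab2im import utils

# hypothetical folders containing matched noisy / target label maps (same file order in both)
list_noisy = utils.list_images_in_folder('/data/denoiser/noisy_segmentations')
list_target = utils.list_images_in_folder('/data/denoiser/target_segmentations')

training(list_paths_input_labels=list_noisy,
         list_paths_target_labels=list_target,
         model_dir='/models/denoiser',
         input_segmentation_labels='/data/denoiser/segmentation_labels.npy',  # placeholder label list
         batchsize=1,
         output_shape=160,               # random 160^3 crops (divisible by 2**n_levels)
         prob_erosion_dilation=0.3,      # degrade 30% of the inputs with random erosion/dilation
         wl2_epochs=1,
         dice_epochs=50,
         steps_per_epoch=5000)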
/SynthSeg/validate_denoiser.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # python imports
18 | import os
19 |
20 | # project imports
21 | from SynthSeg.predict_denoiser import predict
22 |
23 | # third-party imports
24 | from ext.lab2im import utils
25 |
26 |
27 | def validate_training(prediction_dir,
28 | gt_dir,
29 | models_dir,
30 | validation_main_dir,
31 | target_segmentation_labels,
32 | input_segmentation_labels=None,
33 | evaluation_labels=None,
34 | step_eval=1,
35 | min_pad=None,
36 | cropping=None,
37 | topology_classes=None,
38 | sigma_smoothing=0,
39 | keep_biggest_component=False,
40 | n_levels=5,
41 | nb_conv_per_level=2,
42 | conv_size=3,
43 | unet_feat_count=24,
44 | feat_multiplier=2,
45 | activation='elu',
46 | skip_n_concatenations=0,
47 | recompute=True):
48 |
49 | # create result folder
50 | utils.mkdir(validation_main_dir)
51 |
52 | # loop over models
53 | list_models = utils.list_files(models_dir, expr=['dice', '.h5'], cond_type='and')[::step_eval]
54 | # list_models = [p for p in list_models if int(os.path.basename(p)[-6:-3]) % 2 == 0]
55 | loop_info = utils.LoopInfo(len(list_models), 1, 'validating', True)
56 | for model_idx, path_model in enumerate(list_models):
57 |
58 | # build names and create folders
59 | model_val_dir = os.path.join(validation_main_dir, os.path.basename(path_model).replace('.h5', ''))
60 | dice_path = os.path.join(model_val_dir, 'dice.npy')
61 | utils.mkdir(model_val_dir)
62 |
63 | if (not os.path.isfile(dice_path)) | recompute:
64 | loop_info.update(model_idx)
65 | predict(path_predictions=prediction_dir,
66 | path_corrections=model_val_dir,
67 | path_model=path_model,
68 | target_segmentation_labels=target_segmentation_labels,
69 | input_segmentation_labels=input_segmentation_labels,
70 | min_pad=min_pad,
71 | cropping=cropping,
72 | topology_classes=topology_classes,
73 | sigma_smoothing=sigma_smoothing,
74 | keep_biggest_component=keep_biggest_component,
75 | n_levels=n_levels,
76 | nb_conv_per_level=nb_conv_per_level,
77 | conv_size=conv_size,
78 | unet_feat_count=unet_feat_count,
79 | feat_multiplier=feat_multiplier,
80 | activation=activation,
81 | skip_n_concatenations=skip_n_concatenations,
82 | gt_folder=gt_dir,
83 | evaluation_labels=evaluation_labels,
84 | recompute=recompute,
85 | verbose=False)
86 |
--------------------------------------------------------------------------------
/SynthSeg/validate_group.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # python imports
18 | import os
19 |
20 | # project imports
21 | from SynthSeg.predict_group import predict
22 |
23 | # third-party imports
24 | from ext.lab2im import utils
25 |
26 |
27 | def validate_training(image_dir,
28 | mask_dir,
29 | gt_dir,
30 | models_dir,
31 | validation_main_dir,
32 | labels_segmentation,
33 | labels_mask,
34 | evaluation_labels=None,
35 | step_eval=1,
36 | min_pad=None,
37 | cropping=None,
38 | sigma_smoothing=0,
39 | strict_masking=False,
40 | keep_biggest_component=False,
41 | n_levels=5,
42 | nb_conv_per_level=2,
43 | conv_size=3,
44 | unet_feat_count=24,
45 | feat_multiplier=2,
46 | activation='elu',
47 | list_incorrect_labels=None,
48 | list_correct_labels=None,
49 | recompute=False):
50 |
51 | # create result folder
52 | utils.mkdir(validation_main_dir)
53 |
54 | # loop over models
55 | list_models = utils.list_files(models_dir, expr=['dice', '.h5'], cond_type='and')[::step_eval]
56 | # list_models = [p for p in list_models if int(os.path.basename(p)[-6:-3]) % 10 == 0]
57 | loop_info = utils.LoopInfo(len(list_models), 1, 'validating', True)
58 | for model_idx, path_model in enumerate(list_models):
59 |
60 | # build names and create folders
61 | model_val_dir = os.path.join(validation_main_dir, os.path.basename(path_model).replace('.h5', ''))
62 | dice_path = os.path.join(model_val_dir, 'dice.npy')
63 | utils.mkdir(model_val_dir)
64 |
65 | if (not os.path.isfile(dice_path)) | recompute:
66 | loop_info.update(model_idx)
67 | predict(path_images=image_dir,
68 | path_masks=mask_dir,
69 | path_segmentations=model_val_dir,
70 | path_model=path_model,
71 | labels_segmentation=labels_segmentation,
72 | labels_mask=labels_mask,
73 | min_pad=min_pad,
74 | cropping=cropping,
75 | sigma_smoothing=sigma_smoothing,
76 | strict_masking=strict_masking,
77 | keep_biggest_component=keep_biggest_component,
78 | n_levels=n_levels,
79 | nb_conv_per_level=nb_conv_per_level,
80 | conv_size=conv_size,
81 | unet_feat_count=unet_feat_count,
82 | feat_multiplier=feat_multiplier,
83 | activation=activation,
84 | gt_folder=gt_dir,
85 | evaluation_labels=evaluation_labels,
86 | list_incorrect_labels=list_incorrect_labels,
87 | list_correct_labels=list_correct_labels,
88 | recompute=recompute,
89 | verbose=False)
90 |
--------------------------------------------------------------------------------
/SynthSeg/validate_qc.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # python imports
18 | import os
19 | import re
20 | import logging
21 | import numpy as np
22 | import matplotlib.pyplot as plt
23 | from tensorflow.python.summary.summary_iterator import summary_iterator
24 |
25 | # project imports
26 | from SynthSeg.predict_qc import predict
27 |
28 | # third-party imports
29 | from ext.lab2im import utils
30 |
31 |
32 | def validate_training(prediction_dir,
33 | gt_dir,
34 | models_dir,
35 | validation_main_dir,
36 | labels_list,
37 | labels_to_convert=None,
38 | convert_gt=False,
39 | shape=224,
40 | n_levels=5,
41 | nb_conv_per_level=3,
42 | conv_size=5,
43 | unet_feat_count=24,
44 | feat_multiplier=2,
45 | activation='relu',
46 | step_eval=1,
47 | recompute=False):
48 |
49 | # create result folder
50 | utils.mkdir(validation_main_dir)
51 |
52 | # loop over models
53 | list_models = utils.list_files(models_dir, expr=['qc', '.h5'], cond_type='and')[::step_eval]
54 | # list_models = [p for p in list_models if int(os.path.basename(p)[-6:-3]) % 10 == 0]
55 | loop_info = utils.LoopInfo(len(list_models), 1, 'validating', True)
56 | for model_idx, path_model in enumerate(list_models):
57 |
58 | # build names and create folders
59 | model_val_dir = os.path.join(validation_main_dir, os.path.basename(path_model).replace('.h5', ''))
60 | score_path = os.path.join(model_val_dir, 'pred_qc_results.npy')
61 | utils.mkdir(model_val_dir)
62 |
63 | if (not os.path.isfile(score_path)) | recompute:
64 | loop_info.update(model_idx)
65 | predict(path_predictions=prediction_dir,
66 | path_qc_results=score_path,
67 | path_model=path_model,
68 | labels_list=labels_list,
69 | labels_to_convert=labels_to_convert,
70 | convert_gt=convert_gt,
71 | shape=shape,
72 | n_levels=n_levels,
73 | nb_conv_per_level=nb_conv_per_level,
74 | conv_size=conv_size,
75 | unet_feat_count=unet_feat_count,
76 | feat_multiplier=feat_multiplier,
77 | activation=activation,
78 | path_gts=gt_dir,
79 | verbose=False)
80 |
81 |
82 | def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_indices=None,
83 | skip_first_dice_row=True, size_max_circle=100, figsize=(11, 6), y_lim=None, fontsize=18,
84 | list_linestyles=None, list_colours=None, plot_legend=False):
85 | """This function plots the validation curves of several networks, based on the results of validate_training().
86 | It takes as input a list of validation folders (one for each network), each containing subfolders with dice scores
87 | for the corresponding validated epoch.
88 | :param list_validation_dirs: list of all the validation folders of the trainings to plot.
89 | :param eval_indices: (optional) compute the average Dice on a subset of labels indicated by the specified indices.
90 | Can be a 1d numpy array, the path to such an array, or a list of 1d numpy arrays as long as list_validation_dirs.
91 | :param skip_first_dice_row: if eval_indices is None, skip the first row of the dice matrices (usually background)
92 | :param size_max_circle: (optional) size of the marker for epochs achieving the best validation scores.
93 | :param figsize: (optional) size of the figure to draw.
94 | :param fontsize: (optional) fontsize used for the graph."""
95 |
96 | n_curves = len(list_validation_dirs)
97 |
98 | if eval_indices is not None:
99 | if isinstance(eval_indices, (np.ndarray, str)):
100 | if isinstance(eval_indices, str):
101 | eval_indices = np.load(eval_indices)
102 | eval_indices = np.squeeze(utils.reformat_to_n_channels_array(eval_indices, n_dims=len(eval_indices)))
103 | eval_indices = [eval_indices] * len(list_validation_dirs)
104 | elif isinstance(eval_indices, list):
105 | for (i, e) in enumerate(eval_indices):
106 | if isinstance(e, np.ndarray):
107 | eval_indices[i] = np.squeeze(utils.reformat_to_n_channels_array(e, n_dims=len(e)))
108 | else:
109 | raise TypeError('if provided as a list, eval_indices should only contain numpy arrays')
110 | else:
111 | raise TypeError('eval_indices can be a numpy array, a path to a numpy array, or a list of numpy arrays.')
112 | else:
113 | eval_indices = [None] * len(list_validation_dirs)
114 |
115 | # reformat model names
116 | if architecture_names is None:
117 | architecture_names = [os.path.basename(os.path.dirname(d)) for d in list_validation_dirs]
118 | else:
119 | architecture_names = utils.reformat_to_list(architecture_names, len(list_validation_dirs))
120 |
121 | # prepare legend labels
122 | if plot_legend is False:
123 | list_legend_labels = ['_nolegend_'] * n_curves
124 | elif plot_legend is True:
125 | list_legend_labels = architecture_names
126 | else:
127 | list_legend_labels = architecture_names
128 | list_legend_labels = ['_nolegend_' if i >= plot_legend else list_legend_labels[i] for i in range(n_curves)]
129 |
130 | # prepare linestyles
131 | if list_linestyles is not None:
132 | list_linestyles = utils.reformat_to_list(list_linestyles)
133 | else:
134 | list_linestyles = [None] * n_curves
135 |
136 | # prepare curve colours
137 | if list_colours is not None:
138 | list_colours = utils.reformat_to_list(list_colours)
139 | else:
140 | list_colours = [None] * n_curves
141 |
142 | # loop over architectures
143 | plt.figure(figsize=figsize)
144 | for idx, (net_val_dir, net_name, linestyle, colour, legend_label, eval_idx) in enumerate(zip(list_validation_dirs,
145 | architecture_names,
146 | list_linestyles,
147 | list_colours,
148 | list_legend_labels,
149 | eval_indices)):
150 |
151 | list_epochs_dir = utils.list_subfolders(net_val_dir, whole_path=False)
152 |
153 | # loop over epochs
154 | list_net_scores = list()
155 | list_epochs = list()
156 | for epoch_dir in list_epochs_dir:
157 |
158 | # build names and create folders
159 | path_epoch_scores = utils.list_files(os.path.join(net_val_dir, epoch_dir), expr='diff')
160 | if len(path_epoch_scores) == 1:
161 | path_epoch_scores = path_epoch_scores[0]
162 | if eval_idx is not None:
163 | list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)[eval_idx, :])))
164 | else:
165 | if skip_first_dice_row:
166 | list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)[1:, :])))
167 | else:
168 | list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores))))
169 | list_epochs.append(int(re.sub('[^0-9]', '', epoch_dir)))
170 |
171 | # plot validation scores for current architecture
172 | if list_net_scores: # check that archi has been validated for at least 1 epoch
173 | list_net_scores = np.array(list_net_scores)
174 | list_epochs = np.array(list_epochs)
175 | list_epochs, idx = np.unique(list_epochs, return_index=True)
176 | list_net_scores = list_net_scores[idx]
177 | min_score = np.min(list_net_scores)
178 | epoch_min_score = list_epochs[np.argmin(list_net_scores)]
179 | print('\n'+net_name)
180 | print('epoch min score: %d' % epoch_min_score)
181 | print('min score: %0.3f' % min_score)
182 | plt.plot(list_epochs, list_net_scores, label=legend_label, linestyle=linestyle, color=colour)
183 | plt.scatter(epoch_min_score, min_score, s=size_max_circle, color=colour)
184 |
185 | # finalise plot
186 | plt.grid()
187 | plt.tick_params(axis='both', labelsize=fontsize)
188 | plt.ylabel('Scores', fontsize=fontsize)
189 | plt.xlabel('Epochs', fontsize=fontsize)
190 | if y_lim is not None:
191 | plt.ylim(y_lim[0], y_lim[1] + 0.01)  # set lower/upper limits of the y-axis
192 | plt.title('Validation curves', fontsize=fontsize)
193 | if plot_legend:
194 | plt.legend(fontsize=fontsize)
195 | plt.tight_layout(pad=1)
196 | plt.show()
197 |
198 |
199 | def draw_learning_curve(path_tensorboard_files, architecture_names, figsize=(11, 6), fontsize=18,
200 | y_lim=None, remove_legend=False):
201 | """This function draws the learning curve of several trainings on the same graph.
202 | :param path_tensorboard_files: list of tensorboard files corresponding to the models to plot.
203 | :param architecture_names: list of the names of the models
204 | :param figsize: (optional) size of the figure to draw.
205 | :param fontsize: (optional) fontsize used for the graph.
206 | """
207 |
208 | # reformat inputs
209 | path_tensorboard_files = utils.reformat_to_list(path_tensorboard_files)
210 | architecture_names = utils.reformat_to_list(architecture_names)
211 | assert len(path_tensorboard_files) == len(architecture_names), 'names and tensorboard lists should have same length'
212 |
213 | # loop over architectures
214 | plt.figure(figsize=figsize)
215 | for path_tensorboard_file, name in zip(path_tensorboard_files, architecture_names):
216 |
217 | path_tensorboard_file = utils.reformat_to_list(path_tensorboard_file)
218 |
219 | # extract loss at the end of all epochs
220 | list_losses = list()
221 | list_epochs = list()
222 | logging.getLogger('tensorflow').disabled = True
223 | for path in path_tensorboard_file:
224 | for e in summary_iterator(path):
225 | for v in e.summary.value:
226 | if v.tag == 'loss' or v.tag == 'accuracy' or v.tag == 'epoch_loss':
227 | list_losses.append(v.simple_value)
228 | list_epochs.append(e.step)
229 | plt.plot(np.array(list_epochs), np.array(list_losses), label=name, linewidth=2)
230 |
231 | # finalise plot
232 | plt.grid()
233 | if not remove_legend:
234 | plt.legend(fontsize=fontsize)
235 | plt.xlabel('Epochs', fontsize=fontsize)
236 | plt.ylabel('Scores', fontsize=fontsize)
237 | if y_lim is not None:
238 | plt.ylim(y_lim[0], y_lim[1] + 0.01)  # set lower/upper limits of the y-axis
239 | plt.tick_params(axis='both', labelsize=fontsize)
240 | plt.title('Learning curves', fontsize=fontsize)
241 | plt.tight_layout(pad=1)
242 | plt.show()
243 |
--------------------------------------------------------------------------------
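A possible way to chain the two helpers above: score every saved QC model on a held-out set, then plot the mean absolute gap between predicted and true Dice scores across epochs. All paths below are placeholders:

from SynthSeg.validate_qc import validate_training, plot_validation_curves

# score every saved QC model (every 5th epoch) on a held-out set of segmentations
validate_training(prediction_dir='/data/qc/predicted_segmentations',     # placeholder
                  gt_dir='/data/qc/ground_truth_segmentations',          # placeholder
                  models_dir='/models/qc',
                  validation_main_dir='/models/qc/validation',
                  labels_list='/data/qc/qc_labels.npy',                  # placeholder label list
                  step_eval=5)

# then plot the |predicted - true| Dice gap across epochs for one or several runs
plot_validation_curves(['/models/qc/validation'], architecture_names=['qc_net'], plot_legend=True)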
/bibtex.bib:
--------------------------------------------------------------------------------
1 | @article{billot_synthseg_2023,
2 | title = {SynthSeg: {Segmentation} of brain {MRI} scans of any contrast and resolution without retraining},
3 | author = {Billot, Benjamin and Greve, Douglas N. and Puonti, Oula and Thielscher, Axel and Van Leemput, Koen and Fischl, Bruce and Dalca, Adrian V. and Iglesias, Juan Eugenio},
4 | journal = {{Medical} {Image} {Analysis}},
5 | year = {2023},
6 | volume = {86},
7 | pages = {102789},
8 | issn = {1361-8415},
9 | doi = {10.1016/j.media.2023.102789},
10 | }
11 |
12 | @article{billot_robust_2023,
13 | title = {{Robust} machine learning segmentation for large-scale analysis of heterogeneous clinical brain {MRI} datasets},
14 | author = {Billot, Benjamin and Magdamo, Colin and Cheng, You and Das, Sudeshna and Iglesias, Juan Eugenio},
15 | journal = {{Proceedings} of the {National} {Academy} of {Sciences} ({PNAS})},
16 | year = {2023},
17 | volume = {120},
18 | number = {9},
19 | pages = {1--10},
20 | doi = {10.1073/pnas.2216399120},
21 | }
22 |
--------------------------------------------------------------------------------
/data/README_figures/new_features.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/new_features.png
--------------------------------------------------------------------------------
/data/README_figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/overview.png
--------------------------------------------------------------------------------
/data/README_figures/robust.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/robust.png
--------------------------------------------------------------------------------
/data/README_figures/segmentations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/segmentations.png
--------------------------------------------------------------------------------
/data/README_figures/table_versions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/table_versions.png
--------------------------------------------------------------------------------
/data/README_figures/training_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/training_data.png
--------------------------------------------------------------------------------
/data/README_figures/youtube_link.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/README_figures/youtube_link.png
--------------------------------------------------------------------------------
/data/labels table.txt:
--------------------------------------------------------------------------------
1 | List of the structures segmented by SynthSeg along with their corresponding label values.
2 |
3 | The structures are given in the same order as they appear in the posteriors, i.e. the first map of the posteriors
4 | corresponds to the background, then the second map is associated with the left cerebral white matter, etc.
5 |
6 | Please note that the label values follow the FreeSurfer classification. Also, we do not provide any colour scheme,
7 | as the colour displayed for each structure depends on the image viewer used.
8 |
9 | WARNING: if you use the --v1 flag, note that it won't segment the CSF (label 24), since this label was only introduced in version 2.
10 |
11 | labels structures
12 | 0 background
13 | 2 left cerebral white matter
14 | 3 left cerebral cortex
15 | 4 left lateral ventricle
16 | 5 left inferior lateral ventricle
17 | 7 left cerebellum white matter
18 | 8 left cerebellum cortex
19 | 10 left thalamus
20 | 11 left caudate
21 | 12 left putamen
22 | 13 left pallidum
23 | 14 3rd ventricle
24 | 15 4th ventricle
25 | 16 brain-stem
26 | 17 left hippocampus
27 | 18 left amygdala
28 | 26 left accumbens area
29 | 24 CSF
30 | 28 left ventral DC
31 | 41 right cerebral white matter
32 | 42 right cerebral cortex
33 | 43 right lateral ventricle
34 | 44 right inferior lateral ventricle
35 | 46 right cerebellum white matter
36 | 47 right cerebellum cortex
37 | 49 right thalamus
38 | 50 right caudate
39 | 51 right putamen
40 | 52 right pallidum
41 | 53 right hippocampus
42 | 54 right amygdala
43 | 58 right accumbens area
44 | 60 right ventral DC
--------------------------------------------------------------------------------
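Since the posterior maps follow the order of this table, the channel corresponding to a given FreeSurfer label can be looked up from the label list shipped with the repository. A small sketch, assuming posteriors were saved to disk (e.g. with the --post option of the prediction script) and that data/labels_classes_priors/synthseg_segmentation_labels_2.0.npy stores the labels in the same order as this table:

import numpy as np
import nibabel as nib

# label values in posterior-channel order (assumption: this file matches the table above)
labels = np.load('data/labels_classes_priors/synthseg_segmentation_labels_2.0.npy')

# posteriors previously saved by SynthSeg (hypothetical path), shape (x, y, z, n_labels)
post = nib.load('subject_posteriors.nii.gz').get_fdata()

# probability map of the left hippocampus (FreeSurfer label 17)
left_hippo_prob = post[..., int(np.where(labels == 17)[0][0])]
print(left_hippo_prob.shape)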
/data/labels_classes_priors/generation_classes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/generation_classes.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/generation_classes_contrast_specific.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/generation_classes_contrast_specific.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/generation_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/generation_labels.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/prior_means_t1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/prior_means_t1.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/prior_stds_t1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/prior_stds_t1.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_denoiser_labels_2.0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_denoiser_labels_2.0.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_parcellation_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_parcellation_labels.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_parcellation_names.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_parcellation_names.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_qc_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_qc_labels.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_qc_labels_2.0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_qc_labels_2.0.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_qc_names.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_qc_names.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_qc_names_2.0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_qc_names_2.0.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_segmentation_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_segmentation_labels.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_segmentation_labels_2.0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_segmentation_labels_2.0.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_segmentation_names.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_segmentation_names.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_segmentation_names_2.0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_segmentation_names_2.0.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_topological_classes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_topological_classes.npy
--------------------------------------------------------------------------------
/data/labels_classes_priors/synthseg_topological_classes_2.0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/labels_classes_priors/synthseg_topological_classes_2.0.npy
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_01.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_01.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_02.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_02.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_03.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_03.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_04.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_04.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_05.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_05.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_06.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_06.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_07.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_07.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_08.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_08.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_09.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_09.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_10.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_10.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_11.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_11.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_12.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_12.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_13.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_13.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_14.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_14.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_15.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_15.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_16.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_16.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_17.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_17.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_18.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_18.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_19.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_19.nii.gz
--------------------------------------------------------------------------------
/data/training_label_maps/training_seg_20.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/training_label_maps/training_seg_20.nii.gz
--------------------------------------------------------------------------------
/data/tutorial_7/noisy_segmentations_d/0001.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/noisy_segmentations_d/0001.nii.gz
--------------------------------------------------------------------------------
/data/tutorial_7/noisy_segmentations_d/0002.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/noisy_segmentations_d/0002.nii.gz
--------------------------------------------------------------------------------
/data/tutorial_7/noisy_segmentations_d/0003.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/noisy_segmentations_d/0003.nii.gz
--------------------------------------------------------------------------------
/data/tutorial_7/segmentation_labels_s1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/segmentation_labels_s1.npy
--------------------------------------------------------------------------------
/data/tutorial_7/target_segmentations_d/0001.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/target_segmentations_d/0001.nii.gz
--------------------------------------------------------------------------------
/data/tutorial_7/target_segmentations_d/0002.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/target_segmentations_d/0002.nii.gz
--------------------------------------------------------------------------------
/data/tutorial_7/target_segmentations_d/0003.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/data/tutorial_7/target_segmentations_d/0003.nii.gz
--------------------------------------------------------------------------------
/ext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/ext/__init__.py
--------------------------------------------------------------------------------
/ext/lab2im/__init__.py:
--------------------------------------------------------------------------------
1 | from . import edit_tensors
2 | from . import edit_volumes
3 | from . import image_generator
4 | from . import lab2im_model
5 | from . import layers
6 | from . import utils
7 |
--------------------------------------------------------------------------------
/ext/lab2im/image_generator.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite the first SynthSeg paper:
3 | https://github.com/BBillot/lab2im/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # python imports
18 | import numpy as np
19 | import numpy.random as npr
20 |
21 | # project imports
22 | from ext.lab2im import utils
23 | from ext.lab2im import edit_volumes
24 | from ext.lab2im.lab2im_model import lab2im_model
25 |
26 |
27 | class ImageGenerator:
28 |
29 | def __init__(self,
30 | labels_dir,
31 | generation_labels=None,
32 | output_labels=None,
33 | batchsize=1,
34 | n_channels=1,
35 | target_res=None,
36 | output_shape=None,
37 | output_div_by_n=None,
38 | generation_classes=None,
39 | prior_distributions='uniform',
40 | prior_means=None,
41 | prior_stds=None,
42 | use_specific_stats_for_channel=False,
43 | blur_range=1.15):
44 | """
45 |         This class is a wrapper around the lab2im_model model. It contains the GPU model that generates images from
46 |         label maps, and a python generator that supplies the input data for this model.
47 |         To generate pairs of image/labels, simply call the method generate_image() on an object of this class.
48 |
49 | :param labels_dir: path of folder with all input label maps, or to a single label map.
50 |
51 | # IMPORTANT !!!
52 | # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
53 | # these values refer to the RAS axes.
54 |
55 | # label maps-related parameters
56 | :param generation_labels: (optional) list of all possible label values in the input label maps.
57 |         Default is None, in which case the label values are read directly from the provided label maps.
58 | If not None, can be a sequence or a 1d numpy array, or the path to a 1d numpy array.
59 | :param output_labels: (optional) list of the same length as generation_labels to indicate which values to use in
60 | the label maps returned by this function, i.e. all occurrences of generation_labels[i] in the input label maps
61 | will be converted to output_labels[i] in the returned label maps. Examples:
62 | Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps.
63 | Set output_labels[i]=generation_labels[i] to keep the value generation_labels[i] in the returned maps.
64 | Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels.
65 |
66 | # output-related parameters
67 | :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1.
68 |         :param n_channels: (optional) number of channels to be synthesised. Default is 1.
69 | :param target_res: (optional) target resolution of the generated images and corresponding label maps.
70 | If None, the outputs will have the same resolution as the input label maps.
71 | Can be a number (isotropic resolution), a sequence, a 1d numpy array, or the path to a 1d numpy array.
72 | :param output_shape: (optional) shape of the output image, obtained by randomly cropping the generated image.
73 | Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
74 | :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites
75 | output_shape if necessary. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or
76 | the path to a 1d numpy array.
77 |
78 | # GMM-sampling parameters
79 | :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity
80 | distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence, a
81 | 1d numpy array, or the path to a 1d numpy array.
82 | It should have the same length as generation_labels, and contain values between 0 and K-1, where K is the total
83 |         number of classes. By default, each label is assigned its own class (K=len(generation_labels)).
84 | :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters.
85 | Can either be 'uniform', or 'normal'. Default is 'uniform'.
86 |         :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because
87 |         these prior distributions are uniform or normal, they are each defined by 2 hyperparameters. Thus prior_means can be:
88 |         1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is
89 |         uniform, [mean, std] if the distribution is normal. The GMM means are then independently sampled at each
90 |         mini_batch from the same distribution.
91 | 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is
92 | not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each
93 | mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, and from
94 | N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal.
95 |         3) an array of shape (2*n_mod, K), where each block of two rows is associated with hyperparameters derived
96 |         from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a
97 |         modality from the n_mod possibilities, and we sample the GMM means like in 2).
98 |         If use_specific_stats_for_channel is True, each block of two rows corresponds to a different channel
99 |         (n_mod=n_channels), so we select the block corresponding to each channel rather than drawing it randomly.
100 | 4) the path to such a numpy array.
101 | Default is None, which corresponds to prior_means = [25, 225].
102 | :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM.
103 | Default is None, which corresponds to prior_stds = [5, 25].
104 |         :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays should
105 |         only be used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False.
106 |
107 | # blurring parameters
108 |         :param blur_range: (optional) Randomise the standard deviation of the blurring kernels (whether data_res is
109 |         given or not). At each mini_batch, the standard deviations of the blurring kernels are multiplied by a
110 |         coefficient sampled from a uniform distribution with bounds [1/blur_range, blur_range].
111 | If None, no randomisation. Default is 1.15.
112 | """
113 |
114 | # prepare data files
115 | self.labels_paths = utils.list_images_in_folder(labels_dir)
116 |
117 | # generation parameters
118 | self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \
119 | utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4))
120 | self.n_channels = n_channels
121 | if generation_labels is not None:
122 | self.generation_labels = utils.load_array_if_path(generation_labels)
123 | else:
124 | self.generation_labels, _ = utils.get_list_labels(labels_dir=labels_dir)
125 | if output_labels is not None:
126 | self.output_labels = utils.load_array_if_path(output_labels)
127 | else:
128 | self.output_labels = self.generation_labels
129 | self.target_res = utils.load_array_if_path(target_res)
130 | self.batchsize = batchsize
131 | # preliminary operations
132 | self.output_shape = utils.load_array_if_path(output_shape)
133 | self.output_div_by_n = output_div_by_n
134 | # GMM parameters
135 | self.prior_distributions = prior_distributions
136 | if generation_classes is not None:
137 | self.generation_classes = utils.load_array_if_path(generation_classes)
138 | assert self.generation_classes.shape == self.generation_labels.shape, \
139 |                 'if provided, generation_classes should have the same shape as generation_labels'
140 | unique_classes = np.unique(self.generation_classes)
141 | assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \
142 |                 'generation_classes should be a linear range between 0 and its maximum value.'
143 | else:
144 | self.generation_classes = np.arange(self.generation_labels.shape[0])
145 | self.prior_means = utils.load_array_if_path(prior_means)
146 | self.prior_stds = utils.load_array_if_path(prior_stds)
147 | self.use_specific_stats_for_channel = use_specific_stats_for_channel
148 |
149 | # blurring parameters
150 | self.blur_range = blur_range
151 |
152 | # build transformation model
153 | self.labels_to_image_model, self.model_output_shape = self._build_lab2im_model()
154 |
155 | # build generator for model inputs
156 | self.model_inputs_generator = self._build_model_inputs(len(self.generation_labels))
157 |
158 | # build brain generator
159 | self.image_generator = self._build_image_generator()
160 |
161 | def _build_lab2im_model(self):
162 | # build_model
163 | lab_to_im_model = lab2im_model(labels_shape=self.labels_shape,
164 | n_channels=self.n_channels,
165 | generation_labels=self.generation_labels,
166 | output_labels=self.output_labels,
167 | atlas_res=self.atlas_res,
168 | target_res=self.target_res,
169 | output_shape=self.output_shape,
170 | output_div_by_n=self.output_div_by_n,
171 | blur_range=self.blur_range)
172 | out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:]
173 | return lab_to_im_model, out_shape
174 |
175 | def _build_image_generator(self):
176 | while True:
177 | model_inputs = next(self.model_inputs_generator)
178 | [image, labels] = self.labels_to_image_model.predict(model_inputs)
179 | yield image, labels
180 |
181 | def generate_image(self):
182 | """call this method when an object of this class has been instantiated to generate new brains"""
183 | (image, labels) = next(self.image_generator)
184 | # put back images in native space
185 | list_images = list()
186 | list_labels = list()
187 | for i in range(self.batchsize):
188 | list_images.append(edit_volumes.align_volume_to_ref(image[i], np.eye(4), aff_ref=self.aff,
189 | n_dims=self.n_dims))
190 | list_labels.append(edit_volumes.align_volume_to_ref(labels[i], np.eye(4), aff_ref=self.aff,
191 | n_dims=self.n_dims))
192 | image = np.stack(list_images, axis=0)
193 | labels = np.stack(list_labels, axis=0)
194 | return np.squeeze(image), np.squeeze(labels)
195 |
196 | def _build_model_inputs(self, n_labels):
197 |
198 | # get label info
199 | _, _, n_dims, _, _, _ = utils.get_volume_info(self.labels_paths[0])
200 |
201 | # Generate!
202 | while True:
203 |
204 | # randomly pick as many images as batchsize
205 | indices = npr.randint(len(self.labels_paths), size=self.batchsize)
206 |
207 | # initialise input lists
208 | list_label_maps = []
209 | list_means = []
210 | list_stds = []
211 |
212 | for idx in indices:
213 |
214 | # load label in identity space, and add them to inputs
215 | y = utils.load_volume(self.labels_paths[idx], dtype='int', aff_ref=np.eye(4))
216 | list_label_maps.append(utils.add_axis(y, axis=[0, -1]))
217 |
218 | # add means and standard deviations to inputs
219 | means = np.empty((1, n_labels, 0))
220 | stds = np.empty((1, n_labels, 0))
221 | for channel in range(self.n_channels):
222 |
223 | # retrieve channel specific stats if necessary
224 | if isinstance(self.prior_means, np.ndarray):
225 | if (self.prior_means.shape[0] > 2) & self.use_specific_stats_for_channel:
226 | if self.prior_means.shape[0] / 2 != self.n_channels:
227 | raise ValueError("the number of blocks in prior_means does not match n_channels. This "
228 | "message is printed because use_specific_stats_for_channel is True.")
229 | tmp_prior_means = self.prior_means[2 * channel:2 * channel + 2, :]
230 | else:
231 | tmp_prior_means = self.prior_means
232 | else:
233 | tmp_prior_means = self.prior_means
234 | if isinstance(self.prior_stds, np.ndarray):
235 | if (self.prior_stds.shape[0] > 2) & self.use_specific_stats_for_channel:
236 | if self.prior_stds.shape[0] / 2 != self.n_channels:
237 | raise ValueError("the number of blocks in prior_stds does not match n_channels. This "
238 | "message is printed because use_specific_stats_for_channel is True.")
239 | tmp_prior_stds = self.prior_stds[2 * channel:2 * channel + 2, :]
240 | else:
241 | tmp_prior_stds = self.prior_stds
242 | else:
243 | tmp_prior_stds = self.prior_stds
244 |
245 | # draw means and std devs from priors
246 | tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_labels,
247 | self.prior_distributions, 125., 100.,
248 | positive_only=True)
249 | tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_labels,
250 | self.prior_distributions, 15., 10.,
251 | positive_only=True)
252 | tmp_means = utils.add_axis(tmp_classes_means[self.generation_classes], axis=[0, -1])
253 | tmp_stds = utils.add_axis(tmp_classes_stds[self.generation_classes], axis=[0, -1])
254 | means = np.concatenate([means, tmp_means], axis=-1)
255 | stds = np.concatenate([stds, tmp_stds], axis=-1)
256 | list_means.append(means)
257 | list_stds.append(stds)
258 |
259 | # build list of inputs of augmentation model
260 | list_inputs = [list_label_maps, list_means, list_stds]
261 | if self.batchsize > 1: # concatenate individual input types if batchsize > 1
262 | list_inputs = [np.concatenate(item, 0) for item in list_inputs]
263 | else:
264 | list_inputs = [item[0] for item in list_inputs]
265 |
266 | yield list_inputs
267 |
--------------------------------------------------------------------------------
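A note on the prior formats described in the ImageGenerator docstring above: the standalone numpy sketch below illustrates how hyperparameter arrays of shape (2, K) and (2*n_mod, K) are interpreted, and how per-class samples are expanded to per-label values via generation_classes indexing (as done with tmp_classes_means[self.generation_classes] in _build_model_inputs). It is only an illustration under toy assumptions: the label values, class groupings and bounds are made up, and the sampling uses plain numpy rather than the library's utils.draw_value_from_distribution.

import numpy as np

# toy setup: 4 generation labels grouped into 3 intensity classes (K=3)
generation_labels = np.array([0, 2, 3, 4])
generation_classes = np.array([0, 1, 1, 2])    # labels 2 and 3 share class 1
K = int(np.max(generation_classes)) + 1

# case 2) of the docstring: prior_means of shape (2, K), i.e. one [min, max] per class
# (here for 'uniform' priors)
prior_means = np.array([[20., 90., 130.],      # lower bounds for classes 0..K-1
                        [40., 110., 160.]])    # upper bounds for classes 0..K-1

# sample one mean per class from U(prior_means[0, k], prior_means[1, k])
classes_means = np.random.uniform(prior_means[0], prior_means[1])

# expand per-class samples to per-label values by indexing with generation_classes
label_means = classes_means[generation_classes]
print(label_means.shape)                       # (4,): one mean per generation label

# case 3) of the docstring: with n_mod=2 modalities, prior_means has shape (2*n_mod, K),
# and each block of two rows holds the [min, max] hyperparameters of one modality
prior_means_multi = np.concatenate([prior_means, prior_means + 50.], axis=0)   # shape (4, K)
channel = 1                                    # with use_specific_stats_for_channel=True
block = prior_means_multi[2 * channel:2 * channel + 2, :]
classes_means_channel = np.random.uniform(block[0], block[1])

The same indexing applies to prior_stds; when a plain length-2 sequence is given instead, the two hyperparameters are shared by all classes.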
/ext/lab2im/lab2im_model.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite the first SynthSeg paper:
3 | https://github.com/BBillot/lab2im/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # python imports
18 | import numpy as np
19 | import keras.layers as KL
20 | from keras.models import Model
21 |
22 | # project imports
23 | from ext.lab2im import utils
24 | from ext.lab2im import layers
25 | from ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling
26 |
27 |
28 | def lab2im_model(labels_shape,
29 | n_channels,
30 | generation_labels,
31 | output_labels,
32 | atlas_res,
33 | target_res,
34 | output_shape=None,
35 | output_div_by_n=None,
36 | blur_range=1.15):
37 | """
38 | This function builds a keras/tensorflow model to generate images from provided label maps.
39 | The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map.
40 |     The model takes as inputs:
41 |         -a label map,
42 |         -a vector containing the means of the Gaussian Mixture Model for each label,
43 |         -a vector containing the standard deviations of the Gaussian Mixture Model for each label.
44 |     The model returns:
45 |         -the generated image normalised between 0 and 1.
46 |         -the corresponding label map, with only the labels present in output_labels (the others are
47 |         reset to zero).
48 | :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array.
49 |     :param n_channels: number of channels to be synthesised.
50 | :param generation_labels: list of all possible label values in the input label maps.
51 | Can be a sequence or a 1d numpy array.
52 | :param output_labels: list of the same length as generation_labels to indicate which values to use in the label maps
53 | returned by this model, i.e. all occurrences of generation_labels[i] in the input label maps will be converted to
54 | output_labels[i] in the returned label maps. Examples:
55 | Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps.
56 | Set output_labels[i]=generation_labels[i] if you wish to keep the value generation_labels[i] in the returned maps.
57 | Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels.
58 | :param atlas_res: resolution of the input label maps.
59 | Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
60 | :param target_res: target resolution of the generated images and corresponding label maps.
61 | Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
62 | :param output_shape: (optional) desired shape of the output images.
63 | If the atlas and target resolutions are the same, the output will be cropped to output_shape, and if the two
64 | resolutions are different, the output will be resized with trilinear interpolation to output_shape.
65 | Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array.
66 | :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape
67 | if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array.
68 |     :param blur_range: (optional) Randomise the standard deviation of the blurring kernels (whether data_res is given
69 |     or not). At each mini_batch, the standard deviations of the blurring kernels are multiplied by a coefficient sampled
70 |     from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15.
71 | """
72 |
73 | # reformat resolutions
74 | labels_shape = utils.reformat_to_list(labels_shape)
75 | n_dims, _ = utils.get_dims(labels_shape)
76 | atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims=n_dims)[0]
77 | target_res = atlas_res if (target_res is None) else utils.reformat_to_n_channels_array(target_res, n_dims)[0]
78 |
79 | # get shapes
80 | crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n)
81 |
82 | # define model inputs
83 | labels_input = KL.Input(shape=labels_shape+[1], name='labels_input', dtype='int32')
84 | means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input')
85 | stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='stds_input')
86 |
87 | # deform labels
88 | labels = layers.RandomSpatialDeformation(inter_method='nearest')(labels_input)
89 |
90 | # cropping
91 | if crop_shape != labels_shape:
92 | labels._keras_shape = tuple(labels.get_shape().as_list())
93 | labels = layers.RandomCrop(crop_shape)(labels)
94 |
95 | # build synthetic image
96 | labels._keras_shape = tuple(labels.get_shape().as_list())
97 | image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input])
98 |
99 | # apply bias field
100 | image._keras_shape = tuple(image.get_shape().as_list())
101 | image = layers.BiasFieldCorruption(.3, .025, same_bias_for_all_channels=False)(image)
102 |
103 | # intensity augmentation
104 | image._keras_shape = tuple(image.get_shape().as_list())
105 | image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.2)(image)
106 |
107 | # blur image
108 | sigma = blurring_sigma_for_downsampling(atlas_res, target_res)
109 | image._keras_shape = tuple(image.get_shape().as_list())
110 | image = layers.GaussianBlur(sigma=sigma, random_blur_range=blur_range)(image)
111 |
112 | # resample to target res
113 | if crop_shape != output_shape:
114 | image = resample_tensor(image, output_shape, interp_method='linear')
115 | labels = resample_tensor(labels, output_shape, interp_method='nearest')
116 |
117 | # reset unwanted labels to zero
118 | labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels)
119 |
120 | # build model (dummy layer enables to keep the labels when plugging this model to other models)
121 | image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
122 | brain_model = Model(inputs=[labels_input, means_input, stds_input], outputs=[image, labels])
123 |
124 | return brain_model
125 |
126 |
127 | def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n):
128 |
129 | n_dims = len(atlas_res)
130 |
131 | # get resampling factor
132 | if atlas_res.tolist() != target_res.tolist():
133 | resample_factor = [atlas_res[i] / float(target_res[i]) for i in range(n_dims)]
134 | else:
135 | resample_factor = None
136 |
137 | # output shape specified, need to get cropping shape, and resample shape if necessary
138 | if output_shape is not None:
139 | output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int')
140 |
141 | # make sure that output shape is smaller or equal to label shape
142 | if resample_factor is not None:
143 | output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)]
144 | else:
145 | output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)]
146 |
147 | # make sure output shape is divisible by output_div_by_n
148 | if output_div_by_n is not None:
149 | tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n)
150 | for s in output_shape]
151 | if output_shape != tmp_shape:
152 | print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n,
153 | tmp_shape))
154 | output_shape = tmp_shape
155 |
156 | # get cropping and resample shape
157 | if resample_factor is not None:
158 | cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)]
159 | else:
160 | cropping_shape = output_shape
161 |
162 | # no output shape specified, so no cropping unless label_shape is not divisible by output_div_by_n
163 | else:
164 | cropping_shape = labels_shape
165 | if resample_factor is not None:
166 | output_shape = [int(np.around(cropping_shape[i]*resample_factor[i], 0)) for i in range(n_dims)]
167 | else:
168 | output_shape = cropping_shape
169 | # make sure output shape is divisible by output_div_by_n
170 | if output_div_by_n is not None:
171 | output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n, answer_type='closer')
172 | for s in output_shape]
173 |
174 | return cropping_shape, output_shape
175 |
--------------------------------------------------------------------------------
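To make the cropping/output-shape arithmetic in get_shapes() above more concrete, here is a small standalone numpy sketch of the branch where an output_shape is requested and the atlas and target resolutions differ. The numbers (1 mm label maps, 2 mm target resolution, requested 160 output) are only example values, and the output_div_by_n adjustment is ignored; it re-derives the arithmetic rather than calling the function.

import numpy as np

labels_shape = [256, 256, 256]            # shape of the input label maps (at 1 mm isotropic)
atlas_res = np.array([1., 1., 1.])        # resolution of the label maps
target_res = np.array([2., 2., 2.])       # resolution of the generated images
requested_output_shape = [160, 160, 160]

# resampling factor between atlas and target resolutions
resample_factor = [atlas_res[i] / float(target_res[i]) for i in range(3)]    # [0.5, 0.5, 0.5]

# the output cannot be bigger than the resampled label maps
output_shape = [min(int(labels_shape[i] * resample_factor[i]), requested_output_shape[i])
                for i in range(3)]                                           # [128, 128, 128]

# cropping happens at atlas resolution, so the crop must cover output_shape once resampled
cropping_shape = [int(np.around(output_shape[i] / resample_factor[i], 0)) for i in range(3)]

print(cropping_shape, output_shape)       # [256, 256, 256] [128, 128, 128]

In other words, the model first crops the label map to cropping_shape at the native resolution, and the generated image is then resampled to output_shape at target_res.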
/ext/neuron/__init__.py:
--------------------------------------------------------------------------------
1 | from . import layers
2 | from . import models
3 | from . import utils
4 |
--------------------------------------------------------------------------------
/models/synthseg_1.0.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BBillot/SynthSeg/2a2aa3bbfccb83f8253a51ca8b329b9938a2646d/models/synthseg_1.0.h5
--------------------------------------------------------------------------------
/requirements_python3.6.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | astor==0.8.1
3 | cached-property==1.5.2
4 | cachetools==4.2.4
5 | certifi==2022.12.7
6 | charset-normalizer==2.0.12
7 | cycler==0.11.0
8 | dataclasses==0.8
9 | gast==0.2.2
10 | google-auth==1.35.0
11 | google-auth-oauthlib==0.4.6
12 | google-pasta==0.2.0
13 | grpcio==1.48.2
14 | h5py==2.10.0
15 | idna==3.4
16 | importlib-metadata==4.8.3
17 | Keras==2.3.1
18 | Keras-Applications==1.0.8
19 | Keras-Preprocessing==1.1.2
20 | kiwisolver==1.3.1
21 | Markdown==3.3.7
22 | matplotlib==3.3.4
23 | nibabel==3.2.2
24 | numpy==1.19.5
25 | oauthlib==3.2.2
26 | opt-einsum==3.3.0
27 | packaging==21.3
28 | Pillow==8.4.0
29 | protobuf==3.19.6
30 | pyasn1==0.4.8
31 | pyasn1-modules==0.2.8
32 | pyparsing==3.0.9
33 | python-dateutil==2.8.2
34 | PyYAML==6.0
35 | requests==2.27.1
36 | requests-oauthlib==1.3.1
37 | rsa==4.9
38 | scipy==1.5.4
39 | six==1.16.0
40 | tensorboard==2.0.2
41 | tensorflow-estimator==2.0.1
42 | tensorflow-gpu==2.0.0
43 | termcolor==1.1.0
44 | typing_extensions==4.1.1
45 | urllib3==1.26.15
46 | Werkzeug==2.0.3
47 | wrapt==1.15.0
48 | zipp==3.6.0
49 |
--------------------------------------------------------------------------------
/requirements_python3.8.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | astunparse==1.6.3
3 | cachetools==4.2.4
4 | certifi==2022.12.7
5 | charset-normalizer==3.1.0
6 | contourpy==1.0.7
7 | cycler==0.11.0
8 | fonttools==4.39.2
9 | gast==0.3.3
10 | google-auth==1.35.0
11 | google-auth-oauthlib==0.4.6
12 | google-pasta==0.2.0
13 | grpcio==1.51.3
14 | h5py==2.10.0
15 | idna==3.4
16 | importlib-metadata==6.0.0
17 | Keras==2.3.1
18 | Keras-Applications==1.0.8
19 | Keras-Preprocessing==1.1.2
20 | kiwisolver==1.4.4
21 | Markdown==3.4.1
22 | MarkupSafe==2.1.2
23 | matplotlib==3.6.2
24 | nibabel==5.0.1
25 | numpy==1.23.5
26 | oauthlib==3.2.2
27 | opt-einsum==3.3.0
28 | packaging==23.0
29 | Pillow==9.4.0
30 | protobuf==3.20.3
31 | pyasn1==0.4.8
32 | pyasn1-modules==0.2.8
33 | pyparsing==3.0.9
34 | python-dateutil==2.8.2
35 | PyYAML==6.0
36 | requests==2.28.2
37 | requests-oauthlib==1.3.1
38 | rsa==4.9
39 | scipy==1.4.1
40 | six==1.16.0
41 | tensorboard==2.2.2
42 | tensorboard-plugin-wit==1.8.1
43 | tensorflow-estimator==2.2.0
44 | tensorflow-gpu==2.2.0
45 | termcolor==2.2.0
46 | urllib3==1.26.15
47 | Werkzeug==2.2.3
48 | wrapt==1.15.0
49 | zipp==3.15.0
50 |
--------------------------------------------------------------------------------
/scripts/commands/SynthSeg_predict.py:
--------------------------------------------------------------------------------
1 | """
2 | This script enables you to launch SynthSeg predictions from the terminal.
3 |
4 | If you use this code, please cite one of the SynthSeg papers:
5 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
6 |
7 | Copyright 2020 Benjamin Billot
8 |
9 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
10 | compliance with the License. You may obtain a copy of the License at
11 | https://www.apache.org/licenses/LICENSE-2.0
12 | Unless required by applicable law or agreed to in writing, software distributed under the License is
13 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14 | implied. See the License for the specific language governing permissions and limitations under the
15 | License.
16 | """
17 |
18 | # python imports
19 | import os
20 | import sys
21 | from argparse import ArgumentParser
22 |
23 | # add main folder to python path and import ./SynthSeg/predict_synthseg.py
24 | synthseg_home = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))))
25 | sys.path.append(synthseg_home)
26 | model_dir = os.path.join(synthseg_home, 'models')
27 | labels_dir = os.path.join(synthseg_home, 'data/labels_classes_priors')
28 | from SynthSeg.predict_synthseg import predict
29 |
30 |
31 | # parse arguments
32 | parser = ArgumentParser(description="SynthSeg", epilog='\n')
33 |
34 | # input/outputs
35 | parser.add_argument("--i", help="Image(s) to segment. Can be a path to an image or to a folder.")
36 | parser.add_argument("--o", help="Segmentation output(s). Must be a folder if --i designates a folder.")
37 | parser.add_argument("--parc", action="store_true", help="(optional) Whether to perform cortex parcellation.")
38 | parser.add_argument("--robust", action="store_true", help="(optional) Whether to use robust predictions (slower).")
39 | parser.add_argument("--fast", action="store_true", help="(optional) Bypass some postprocessing for faster predictions.")
40 | parser.add_argument("--ct", action="store_true", help="(optional) Clip intensities to [0,80] for CT scans.")
41 | parser.add_argument("--vol", help="(optional) Path to output CSV file with volumes (mm3) for all regions and subjects.")
42 | parser.add_argument("--qc", help="(optional) Path to output CSV file with qc scores for all subjects.")
43 | parser.add_argument("--post", help="(optional) Posteriors output(s). Must be a folder if --i designates a folder.")
44 | parser.add_argument("--resample", help="(optional) Resampled image(s). Must be a folder if --i designates a folder.")
45 | parser.add_argument("--crop", nargs='+', type=int, help="(optional) Size of 3D patches to analyse. Default is 192.")
46 | parser.add_argument("--threads", type=int, default=1, help="(optional) Number of cores to be used. Default is 1.")
47 | parser.add_argument("--cpu", action="store_true", help="(optional) Enforce running with CPU rather than GPU.")
48 | parser.add_argument("--v1", action="store_true", help="(optional) Use SynthSeg 1.0 (updated 25/06/22).")
49 |
50 | # check for no arguments
51 | if len(sys.argv) < 2:
52 | parser.print_help()
53 | sys.exit(1)
54 |
55 | # parse commandline
56 | args = vars(parser.parse_args())
57 |
58 | # print SynthSeg version and check boolean params for SynthSeg-robust
59 | if args['robust']:
60 | args['fast'] = True
61 | assert not args['v1'], 'The flag --v1 cannot be used with --robust since SynthSeg-robust only came out with 2.0.'
62 | version = 'SynthSeg-robust 2.0'
63 | else:
64 | version = 'SynthSeg 1.0' if args['v1'] else 'SynthSeg 2.0'
65 | if args['fast']:
66 | version += ' (fast)'
67 | print('\n' + version + '\n')
68 |
69 | # enforce CPU processing if necessary
70 | if args['cpu']:
71 | print('using CPU, hiding all CUDA_VISIBLE_DEVICES')
72 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
73 |
74 | # limit the number of threads to be used if running on CPU
75 | import tensorflow as tf
76 | if args['threads'] == 1:
77 | print('using 1 thread')
78 | else:
79 | print('using %s threads' % args['threads'])
80 | tf.config.threading.set_inter_op_parallelism_threads(args['threads'])
81 | tf.config.threading.set_intra_op_parallelism_threads(args['threads'])
82 |
83 | # path models
84 | if args['robust']:
85 | args['path_model_segmentation'] = os.path.join(model_dir, 'synthseg_robust_2.0.h5')
86 | else:
87 | args['path_model_segmentation'] = os.path.join(model_dir, 'synthseg_2.0.h5')
88 | args['path_model_parcellation'] = os.path.join(model_dir, 'synthseg_parc_2.0.h5')
89 | args['path_model_qc'] = os.path.join(model_dir, 'synthseg_qc_2.0.h5')
90 |
91 | # path labels
92 | args['labels_segmentation'] = os.path.join(labels_dir, 'synthseg_segmentation_labels_2.0.npy')
93 | args['labels_denoiser'] = os.path.join(labels_dir, 'synthseg_denoiser_labels_2.0.npy')
94 | args['labels_parcellation'] = os.path.join(labels_dir, 'synthseg_parcellation_labels.npy')
95 | args['labels_qc'] = os.path.join(labels_dir, 'synthseg_qc_labels_2.0.npy')
96 | args['names_segmentation_labels'] = os.path.join(labels_dir, 'synthseg_segmentation_names_2.0.npy')
97 | args['names_parcellation_labels'] = os.path.join(labels_dir, 'synthseg_parcellation_names.npy')
98 | args['names_qc_labels'] = os.path.join(labels_dir, 'synthseg_qc_names_2.0.npy')
99 | args['topology_classes'] = os.path.join(labels_dir, 'synthseg_topological_classes_2.0.npy')
100 | args['n_neutral_labels'] = 19
101 |
102 | # use previous model if needed
103 | if args['v1']:
104 | args['path_model_segmentation'] = os.path.join(model_dir, 'synthseg_1.0.h5')
105 | args['labels_segmentation'] = args['labels_segmentation'].replace('_2.0.npy', '.npy')
106 | args['labels_qc'] = args['labels_qc'].replace('_2.0.npy', '.npy')
107 | args['names_segmentation_labels'] = args['names_segmentation_labels'].replace('_2.0.npy', '.npy')
108 | args['names_qc_labels'] = args['names_qc_labels'].replace('_2.0.npy', '.npy')
109 | args['topology_classes'] = args['topology_classes'].replace('_2.0.npy', '.npy')
110 | args['n_neutral_labels'] = 18
111 |
112 | # run prediction
113 | predict(path_images=args['i'],
114 | path_segmentations=args['o'],
115 | path_model_segmentation=args['path_model_segmentation'],
116 | labels_segmentation=args['labels_segmentation'],
117 | robust=args['robust'],
118 | fast=args['fast'],
119 | v1=args['v1'],
120 | do_parcellation=args['parc'],
121 | n_neutral_labels=args['n_neutral_labels'],
122 | names_segmentation=args['names_segmentation_labels'],
123 | labels_denoiser=args['labels_denoiser'],
124 | path_posteriors=args['post'],
125 | path_resampled=args['resample'],
126 | path_volumes=args['vol'],
127 | path_model_parcellation=args['path_model_parcellation'],
128 | labels_parcellation=args['labels_parcellation'],
129 | names_parcellation=args['names_parcellation_labels'],
130 | path_model_qc=args['path_model_qc'],
131 | labels_qc=args['labels_qc'],
132 | path_qc_scores=args['qc'],
133 | names_qc=args['names_qc_labels'],
134 | cropping=args['crop'],
135 | topology_classes=args['topology_classes'],
136 | ct=args['ct'])
137 |
--------------------------------------------------------------------------------
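For reference, here is a minimal sketch of the call that the script above resolves to for a default SynthSeg 2.0 run (no --robust, --v1, --parc or other optional flags). It simply mirrors the keyword arguments assembled by the script, with hypothetical input/output paths and an assumed clone location, and it presumes the 2.0 model files are present under models/.

import os
from SynthSeg.predict_synthseg import predict

synthseg_home = '/path/to/SynthSeg'                 # hypothetical clone location
model_dir = os.path.join(synthseg_home, 'models')
labels_dir = os.path.join(synthseg_home, 'data/labels_classes_priors')

predict(path_images='/data/subject_t1.nii.gz',      # hypothetical input image
        path_segmentations='/data/subject_t1_synthseg.nii.gz',
        path_model_segmentation=os.path.join(model_dir, 'synthseg_2.0.h5'),
        labels_segmentation=os.path.join(labels_dir, 'synthseg_segmentation_labels_2.0.npy'),
        robust=False,
        fast=False,
        v1=False,
        do_parcellation=False,
        n_neutral_labels=19,
        names_segmentation=os.path.join(labels_dir, 'synthseg_segmentation_names_2.0.npy'),
        labels_denoiser=os.path.join(labels_dir, 'synthseg_denoiser_labels_2.0.npy'),
        path_posteriors=None,
        path_resampled=None,
        path_volumes=None,
        path_model_parcellation=os.path.join(model_dir, 'synthseg_parc_2.0.h5'),
        labels_parcellation=os.path.join(labels_dir, 'synthseg_parcellation_labels.npy'),
        names_parcellation=os.path.join(labels_dir, 'synthseg_parcellation_names.npy'),
        path_model_qc=os.path.join(model_dir, 'synthseg_qc_2.0.h5'),
        labels_qc=os.path.join(labels_dir, 'synthseg_qc_labels_2.0.npy'),
        path_qc_scores=None,
        names_qc=os.path.join(labels_dir, 'synthseg_qc_names_2.0.npy'),
        cropping=None,
        topology_classes=os.path.join(labels_dir, 'synthseg_topological_classes_2.0.npy'),
        ct=False)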
/scripts/commands/predict.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # imports
18 | from argparse import ArgumentParser
19 | from SynthSeg.predict import predict
20 |
21 | parser = ArgumentParser()
22 |
23 | # Positional arguments
24 | parser.add_argument("path_images", type=str, help="path single image or path of the folders with training labels")
25 | parser.add_argument("path_segmentations", type=str, help="segmentations folder/path")
26 | parser.add_argument("path_model", type=str, help="model file path")
27 |
28 | # labels parameters
29 | parser.add_argument("labels_segmentation", type=str, help="path label list")
30 | parser.add_argument("--neutral_labels", type=int, dest="n_neutral_labels", default=None)
31 | parser.add_argument("--names_list", type=str, dest="names_segmentation", default=None,
32 | help="path list of label names, only used if --vol is specified")
33 |
34 | # Saving paths
35 | parser.add_argument("--post", type=str, dest="path_posteriors", default=None, help="posteriors folder/path")
36 | parser.add_argument("--resampled", type=str, dest="path_resampled", default=None,
37 | help="path/folder of the images resampled at the given target resolution")
38 | parser.add_argument("--vol", type=str, dest="path_volumes", default=None, help="path volume file")
39 |
40 | # Processing parameters
41 | parser.add_argument("--min_pad", type=int, dest="min_pad", default=None,
42 | help="margin of the padding")
43 | parser.add_argument("--cropping", type=int, dest="cropping", default=None,
44 | help="crop volume before processing. Segmentations will have the same size as input image.")
45 | parser.add_argument("--target_res", type=float, dest="target_res", default=1.,
46 | help="Target resolution at which segmentations will be given.")
47 | parser.add_argument("--flip", action='store_true', dest="flip",
48 | help="to activate test-time augmentation (right/left flipping)")
49 | parser.add_argument("--topology_classes", type=str, dest="topology_classes", default=None,
50 | help="path list of classes, for topologically enhanced biggest connected component analysis")
51 | parser.add_argument("--smoothing", type=float, dest="sigma_smoothing", default=0.5,
52 | help="var for gaussian blurring of the posteriors")
53 | parser.add_argument("--biggest_component", action='store_true', dest="keep_biggest_component",
54 | help="only keep biggest component in segmentation (recommended)")
55 |
56 | # Architecture parameters
57 | parser.add_argument("--conv_size", type=int, dest="conv_size", default=3, help="size of unet convolution masks")
58 | parser.add_argument("--n_levels", type=int, dest="n_levels", default=5, help="number of levels for unet")
59 | parser.add_argument("--conv_per_level", type=int, dest="nb_conv_per_level", default=2, help="conv par level")
60 | parser.add_argument("--unet_feat", type=int, dest="unet_feat_count", default=24,
61 | help="number of features of unet first layer")
62 | parser.add_argument("--feat_mult", type=int, dest="feat_multiplier", default=2,
63 | help="factor of new feature maps per level")
64 | parser.add_argument("--activation", type=str, dest="activation", default='elu', help="activation function")
65 |
66 | # Evaluation parameters
67 | parser.add_argument("--gt", type=str, default=None, dest="gt_folder",
68 | help="folder containing ground truth segmentations, which triggers the evaluation.")
69 | parser.add_argument("--eval_label_list", type=str, dest="evaluation_labels", default=None,
70 | help="labels to evaluate Dice scores on if gt is provided. Default is the same as label_list.")
71 | parser.add_argument("--incorrect_labels", type=str, default=None, dest="list_incorrect_labels",
72 | help="path list labels to correct.")
73 | parser.add_argument("--correct_labels", type=str, default=None, dest="list_correct_labels",
74 | help="path list correct labels.")
75 |
76 | args = parser.parse_args()
77 | predict(**vars(args))
78 |
--------------------------------------------------------------------------------
/scripts/commands/training.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | from argparse import ArgumentParser
18 | from SynthSeg.training import training
19 | from ext.lab2im.utils import infer
20 |
21 | parser = ArgumentParser()
22 |
23 | # ------------------------------------------------- General parameters -------------------------------------------------
24 | # Positional arguments
25 | parser.add_argument("labels_dir", type=str)
26 | parser.add_argument("model_dir", type=str)
27 |
28 | # ---------------------------------------------- Generation parameters ----------------------------------------------
29 | # label maps parameters
30 | parser.add_argument("--generation_labels", type=str, dest="generation_labels", default=None)
31 | parser.add_argument("--neutral_labels", type=int, dest="n_neutral_labels", default=None)
32 | parser.add_argument("--segmentation_labels", type=str, dest="segmentation_labels", default=None)
33 | parser.add_argument("--subjects_prob", type=str, dest="subjects_prob", default=None)
34 |
35 | # output-related parameters
36 | parser.add_argument("--batch_size", type=int, dest="batchsize", default=1)
37 | parser.add_argument("--channels", type=int, dest="n_channels", default=1)
38 | parser.add_argument("--target_res", type=float, dest="target_res", default=None)
39 | parser.add_argument("--output_shape", type=int, dest="output_shape", default=None)
40 |
41 | # GMM-sampling parameters
42 | parser.add_argument("--generation_classes", type=str, dest="generation_classes", default=None)
43 | parser.add_argument("--prior_type", type=str, dest="prior_distributions", default='uniform')
44 | parser.add_argument("--prior_means", type=str, dest="prior_means", default=None)
45 | parser.add_argument("--prior_stds", type=str, dest="prior_stds", default=None)
46 | parser.add_argument("--specific_stats", action='store_true', dest="use_specific_stats_for_channel")
47 | parser.add_argument("--mix_prior_and_random", action='store_true', dest="mix_prior_and_random")
48 |
49 | # spatial deformation parameters
50 | parser.add_argument("--no_flipping", action='store_false', dest="flipping")
51 | parser.add_argument("--scaling", dest="scaling_bounds", type=infer, default=0.2)
52 | parser.add_argument("--rotation", dest="rotation_bounds", type=infer, default=15)
53 | parser.add_argument("--shearing", dest="shearing_bounds", type=infer, default=.012)
54 | parser.add_argument("--translation", dest="translation_bounds", type=infer, default=False)
55 | parser.add_argument("--nonlin_std", type=float, dest="nonlin_std", default=4.)
56 | parser.add_argument("--nonlin_scale", type=float, dest="nonlin_scale", default=.04)
57 |
58 | # blurring/resampling parameters
59 | parser.add_argument("--randomise_res", action='store_true', dest="randomise_res")
60 | parser.add_argument("--max_res_iso", type=float, dest="max_res_iso", default=4.)
61 | parser.add_argument("--max_res_aniso", type=float, dest="max_res_aniso", default=8.)
62 | parser.add_argument("--data_res", dest="data_res", type=infer, default=None)
63 | parser.add_argument("--thickness", dest="thickness", type=infer, default=None)
64 |
65 | # bias field parameters
66 | parser.add_argument("--bias_std", type=float, dest="bias_field_std", default=.7)
67 | parser.add_argument("--bias_scale", type=float, dest="bias_scale", default=.025)
68 |
69 | parser.add_argument("--gradients", action='store_true', dest="return_gradients")
70 |
71 | # -------------------------------------------- UNet architecture parameters --------------------------------------------
72 | parser.add_argument("--n_levels", type=int, dest="n_levels", default=5)
73 | parser.add_argument("--conv_per_level", type=int, dest="nb_conv_per_level", default=2)
74 | parser.add_argument("--conv_size", type=int, dest="conv_size", default=3)
75 | parser.add_argument("--unet_feat", type=int, dest="unet_feat_count", default=24)
76 | parser.add_argument("--feat_mult", type=int, dest="feat_multiplier", default=2)
77 | parser.add_argument("--activation", type=str, dest="activation", default='elu')
78 |
79 | # ------------------------------------------------- Training parameters ------------------------------------------------
80 | parser.add_argument("--lr", type=float, dest="lr", default=1e-4)
81 | parser.add_argument("--wl2_epochs", type=int, dest="wl2_epochs", default=1)
82 | parser.add_argument("--dice_epochs", type=int, dest="dice_epochs", default=50)
83 | parser.add_argument("--steps_per_epoch", type=int, dest="steps_per_epoch", default=10000)
84 | parser.add_argument("--checkpoint", type=str, dest="checkpoint", default=None)
85 |
86 | args = parser.parse_args()
87 | training(**vars(args))
88 |
--------------------------------------------------------------------------------
/scripts/commands/training_supervised.py:
--------------------------------------------------------------------------------
1 | """
2 | If you use this code, please cite one of the SynthSeg papers:
3 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4 |
5 | Copyright 2020 Benjamin Billot
6 |
7 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8 | compliance with the License. You may obtain a copy of the License at
9 | https://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software distributed under the License is
11 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing permissions and limitations under the
13 | License.
14 | """
15 |
16 |
17 | # imports
18 | from argparse import ArgumentParser
19 | from SynthSeg.training_supervised import training
20 | from ext.lab2im.utils import infer
21 |
22 | parser = ArgumentParser()
23 |
24 | # ------------------------------------------------- General parameters -------------------------------------------------
25 | # Positional arguments
26 | parser.add_argument("image_dir", type=str)
27 | parser.add_argument("labels_dir", type=str)
28 | parser.add_argument("model_dir", type=str)
29 |
30 | # label maps parameters
31 | parser.add_argument("--segmentation_labels", type=str, dest="segmentation_labels", default=None)
32 | parser.add_argument("--neutral_labels", type=int, dest="n_neutral_labels", default=None)
33 | parser.add_argument("--subjects_prob", type=str, dest="subjects_prob", default=None)
34 |
35 | # output-related parameters
36 | parser.add_argument("--batch_size", type=int, dest="batchsize", default=1)
37 | parser.add_argument("--target_res", type=int, dest="target_res", default=None)
38 | parser.add_argument("--output_shape", type=int, dest="output_shape", default=None)
39 |
40 | # ----------------------------------------------- Augmentation parameters ----------------------------------------------
41 | # spatial deformation parameters
42 | parser.add_argument("--no_flipping", action='store_false', dest="flipping")
43 | parser.add_argument("--scaling", dest="scaling_bounds", type=infer, default=.2)
44 | parser.add_argument("--rotation", dest="rotation_bounds", type=infer, default=15)
45 | parser.add_argument("--shearing", dest="shearing_bounds", type=infer, default=.012)
46 | parser.add_argument("--translation", dest="translation_bounds", type=infer, default=False)
47 | parser.add_argument("--nonlin_std", type=float, dest="nonlin_std", default=4.)
48 | parser.add_argument("--nonlin_scale", type=float, dest="nonlin_scale", default=.04)
49 |
50 | # resampling parameters
51 | parser.add_argument("--randomise_res", action='store_true', dest="randomise_res")
52 | parser.add_argument("--max_res_iso", type=float, dest="max_res_iso", default=4.)
53 | parser.add_argument("--max_res_aniso", type=float, dest="max_res_aniso", default=8.)
54 | parser.add_argument("--data_res", dest="data_res", type=infer, default=None)
55 | parser.add_argument("--thickness", dest="thickness", type=infer, default=None)
56 |
57 | # bias field parameters
58 | parser.add_argument("--bias_std", type=float, dest="bias_field_std", default=.7)
59 | parser.add_argument("--bias_scale", type=float, dest="bias_scale", default=.025)
60 |
61 | parser.add_argument("--gradients", action='store_true', dest="return_gradients")
62 |
63 | # -------------------------------------------- UNet architecture parameters --------------------------------------------
64 | parser.add_argument("--n_levels", type=int, dest="n_levels", default=5)
65 | parser.add_argument("--conv_per_level", type=int, dest="nb_conv_per_level", default=2)
66 | parser.add_argument("--conv_size", type=int, dest="conv_size", default=3)
67 | parser.add_argument("--unet_feat", type=int, dest="unet_feat_count", default=24)
68 | parser.add_argument("--feat_mult", type=int, dest="feat_multiplier", default=2)
69 | parser.add_argument("--activation", type=str, dest="activation", default='elu')
70 |
71 | # ------------------------------------------------- Training parameters ------------------------------------------------
72 | parser.add_argument("--lr", type=float, dest="lr", default=1e-4)
73 | parser.add_argument("--wl2_epochs", type=int, dest="wl2_epochs", default=1)
74 | parser.add_argument("--dice_epochs", type=int, dest="dice_epochs", default=50)
75 | parser.add_argument("--steps_per_epoch", type=int, dest="steps_per_epoch", default=10000)
76 | parser.add_argument("--checkpoint", type=str, dest="checkpoint", default=None)
77 |
78 | args = parser.parse_args()
79 | training(**vars(args))
80 |
--------------------------------------------------------------------------------
/scripts/tutorials/1-generation_visualisation.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | Very simple script to generate an example of the synthetic data used to train SynthSeg.
4 | This is for visualisation purposes, since it uses all the default parameters.
5 |
6 |
7 |
8 | If you use this code, please cite one of the SynthSeg papers:
9 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
10 |
11 | Copyright 2020 Benjamin Billot
12 |
13 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
14 | compliance with the License. You may obtain a copy of the License at
15 | https://www.apache.org/licenses/LICENSE-2.0
16 | Unless required by applicable law or agreed to in writing, software distributed under the License is
17 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
18 | implied. See the License for the specific language governing permissions and limitations under the
19 | License.
20 | """
21 |
22 |
23 | from ext.lab2im import utils
24 | from SynthSeg.brain_generator import BrainGenerator
25 |
26 | # generate an image from the label map.
27 | brain_generator = BrainGenerator('../../data/training_label_maps/training_seg_01.nii.gz')
28 | im, lab = brain_generator.generate_brain()
29 |
30 | # save the output image and label map under ./outputs_tutorial_1
31 | utils.save_volume(im, brain_generator.aff, brain_generator.header, './outputs_tutorial_1/image.nii.gz')
32 | utils.save_volume(lab, brain_generator.aff, brain_generator.header, './outputs_tutorial_1/labels.nii.gz')
33 |
--------------------------------------------------------------------------------
/scripts/tutorials/2-generation_explained.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | This script explains the different parameters controlling the generation of the synthetic data.
4 | These parameters will be reused in the training function, but we describe them here, as the synthetic images are saved,
5 | and thus can be visualised.
6 | Note that most of the parameters here are set to their default value, but we show them nonetheless, just to explain
7 | their effect. Moreover, we encourage the user to play with them to get a sense of their impact on the generation.
8 |
9 |
10 |
11 | If you use this code, please cite one of the SynthSeg papers:
12 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
13 |
14 | Copyright 2020 Benjamin Billot
15 |
16 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
17 | compliance with the License. You may obtain a copy of the License at
18 | https://www.apache.org/licenses/LICENSE-2.0
19 | Unless required by applicable law or agreed to in writing, software distributed under the License is
20 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
21 | implied. See the License for the specific language governing permissions and limitations under the
22 | License.
23 | """
24 |
25 |
26 | import os
27 | from ext.lab2im import utils
28 | from SynthSeg.brain_generator import BrainGenerator
29 |
30 | # script parameters
31 | n_examples = 5 # number of examples to generate in this script
32 | result_dir = './outputs_tutorial_2' # folder where examples will be saved
33 |
34 |
35 | # ---------- Input label maps and associated values ----------
36 |
37 | # folder containing label maps to generate images from (note that they must have a ".nii", ".nii.gz" or ".mgz" format)
38 | path_label_map = '../../data/training_label_maps'
39 |
40 | # Here we specify the structures in the label maps for which we want to generate intensities.
41 | # This is given as a list of label values, which do not necessarily need to be present in every label map.
42 | # However, these labels must follow a specific order: first the background, and then all the other labels. Moreover, if
43 | # 1) the label maps contain some right/left-specific label values, and 2) we activate flipping augmentation (which is
44 | # true by default), then the rest of the labels must follow a strict order:
45 | # first the non-sided labels (i.e. those which are not right/left specific), then all the left labels, and finally the
46 | # corresponding right labels (in the same order as the left ones). Please make sure that each sided label has both a
47 | # right and a left value (this is essential!!!).
48 | #
49 | # Example: generation_labels = [0, # background
50 | # 24, # CSF
51 | # 507, # extra-cerebral soft tissues
52 | # 2, # left white matter
53 | # 3, # left cerebral cortex
54 | # 4, # left lateral ventricle
55 | # 17, # left hippocampus
56 | # 25, # left lesions
57 | # 41, # right white matter
58 | # 42, # right cerebral cortex
59 | # 43, # right lateral ventricle
60 | # 53, # right hippocampus
61 | # 57] # right lesions
62 | # Note that plenty of structures are not represented here..... but it's just an example ! :)
63 | generation_labels = '../../data/labels_classes_priors/generation_labels.npy'
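
# Purely illustrative sketch (not needed for the rest of this tutorial): the example list from the comment
# above, built in the required order. These are the values of the small example, not the full SynthSeg label
# set contained in generation_labels.npy.
import numpy as np
example_generation_labels = np.array([0, 24, 507,           # first the background, then the non-sided labels
                                      2, 3, 4, 17, 25,      # then all the left labels
                                      41, 42, 43, 53, 57])  # finally the corresponding right labels, same order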
64 |
65 |
66 | # We also have to specify the number of non-sided labels in order to differentiate them from the labels with
67 | # right/left values.
68 | # Example: (continuing the previous one): in this example it would be 3 (background, CSF, extra-cerebral soft tissues).
69 | n_neutral_labels = 18
70 |
71 | # By default, the output label maps (i.e. the target segmentations) contain all the labels used for generation.
72 | # However, we may not want to predict all the generation labels (e.g. extra-cerebral soft tissues).
73 | # For this reason, we specify here the target segmentation label corresponding to every generation structure.
74 | # This new list must have the same length as generation_labels, and follow the same order.
75 | #
76 | # Example: (continuing the previous one) generation_labels = [0, 24, 507, 2, 3, 4, 17, 25, 41, 42, 43, 53, 57]
77 | # output_labels = [0, 0, 0, 2, 3, 4, 17, 2, 41, 42, 43, 53, 41]
78 | # Note that in this example the labels 24 (CSF), and 507 (extra-cerebral soft tissues) are not predicted, or said
79 | # differently they are segmented as background.
80 | # Also, the left and right lesions (labels 25 and 57) are segmented as left and right white matter (labels 2 and 41).
81 | output_labels = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy'
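
# Again a purely illustrative sketch: the output_labels example from the comment above. It must be aligned
# element-wise with the generation labels (here with example_generation_labels defined earlier), so both
# arrays have the same length.
example_output_labels = np.array([0, 0, 0, 2, 3, 4, 17, 2, 41, 42, 43, 53, 41])
assert len(example_output_labels) == len(example_generation_labels)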
82 |
83 |
84 | # ---------- Shape and resolution of the outputs ----------
85 |
86 | # number of channels to synthesise for multi-modality settings. Set this to 1 (default) in the uni-modality scenario.
87 | n_channels = 1
88 |
89 | # We have the possibility to generate training examples at a different resolution than the training label maps (e.g.
90 | # when using ultra HR training label maps). Here we want to generate at the same resolution as the training label maps,
91 | # so we set this to None.
92 | target_res = None
93 |
94 | # The generative model offers the possibility to randomly crop the training examples to a given size.
95 | # Here we crop them to 160^3, such that the produced images fit on the GPU during training.
96 | output_shape = 160
97 |
98 |
99 | # ---------- GMM sampling parameters ----------
100 |
101 | # Here we use uniform prior distributions to sample the means/stds of the GMM. Because we don't specify prior_means and
102 | # prior_stds, those priors will have default bounds of [0, 250] and [0, 35]. Those values enable us to generate a wide
103 | # range of contrasts (often unrealistic), which will make the segmentation network contrast-agnostic.
104 | prior_distributions = 'uniform'
105 |
106 | # We regroup labels with similar tissue types into K "classes", so that intensities of similar regions are sampled
107 | # from the same Gaussian distribution. This is achieved by providing a list indicating the class of each label.
108 | # It should have the same length as generation_labels, and follow the same order. Importantly the class values must be
109 | # between 0 and K-1, where K is the total number of different classes.
110 | #
111 | # Example: (continuing the previous one) generation_labels = [0, 24, 507, 2, 3, 4, 17, 25, 41, 42, 43, 53, 57]
112 | # generation_classes = [0, 1, 2, 3, 4, 5, 4, 6, 7, 8, 9, 8, 10]
113 | # In this example labels 3 and 17 are in the same *class* 4 (that has nothing to do with *label* 4), and thus will be
114 | # associated to the same Gaussian distribution when sampling the GMM.
115 | generation_classes = '../../data/labels_classes_priors/generation_classes.npy'
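
# Purely illustrative sketch: the generation_classes example from the comment above. It is aligned with
# example_generation_labels, and its values are consecutive integers between 0 and K-1 (here K = 11 classes).
example_generation_classes = np.array([0, 1, 2, 3, 4, 5, 4, 6, 7, 8, 9, 8, 10])
assert len(example_generation_classes) == len(example_generation_labels)
assert example_generation_classes.max() == len(np.unique(example_generation_classes)) - 1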
116 |
117 |
118 | # ---------- Spatial augmentation ----------
119 |
120 | # We now introduce some parameters concerning the spatial deformation. They set the range of the uniform
121 | # distribution from which the corresponding deformation parameters are sampled.
122 | # We note that because the label maps will be resampled with nearest neighbour interpolation, they can look less smooth
123 | # than the original segmentations.
124 |
125 | flipping = True # enable right/left flipping
126 | scaling_bounds = 0.2 # the scaling coefficients will be sampled from U(1-scaling_bounds; 1+scaling_bounds)
127 | rotation_bounds = 15 # the rotation angles will be sampled from U(-rotation_bounds; rotation_bounds)
128 | shearing_bounds = 0.012 # the shearing coefficients will be sampled from U(-shearing_bounds; shearing_bounds)
129 | translation_bounds = False # no translation is performed, as this is already modelled by the random cropping
130 | nonlin_std = 4. # this controls the maximum elastic deformation (higher = more deformation)
131 | bias_field_std = 0.7 # this controls the maximum bias field corruption (higher = more bias)
132 |
133 |
134 | # ---------- Resolution parameters ----------
135 |
136 | # This enables us to randomise the resolution of the produced images.
137 | # Although it is only one parameter, this is crucial !!
138 | randomise_res = True
139 |
140 |
141 | # ------------------------------------------------------ Generate ------------------------------------------------------
142 |
143 | # instantiate BrainGenerator object
144 | brain_generator = BrainGenerator(labels_dir=path_label_map,
145 | generation_labels=generation_labels,
146 | n_neutral_labels=n_neutral_labels,
147 | prior_distributions=prior_distributions,
148 | generation_classes=generation_classes,
149 | output_labels=output_labels,
150 | n_channels=n_channels,
151 | target_res=target_res,
152 | output_shape=output_shape,
153 | flipping=flipping,
154 | scaling_bounds=scaling_bounds,
155 | rotation_bounds=rotation_bounds,
156 | shearing_bounds=shearing_bounds,
157 | translation_bounds=translation_bounds,
158 | nonlin_std=nonlin_std,
159 | bias_field_std=bias_field_std,
160 | randomise_res=randomise_res)
161 |
162 | for n in range(n_examples):
163 |
164 | # generate new image and corresponding labels
165 | im, lab = brain_generator.generate_brain()
166 |
167 | # save output image and label map
168 | utils.save_volume(im, brain_generator.aff, brain_generator.header,
169 | os.path.join(result_dir, 'image_%s.nii.gz' % n))
170 | utils.save_volume(lab, brain_generator.aff, brain_generator.header,
171 | os.path.join(result_dir, 'labels_%s.nii.gz' % n))
172 |
--------------------------------------------------------------------------------
/scripts/tutorials/3-training.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | This script shows how we trained SynthSeg.
4 | Importantly, it reuses numerous parameters seen in the previous tutorial about image generation
5 | (i.e., 2-generation_explained.py), which we strongly recommend reading before this one.
6 |
7 |
8 |
9 | If you use this code, please cite one of the SynthSeg papers:
10 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
11 |
12 | Copyright 2020 Benjamin Billot
13 |
14 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
15 | compliance with the License. You may obtain a copy of the License at
16 | https://www.apache.org/licenses/LICENSE-2.0
17 | Unless required by applicable law or agreed to in writing, software distributed under the License is
18 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
19 | implied. See the License for the specific language governing permissions and limitations under the
20 | License.
21 | """
22 |
23 |
24 | # project imports
25 | from SynthSeg.training import training
26 |
27 |
28 | # path training label maps
29 | path_training_label_maps = '../../data/training_label_maps'
30 | path_model_dir = './outputs_tutorial_3/'
31 | batchsize = 1
32 |
33 | # architecture parameters
34 | n_levels = 5 # number of resolution levels
35 | nb_conv_per_level = 2  # number of convolutions per level
36 | conv_size = 3 # size of the convolution kernel (e.g. 3x3x3)
37 | unet_feat_count = 24 # number of feature maps after the first convolution
38 | activation = 'elu' # activation for all convolution layers except the last, which will use softmax regardless
39 | feat_multiplier = 2 # if feat_multiplier is set to 1, we will keep the number of feature maps constant throughout the
40 |                      # network; 2 will double (resp. halve) them after each max-pooling (resp. upsampling);
41 | # 3 will triple them, etc.
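
# Quick illustrative check (not used by the training call below): with the settings above, the number of
# feature maps per resolution level follows the geometric progression implied by the comments.
print([unet_feat_count * feat_multiplier ** level for level in range(n_levels)])  # -> [24, 48, 96, 192, 384]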
42 |
43 | # training parameters
44 | lr = 1e-4 # learning rate
45 | wl2_epochs = 1 # number of pre-training epochs with wl2 metric w.r.t. the layer before the softmax
46 | dice_epochs = 100 # number of training epochs
47 | steps_per_epoch = 5000  # number of iterations per epoch
48 |
49 |
50 | # ---------- Generation parameters ----------
51 | # these parameters are from the previous tutorial, and thus we do not explain them again here
52 |
53 | # generation and segmentation labels
54 | path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy'
55 | n_neutral_labels = 18
56 | path_segmentation_labels = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy'
57 |
58 | # shape and resolution of the outputs
59 | target_res = None
60 | output_shape = 160
61 | n_channels = 1
62 |
63 | # GMM sampling
64 | prior_distributions = 'uniform'
65 | path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy'
66 |
67 | # spatial deformation parameters
68 | flipping = True
69 | scaling_bounds = .2
70 | rotation_bounds = 15
71 | shearing_bounds = .012
72 | translation_bounds = False
73 | nonlin_std = 4.
74 | bias_field_std = .7
75 |
76 | # acquisition resolution parameters
77 | randomise_res = True
78 |
79 | # ------------------------------------------------------ Training ------------------------------------------------------
80 |
81 | training(path_training_label_maps,
82 | path_model_dir,
83 | generation_labels=path_generation_labels,
84 | segmentation_labels=path_segmentation_labels,
85 | n_neutral_labels=n_neutral_labels,
86 | batchsize=batchsize,
87 | n_channels=n_channels,
88 | target_res=target_res,
89 | output_shape=output_shape,
90 | prior_distributions=prior_distributions,
91 | generation_classes=path_generation_classes,
92 | flipping=flipping,
93 | scaling_bounds=scaling_bounds,
94 | rotation_bounds=rotation_bounds,
95 | shearing_bounds=shearing_bounds,
96 | translation_bounds=translation_bounds,
97 | nonlin_std=nonlin_std,
98 | randomise_res=randomise_res,
99 | bias_field_std=bias_field_std,
100 | n_levels=n_levels,
101 | nb_conv_per_level=nb_conv_per_level,
102 | conv_size=conv_size,
103 | unet_feat_count=unet_feat_count,
104 | feat_multiplier=feat_multiplier,
105 | activation=activation,
106 | lr=lr,
107 | wl2_epochs=wl2_epochs,
108 | dice_epochs=dice_epochs,
109 | steps_per_epoch=steps_per_epoch)
110 |
--------------------------------------------------------------------------------
/scripts/tutorials/4-prediction.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | This script shows how to perform inference after having trained your own model.
4 | Importantly, it reuses some of the parameters used in tutorial 3-training.
5 | Moreover, we emphasise that this tutorial explains how to perform inference on your own trained models.
6 | To predict segmentations with the distributed models of SynthSeg, please refer to the README.md file.
7 |
8 |
9 |
10 | If you use this code, please cite one of the SynthSeg papers:
11 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
12 |
13 | Copyright 2020 Benjamin Billot
14 |
15 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
16 | compliance with the License. You may obtain a copy of the License at
17 | https://www.apache.org/licenses/LICENSE-2.0
18 | Unless required by applicable law or agreed to in writing, software distributed under the License is
19 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
20 | implied. See the License for the specific language governing permissions and limitations under the
21 | License.
22 | """
23 |
24 | # project imports
25 | from SynthSeg.predict import predict
26 |
27 | # paths to input/output files
28 | # Here we assume the availability of an image that we wish to segment with a model we have just trained.
29 | # We emphasise that we do not provide such an image (this is just an example after all :))
30 | # Input images must have a .nii, .nii.gz, or .mgz extension.
31 | # Note that path_images can also be the path to an entire folder, in which case all the images within this folder will
32 | # be segmented. In this case, please provide path_segm (and possibly path_posteriors, and path_resampled) as folder.
33 | path_images = '/a/path/to/an/image/im.nii.gz'
34 | # path to the output segmentation
35 | path_segm = './outputs_tutorial_4/predicted_segmentations/im_seg.nii.gz'
36 | # we can also provide paths for optional files containing the probability map for all predicted labels
37 | path_posteriors = './outputs_tutorial_4/predicted_information/im_post.nii.gz'
38 | # and for a csv file that will contain the volumes of each segmented structure
39 | path_vol = './outputs_tutorial_4/predicted_information/volumes.csv'
40 |
41 | # of course we need to provide the path to the trained model (here we use the main synthseg model).
42 | path_model = '../../models/synthseg_1.0.h5'
43 | # but we also need to provide the path to the segmentation labels used during training
44 | path_segmentation_labels = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy'
45 | # optionally we can give a numpy array with the names corresponding to the structures in path_segmentation_labels
46 | path_segmentation_names = '../../data/labels_classes_priors/synthseg_segmentation_names.npy'
47 |
48 | # We can now provide various parameters to control the preprocessing of the input.
49 | # First we can play with the size of the input. Remember that the size of the input must be divisible by 2**n_levels, so the
50 | # input image will be automatically padded to the nearest shape divisible by 2**n_levels (this is just for processing,
51 | # the output will then be cropped to the original image size).
52 | # Alternatively, you can crop the input to a smaller shape for faster processing, or to make it fit on your GPU.
53 | cropping = 192
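
# Illustrative sketch with a hypothetical input shape (not used below): the automatic padding rounds every
# dimension up to the nearest multiple of 2**n_levels (i.e. 32 for the default n_levels=5 used further down).
print([(s + 31) // 32 * 32 for s in (150, 225, 155)])  # -> [160, 256, 160]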
54 | # Finally, we finish preprocessing the input by resampling it to the resolution at which the network has been trained to
55 | # produce predictions. If the input image has a resolution outside the range [target_res-0.05, target_res+0.05], it will
56 | # automatically be resampled to target_res.
57 | target_res = 1.
58 | # Note that if the image is indeed resampled, you have the option to save the resampled image.
59 | path_resampled = './outputs_tutorial_4/predicted_information/im_resampled_target_res.nii.gz'
60 |
61 | # After the image has been processed by the network, there are again various options to postprocess it.
62 | # First, we can apply some test-time augmentation by flipping the input along the right-left axis and segmenting
63 | # the resulting image. In this case, and if the network has right/left specific labels, it is also very important to
64 | # provide the number of neutral labels. This must be the exact same as the one used during training.
65 | flip = True
66 | n_neutral_labels = 18
67 | # Second, we can smooth the probability maps produced by the network. This doesn't change the results much, but helps to
68 | # reduce high frequency noise in the obtained segmentations.
69 | sigma_smoothing = 0.5
70 | # Then we can apply a fancier version of biggest connected component, by regrouping structures within so-called
71 | # "topological classes". For each class we successively: 1) sum all the posteriors corresponding to the labels of this
72 | # class, 2) obtain a mask for this class by thresholding the summed posteriors by a low value (arbitrarily set to 0.1),
73 | # 3) keep the biggest connected component, and 4) individually apply the obtained mask to the posteriors of all the
74 | # labels for this class.
75 | # Example: (continuing the previous one) generation_labels = [0, 24, 507, 2, 3, 4, 17, 25, 41, 42, 43, 53, 57]
76 | # output_labels = [0, 0, 0, 2, 3, 4, 17, 2, 41, 42, 43, 53, 41]
77 | # topological_classes = [0, 0, 0, 1, 1, 2, 3, 1, 4, 4, 5, 6, 7]
78 | # Here we regroup labels 2 and 3 in the same topological class, same for labels 41 and 42. The topological class of
79 | # unsegmented structures must be set to 0 (like for 24 and 507).
80 | topology_classes = '../../data/labels_classes_priors/synthseg_topological_classes.npy'
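
# Illustrative sketch using the example values above (not used below): regrouping the example output labels by
# topological class, to visualise which structures will share a single connected-component mask.
example_output_labels = [0, 0, 0, 2, 3, 4, 17, 2, 41, 42, 43, 53, 41]
example_topology_classes = [0, 0, 0, 1, 1, 2, 3, 1, 4, 4, 5, 6, 7]
example_groups = {}
for label, topo_class in zip(example_output_labels, example_topology_classes):
    example_groups.setdefault(topo_class, set()).add(label)
print(example_groups)  # class 1 regroups labels {2, 3}, class 4 regroups labels {41, 42}, etc.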
81 | # Finally, we can also apply a strict version of biggest connected component, to get rid of unwanted noisy label
82 | # patches that can sometimes occur in the background. If so, we recommend also using the smoothing option described above.
83 | keep_biggest_component = True
84 |
85 | # Regarding the architecture of the network, we must provide the predict function with the same parameters as during
86 | # training.
87 | n_levels = 5
88 | nb_conv_per_level = 2
89 | conv_size = 3
90 | unet_feat_count = 24
91 | activation = 'elu'
92 | feat_multiplier = 2
93 |
94 | # Finally, we can set up an evaluation step after all images have been segmented.
95 | # For this purpose, we need to provide the path to the ground truth corresponding to the input image(s).
96 | # This is done by using the "gt_folder" parameter, which must have the same type as path_images (i.e., the path to a
97 | # single image or to a folder). If provided as a folder, ground truths must be sorted in the same order as images in
98 | # path_images.
99 | # Just set this to None if you do not want to run evaluation.
100 | gt_folder = '/the/path/to/the/ground_truth/gt.nii.gz'
101 | # Dice scores will be computed and saved as a numpy array in the folder containing the segmentation(s).
102 | # This numpy array will be organised as follows: rows correspond to structures, and columns to subjects. Importantly,
103 | # rows are given in a sorted order.
104 | # Example: we segment 2 subjects, where output_labels = [0, 0, 0, 2, 3, 4, 17, 2, 41, 42, 43, 53, 41]
105 | # so sorted output_labels = [0, 2, 3, 4, 17, 41, 42, 43, 53]
106 | # dice = [[xxx, xxx], # scores for label 0
107 | # [xxx, xxx], # scores for label 2
108 | # [xxx, xxx], # scores for label 3
109 | # [xxx, xxx], # scores for label 4
110 | # [xxx, xxx], # scores for label 17
111 | # [xxx, xxx], # scores for label 41
112 | # [xxx, xxx], # scores for label 42
113 | # [xxx, xxx], # scores for label 43
114 | # [xxx, xxx]] # scores for label 53
115 | # / \
116 | # subject 1 subject 2
117 | #
118 | # Also we can compute different surface distances (Hausdorff, Hausdorff99, Hausdorff95 and mean surface distance). The
119 | # results will be saved in arrays similar to the Dice scores.
120 | compute_distances = True
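
# Illustrative sketch (commented out, as it can only be run once evaluation has completed): the scores can then
# be inspected with numpy. The file name 'dice.npy' is an assumption for this example; look for the array
# written in the folder containing the segmentation(s).
# import numpy as np
# dice = np.load('./outputs_tutorial_4/predicted_segmentations/dice.npy')
# print(dice.shape)  # (number of evaluated structures, number of subjects), with rows in sorted label order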
121 |
122 | # All right, we're ready to make predictions !!
123 | predict(path_images,
124 | path_segm,
125 | path_model,
126 | path_segmentation_labels,
127 | n_neutral_labels=n_neutral_labels,
128 | path_posteriors=path_posteriors,
129 | path_resampled=path_resampled,
130 | path_volumes=path_vol,
131 | names_segmentation=path_segmentation_names,
132 | cropping=cropping,
133 | target_res=target_res,
134 | flip=flip,
135 | topology_classes=topology_classes,
136 | sigma_smoothing=sigma_smoothing,
137 | keep_biggest_component=keep_biggest_component,
138 | n_levels=n_levels,
139 | nb_conv_per_level=nb_conv_per_level,
140 | conv_size=conv_size,
141 | unet_feat_count=unet_feat_count,
142 | feat_multiplier=feat_multiplier,
143 | activation=activation,
144 | gt_folder=gt_folder,
145 | compute_distances=compute_distances)
146 |
--------------------------------------------------------------------------------
/scripts/tutorials/5-generation_advanced.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | This script shows how to generate synthetic images with narrowed intensity distributions (e.g. T1-weighted scans) and
4 | at a specific resolution. All the arguments shown here can be used in the training function.
5 | These parameters were not explained in the previous tutorials as they were not used for the training of SynthSeg.
6 |
7 | Specifically, this script generates 5 examples of training data simulating 3mm axial T1 scans, which have been resampled
8 | to 1mm resolution to be segmented.
9 | Contrast-specificity is achieved by now imposing Gaussian priors (instead of uniform) over the GMM parameters.
10 | Resolution-specificity is achieved by first blurring and downsampling to the simulated LR. The data will then be
11 | upsampled back to HR, so that the downstream network is trained to segment at HR. This upsampling step mimics the
12 | process that will happen at test time.
13 |
14 |
15 |
16 | If you use this code, please cite one of the SynthSeg papers:
17 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
18 |
19 | Copyright 2020 Benjamin Billot
20 |
21 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
22 | compliance with the License. You may obtain a copy of the License at
23 | https://www.apache.org/licenses/LICENSE-2.0
24 | Unless required by applicable law or agreed to in writing, software distributed under the License is
25 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
26 | implied. See the License for the specific language governing permissions and limitations under the
27 | License.
28 | """
29 |
30 |
31 | import os
32 | import numpy as np
33 | from ext.lab2im import utils
34 | from SynthSeg.brain_generator import BrainGenerator
35 |
36 | # script parameters
37 | n_examples = 5 # number of examples to generate in this script
38 | result_dir = './outputs_tutorial_5' # folder where examples will be saved
39 |
40 |
41 | # path training label maps
42 | path_label_map = '../../data/training_label_maps'
43 | generation_labels = '../../data/labels_classes_priors/generation_labels.npy'
44 | output_labels = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy'
45 | n_neutral_labels = 18
46 | output_shape = 160
47 |
48 |
49 | # ---------- GMM sampling parameters ----------
50 |
51 | # Here we use Gaussian priors to control the means and standard deviations of the GMM.
52 | prior_distributions = 'normal'
53 |
54 | # Here we still regroup labels into classes of similar tissue types:
55 | # Example: (continuing the example of tutorial 2) generation_labels = [0, 24, 507, 2, 3, 4, 17, 25, 41, 42, 43, 53, 57]
56 | # generation_classes = [0, 1, 2, 3, 4, 5, 4, 6, 3, 4, 5, 4, 6]
57 | # Note that structures with right/left labels are now associated with the same class.
58 | generation_classes = '../../data/labels_classes_priors/generation_classes_contrast_specific.npy'
59 |
60 | # We specify here the hyperparameters governing the prior distribution of the GMM.
61 | # As these prior distributions are Gaussian, they are each controlled by a mean and a standard deviation.
62 | # Therefore, the numpy array pointed to by prior_means is of size (2, K), where K is the total number of classes specified
63 | # in generation_classes. The first row of prior_means corresponds to the means of the Gaussian priors, and the second row
64 | # corresponds to the standard deviations.
65 | #
66 | # Example: (continuing the previous one) prior_means = np.array([[0, 30, 80, 110, 95, 40, 70],
67 | # [0, 10, 50, 15, 10, 15, 30]])
68 | # This means that the intensities of labels 3 and 17, which are both in class 4, will be drawn from a Gaussian
69 | # distribution whose mean is itself sampled from the prior with index 4 in prior_means, i.e. N(95, 10).
70 | # Here is the complete table of correspondence for this example:
71 | # mean of Gaussian for label 0 drawn from N(0,0)=0
72 | # mean of Gaussian for label 24 drawn from N(30,10)
73 | # mean of Gaussian for label 507 drawn from N(80,50)
74 | # mean of Gaussian for labels 2 and 41 drawn from N(110,15)
75 | # mean of Gaussian for labels 3, 17, 42, 53 drawn from N(95,10)
76 | # mean of Gaussian for labels 4 and 43 drawn from N(40,15)
77 | # mean of Gaussian for labels 25 and 57 drawn from N(70,30)
78 | # These hyperparameters were estimated with the function SynthSeg/estimate_priors.py/build_intensity_stats()
79 | prior_means = '../../data/labels_classes_priors/prior_means_t1.npy'
80 | # same as for prior_means, but for the standard deviations of the GMM.
81 | prior_stds = '../../data/labels_classes_priors/prior_stds_t1.npy'
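
# Purely illustrative sketch (values taken from the example above, not from the provided prior_means_t1.npy):
# the expected format is a (2, K) numpy array, here with K = 7 classes.
example_prior_means = np.array([[0, 30, 80, 110, 95, 40, 70],   # row 0: means of the Gaussian priors
                                [0, 10, 50, 15, 10, 15, 30]])   # row 1: stds of the Gaussian priors
assert example_prior_means.shape == (2, 7)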
82 |
83 | # ---------- Resolution parameters ----------
84 |
85 | # here we aim to synthesise data at a specific resolution, thus we do not randomise it anymore !
86 | randomise_res = False
87 |
88 | # blurring/downsampling parameters
89 | # We specify here the slice spacing/thickness that we want the synthetic scans to mimic. The axes refer to the *RAS*
90 | # axes, as all the provided data (label maps and images) will be automatically aligned to those axes during training.
91 | # RAS refers to Right-left/Anterior-posterior/Superior-inferior axes, i.e. sagittal/coronal/axial directions.
92 | data_res = np.array([1., 1., 3.]) # slice spacing i.e. resolution to mimic
93 | thickness = np.array([1., 1., 3.]) # slice thickness
94 |
95 | # ------------------------------------------------------ Generate ------------------------------------------------------
96 |
97 | # instantiate BrainGenerator object
98 | brain_generator = BrainGenerator(labels_dir=path_label_map,
99 | generation_labels=generation_labels,
100 | output_labels=output_labels,
101 | n_neutral_labels=n_neutral_labels,
102 | output_shape=output_shape,
103 | prior_distributions=prior_distributions,
104 | generation_classes=generation_classes,
105 | prior_means=prior_means,
106 | prior_stds=prior_stds,
107 | randomise_res=randomise_res,
108 | data_res=data_res,
109 | thickness=thickness)
110 |
111 | for n in range(n_examples):
112 |
113 | # generate new image and corresponding labels
114 | im, lab = brain_generator.generate_brain()
115 |
116 | # save output image and label map
117 | utils.save_volume(im, brain_generator.aff, brain_generator.header,
118 | os.path.join(result_dir, 'image_t1_%s.nii.gz' % n))
119 | utils.save_volume(lab, brain_generator.aff, brain_generator.header,
120 | os.path.join(result_dir, 'labels_t1_%s.nii.gz' % n))
121 |
--------------------------------------------------------------------------------
/scripts/tutorials/6-intensity_estimation.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | Examples showing how to estimate the hyperparameters governing the GMM prior distributions.
4 | This is for the case where you want to train contrast-specific versions of SynthSeg.
5 | Beware, if you do so, your model will not be able to segment any contrast at test time !
6 | We do not provide example images and associated label maps, so do not try to run this directly !
7 |
8 |
9 |
10 |
11 | If you use this code, please cite one of the SynthSeg papers:
12 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
13 |
14 | Copyright 2020 Benjamin Billot
15 |
16 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
17 | compliance with the License. You may obtain a copy of the License at
18 | https://www.apache.org/licenses/LICENSE-2.0
19 | Unless required by applicable law or agreed to in writing, software distributed under the License is
20 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
21 | implied. See the License for the specific language governing permissions and limitations under the
22 | License.
23 | """
24 |
25 |
26 | from SynthSeg.estimate_priors import build_intensity_stats
27 |
28 | # ----------------------------------------------- simple uni-modal case ------------------------------------------------
29 |
30 | # paths of directories containing the images and corresponding label maps
31 | image_dir = '/image_folder/t1'
32 | labels_dir = '/labels_folder'
33 | # list of labels from which we want to estimate the GMM prior distributions
34 | estimation_labels = '../../data/labels_classes_priors/generation_labels.npy'
35 | # path of folder where to write estimated priors
36 | result_dir = './outputs_tutorial_6/t1_priors'
37 |
38 | build_intensity_stats(list_image_dir=image_dir,
39 | list_labels_dir=labels_dir,
40 | estimation_labels=estimation_labels,
41 | result_dir=result_dir,
42 | rescale=True)
43 |
44 | # ------------------------------------ building Gaussian priors from several labels ------------------------------------
45 |
46 | # same as before
47 | image_dir = '/image_folder/t1'
48 | labels_dir = '/labels_folder'
49 | estimation_labels = '../../data/labels_classes_priors/generation_labels.npy'
50 | result_dir = './outputs_tutorial_6/estimated_t1_priors_classes'
51 |
52 | # In the previous example, each label value is used to build the priors of a single Gaussian distribution.
53 | # We show here how to build Gaussian priors from intensities associated with several label values. For example, this
54 | # could be building the Gaussian prior of white matter by using the labels of both right and left white matter.
55 | # This is done by specifying a vector, which regroups label values into "classes".
56 | # Labels sharing the same class will contribute to the construction of the same Gaussian prior.
57 | estimation_classes = '../../data/labels_classes_priors/generation_classes.npy'
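
# Minimal illustrative sketch (hypothetical values, not the contents of the provided .npy files): with these two
# aligned vectors, left (2) and right (41) white matter contribute to the same Gaussian prior, and left (3) and
# right (42) cerebral cortex to another one.
import numpy as np
example_estimation_labels = np.array([0, 2, 3, 41, 42])
example_estimation_classes = np.array([0, 1, 2, 1, 2])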
58 |
59 | build_intensity_stats(list_image_dir=image_dir,
60 | list_labels_dir=labels_dir,
61 | estimation_labels=estimation_labels,
62 | estimation_classes=estimation_classes,
63 | result_dir=result_dir,
64 | rescale=True)
65 |
66 | # ---------------------------------------------- simple multi-modal case -----------------------------------------------
67 |
68 | # Here we have multi-modal images, where every image contains all channels.
69 | # Channels are supposed to be sorted in the same order for all subjects.
70 | image_dir = '/image_folder/multi-modal_t1_t2'
71 |
72 | # same as before
73 | labels_dir = '/labels_folder'
74 | estimation_labels = '../../data/labels_classes_priors/generation_labels.npy'
75 | estimation_classes = '../../data/labels_classes_priors/generation_classes.npy'
76 | result_dir = './outputs_tutorial_6/estimated_priors_multi_modal'
77 |
78 | build_intensity_stats(list_image_dir=image_dir,
79 | list_labels_dir=labels_dir,
80 | estimation_labels=estimation_labels,
81 | estimation_classes=estimation_classes,
82 | result_dir=result_dir,
83 | rescale=True)
84 |
85 | # ------------------------------------- multi-modal images with separate channels -------------------------------------
86 |
87 | # Here we have multi-modal images, where the different channels are stored in separate directories.
88 | # We provide these different directories as a list.
89 | list_image_dir = ['/image_folder/t1', '/image_folder/t2']
90 | # In this example, we assume that channels are registered and at the same resolution.
91 | # Therefore we can use the same label maps for all channels.
92 | labels_dir = '/labels_folder'
93 |
94 | # same as before
95 | estimation_labels = '../../data/labels_classes_priors/generation_labels.npy'
96 | estimation_classes = '../../data/labels_classes_priors/generation_classes.npy'
97 | result_dir = './outputs_tutorial_6/estimated_priors_multi_modal'
98 |
99 | build_intensity_stats(list_image_dir=list_image_dir,
100 | list_labels_dir=labels_dir,
101 | estimation_labels=estimation_labels,
102 | estimation_classes=estimation_classes,
103 | result_dir=result_dir,
104 | rescale=True)
105 |
106 | # ------------------------------------ multi-modal case with unregistered channels -------------------------------------
107 |
108 | # Again, we have multi-modal images where the different channels are stored in separate directories.
109 | list_image_dir = ['/image_folder/t1', '/image_folder/t2']
110 | # In this example, we assume that the channels are no longer registered.
111 | # Therefore we cannot use the same label maps for all channels, and must provide label maps for all modalities.
112 | labels_dir = ['/labels_folder/t1', '/labels_folder/t2']
113 |
114 | # same as before
115 | estimation_labels = '../../data/labels_classes_priors/generation_labels.npy'
116 | estimation_classes = '../../data/labels_classes_priors/generation_classes.npy'
117 | result_dir = './outputs_tutorial_6/estimated_unregistered_multi_modal'
118 |
119 | build_intensity_stats(list_image_dir=list_image_dir,
120 | list_labels_dir=labels_dir,
121 | estimation_labels=estimation_labels,
122 | estimation_classes=estimation_classes,
123 | result_dir=result_dir,
124 | rescale=True)
125 |
--------------------------------------------------------------------------------
/scripts/tutorials/7-synthseg+.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | Very simple script to show how we trained SynthSeg+, which extends SynthSeg by building robustness to clinical
4 | acquisitions. For more details, please look at our MICCAI 2022 paper:
5 |
6 | Robust Segmentation of Brain MRI in the Wild with Hierarchical CNNs and no Retraining,
7 | Billot, Magdamo, Das, Arnold, Iglesias
8 | MICCAI 2022
9 |
10 | If you use this code, please cite one of the SynthSeg papers:
11 | https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
12 |
13 | Copyright 2020 Benjamin Billot
14 |
15 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
16 | compliance with the License. You may obtain a copy of the License at
17 | https://www.apache.org/licenses/LICENSE-2.0
18 | Unless required by applicable law or agreed to in writing, software distributed under the License is
19 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
20 | implied. See the License for the specific language governing permissions and limitations under the
21 | License.
22 | """
23 | from SynthSeg.training import training as training_s1
24 | from SynthSeg.training_denoiser import training as training_d
25 | from SynthSeg.training_group import training as training_s2
26 |
27 | import numpy as np
28 |
29 | # ------------------ segmenter S1
30 | # Here the purpose is to train a first network to produce preliminary segmentations of input scans with five general
31 | # labels: 0-background, 1-white matter, 2-grey matter, 3-fluids, 4-cerebellum.
32 |
33 | # As in tutorial 3, S1 is trained with synthetic images with randomised contrasts/resolution/artefacts such that it can
34 | # readily segment a wide range of test scans without retraining. The synthetic scans are obtained from the same label
35 | # maps and generative model as in the previous tutorials.
36 | labels_dir_s1 = '../../data/training_label_maps'
37 | path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy'
38 | path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy'
39 | # However, because we now wish to segment scans using only five labels, we use a different list of segmentation labels
40 | # where all label values in generation_labels are assigned to a target value between [0, 4].
41 | path_segmentation_labels_s1 = '../../data/tutorial_7/segmentation_labels_s1.npy'
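
# Hypothetical illustration (this is not the content of segmentation_labels_s1.npy): reusing the example labels
# of tutorial 2, such a mapping could look like the following. Class 4 (cerebellum) does not appear here only
# because the example labels contain no cerebellar structures.
# generation labels: [0, 24, 507,  2,  3,  4, 17, 25, 41, 42, 43, 53, 57]
# S1 targets:        [0,  3,   0,  1,  2,  3,  2,  1,  1,  2,  3,  2,  1]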
42 |
43 | model_dir_s1 = './outputs_tutorial_7/training_s1' # folder where the models will be saved
44 |
45 |
46 | training_s1(labels_dir=labels_dir_s1,
47 | model_dir=model_dir_s1,
48 | generation_labels=path_generation_labels,
49 | segmentation_labels=path_segmentation_labels_s1,
50 | n_neutral_labels=18,
51 | generation_classes=path_generation_classes,
52 | target_res=1,
53 | output_shape=160,
54 | prior_distributions='uniform',
55 | prior_means=[0, 255],
56 | prior_stds=[0, 50],
57 | randomise_res=True)
58 |
59 | # ------------------ denoiser D
60 | # The purpose of this network is to perform label-to-label correction in order to correct potential mistakes made by S1
61 | # at test time. Therefore, D is trained with two sets of label maps: noisy segmentations from S1 (used as inputs to D),
62 | # and their corresponding ground truth (used as target to train D). In order to obtain input segmentations
63 | # representative of the mistakes of S1, these are produced by degrading real images with extreme augmentation (spatial,
64 | # intensity, resolution, etc.), and feeding them to S1.
65 |
66 | # Obtaining the input/target segmentations is done offline by using the script sample_segmentation_pairs_d.py.
67 | # In practice we sample a lot of them (e.g. 10,000), but we give here 3 example pairs. Note that these segmentations
68 | # have the same label values as the output of S1 (i.e. between [0, 4]).
69 | list_input_labels = ['../../data/tutorial_7/noisy_segmentations_d/0001.nii.gz',
70 | '../../data/tutorial_7/noisy_segmentations_d/0002.nii.gz',
71 | '../../data/tutorial_7/noisy_segmentations_d/0003.nii.gz']
72 | list_target_labels = ['../../data/tutorial_7/target_segmentations_d/0001.nii.gz',
73 | '../../data/tutorial_7/target_segmentations_d/0002.nii.gz',
74 | '../../data/tutorial_7/target_segmentations_d/0003.nii.gz']
75 |
76 | # Moreover, we perform spatial augmentation on the sampled pairs, in order to further increase the morphological
77 | # variability seen by the network. In addition, the input "noisy" segmentations are further augmented with random
78 | # erosion/dilation:
79 | prob_erosion_dilation = 0.3 # probability of performing random erosion/dilation
80 | min_erosion_dilation = 4  # minimum coefficient for erosion/dilation
81 | max_erosion_dilation = 5 # maximum coefficient for erosion/dilation
82 |
83 | # This is the list of label values included in the input/GT label maps. This list must contain unique values.
84 | input_segmentation_labels = np.array([0, 1, 2, 3, 4])
85 |
86 | model_dir_d = './outputs_tutorial_7/training_d' # folder where the models will be saved
87 |
88 | training_d(list_paths_input_labels=list_input_labels,
89 | list_paths_target_labels=list_target_labels,
90 | model_dir=model_dir_d,
91 | input_segmentation_labels=input_segmentation_labels,
92 | output_shape=160,
93 | prob_erosion_dilation=prob_erosion_dilation,
94 | min_erosion_dilation=min_erosion_dilation,
95 | max_erosion_dilation=max_erosion_dilation,
96 | conv_size=5,
97 | unet_feat_count=16,
98 | skip_n_concatenations=2)
99 |
100 | # ------------------ segmenter S2
101 | # Final segmentations are obtained with a last segmenter S2, which takes as inputs an image as well as the preliminary
102 | # segmentations of S1 that are corrected by D.
103 |
104 | # Here S2 is trained with synthetic images sampled from the usual training label maps, with the associated generation
105 | # labels and classes. Also, we now use the same segmentation labels as in tutorials 2, 3, and 4, as we segment all the usual
106 | # regions.
107 | labels_dir_s2 = '../../data/training_label_maps' # these are the same as for S1
108 | path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy'
109 | path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy'
110 | path_segmentation_labels_s2 = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy'
111 |
112 | # The preliminary segmentations are given as soft probability maps and are directly derived from the ground truth.
113 | # Specifically, we take the structures that were segmented by S1, and regroup them into the same "groups" as before.
114 | grouping_labels = '../../data/tutorial_7/segmentation_labels_s1.npy'
115 | # However, in order to simulate the test-time imperfections of D, these soft probability maps are slightly
116 | # augmented with spatial transforms, and sometimes undergo a random dilation/erosion.
117 |
118 | model_dir_s2 = './outputs_tutorial_7/training_s2' # folder where the models will be saved
119 |
120 | training_s2(labels_dir=labels_dir_s2,
121 | model_dir=model_dir_s2,
122 | generation_labels=path_generation_labels,
123 | n_neutral_labels=18,
124 | segmentation_labels=path_segmentation_labels_s2,
125 | generation_classes=path_generation_classes,
126 | grouping_labels=grouping_labels,
127 | target_res=1,
128 | output_shape=160,
129 | prior_distributions='uniform',
130 | prior_means=[0, 255],
131 | prior_stds=[0, 50],
132 | randomise_res=True)
133 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import sys
4 | import setuptools
5 |
6 | python_version = '{}.{}'.format(*sys.version_info[:2])  # e.g. '3.6' or '3.8'
7 |
8 | if (python_version != '3.6') and (python_version != '3.8'):
9 | raise Exception('Setup.py only works with python version 3.6 or 3.8, not {}'.format(python_version))
10 |
11 | else:
12 |
13 | with open('requirements_python' + python_version + '.txt') as f:
14 | required_packages = [line.strip() for line in f.readlines()]
15 |
16 | print(setuptools.find_packages())
17 |
18 | setuptools.setup(name='SynthSeg',
19 | version='2.0',
20 | license='Apache 2.0',
21 | description='Domain-agnostic segmentation of brain scans',
22 | author='Benjamin Billot',
23 | url='https://github.com/BBillot/SynthSeg',
24 | keywords=['segmentation', 'domain-agnostic', 'brain'],
25 | packages=setuptools.find_packages(),
26 | python_requires='>=3.6',
27 | install_requires=required_packages,
28 | include_package_data=True)
29 |
--------------------------------------------------------------------------------