├── .github
└── workflows
│ └── run-tests.yml
├── .gitignore
├── .readthedocs.yml
├── AUTHORS.md
├── CONTRIBUTING.md
├── LICENSE.md
├── README.md
├── atlalign
├── __init__.py
├── augmentations.py
├── base.py
├── data.py
├── label
│ ├── __init__.py
│ ├── cli.py
│ ├── io.py
│ └── new_GUI.py
├── metrics.py
├── ml_utils
│ ├── __init__.py
│ ├── augmentation.py
│ ├── callbacks.py
│ ├── io.py
│ ├── layers.py
│ ├── losses.py
│ └── models.py
├── nn.py
├── non_ml
│ ├── __init__.py
│ └── intensity.py
├── utils.py
├── visualization.py
├── volume.py
└── zoo.py
├── docs
├── Makefile
├── _images
│ ├── affine_simple.png
│ ├── anchoring.png
│ ├── annot_warping.png
│ ├── antspy.png
│ ├── aug_pipeline.png
│ ├── clipped_vd.png
│ ├── composition.png
│ ├── control_points.png
│ ├── coronal_interpolator.png
│ ├── edge_stretching.png
│ ├── evaluation_metrics.png
│ ├── example_augmentation.png
│ ├── feature_based.png
│ ├── image_registration.png
│ ├── image_registration_2.png
│ ├── int_augmentations.gif
│ ├── inverse.png
│ ├── labeling_tool.png
│ ├── metrics_overview.png
│ ├── resizing.png
│ ├── typical_dataset.png
│ └── warping.png
├── _static
│ └── .keep
├── conf.py
├── generate_metadata.py
├── index.rst
└── source
│ ├── 3d_interpolation.rst
│ ├── api
│ ├── atlalign.label.rst
│ ├── atlalign.ml_utils.rst
│ ├── atlalign.non_ml.rst
│ ├── atlalign.rst
│ └── modules.rst
│ ├── building_blocks.rst
│ ├── datasets.rst
│ ├── deep_learning_data.rst
│ ├── deep_learning_inference.rst
│ ├── deep_learning_training.rst
│ ├── evaluation.rst
│ ├── image_registration.rst
│ ├── installation.rst
│ ├── intensity.rst
│ ├── labeling_tool.rst
│ └── logo
│ └── Atlas_Alignment_banner.jpg
├── scripts
└── experiment_1.py
├── setup.py
├── tests
├── conftest.py
├── data
│ ├── animals.jpg
│ ├── mild_inversion.npy
│ └── supervised_dataset.h5
├── test_augmentations.py
├── test_base.py
├── test_data.py
├── test_metrics.py
├── test_ml_utils
│ ├── test_augmentation.py
│ ├── test_callbacks.py
│ ├── test_io.py
│ ├── test_layers.py
│ └── test_models.py
├── test_nn.py
├── test_non_ml.py
├── test_utils.py
├── test_visualization.py
├── test_volume.py
└── test_zoo.py
└── tox.ini
/.github/workflows/run-tests.yml:
--------------------------------------------------------------------------------
1 | name: ci testing
2 |
3 |
4 | on:
5 | push:
6 | branches: master
7 | pull_request:
8 |
9 |
10 | jobs:
11 |
12 | run_test:
13 |
14 | runs-on: ${{ matrix.os }}
15 |
16 | strategy:
17 |
18 | matrix:
19 | os: [ubuntu-latest] # macos 11 is currently in preview, macos-latest == 10.15
20 | python-version: [
21 | 3.8,
22 | 3.9,
23 | ]
24 | include:
25 | - python-version: 3.8
26 | tox-env: py38
27 | - python-version: 3.9
28 | tox-env: py39
29 |
30 | steps:
31 |
32 | - name: checkout latest commit
33 | uses: actions/checkout@v2
34 |
35 | - name: setup python ${{ matrix.python-version }}
36 | uses: actions/setup-python@v2
37 | with:
38 | python-version: ${{ matrix.python-version }}
39 |
40 | - name: install python dependencies
41 | run: |
42 | python -m pip install --upgrade pip
43 | pip install tox tox-gh-actions
44 |
45 | - name: linting and code style
46 | run: tox -vv -e lint
47 |
48 | - name: tests and coverage
49 | run: tox -vv -e ${{ matrix.tox-env }} -- --color=yes
50 |
51 | - name: docs
52 | run: tox -vv -e docs
53 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom entries
2 | /data/
3 | /.idea/
4 | .DS_Store
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | cover/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | .pybuilder/
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | # For a library or package, you might want to ignore these files since the code is
93 | # intended to run in multiple environments; otherwise, check them in:
94 | # .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
140 | # pytype static type analyzer
141 | .pytype/
142 |
143 | # Cython debug symbols
144 | cython_debug/
145 |
146 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | formats: []
4 |
5 | sphinx:
6 | builder: html
7 | configuration: docs/conf.py
8 |
9 | build:
10 | image: latest
11 |
12 | python:
13 | version: 3.7
14 | install:
15 | - method: pip
16 | path: .
17 | extra_requirements:
18 | - tf
19 | - docs
20 | system_packages: true
21 |
--------------------------------------------------------------------------------
/AUTHORS.md:
--------------------------------------------------------------------------------
1 | ## Maintainers
2 | - Jan Krepl
3 | - Francesco Casalegno
4 | - Emilie Delattre
5 |
6 | ## Authors
7 | - Jan Krepl
8 | - Francesco Casalegno
9 | - Emilie Delattre
10 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | This page includes some guidelines to enable you to contribute to the project.
4 |
5 | ## Found a bug?
6 |
7 | If you find a bug in the source code or in using the package, you can
8 | open an issue on GitHub.
9 | Even better, you can submit a pull request with a fix.
10 |
11 | ## Submission guidelines
12 |
13 | ### Submitting an issue
14 |
15 | Before you submit an issue, please search the issue tracker, maybe an issue
16 | for your problem already exists and the discussion might inform you of workarounds
17 | readily available.
18 |
19 | We want to fix all the issues as soon as possible, but before fixing a bug we
20 | need to reproduce and confirm it. In order to reproduce bugs we will need as
21 | much information as possible, and preferably a sample demonstrating the issue.
22 |
23 | ### Submitting a pull request (PR)
24 |
25 | If you wish to contribute to the code base, please open a pull request by
26 | following GitHub's guidelines.
27 |
28 | ## Development Conventions
29 |
30 | `atlalign` uses:
31 | - Black for formatting code
32 | - Flake8 for linting code
33 | - PyDocStyle for checking docstrings
34 |
35 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Atlas Alignment
4 |
5 |
64 |
65 | Atlas Alignment is a toolbox to perform multimodal image registration. It
66 | includes both traditional and supervised deep learning models.
67 |
68 | This project originated from the Blue Brain Project efforts on aligning mouse
69 | brain atlases obtained with ISH gene expression and Nissl stains.
70 |
71 |
72 | ### Official documentation
73 | All details related to installation and logic are described in the
74 | [official documentation](https://atlas-alignment.readthedocs.io/).
75 |
76 |
77 | ### Installation
78 |
79 | #### Installation Requirements
80 |
81 | Some of the functionalities of `atlalign` depend on the [TensorFlow implementation
82 | of the Learned Perceptual Image Patch Similarity (LPIPS)](https://github.com/alexlee-gk/lpips-tensorflow). Unfortunately, the
83 | package is not available on PyPI and must be installed manually as follows
84 | for full functionality.
85 | ```shell script
86 | pip install git+https://github.com/alexlee-gk/lpips-tensorflow.git#egg=lpips_tf
87 | ```
88 |
89 | You can now move on to installing the actual `atlalign` package!
90 |
91 | #### Installation from PyPI
92 | The `atlalign` package can be easily installed from PyPI.
93 | ```shell script
94 | pip install atlalign
95 | ```
96 |
97 | #### Installation from source
98 | As an alternative to installing from PyPI, if you want to try the latest version
99 | you can also install from source.
100 | ```shell script
101 | pip install git+https://github.com/BlueBrain/atlas_alignment#egg=atlalign
102 | ```
103 |
104 | #### Installation for development
105 | If you want a dev install, you should install the latest version from source with
106 | all the extra requirements for running test and generating docs.
107 | ```shell script
108 | git clone https://github.com/BlueBrain/atlas_alignment
109 | cd atlas_alignment
110 | pip install -e .[dev,docs]
111 | ```
112 |
113 | ### Examples
114 | You can find multiple examples in the documentation. Specifically, make
115 | sure to read the
116 | [Building Blocks](https://atlas-alignment.readthedocs.io/en/latest/source/building_blocks.html)
117 | section of the docs to understand the basics.
118 |
119 | ### Data
120 | You can find example data on [Zenodo](https://zenodo.org/record/4541446#.YCqGFc9Kg4g).
121 | Unzip the files to `~/.atlalign/` folder so that you can use the `data.py` module
122 | without manual specification of paths.
123 |
124 | #### Allen Brain Institute Database
125 | You can find and download ISH data from Allen Brain Institute thanks to
126 | [Atlas Download Tools](https://github.com/BlueBrain/Atlas-Download-Tools) repository.
127 |
128 | ### Funding & Acknowledgment
129 | This project was supported by funding to the Blue Brain
130 | Project, a research center of the Ecole polytechnique fédérale de Lausanne, from
131 | the Swiss government's ETH Board of the Swiss Federal Institutes of Technology.
132 |
133 | COPYRIGHT (c) 2021-2022 Blue Brain Project/EPFL
134 |
--------------------------------------------------------------------------------
/atlalign/__init__.py:
--------------------------------------------------------------------------------
1 | """Image registration package.
2 |
3 | Release markers:
4 | X.Y
5 | X.Y.Z for bug fixes
6 | """
7 |
8 | """
9 | The package atlalign is a tool for registration of 2D images.
10 |
11 | Copyright (C) 2021 EPFL/Blue Brain Project
12 |
13 | This program is free software: you can redistribute it and/or modify
14 | it under the terms of the GNU Lesser General Public License as published by
15 | the Free Software Foundation, either version 3 of the License, or
16 | (at your option) any later version.
17 |
18 | This program is distributed in the hope that it will be useful,
19 | but WITHOUT ANY WARRANTY; without even the implied warranty of
20 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 | GNU Lesser General Public License for more details.
22 |
23 | You should have received a copy of the GNU Lesser General Public License
24 | along with this program. If not, see <https://www.gnu.org/licenses/>.
25 | """
26 |
27 | __version__ = "0.6.2"
28 |
--------------------------------------------------------------------------------
/atlalign/augmentations.py:
--------------------------------------------------------------------------------
1 | """Module creating one-to-many augmentations."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see <https://www.gnu.org/licenses/>.
20 | """
21 |
22 | import h5py
23 | import numpy as np
24 | from skimage.feature import canny
25 | from skimage.util import img_as_float32
26 |
27 | from atlalign.base import DisplacementField
28 |
29 |
30 | def load_dataset_in_memory(h5_path, dataset_name):
31 | """Load a dataset of a h5 file in memory."""
32 | with h5py.File(h5_path, "r") as f:
33 | return f[dataset_name][:]
34 |
35 |
36 | class DatasetAugmenter:
37 | """Class that does the augmentation.
38 |
39 | Attributes
40 | ----------
41 | original_path : str
42 | Path to where the original dataset is located.
43 |
44 | """
45 |
46 | def __init__(self, original_path):
47 | self.original_path = original_path
48 |
49 | self.n_orig = len(load_dataset_in_memory(self.original_path, "image_id"))
50 |
51 | def augment(
52 | self,
53 | output_path,
54 | n_iter=10,
55 | anchor=True,
56 | p_reg=0.5,
57 | random_state=None,
58 | max_corrupted_pixels=500,
59 | ds_f=8,
60 | max_trials=5,
61 | ):
62 | """Augment the original dataset and create a new one.
63 |
64 | Note that this does not modify the original dataset.
65 |
66 | Parameters
67 | ----------
68 | output_path : str
69 | Path to where the new h5 file will be stored.
70 |
71 | n_iter : int
72 | Number of augmented samples for each sample in the original dataset.
73 |
74 | anchor : bool
75 | If True, then the dvf is anchored before being inverted.
76 |
77 | p_reg : float
78 | Probability that we start from a registered image
79 | (rather than the moving).
80 |
81 | random_state : int or None
82 | Random state
83 |
84 | max_corrupted_pixels : int
85 | Maximum number of corrupted pixels allowed for a dvf - the actual
86 | number is computed as np.sum(df.jacobian() < 0)
87 |
88 | ds_f : int
89 | Downsampling factor for inverses. 1 creates the least artifacts.
90 |
91 | max_trials : int
92 | Max number of attempts to augment before an identity displacement
93 | is used as augmentation.
94 | """
95 | np.random.seed(random_state)
96 |
97 | n_new = n_iter * self.n_orig
98 | print(n_new)
99 |
100 | with h5py.File(self.original_path, "r") as f_orig:
101 | # extract
102 | dset_img_orig = f_orig["img"]
103 | dset_image_id_orig = f_orig["image_id"]
104 | dset_dataset_id_orig = f_orig["dataset_id"]
105 | dset_deltas_xy_orig = f_orig["deltas_xy"]
106 | dset_inv_deltas_xy_orig = f_orig["inv_deltas_xy"]
107 | dset_p_orig = f_orig["p"]
108 |
109 | with h5py.File(output_path, "w") as f_aug:
110 | dset_img_aug = f_aug.create_dataset(
111 | "img", (n_new, 320, 456), dtype="uint8"
112 | )
113 | dset_image_id_aug = f_aug.create_dataset(
114 | "image_id", (n_new,), dtype="int"
115 | )
116 | dset_dataset_id_aug = f_aug.create_dataset(
117 | "dataset_id", (n_new,), dtype="int"
118 | )
119 | dset_p_aug = f_aug.create_dataset("p", (n_new,), dtype="int")
120 | dset_deltas_xy_aug = f_aug.create_dataset(
121 | "deltas_xy", (n_new, 320, 456, 2), dtype=np.float16
122 | )
123 | dset_inv_deltas_xy_aug = f_aug.create_dataset(
124 | "inv_deltas_xy", (n_new, 320, 456, 2), dtype=np.float16
125 | )
126 |
127 | for i in range(n_new):
128 | print(i)
129 | i_orig = i % self.n_orig
130 |
131 | mov2reg = DisplacementField(
132 | dset_deltas_xy_orig[i_orig, ..., 0],
133 | dset_deltas_xy_orig[i_orig, ..., 1],
134 | )
135 |
136 | # copy
137 | dset_image_id_aug[i] = dset_image_id_orig[i_orig]
138 | dset_dataset_id_aug[i] = dset_dataset_id_orig[i_orig]
139 | dset_p_aug[i] = dset_p_orig[i_orig]
140 |
141 | use_reg = np.random.random() > p_reg
142 | print("Using registered: {}".format(use_reg))
143 |
144 | if not use_reg:
145 | # mov != reg
146 | img_mov = dset_img_orig[i_orig]
147 | else:
148 | # mov=reg
149 | img_mov = mov2reg.warp(dset_img_orig[i_orig])
150 | mov2reg = DisplacementField.generate(
151 | (320, 456), approach="identity"
152 | )
153 |
154 | is_nice = False
155 | n_trials = 0
156 |
157 | while not is_nice:
158 | n_trials += 1
159 |
160 | if n_trials == max_trials:
161 | print("Replicating original: out of trials")
162 | dset_img_aug[i] = dset_img_orig[i_orig]
163 | dset_deltas_xy_aug[i] = dset_deltas_xy_orig[i_orig]
164 | dset_inv_deltas_xy_aug[i] = dset_inv_deltas_xy_orig[i_orig]
165 | break
166 |
167 | else:
168 | mov2art = self.generate_mov2art(img_mov)
169 |
170 | reg2mov = mov2reg.pseudo_inverse(ds_f=ds_f)
171 | reg2art = reg2mov(mov2art)
172 |
173 | # anchor
174 | if anchor:
175 | print("ANCHORING")
176 | reg2art = reg2art.anchor(
177 | ds_f=50, smooth=0, h_kept=0.9, w_kept=0.9
178 | )
179 |
180 | art2reg = reg2art.pseudo_inverse(ds_f=ds_f)
181 |
182 | validity_check = np.all(
183 | np.isfinite(reg2art.delta_x)
184 | ) and np.all(np.isfinite(reg2art.delta_y))
185 | validity_check &= np.all(
186 | np.isfinite(art2reg.delta_x)
187 | ) and np.all(np.isfinite(art2reg.delta_y))
188 | jacobian_check = (
189 | np.sum(reg2art.jacobian < 0) < max_corrupted_pixels
190 | )
191 | jacobian_check &= (
192 | np.sum(art2reg.jacobian < 0) < max_corrupted_pixels
193 | )
194 |
195 | if validity_check and jacobian_check:
196 | is_nice = True
197 | print("Check passed")
198 | else:
199 | print("Check failed")
200 |
201 | if n_trials != max_trials:
202 | dset_img_aug[i] = mov2art.warp(img_mov)
203 | dset_deltas_xy_aug[i] = np.stack(
204 | [art2reg.delta_x, art2reg.delta_y], axis=-1
205 | )
206 | dset_inv_deltas_xy_aug[i] = np.stack(
207 | [reg2art.delta_x, reg2art.delta_y], axis=-1
208 | )
209 |
210 | @staticmethod
211 | def generate_mov2art(img_mov, verbose=True, radius_max=60, use_normal=True):
212 | """Generate geometric augmentation and its inverse."""
213 | shape = img_mov.shape
214 | img_mov_float = img_as_float32(img_mov)
215 | edge_mask = canny(img_mov_float)
216 |
217 | if use_normal:
218 | c = np.random.normal(0.7, 0.3)
219 | else:
220 | c = np.random.random()
221 |
222 | if verbose:
223 | print("Scalar: {}".format(c))
224 |
225 | mov2art = c * DisplacementField.generate(
226 | shape,
227 | approach="edge_stretching",
228 | edge_mask=edge_mask,
229 | interpolation_method="rbf",
230 | interpolator_kwargs={"function": "linear"},
231 | n_perturbation_points=6,
232 | radius_max=radius_max,
233 | )
234 |
235 | return mov2art
236 |
--------------------------------------------------------------------------------
/atlalign/label/__init__.py:
--------------------------------------------------------------------------------
1 | """Module containing the interactive labeling tool."""
2 | # The package atlalign is a tool for registration of 2D images.
3 | #
4 | # Copyright (C) 2021 EPFL/Blue Brain Project
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU Lesser General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 |
--------------------------------------------------------------------------------
/atlalign/label/cli.py:
--------------------------------------------------------------------------------
1 | """Command line interface implementation."""
2 | # The package atlalign is a tool for registration of 2D images.
3 | #
4 | # Copyright (C) 2021 EPFL/Blue Brain Project
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU Lesser General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | import argparse
19 | import datetime
20 | import pathlib
21 | import sys
22 | from contextlib import redirect_stdout
23 |
24 |
def main(argv=None):
    """Run the labeling-tool CLI.

    Parses the command line, loads the reference and moving images, launches
    the interactive labeling GUI and finally dumps all results (arrays,
    rendered images, keypoints CSV and a parameter log) into the output folder.

    Parameters
    ----------
    argv : list or None
        Command-line arguments. If None, ``argparse`` falls back to
        ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "ref",
        type=str,
        help="Either a path to a reference image or a number from [0, 528) "
        "representing the coronal dimension in the nissl stain volume.",
    )
    parser.add_argument(
        "mov",
        type=str,
        help="Path to a moving image. Needs to be of the same shape as " "reference.",
    )
    parser.add_argument(
        "output_path", type=str, help="Folder where the outputs will be stored."
    )
    parser.add_argument(
        "-s",
        "--swap",
        default=False,
        help="Swap to the moving to reference mode.",
        action="store_true",
    )
    parser.add_argument(
        "-g",
        "--force-grayscale",
        default=False,
        help="Force the images to be grayscale. Convert RGB to grayscale if necessary.",
        action="store_true",
    )
    args = parser.parse_args(argv)

    # Imports are deferred so that `--help` stays fast and does not pull in
    # heavy dependencies (matplotlib, numpy, the GUI stack).
    import matplotlib.pyplot as plt
    import numpy as np

    from atlalign.data import nissl_volume
    from atlalign.label.io import load_image
    from atlalign.label.new_GUI import run_gui

    # Read input images. A purely numeric `ref` selects a coronal section of
    # the nissl volume instead of reading a file from disk.
    output_channels = 1 if args.force_grayscale else None
    if args.ref.isdigit():
        img_ref = nissl_volume()[int(args.ref), ..., 0]
    else:
        img_ref_path = pathlib.Path(args.ref)
        img_ref = load_image(
            img_ref_path, output_channels=output_channels, output_dtype="float32"
        )
    img_mov_path = pathlib.Path(args.mov)
    img_mov = load_image(
        img_mov_path,
        output_channels=output_channels,
        output_dtype="float32",
    )

    # Launch GUI (blocks until the user finishes the labeling session).
    (
        result_df,
        keypoints,
        symmetric_registration,
        img_reg,
        interpolation_method,
        kernel,
    ) = run_gui(img_ref, img_mov, mode="mov2ref" if args.swap else "ref2mov")

    # Dump results and metadata to disk
    output_path = pathlib.Path(args.output_path)
    output_path.mkdir(exist_ok=True, parents=True)

    # result_df — presumably a DisplacementField with an .npy serializer;
    # TODO confirm against atlalign.base.DisplacementField.save.
    result_df.save(output_path / "df.npy")
    np.save(output_path / "img_reg.npy", img_reg)
    np.save(output_path / "img_ref.npy", img_ref)
    np.save(output_path / "img_mov.npy", img_mov)
    plt.imsave(output_path / "img_reg.png", img_reg)
    plt.imsave(output_path / "img_ref.png", img_ref)
    plt.imsave(output_path / "img_mov.png", img_mov)
    # `redirect_stdout` lets us reuse plain print() for file output.
    with open(output_path / "keypoints.csv", "w") as file, redirect_stdout(file):
        if args.swap:
            print("mov x,mov y,ref x,ref y")
        else:
            print("ref x,ref y,mov x,mov y")
        for (x1, y1), (x2, y2) in keypoints.items():
            print(f"{x1},{y1},{x2},{y2}")
    with open(output_path / "info.log", "w") as file, redirect_stdout(file):
        print("Timestamp :", datetime.datetime.now().ctime())
        print("")
        print("Parameters")
        print("----------")
        print("ref :", args.ref)
        print("mov :", args.mov)
        print("output_path :", output_path.resolve())
        print("swap :", args.swap)
        print("force-grayscale :", args.force_grayscale)
        print()
        print("Interpolation")
        print("-------------")
        print("Symmetric :", symmetric_registration)
        print("Method :", interpolation_method)
        print("Kernel :", kernel)
        # NOTE(review): this line is inside the redirect_stdout block, so the
        # message is written to info.log rather than the terminal — confirm
        # this is intended.
        print("Results were saved to", output_path.resolve())


if __name__ == "__main__":
    sys.exit(main())
133 |
--------------------------------------------------------------------------------
/atlalign/label/io.py:
--------------------------------------------------------------------------------
1 | """Input and output utilities for the command line interface."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see .
20 | """
21 |
22 | import cv2
23 |
24 |
def load_image(
    file_path,
    allowed_suffix=("jpg", "png"),
    output_dtype=None,
    output_channels=None,
    output_shape=None,
    input_shape=None,
    keep_last=False,
):
    """Load an image from disk, optionally converting channels, dtype and size.

    Parameters
    ----------
    file_path : str or pathlib.Path
        Path to where the image is stored.

    allowed_suffix : tuple or None
        Collection of allowed (case-insensitive) file suffixes, e.g.
        ``("jpg", "png")``. If None, no suffix check is performed.

    output_dtype : str or None
        Determines the dtype of the output image. If None, then the same as
        input. Float dtypes additionally rescale intensities from [0, 255]
        to [0, 1].

    output_channels : int, {1, 3} or None
        If 1 then grayscale, if 3 then RGB. If None then the same as the
        input image.

    output_shape : tuple or None
        Two element tuple representing (h_output, w_output). If None, no
        resizing takes place.

    input_shape : tuple or None
        If None no assertion on the input shape. If not None then a tuple
        representing (h_input_expected, w_input_expected).

    keep_last : bool
        Only active if `output_channels=1`. If True, then the output has
        shape (h, w, 1). Else (h, w).

    Returns
    -------
    img : np.array
        Array of shape (h, w) (grayscale) or (h, w, 3) (RGB).

    Raises
    ------
    ValueError
        If the suffix is not allowed, the image cannot be read, the actual
        shape differs from `input_shape`, or `output_channels` is invalid.
    """
    # Enforce the suffix whitelist — the parameter existed before but was
    # silently ignored.
    if allowed_suffix is not None:
        suffix = str(file_path).rsplit(".", 1)[-1].lower()
        allowed = {s.lower().lstrip(".") for s in allowed_suffix}
        if suffix not in allowed:
            raise ValueError(
                "Suffix '{}' not among the allowed ones: {}".format(
                    suffix, sorted(allowed)
                )
            )

    raw_input_ = cv2.imread(str(file_path), cv2.IMREAD_UNCHANGED)

    # cv2.imread signals failure by returning None instead of raising.
    if raw_input_ is None:
        raise ValueError("Could not read an image from {}".format(file_path))

    # Bug fix: the original compared against `raw_input_[:2]` — the first two
    # ROWS of the array — which raised "truth value of an array is ambiguous"
    # whenever `input_shape` was provided. Compare against the shape instead.
    if input_shape is not None and tuple(input_shape) != raw_input_.shape[:2]:
        raise ValueError(
            "Asserted input shape {} different than actual one {}".format(
                input_shape, raw_input_.shape
            )
        )

    # If len(shape) == 2, i.e. shape = (height, width), then raw_input_ is
    # already a grayscale image and we don't need to do any conversion.
    if len(raw_input_.shape) > 2:
        if output_channels is None or output_channels == 3:
            # OpenCV loads images in BGR order; convert to conventional RGB.
            input_img = cv2.cvtColor(raw_input_, cv2.COLOR_BGR2RGB)
        elif output_channels == 1:
            input_img = cv2.cvtColor(raw_input_, cv2.COLOR_BGR2GRAY)
        else:
            raise ValueError("Invalid output channels: {}".format(output_channels))
    else:
        input_img = raw_input_

    if output_shape is not None:
        # cv2.resize expects (width, height), i.e. the reversed order.
        input_img = cv2.resize(input_img, (output_shape[1], output_shape[0]))

    if input_img.ndim == 3 and input_img.shape[2] == 1 and not keep_last:
        input_img = input_img[..., 0]

    if output_dtype:
        if "float" in output_dtype:
            # Rescale 8-bit intensities into [0, 1] as part of the float cast.
            input_img = (input_img / 255).astype(output_dtype)
        else:
            # Previously any non-float dtype was silently ignored; cast
            # explicitly (no intensity rescaling for integer dtypes).
            input_img = input_img.astype(output_dtype)

    return input_img
98 |
--------------------------------------------------------------------------------
/atlalign/ml_utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Prevent 3 level imports for important objects."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see .
20 | """
21 |
22 | from atlalign.ml_utils.augmentation import augmenter_1 # noqa
23 | from atlalign.ml_utils.callbacks import MLFlowCallback, get_mlflow_artifact_path # noqa
24 | from atlalign.ml_utils.io import SupervisedGenerator # noqa
25 | from atlalign.ml_utils.layers import ( # noqa
26 | Affine2DVF,
27 | BilinearInterpolation,
28 | DVFComposition,
29 | ExtractMoving,
30 | NoOp,
31 | block_stn,
32 | )
33 | from atlalign.ml_utils.losses import ( # noqa
34 | DVF2IMG,
35 | NCC,
36 | Grad,
37 | Mixer,
38 | PerceptualLoss,
39 | VDClipper,
40 | cross_correlation,
41 | jacobian,
42 | jacobian_distance,
43 | mse_po,
44 | psnr,
45 | ssim,
46 | vector_distance,
47 | )
48 | from atlalign.ml_utils.models import ( # noqa
49 | load_model,
50 | merge_global_local,
51 | replace_lambda_in_config,
52 | save_model,
53 | )
54 |
# Create utility dictionaries mapping human-readable names to loss callables
# (or to Keras built-in loss identifiers such as "mae"/"mse", kept as plain
# strings).

# Losses that compare two images (registered vs. reference).
ALL_IMAGE_LOSSES = {
    "mae": "mae",
    "mse": "mse",
    "ncc_5": NCC(win=5).loss,
    "ncc_9": NCC(win=9).loss,
    "ncc_12": NCC(win=12).loss,
    "ncc_20": NCC(win=20).loss,
    "pearson": cross_correlation,
    "perceptual_loss_net-lin_alex": PerceptualLoss(model="net-lin", net="alex").loss,
    "perceptual_loss_net-lin_vgg": PerceptualLoss(model="net-lin", net="vgg").loss,
    "perceptual_loss_net_alex": PerceptualLoss(model="net", net="alex").loss,
    "perceptual_loss_net_vgg": PerceptualLoss(model="net", net="vgg").loss,
    "psnr": psnr,
    "ssim": ssim,
}
# Losses that act directly on displacement vector fields.
ALL_DVF_LOSSES = {
    "grad": Grad().loss,
    "jacobian": jacobian,
    "jacobian_distance": jacobian_distance,
    "mae": "mae",
    "mse": "mse",
    "mse_po": mse_po,
    "vector_distance": vector_distance,
    "vdclip2": VDClipper(20, power=2).loss,
    "vdclip3": VDClipper(20, power=3).loss,
    "vector_jacobian_distance": Mixer(vector_distance, jacobian_distance).loss,
    "vector_jacobian_distance_02": Mixer(
        vector_distance, jacobian_distance, weights=[0.2, 0.8]
    ).loss,
}

# Lift every callable image loss into the DVF domain via DVF2IMG (warps the
# image with the DVF and then applies the image loss). String identifiers
# ("mae"/"mse") are not callable and are skipped.
ALL_DVF_LOSSES = {
    **ALL_DVF_LOSSES,
    **{k: DVF2IMG(v).loss for k, v in ALL_IMAGE_LOSSES.items() if callable(v)},
}

# Sort by name so the generated pairwise keys below are deterministic.
all_dvf_losses_items = list(ALL_DVF_LOSSES.items())
all_dvf_losses_items.sort(key=lambda x: x[0])

# All unordered pairs of callable DVF losses, mixed 50/50; `j < i` avoids
# duplicating a pair in both orders and pairing a loss with itself.
MIXED_DVF_LOSSES = {
    "{}&{}".format(k_i, k_o): Mixer(v_i, v_o).loss
    for i, (k_o, v_o) in enumerate(all_dvf_losses_items)
    for j, (k_i, v_i) in enumerate(all_dvf_losses_items)
    if j < i and callable(v_i) and callable(v_o)
}

ALL_DVF_LOSSES = {**ALL_DVF_LOSSES, **MIXED_DVF_LOSSES}
103 |
--------------------------------------------------------------------------------
/atlalign/ml_utils/augmentation.py:
--------------------------------------------------------------------------------
1 | """Augmentation related tools."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see .
20 | """
21 |
22 | import imgaug.augmenters as iaa
23 |
24 |
def augmenter_1(p=0.99):
    """Create augmenter.

    Contains no coordinate transforms — only intensity-level augmentations,
    so displacement fields associated with the images remain valid.

    Parameters
    ----------
    p : float
        Number in [0, 1] representing the probability of a random
        augmentation happening.

    Returns
    -------
    seq : iaa.Augmenter
        Augmenter where each augmentation was manually inspected and makes
        sense.

    """
    # One candidate augmentation is picked at random per image; the entries
    # are constructed in a fixed order to keep behavior reproducible.
    candidates = [
        iaa.Multiply(mul=(0.8, 1.2)),
        iaa.Sequential([iaa.Sharpen(alpha=(0, 1))]),
        iaa.Sequential([iaa.EdgeDetect(alpha=(0, 0.9))]),
        iaa.OneOf([iaa.GaussianBlur((0, 3.0)), iaa.AverageBlur(k=(2, 7))]),
        iaa.AdditiveGaussianNoise(loc=(0, 0.5), scale=(0, 0.2)),
        iaa.Add((-0.3, 0.3)),
        iaa.Invert(p=1),
        iaa.CoarseDropout(p=0.25, size_percent=(0.005, 0.06)),
        iaa.SigmoidContrast(gain=(0.8, 1.2)),
        iaa.LinearContrast(alpha=(0.8, 1.2)),
        iaa.Sequential([iaa.Emboss(alpha=(0, 1))]),
    ]

    # Apply exactly one of the candidates with probability `p`, otherwise
    # leave the image untouched.
    return iaa.Sometimes(p, iaa.OneOf(candidates))
83 |
--------------------------------------------------------------------------------
/atlalign/ml_utils/callbacks.py:
--------------------------------------------------------------------------------
1 | """Callbacks and aggregation functions."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see .
20 | """
21 |
22 | import pathlib
23 |
24 | import h5py
25 | import mlflow
26 | import pandas as pd
27 | from tensorflow import keras
28 |
29 | from atlalign.data import annotation_volume, segmentation_collapsing_labels
30 | from atlalign.metrics import evaluate_single
31 | from atlalign.ml_utils.io import SupervisedGenerator
32 |
33 |
def get_mlflow_artifact_path(start_char=7):
    """Get path to the MLFlow artifacts of the active run.

    Naive implementation: the artifact URI looks like "file:///actual/path",
    so the local path is recovered by slicing off the leading scheme.

    Parameters
    ----------
    start_char : int
        Number of leading characters (the "file://" scheme) to strip from
        the artifact URI.
    """
    artifact_uri = mlflow.active_run().info.artifact_uri
    local_path = artifact_uri[start_char:]

    return pathlib.Path(local_path)
46 |
47 |
class MLFlowCallback(keras.callbacks.Callback):
    """Log training metrics and checkpoints into MLflow.

    Notes
    -----
    Only runs inside of an mlflow context (``mlflow.start_run``).

    Parameters
    ----------
    merged_path : str
        Path to the master h5 file containing all the data.

    train_original_ixs_path : str
        Path to where original training indices are stored.

    val_original_ixs_path : str
        Path to where original validation indices are stored.

    freq : int
        Reports metrics on every `freq` batch.

    workers : int
        Number of workers to be used for each of the evaluations.

    return_inverse : bool
        If True, then generators behave differently.

    starting_step : int
        Useful when we want to use a checkpointed model and log metrics as of a different step then 1.

    use_validation : int
        If True, then the custom metrics are computed on the validation set.
        Otherwise they will be computed on the training set.
    """

    def __init__(
        self,
        merged_path,
        train_original_ixs_path,
        val_original_ixs_path,
        freq=10,
        workers=1,
        return_inverse=False,
        starting_step=0,
        use_validation=True,
    ):
        """Validate the MLflow context, log parameters and build generators."""
        super().__init__()

        # Check if inside of an mlflow context
        if mlflow.active_run() is None:
            raise ValueError(
                "To use the MLFlowCallback one needs to be inside of a mlflow.start_run context."
            )

        # mlflow — record where artifacts go and which data files are used.
        self.root_path = get_mlflow_artifact_path()
        mlflow.log_params(
            {
                "train_original_ixs_path": train_original_ixs_path,
                "val_original_ixs_path": val_original_ixs_path,
                "merged_path": merged_path,
            }
        )

        # Deterministic single-sample generators: shuffle=False and
        # batch_size=1 are required by compute_external_metrics below.
        self.train_original_gen = SupervisedGenerator(
            merged_path,
            indexes=train_original_ixs_path,
            shuffle=False,
            batch_size=1,
            return_inverse=return_inverse,
        )

        self.val_original_gen = SupervisedGenerator(
            merged_path,
            indexes=val_original_ixs_path,
            shuffle=False,
            batch_size=1,
            return_inverse=return_inverse,
        )
        self.freq = freq

        self.workers = workers
        # Global batch counter across epochs; starts at `starting_step` so
        # resumed runs continue the metric step axis.
        self.overall_batch = starting_step
        self.use_validation = use_validation

    def on_train_begin(self, logs=None):
        """Create the artifact subfolders (architecture and checkpoints)."""
        arch_path = self.root_path / "architecture"
        checkpoints_path = self.root_path / "checkpoints"

        arch_path.mkdir(parents=True, exist_ok=True)
        checkpoints_path.mkdir(parents=True, exist_ok=True)

    # NOTE(review): `on_batch_end` is the legacy Keras hook name
    # (`on_train_batch_end` in newer TF) — confirm against the pinned
    # TensorFlow version.
    def on_batch_end(self, batch, logs=None):
        """Log metrics to mlflow.

        The goal here is to extract 3 types of metrics:
            - train_merged - extracted from logs (it is a running average over epoch)
            - train_original - computed via evaluate_generator
            - val_original - computed via evaluate_generator
        """
        self.overall_batch += 1

        # Only report every `freq`-th batch — full evaluation is expensive.
        if self.overall_batch % self.freq != 0:
            return

        model = self.model
        metric_names = model.metrics_names

        all_metrics = {}

        # Keras running averages for the current epoch, straight from `logs`.
        all_metrics.update(
            {"{}_train_merged".format(metric): logs[metric] for metric in metric_names}
        )

        # NOTE(review): `evaluate_generator` is deprecated in TF2 in favor of
        # `evaluate` — confirm against the pinned TensorFlow version.
        eval_train_original = model.evaluate_generator(
            self.train_original_gen, workers=self.workers
        )
        all_metrics.update(
            {
                "{}_train_original".format(metric): value
                for metric, value in zip(metric_names, eval_train_original)
            }
        )

        eval_val_original = model.evaluate_generator(
            self.val_original_gen, workers=self.workers
        )
        all_metrics.update(
            {
                "{}_val_original".format(metric): value
                for metric, value in zip(metric_names, eval_val_original)
            }
        )

        # Custom (non-Keras) metrics computed sample by sample.
        print(
            "\nComputing custom metrics on {} set!".format(
                "val" if self.use_validation else "train"
            )
        )
        gen = self.val_original_gen if self.use_validation else self.train_original_gen

        external_metrics_df = self.compute_external_metrics(model, gen)

        # Persist per-sample stats as CSV and HTML under a step-named folder.
        stats_dir = self.root_path / str(self.overall_batch) / "stats"
        stats_dir.mkdir(parents=True, exist_ok=True)

        external_metrics_df.to_csv(str(stats_dir / "stats.csv"))
        external_metrics_df.to_html(str(stats_dir / "stats.html"))

        # Per-sample metrics are averaged before being logged as scalars.
        external_metrics = dict(external_metrics_df.mean())
        all_metrics.update(external_metrics)

        # log into mlflow
        mlflow.log_metrics(all_metrics, step=self.overall_batch)

        # Checkpoint the model alongside the logged metrics.
        keras.models.save_model(
            model,
            str(
                self.root_path
                / "checkpoints"
                / "model_{}.h5".format(self.overall_batch)
            ),
        )

    @staticmethod
    def compute_external_metrics(model, gen):
        """Compute external metrics sample by sample.

        Parameters
        ----------
        model
            Keras model

        gen : SupervisedGenerator
            Generator. Must have ``shuffle=False`` and ``batch_size=1`` so
            that sample `i` deterministically corresponds to ``gen.indexes[i]``.

        Returns
        -------
        external_metrics_df : pd.DataFrame
            One row of metrics per sample, indexed by image id.
        """
        # checks — the per-sample bookkeeping below relies on both.
        if gen.shuffle:
            raise ValueError("Shuffling is not allowed for external metrics!")
        if gen.batch_size != 1:
            raise ValueError("Batch size has to be 1 for external metrics")

        # Prepare annotation related stuff (load in RAM, small arrays)
        indexes = gen.indexes
        with h5py.File(gen.hdf_path, "r") as f:
            ps = f["p"][:][indexes]
            ids = f["image_id"][:][indexes]

        avol = annotation_volume()
        collapsing_labels = segmentation_collapsing_labels()

        external_metrics_per_sample = []

        for i, p in enumerate(ps):
            sample = gen[i]  # data[indexes[i]]
            # The generator's sample layout differs with return_inverse:
            # inputs become [X, X_mr] and targets gain the inverse DVF.
            if gen.return_inverse:
                img_mov = sample[0][0][0, ..., 1]
                deltas_true = sample[1][1][0]
                deltas_true_inv = sample[1][2][0]

            else:
                img_mov = sample[0][0, ..., 1]
                deltas_true = sample[1][1][0]
                deltas_true_inv = None

            deltas_pred = model.predict(sample[0])[1][0]
            deltas_pred_inv = None  # we do not use the predicted one

            # compute external metrics
            res, images = evaluate_single(
                deltas_true,
                deltas_pred,
                img_mov,
                ds_f=8,  # orig 8
                p=p,
                deltas_true_inv=deltas_true_inv,
                deltas_pred_inv=deltas_pred_inv,
                avol=avol,
                collapsing_labels=collapsing_labels,
                depths=(0, 1, 2, 3, 4, 5, 6, 7, 8),
            )
            external_metrics_per_sample.append(res)

        external_metrics_df = pd.DataFrame(external_metrics_per_sample, index=ids)

        return external_metrics_df
282 |
--------------------------------------------------------------------------------
/atlalign/ml_utils/io.py:
--------------------------------------------------------------------------------
1 | """Collection of functions dealing with input and output."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see .
20 | """
21 |
22 | import copy
23 | import datetime
24 | import pathlib
25 |
26 | import h5py
27 | import mlflow
28 | import numpy as np
29 | from tensorflow import keras
30 |
31 | from atlalign.base import DisplacementField
32 | from atlalign.data import nissl_volume
33 |
34 |
class SupervisedGenerator(keras.utils.Sequence):
    """Generator streaming supervised data from a HDF5 file.

    Parameters
    ----------
    hdf_path : str or pathlib.Path
        Path to where the hdf5 file is.

    batch_size : int
        Batch size.

    shuffle : bool
        If True, then data shuffled at the end of each epoch and also in the constructor.

    augmenter_ref : None or imgaug.augmenters.Sequential
        If None, no augmentation. If instance of a imgaug `Sequential` then its a pipeline that will be used
        to augment all reference images in a batch.

    augmenter_mov : None or imgaug.augmenters.Sequential
        If None, no augmentation. If instance of a imgaug `Sequential` then its a pipeline that will be used
        to augment all moving images in a batch.

    return_inverse : bool
        If True, then targets are [img_reg, dvf, dvf_inv] else its only [img_reg, dvf].

    mlflow_log : bool
        If True, the constructor arguments are logged as parameters into the
        active MLflow run.

    indexes : None or list or str/pathlib.Path
        A user defined list of indices to be used for streaming. This is used to give user the chance
        to only stream parts of the data. If str or pathlib.Path then read from a .npy file

    Attributes
    ----------
    indexes : list
        List of indices determining the slicing order.

    volume : np.array
        Array representing the nissl stain volume of dtype float32.

    times : list
        List of timedeltas measuring how long each batch took to yield.
    """

    def __init__(
        self,
        hdf_path,
        batch_size=32,
        shuffle=False,
        augmenter_ref=None,
        augmenter_mov=None,
        return_inverse=False,
        mlflow_log=False,
        indexes=None,
    ):
        """Store parameters, resolve indices and cache the nissl volume."""
        # Snapshot constructor arguments so get_all_data() can rebuild an
        # equivalent generator; augmenters are replaced by a marker string
        # since they are not reproducible from the snapshot.
        self._locals = locals()
        del self._locals["self"]
        self._locals["augmenter_mov"] = None if augmenter_mov is None else "active"
        self._locals["augmenter_ref"] = None if augmenter_ref is None else "active"

        if mlflow_log:
            mlflow.log_params(self._locals)

        self.hdf_path = str(hdf_path)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augmenter_ref = augmenter_ref
        self.augmenter_mov = augmenter_mov
        self.return_inverse = return_inverse

        # Dataset length is taken from the "p" dataset of the HDF5 file.
        with h5py.File(self.hdf_path, "r") as f:
            length = len(f["p"])

        if indexes is None:
            self.indexes = list(np.arange(length))

        elif isinstance(indexes, list):
            self.indexes = indexes

        elif isinstance(indexes, (str, pathlib.Path)):
            self.indexes = list(np.load(str(indexes)))

        else:
            raise TypeError("Invalid indexes type {}".format(type(indexes)))

        # Cache the reference volume in memory once; reused for every batch.
        self.volume = nissl_volume()

        self.times = []
        self.temp = []

        # Perform the initial shuffle (if enabled).
        self.on_epoch_end()

    def __len__(self):
        """Length of the iterator = number of steps per epoch."""
        n_samples = len(self.indexes)

        # Incomplete trailing batches are dropped.
        return int(np.floor(n_samples / self.batch_size))

    def __getitem__(self, index):
        """Load samples in memory and possibly augment.

        Parameters
        ----------
        index : int
            Integer representing the index of to be returned batch.

        Returns
        -------
        X : np.array
            Array of shape (`self.batch_size`, 320, 456, 2) representing the stacked samples of reference
            and moving images.

        targets : list
            If `self.return_inverse=False` then 2 element list. First element represents the true registered images
            - shape (`self.batch_size`, 320, 456, 1). The second element is a batch of ground truth displacement
            vector fields - shape (`self.batch_size`, 320, 456, 2).
            If `self.return_inverse=True` then 3 element list. The first two elements are like above and the
            third one is a batch of ground truth inverse displacements fields (warping images in reference space to
            moving space) of shape (`self.batch_size`, 320, 456, 2).
        """
        begin_time = datetime.datetime.now()
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        # NOTE(review): the batch is assembled in sorted-index order, which
        # may differ from the order within self.indexes when shuffling —
        # confirm downstream code does not rely on the unsorted order.
        sorted_indexes = sorted(indexes)  # hdf5 only supports sorted indexing

        # Generate indexes of the batch
        with h5py.File(self.hdf_path, "r") as f:
            dset_img = f["img"]
            dset_deltas_xy = f["deltas_xy"]
            dset_inv_deltas_xy = f["inv_deltas_xy"]
            dset_p = f["p"]

            # Section number: p // 25, clipped at 527 (last coronal section).
            sn = np.minimum(
                dset_p[sorted_indexes] // 25, np.ones(len(indexes), dtype=int) * 527
            )
            ref_images = self.volume[sn]
            # Moving images are stored as uint8; normalize to [0, 1] floats.
            mov_images = (
                dset_img[sorted_indexes][..., np.newaxis].astype("float32") / 255
            )
            batch_deltas_xy = dset_deltas_xy[sorted_indexes]
            batch_inv_deltas_xy = (
                dset_inv_deltas_xy[sorted_indexes] if self.return_inverse else None
            )

        if self.augmenter_ref is not None:
            ref_images = self.augmenter_ref.augment_images(ref_images)

        if self.augmenter_mov is not None:
            mov_images = self.augmenter_mov.augment_images(mov_images)

        # Stack (reference, moving) along the channel axis.
        X = np.concatenate([ref_images, mov_images], axis=3)
        if self.return_inverse:
            # Mirror input (moving, reference) used for the inverse branch.
            X_mr = np.concatenate([mov_images, ref_images], axis=3)

        # Registered images: warp each moving image with its ground-truth DVF.
        reg_images = np.zeros_like(mov_images)

        for i in range(len(mov_images)):
            df = DisplacementField(
                batch_deltas_xy[i, ..., 0], batch_deltas_xy[i, ..., 1]
            )
            assert df.is_valid, "{} is not valid".format(sorted_indexes[i])
            reg_images[i, ..., 0] = df.warp(mov_images[i, ..., 0])

        self.times.append((datetime.datetime.now() - begin_time))

        if self.return_inverse:

            return [X, X_mr], [reg_images, batch_deltas_xy, batch_inv_deltas_xy]
        else:
            return X, [reg_images, batch_deltas_xy]

    def on_epoch_end(self):
        """Take end of epoch action (reshuffle indices if enabled)."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def get_all_data(self):
        """Load entire dataset into memory.

        Builds a fresh, unshuffled, batch-size-1 clone of this generator and
        iterates it to completion.

        Returns
        -------
        all_inps, all_outs : list
            Per-sample inputs and targets in the order of ``self.indexes``.
        """
        orig_params = copy.deepcopy(self._locals)

        orig_params["batch_size"] = 1
        orig_params["shuffle"] = False
        orig_params["mlflow_log"] = False

        # NOTE(review): the clone receives the "active" marker strings instead
        # of the original augmenters, so the returned data is unaugmented.
        new_gen = self.__class__(**orig_params)

        all_inps, all_outs = [], []
        for inps, outs in new_gen:
            all_inps.append(inps)
            all_outs.append(outs)

        return all_inps, all_outs
225 |
--------------------------------------------------------------------------------
/atlalign/non_ml/__init__.py:
--------------------------------------------------------------------------------
1 | """Registration algorithm that do not use learning approaches."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see .
20 | """
21 |
22 | from atlalign.non_ml.intensity import antspy_registration # noqa
23 |
--------------------------------------------------------------------------------
/atlalign/non_ml/intensity.py:
--------------------------------------------------------------------------------
1 | """Collection of intensity based registration methods."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see .
20 | """
21 | import warnings
22 |
23 | warnings.simplefilter(action="ignore") # noqa
24 |
25 | import ants # noqa
26 | import nibabel as nib # noqa
27 |
28 | from atlalign.base import GLOBAL_CACHE_FOLDER, DisplacementField # noqa
29 |
30 |
def antspy_registration(
    fixed_img,
    moving_img,
    registration_type="SyN",
    reg_iterations=(40, 20, 0),
    aff_metric="mattes",
    syn_metric="mattes",
    verbose=False,
    initial_transform=None,
    path=GLOBAL_CACHE_FOLDER,
):
    """Register images using ANTsPY.

    Parameters
    ----------
    fixed_img: np.ndarray
        Fixed image.

    moving_img: np.ndarray
        Moving image to register.

    registration_type: {'Translation', 'Rigid', 'Similarity', 'QuickRigid', 'DenseRigid', 'BOLDRigid', 'Affine',
                        'AffineFast', 'BOLDAffine', 'TRSAA', 'ElasticSyN', 'SyN', 'SyNRA', 'SyNOnly', 'SyNCC',
                        'SyNabp', 'SyNBold', 'SyNBoldAff', 'SyNAggro', 'TVMSQ', 'TVMSQC'}, default 'SyN'

        Optimization algorithm to use to register (more info: https://antspy.readthedocs.io/en/latest/registration.
        html?highlight=registration#ants.registration)

    reg_iterations: tuple, default (40, 20, 0)
        Vector of iterations for SyN.

    aff_metric: {'GC', 'mattes', 'meansquares'}, default 'mattes'
        The metric for the affine part.

    syn_metric: {'CC', 'mattes', 'meansquares', 'demons'}, default 'mattes'
        The metric for the SyN part.

    verbose : bool, default False
        If True, then the inner solver prints convergence related information in standard output.

    initial_transform : list or None
        Transforms to prepend before the registration.

    path : str
        Path to a folder where the `.nii.gz` file representing the composite
        transform is saved.

    Returns
    -------
    df: DisplacementField
        Displacement field between the moving and the fixed image.

    meta : dict
        Contains relevant images and paths.

    """
    import os

    # Build the output prefix for the composed transform. The original code
    # appended "/" via `path[-1]`, which raised an IndexError for an empty
    # `path` string; os.path.join handles separators robustly.
    compose_prefix = os.path.join(str(path), "final_transform")

    # ANTs requires float pixeltype images.
    fixed_ants_image = ants.image_clone(ants.from_numpy(fixed_img), pixeltype="float")
    moving_ants_image = ants.image_clone(ants.from_numpy(moving_img), pixeltype="float")
    meta = ants.registration(
        fixed_ants_image,
        moving_ants_image,
        registration_type,
        reg_iterations=reg_iterations,
        aff_metric=aff_metric,
        syn_metric=syn_metric,
        verbose=verbose,
        initial_transform=initial_transform,
        syn_sampling=32,
        aff_sampling=32,
    )

    # Compose all forward transforms into a single displacement field stored
    # on disk; `compose` makes apply_transforms return the file path.
    filename = ants.apply_transforms(
        fixed_ants_image,
        moving_ants_image,
        meta["fwdtransforms"],
        compose=compose_prefix,
    )

    nii = nib.load(filename)
    data = nii.get_fdata()
    data = data.squeeze()
    # Channel 1 holds the x (column) displacements, channel 0 the y (row)
    # ones — hence the swap when building the DisplacementField.
    dx = data[:, :, 1]
    dy = data[:, :, 0]
    df_final = DisplacementField(dx, dy)

    return df_final, meta
119 |
--------------------------------------------------------------------------------
/atlalign/utils.py:
--------------------------------------------------------------------------------
1 | """Collection of helper classes and function that do not deserve to be in base.py.
2 |
3 | Notes
4 | -----
5 | This module cannot import from anywhere else within this project to prevent circular dependencies.
6 |
7 | """
8 |
9 | """
10 | The package atlalign is a tool for registration of 2D images.
11 |
12 | Copyright (C) 2021 EPFL/Blue Brain Project
13 |
14 | This program is free software: you can redistribute it and/or modify
15 | it under the terms of the GNU Lesser General Public License as published by
16 | the Free Software Foundation, either version 3 of the License, or
17 | (at your option) any later version.
18 |
19 | This program is distributed in the hope that it will be useful,
20 | but WITHOUT ANY WARRANTY; without even the implied warranty of
21 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
22 | GNU Lesser General Public License for more details.
23 |
24 | You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
26 | """
27 |
28 | import numpy as np
29 | import scipy.spatial.qhull as qhull
30 |
31 |
32 | def _triangulate(xyz, uvw):
33 | """Perform Delaunay triangulation.
34 |
35 | Parameters
36 | ----------
37 | xyz : np.ndarray
38 | An array of shape (N, 2) where each row represents one point in 2D (stable points) for which we know the
39 | function value.
40 |
41 | uvw : np.ndarray
42 | An array of shape (K, 2) where each row represents one point in 2D (query point) for which we want to
43 | interpolate the function value.
44 |
45 | Returns
46 | -------
47 | vertices : np.ndarray
48 | An array of shape (K, 3) representing the triangle vertices of each query point. Note that these
49 | vertices are always stable points (from range [O, N))
50 |
51 | wts : np.ndarray
52 | An array of shape (K, 3) representing the weights of respective vertices at each query point.
53 |
54 | """
55 | tri = qhull.Delaunay(xyz)
56 | simplex = tri.find_simplex(uvw)
57 | vertices = np.take(tri.simplices, simplex, axis=0)
58 | temp = np.take(tri.transform, simplex, axis=0)
59 | delta = uvw - temp[:, 2]
60 | bary = np.einsum("njk,nk->nj", temp[:, :2, :], delta)
61 | wts = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))
62 |
63 | return vertices, wts
64 |
65 |
66 | def _interpolate(values, vertices, wts):
67 | """Interpolate inside a triangle.
68 |
69 | Parameters
70 | ----------
71 | values : np.ndarray
72 | An array of shape (N,) that represents function value on the know points which are the vertices after
73 | Delaunay triangulation.
74 |
75 | vertices : np.ndarray
76 | An array of shape (K, 3) representing the triangle vertices of each query point. Note that these
77 | vertices are always stable points (from range [O, N)).
78 |
79 | wts : np.ndarray
80 | An array of shape (K, 3) representing the weights of respective vertices at each query point.
81 |
82 | Returns
83 | -------
84 | interpolations : np.ndarray
85 | An array of shape (K,) representing the interpolated function values on the query points.
86 |
87 | """
88 | return np.einsum("nj,nj->n", np.take(values, vertices), wts)
89 |
90 |
def griddata_custom(points, values_f_1, values_f_2, xi):
    """Run a griddata extension that performs only one triangulation.

    Notes
    -----
    The scipy implementation does not allow separating triangulation from
    interpolation. Since we need to evaluate 2 different functions on the
    !same! non-regular grid of points, the triangulation can simply be done
    once and reused.

    Parameters
    ----------
    points : np.ndarray
        An array of shape (N, 2) where each row represents one point in 2D for which we know the function value.

    values_f_1 : np.ndarray
        An array of shape (N,) where each row represents a value of function f_1 on the corresponding point in `points`.

    values_f_2 : np.ndarray
        An array of shape (N,) where each row represents a value of function f_2 on the corresponding point in `points`.

    xi : tuple
        Tuple of 2 np.ndarray of shapes (h, w) representing the x and y coordinates of the points where we want to
        interpolate data (e.g. the output of `np.meshgrid` for a regular grid).

    Returns
    -------
    f_1_interpolation_on_xi : np.ndarray
        An array of shape (h, w) representing the interpolation of f_1 on the `xi` points.

    f_2_interpolation_on_xi : np.ndarray
        An array of shape (h, w) representing the interpolation of f_2 on the `xi` points.

    References
    ----------
    https://stackoverflow.com/questions/20915502/speedup-scipy-griddata-for-multiple-interpolations-between-two-irregular-grids # noqa

    """
    if not isinstance(xi, tuple):
        raise TypeError("The xi needs to be a tuple of equally shaped np.ndarrays.")

    shape = xi[0].shape
    # Flatten the query grid into an array of (x, y) rows.
    query = np.stack([xi[0].ravel(), xi[1].ravel()], axis=1)

    # Triangulate once, then reuse the vertices/weights for both functions.
    vertices, wts = _triangulate(points, query)

    f_1_interpolation_on_xi = _interpolate(values_f_1, vertices, wts).reshape(shape)
    f_2_interpolation_on_xi = _interpolate(values_f_2, vertices, wts).reshape(shape)

    return f_1_interpolation_on_xi, f_2_interpolation_on_xi
142 |
143 |
144 | def _find_all_children(d, children_list=None):
145 | """Construct a list of all the ids of the children of a node and the node itself.
146 |
147 | Parameters
148 | ----------
149 | d : dict
150 | Dictionary node from whom we want the list of all the children and children's children.
151 |
152 | children_list : list, default None
153 | List of children which has to be empty for the first iteration of the function.
154 |
155 | Returns
156 | -------
157 | children_list : list
158 | List of children's ids.
159 |
160 | """
161 | if children_list is None:
162 | children_list = []
163 |
164 | for key, value in d.items():
165 | if key == "id":
166 | children_list.append(value)
167 | if isinstance(value, list):
168 | for child in value:
169 | _find_all_children(child, children_list)
170 |
171 | return children_list
172 |
173 |
def _find_concatenate_labels(d, chosen_depth, dict_of_labels=None, current_depth=0):
    """Construct a dictionary which has for each key, the value of the new label after concatenation.

    Parameters
    ----------
    d : dict
        Dictionary node for which we want to concatenate some ids depending on the depth branch.

    chosen_depth : int
        Depth at which it is wanted to concatenate the labels.

    dict_of_labels : dict or None, default None
        Dictionary of corresponding labels. None at the first call; a fresh dict is
        then created and shared (mutated) by all recursive calls.

    current_depth : int, default 0
        Depth of the dictionary node.

    Returns
    -------
    dict_of_labels: dict
        Dictionary of corresponding labels after concatenation of labels tree.

    """
    if dict_of_labels is None:
        dict_of_labels = {}

    if current_depth < chosen_depth:
        # Above the chosen depth every node keeps its own label (identity mapping).
        for key, value in d.items():
            if key == "id":
                dict_of_labels[value] = value
            if isinstance(value, list):
                # NOTE(review): `current_depth` is incremented inside the item loop,
                # so a second list-valued key in the same node would recurse one level
                # deeper than the first. Presumably each node carries a single
                # children list — verify against the tree schema.
                current_depth = current_depth + 1
                for child in value:
                    _find_concatenate_labels(
                        child,
                        chosen_depth,
                        dict_of_labels=dict_of_labels,
                        current_depth=current_depth,
                    )
    else:
        # At (or beyond) the chosen depth: collapse this node and all of its
        # descendants onto this node's own id.
        children_list = []
        _find_all_children(d, children_list)
        for key, value in d.items():
            if key == "id":
                for child in children_list:
                    dict_of_labels[child] = value

    return dict_of_labels
222 |
223 |
def find_labels_dic(segmentation_array, dic, chosen_depth):
    """Collapse existing labels into parent labels corresponding to the tree provided in a dictionary.

    Parameters
    ----------
    segmentation_array : np.array
        Annotation array before the concatenation of the labels.

    dic : dict
        Dictionary of tree of labels.

    chosen_depth : int
        Depth at which it is wanted to concatenate the labels.

    Returns
    -------
    new_segmentation_array : np.array
        New annotation array with the concatenation of the labels at the desired depth. If a specific label
        does not exist in the tree it is assigned -1. Label 0 (background) is left untouched.
        NOTE(review): assigning -1 into an unsigned-integer array would wrap around —
        presumably the annotation arrays use a signed dtype; verify.

    """
    labels_dic = _find_concatenate_labels(dic, chosen_depth)

    new_segmentation_array = segmentation_array.copy()
    all_labels = np.unique(segmentation_array)

    for label in all_labels:
        if label != 0:
            new_label = labels_dic.get(label, -1)
            # Mask against the ORIGINAL array, not the partially-relabelled copy.
            # Previously the mask was computed on `new_segmentation_array`, so a
            # pixel relabelled in an earlier iteration could be matched — and
            # relabelled a second time — when `new_label` collided with a
            # not-yet-processed original label.
            new_segmentation_array[segmentation_array == label] = new_label

    return new_segmentation_array
256 |
--------------------------------------------------------------------------------
/atlalign/volume.py:
--------------------------------------------------------------------------------
1 | """Collection of tools for aggregating slices to 3D models."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
20 | """
21 |
22 | import numpy as np
23 | import scipy
24 |
25 | from atlalign.data import nissl_volume
26 |
27 |
class Volume:
    """Container for multiple registered coronal slices.

    Parameters
    ----------
    sn : list
        List of section numbers, each in [0, 528).

    mov_imgs : list
        List of np.ndarrays representing the moving images corresponding to the `sn`.

    dvfs : list
        List of displacement fields corresponding to the `sn`.
    """

    def __init__(self, sn, mov_imgs, dvfs):
        # sanity checks on the parallel input lists
        if not (len(sn) == len(mov_imgs) == len(dvfs)):
            raise ValueError("All the input lists need to have the same length")

        if len(set(sn)) != len(sn):
            raise ValueError("There are duplicate section numbers.")

        if not all(0 <= x < 528 for x in sn):
            raise ValueError("All section numbers must lie in [0, 528).")

        self.sn = sn
        self.mov_imgs = mov_imgs
        self.dvfs = dvfs

        # reference slices taken from the Nissl atlas at the given sections
        self.ref_imgs = nissl_volume()[self.sn, ..., 0]
        # section number -> position in the input lists (None if absent)
        position = {s: i for i, s in enumerate(self.sn)}
        self.sn_to_ix = [position.get(s) for s in range(528)]
        self.reg_imgs = self._warp()

    @property
    def sorted_dvfs(self):
        """Return displacement fields sorted by the coronal section."""
        ordered_sn = sorted(self.sn)
        ordered = [self.dvfs[self.sn_to_ix[s]] for s in ordered_sn]

        return ordered, ordered_sn

    @property
    def sorted_mov(self):
        """Return moving images as sorted by the coronal section."""
        ordered_sn = sorted(self.sn)
        ordered = [self.mov_imgs[self.sn_to_ix[s]] for s in ordered_sn]

        return ordered, ordered_sn

    @property
    def sorted_ref(self):
        """Return reference images as sorted by the coronal section."""
        ordered_sn = sorted(self.sn)
        ordered = [self.ref_imgs[self.sn_to_ix[s]] for s in ordered_sn]

        return ordered, ordered_sn

    @property
    def sorted_reg(self):
        """Return registered images as sorted by the coronal section."""
        ordered_sn = sorted(self.sn)
        ordered = [self.reg_imgs[self.sn_to_ix[s]] for s in ordered_sn]

        return ordered, ordered_sn

    def _warp(self):
        """Warp each moving image with its displacement field."""
        registered = []
        for df, img in zip(self.dvfs, self.mov_imgs):
            registered.append(df.warp(img))

        return registered

    def __getitem__(self, key):
        """Get all relevant data for a specified section.

        Parameters
        ----------
        key : int
            Section number to query.

        Returns
        -------
        ref_img : np.ndarray
            Reference image.

        mov_img : np.ndarray
            Moving image.

        reg_img : np.ndarray
            Registered image.

        df : DisplacementField
            Displacement field (mov2reg).
        """
        ix = self.sn_to_ix[key]
        if ix is None:
            raise KeyError("The section {} not found".format(key))

        return self.ref_imgs[ix], self.mov_imgs[ix], self.reg_imgs[ix], self.dvfs[ix]
125 |
126 |
class GappedVolume:
    """Volume containing gaps.

    Parameters
    ----------
    sn : list
        List of section numbers. Note that they are not required to be ordered.

    imgs : np.ndarray or list
        Internally converted to a list of grayscale images of the same shape
        representing different coronal sections. Order corresponds to the one in `sn`.

    """

    def __init__(self, sn, imgs):

        if isinstance(imgs, np.ndarray):
            # Turn the (n, h, w[, 1]) array into a list of n (h, w) images.
            imgs = np.squeeze(imgs)
            imgs = [imgs[i] for i in range(len(imgs))]

        # checks
        if len(sn) != len(imgs):
            # Message fixed from the original "Inconsitent lenghts" typo.
            raise ValueError("Inconsistent lengths")

        if len({img.shape for img in imgs}) != 1:
            raise ValueError("All the images need to have the same shape")

        if len(sn) != len(set(sn)):
            raise ValueError("There are duplicates in section numbers.")

        self.sn = sn
        self.imgs = imgs

        # common (h, w) shape of every section image
        self.shape = imgs[0].shape
162 |
163 |
class CoronalInterpolator:
    """Interpolator that works pixel by pixel in the coronal dimension.

    Parameters
    ----------
    kind : str, default 'linear'
        Interpolation kind forwarded to ``scipy.interpolate.interp1d``.

    fill_value : float, default 0
        Value used outside of the observed section range (when
        ``bounds_error`` is False) and for pixels with no valid
        (non-NaN) observation at all.

    bounds_error : bool, default False
        If True, out-of-range queries raise instead of using `fill_value`.

    n_sections : int, default 528
        Number of coronal sections in the output volume. Parameterized so
        volumes with a different section count can be interpolated; the
        default of 528 matches the previous hard-coded behavior.
    """

    def __init__(self, kind="linear", fill_value=0, bounds_error=False, n_sections=528):
        """Construct."""
        self.kind = kind
        self.fill_value = fill_value
        self.bounds_error = bounds_error
        self.n_sections = n_sections

    def interpolate(self, gv):
        """Interpolate.

        Note that some section images might have pixels equal to np.nan. In this case these pixels are skipped in the
        interpolation. A pixel that is NaN in every provided section is filled with `fill_value`.

        Parameters
        ----------
        gv : GappedVolume
            Instance of the ``GappedVolume`` to be interpolated.

        Returns
        -------
        final_volume : np.ndarray
            Array of shape (`n_sections`, *gv.shape) that holds the entire interpolated volume without gaps.

        """
        grid = np.arange(self.n_sections)
        final_volume = np.empty((len(grid), *gv.shape))

        for r in range(gv.shape[0]):
            for c in range(gv.shape[1]):
                # Collect (section, value) samples for this pixel, skipping NaNs.
                samples = [
                    (s, img[r, c])
                    for s, img in zip(gv.sn, gv.imgs)
                    if not np.isnan(img[r, c])
                ]

                if not samples:
                    # Previously `zip(*[])` unpacking raised here; an all-NaN
                    # pixel column now falls back to `fill_value`.
                    final_volume[:, r, c] = self.fill_value
                    continue

                x_pixel, y_pixel = zip(*samples)

                f = scipy.interpolate.interp1d(
                    x_pixel,
                    y_pixel,
                    kind=self.kind,
                    bounds_error=self.bounds_error,
                    fill_value=self.fill_value,
                )
                try:
                    final_volume[:, r, c] = f(grid)
                except Exception as e:
                    # Best-effort: a single failing pixel column is reported but
                    # does not abort the whole volume (original behavior kept).
                    print(e)

        return final_volume
216 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_images/affine_simple.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/affine_simple.png
--------------------------------------------------------------------------------
/docs/_images/anchoring.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/anchoring.png
--------------------------------------------------------------------------------
/docs/_images/annot_warping.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/annot_warping.png
--------------------------------------------------------------------------------
/docs/_images/antspy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/antspy.png
--------------------------------------------------------------------------------
/docs/_images/aug_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/aug_pipeline.png
--------------------------------------------------------------------------------
/docs/_images/clipped_vd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/clipped_vd.png
--------------------------------------------------------------------------------
/docs/_images/composition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/composition.png
--------------------------------------------------------------------------------
/docs/_images/control_points.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/control_points.png
--------------------------------------------------------------------------------
/docs/_images/coronal_interpolator.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/coronal_interpolator.png
--------------------------------------------------------------------------------
/docs/_images/edge_stretching.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/edge_stretching.png
--------------------------------------------------------------------------------
/docs/_images/evaluation_metrics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/evaluation_metrics.png
--------------------------------------------------------------------------------
/docs/_images/example_augmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/example_augmentation.png
--------------------------------------------------------------------------------
/docs/_images/feature_based.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/feature_based.png
--------------------------------------------------------------------------------
/docs/_images/image_registration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/image_registration.png
--------------------------------------------------------------------------------
/docs/_images/image_registration_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/image_registration_2.png
--------------------------------------------------------------------------------
/docs/_images/int_augmentations.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/int_augmentations.gif
--------------------------------------------------------------------------------
/docs/_images/inverse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/inverse.png
--------------------------------------------------------------------------------
/docs/_images/labeling_tool.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/labeling_tool.png
--------------------------------------------------------------------------------
/docs/_images/metrics_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/metrics_overview.png
--------------------------------------------------------------------------------
/docs/_images/resizing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/resizing.png
--------------------------------------------------------------------------------
/docs/_images/typical_dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/typical_dataset.png
--------------------------------------------------------------------------------
/docs/_images/warping.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/warping.png
--------------------------------------------------------------------------------
/docs/_static/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_static/.keep
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys

import atlalign

# sys.path.insert(0, os.path.abspath('.'))
print(sys.path)

# -- Project information -----------------------------------------------------

project = 'Atlas Alignment'
author = 'Blue Brain Project, EPFL'

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.autodoc',
              'sphinx.ext.doctest',
              'sphinx.ext.napoleon',
              'sphinx.ext.viewcode'
              ]


# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx-bluebrain-theme'
html_title = 'Atlas Alignment'
version = atlalign.__version__

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Do not mention module names
add_module_names = False

# Hide the "view page source" link. The option was previously misspelled as
# `html_show_sourceling`, which Sphinx silently ignores.
html_show_sourcelink = False
65 |
--------------------------------------------------------------------------------
/docs/generate_metadata.py:
--------------------------------------------------------------------------------
"""Generate the ``metadata.md`` file next to this script from the package version."""
import datetime
import pathlib

import atlalign

metadata_template = """---
packageurl: https://github.com/BlueBrain/atlas_alignment
major: {major_version}
description: Image registration with deep learning
repository: https://github.com/BlueBrain/atlas_alignment
externaldoc: https://bbpteam.epfl.ch/documentation/a.html#atlas-alignment
updated: {date}
maintainers: Jan Krepl
name: Atlas Alignment
license: BBP-internal-confidential
issuesurl: https://github.com/BlueBrain/atlas_alignment
version: {version}
contributors: Jan Krepl
minor: {minor_version}
---
"""

# Target path: metadata.md in the same directory as this script.
metadata_path = pathlib.Path(__file__).parent.absolute() / 'metadata.md'

version = atlalign.__version__
major_version, minor_version = version.split('.')[:2]
date = datetime.datetime.now().strftime("%d/%m/%y")

metadata_instance = metadata_template.format(
    version=version,
    major_version=major_version,
    minor_version=minor_version,
    date=date,
)

with metadata_path.open('w') as f:
    f.write(metadata_instance)

print('Finished')
41 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Atlas Alignment documentation master file, created by
2 | sphinx-quickstart on Wed Oct 30 16:22:27 2019.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Atlas Alignment
7 | ================
8 | The goal of this project is to facilitate the registration of gene expression images to a reference (Nissl) atlas.
9 |
10 | .. toctree::
11 | :maxdepth: 2
12 | :caption: Contents:
13 |
14 | source/installation
15 | source/image_registration
16 | source/building_blocks
17 | source/datasets
18 | source/intensity
19 | source/labeling_tool
20 | source/deep_learning_data
21 | source/deep_learning_training
22 | source/deep_learning_inference
23 | source/evaluation
24 | source/3d_interpolation
25 |
26 | .. toctree::
27 | :maxdepth: 2
28 | :caption: API Reference
29 |
30 | source/api/modules
31 |
32 |
33 |
34 | Indices and tables
35 | ==================
36 |
37 | * :ref:`genindex`
38 | * :ref:`modindex`
39 | * :ref:`search`
40 |
--------------------------------------------------------------------------------
/docs/source/3d_interpolation.rst:
--------------------------------------------------------------------------------
1 | 3D Interpolation
2 | ================
3 | Once we register consecutive 2D slices the next step is to think of creating 3D volumes. This is especially relevant
4 | for data from the **Allen Brain Institute** that come from experiments with a very specific design.
5 |
There are multiple ways to perform this interpolation, and some of them are contained in the :code:`atlalign.volume`
7 | module:
8 |
9 | - :code:`CoronalInterpolator`
10 |
11 | CoronalInterpolator
12 | -------------------
13 | The :code:`CoronalInterpolator` turns the entire problem into a 1D function interpolation. See below the sketch:
14 |
15 | .. image:: ../_images/coronal_interpolator.png
16 | :width: 400
17 | :height: 300
18 | :alt: Coronal interpolator
19 | :align: center
20 |
21 | For each pixel separately one interpolates over missing sections based purely on the corresponding pixels in all
22 | of the existing sections.
23 |
24 | .. testcode::
25 |
26 | import numpy as np
27 |
28 | from atlalign.volume import CoronalInterpolator, GappedVolume
29 |
30 | n_sections = 55
31 | shape = (30, 40)
32 |
33 | sn = np.random.choice(np.arange(527), size=n_sections, replace=False)
34 | imgs = np.random.random((n_sections, *shape))
35 |
36 | gv = GappedVolume(sn, imgs)
37 | ci = CoronalInterpolator()
38 |
39 | final_volume = ci.interpolate(gv)
40 | print(final_volume.shape) # (528, shape[0], shape[1])
41 |
42 | .. testoutput::
43 | :hide:
44 | :options: -ELLIPSIS, +NORMALIZE_WHITESPACE
45 |
46 | (528, 30, 40)
47 |
--------------------------------------------------------------------------------
/docs/source/api/atlalign.label.rst:
--------------------------------------------------------------------------------
1 | atlalign.label package
2 | ======================
3 |
4 | Submodules
5 | ----------
6 |
7 | atlalign.label.cli module
8 | -------------------------
9 |
10 | .. automodule:: atlalign.label.cli
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | atlalign.label.io module
16 | ------------------------
17 |
18 | .. automodule:: atlalign.label.io
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | atlalign.label.new\_GUI module
24 | ------------------------------
25 |
26 | .. automodule:: atlalign.label.new_GUI
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | Module contents
32 | ---------------
33 |
34 | .. automodule:: atlalign.label
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
--------------------------------------------------------------------------------
/docs/source/api/atlalign.ml_utils.rst:
--------------------------------------------------------------------------------
1 | atlalign.ml\_utils package
2 | ==========================
3 |
4 | Submodules
5 | ----------
6 |
7 | atlalign.ml\_utils.augmentation module
8 | --------------------------------------
9 |
10 | .. automodule:: atlalign.ml_utils.augmentation
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | atlalign.ml\_utils.callbacks module
16 | -----------------------------------
17 |
18 | .. automodule:: atlalign.ml_utils.callbacks
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | atlalign.ml\_utils.io module
24 | ----------------------------
25 |
26 | .. automodule:: atlalign.ml_utils.io
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | atlalign.ml\_utils.layers module
32 | --------------------------------
33 |
34 | .. automodule:: atlalign.ml_utils.layers
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | atlalign.ml\_utils.losses module
40 | --------------------------------
41 |
42 | .. automodule:: atlalign.ml_utils.losses
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | atlalign.ml\_utils.models module
48 | --------------------------------
49 |
50 | .. automodule:: atlalign.ml_utils.models
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | Module contents
56 | ---------------
57 |
58 | .. automodule:: atlalign.ml_utils
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
--------------------------------------------------------------------------------
/docs/source/api/atlalign.non_ml.rst:
--------------------------------------------------------------------------------
1 | atlalign.non\_ml package
2 | ========================
3 |
4 | Submodules
5 | ----------
6 |
7 | atlalign.non\_ml.intensity module
8 | ---------------------------------
9 |
10 | .. automodule:: atlalign.non_ml.intensity
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | Module contents
16 | ---------------
17 |
18 | .. automodule:: atlalign.non_ml
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
--------------------------------------------------------------------------------
/docs/source/api/atlalign.rst:
--------------------------------------------------------------------------------
1 | atlalign package
2 | ================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 | :maxdepth: 4
9 |
10 | atlalign.label
11 | atlalign.ml_utils
12 | atlalign.non_ml
13 |
14 | Submodules
15 | ----------
16 |
17 | atlalign.augmentations module
18 | -----------------------------
19 |
20 | .. automodule:: atlalign.augmentations
21 | :members:
22 | :undoc-members:
23 | :show-inheritance:
24 |
25 | atlalign.base module
26 | --------------------
27 |
28 | .. automodule:: atlalign.base
29 | :members:
30 | :undoc-members:
31 | :show-inheritance:
32 |
33 | atlalign.data module
34 | --------------------
35 |
36 | .. automodule:: atlalign.data
37 | :members:
38 | :undoc-members:
39 | :show-inheritance:
40 |
41 | atlalign.metrics module
42 | -----------------------
43 |
44 | .. automodule:: atlalign.metrics
45 | :members:
46 | :undoc-members:
47 | :show-inheritance:
48 |
49 | atlalign.nn module
50 | ------------------
51 |
52 | .. automodule:: atlalign.nn
53 | :members:
54 | :undoc-members:
55 | :show-inheritance:
56 |
57 | atlalign.utils module
58 | ---------------------
59 |
60 | .. automodule:: atlalign.utils
61 | :members:
62 | :undoc-members:
63 | :show-inheritance:
64 |
65 | atlalign.visualization module
66 | -----------------------------
67 |
68 | .. automodule:: atlalign.visualization
69 | :members:
70 | :undoc-members:
71 | :show-inheritance:
72 |
73 | atlalign.volume module
74 | ----------------------
75 |
76 | .. automodule:: atlalign.volume
77 | :members:
78 | :undoc-members:
79 | :show-inheritance:
80 |
81 | atlalign.zoo module
82 | -------------------
83 |
84 | .. automodule:: atlalign.zoo
85 | :members:
86 | :undoc-members:
87 | :show-inheritance:
88 |
89 | Module contents
90 | ---------------
91 |
92 | .. automodule:: atlalign
93 | :members:
94 | :undoc-members:
95 | :show-inheritance:
96 |
--------------------------------------------------------------------------------
/docs/source/api/modules.rst:
--------------------------------------------------------------------------------
1 | atlalign
2 | ========
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | atlalign
8 |
--------------------------------------------------------------------------------
/docs/source/datasets.rst:
--------------------------------------------------------------------------------
1 | .. _datasets:
2 |
3 | Datasets
4 | ========
5 |
6 | We include multiple dataset loading utilities. Note that for some functions
7 | implemented in the :code:`atlalign.data` module we assume that the user has
8 | the underlying raw data locally.
9 |
10 | Nissl volume
11 | ------------
12 | Contains the 25 microns reference volume used within the entire project. The corresponding function
13 | is :code:`nissl_volume` that returns a (528, 320, 456, 1) numpy array.
14 |
15 |
16 | Annotation volume
17 | -----------------
18 | Contains per pixel segmentation/annotation of the entire reference atlas. The corresponding
19 | function is :code:`annotation_volume` that returns a (528, 320, 456) numpy array. Note
20 | that one can use the corresponding :code:`segmentation_collapsing_labels` function
21 | to get the tree-like hierarchy of the classes.
22 |
23 |
24 | Manual registration
25 | -------------------
26 | This dataset comes from the labeling tool that extracts displacement fields. The corresponding function is
27 | :code:`manual_registration`.
28 | Available genes:
29 |
30 | - **Calb1**
31 | - **Calb2**
32 | - **Cck**
33 | - **Npy**
34 | - **Pvalb**
35 | - **Sst**
36 | - **Vip**
37 |
38 | .. code-block:: python
39 |
40 | from atlalign.data import manual_registration
41 |
42 | res = manual_registration()
43 |
44 | assert set(res.keys()) == {'dataset_id', 'deltas_xy', 'image_id', 'img', 'inv_deltas_xy', 'p'}
45 | assert len(res['image_id']) == 278
46 |
47 | The returned dictionary contains the following keys:
48 |
49 | - :code:`dataset_id` - unique id of the section dataset
50 | - :code:`deltas_xy` - array of shape (320, 456, 2) where the last dimension represents the x (resp y) deltas of the transformation
51 | - :code:`image_id` - unique id of the section image
52 | - :code:`img` - moving image of shape (320, 456) that was preregistered with the Allen API
53 | - :code:`inv_deltas_xy` - same as :code:`deltas_xy` but represents the inverse transformation
54 | - :code:`p` - coronal coordinate in microns [0, 13200]
55 |
56 | To perform the registration instantiate :code:`atlalign.base.DisplacementField` using the :code:`deltas_xy` and warp the
57 | :code:`img` with it.
58 |
59 | .. code-block:: python
60 |
61 | from atlalign.base import DisplacementField
62 | from atlalign.data import manual_registration
63 |
64 | import numpy as np
65 |
66 | res = manual_registration()
67 | i = 10
68 | delta_x = res['deltas_xy'][i, ..., 0]
69 | delta_y = res['deltas_xy'][i, ..., 1]
70 | img_mov = res['img'][i]
71 |
72 | df = DisplacementField(delta_x, delta_y)
73 | img_reg = df.warp(img_mov)
74 |
75 | For more details on :code:`atlalign.base.DisplacementField` see :ref:`building_blocks`.
76 |
77 | Dummy
78 | -----
79 | Artificially generated datasets.
80 |
81 | Rectangles
82 | ~~~~~~~~~~
83 | Rectangles with stripes of different intensities. Corresponding function - :code:`rectangles`.
84 |
85 | Circles
86 | ~~~~~~~
87 | Circles with inner circles of different intensities. Corresponding function - :code:`circles`.
88 |
89 |
--------------------------------------------------------------------------------
/docs/source/deep_learning_data.rst:
--------------------------------------------------------------------------------
1 | .. _dl_data:
2 |
3 |
4 | Deep Learning - Generating a dataset
5 | ====================================
6 | This section describes how to create a dataset for supervised learning. Note that we already described how to easily
7 | load existing datasets in :ref:`datasets`. However, in this chapter we will discuss in detail how to perform
8 | **augmentations** on these or similar datasets.
9 |
10 |
11 | Augmentation is a common strategy in deep learning. The goal is to use an existing dataset and make it larger
12 | (possibly infinitely) by altering both the inputs and the targets in some sensible way.
13 |
14 | Generally, we will distinguish two types of augmentations
15 |
16 | 1. **Intensity** - change the pixel intensities
17 | 2. **Geometric** - change the geometric structure of the underlying image
18 |
19 |
20 | The reason why in our use case we split augmentations into two groups is very simple : the **intensity** augmentations
21 | do not change the labels, whereas the **geometric** ones do change our labels.
22 |
23 |
24 | Geometric augmentations
25 | -----------------------
26 | The geometric augmentations are a major part of :code:`atlalign` functionality. All the augmentations can be found
27 | in :code:`atlalign.zoo` module. They are easily accessible to the user via the :code:`generate` class method.
28 |
29 | - :code:`affine`
30 | - :code:`affine_simple`
31 | - :code:`control_points`
32 | - :code:`edge_stretching`
33 | - :code:`projective`
34 |
35 |
36 | Affine simple
37 | ~~~~~~~~~~~~~
38 |
39 | .. testcode::
40 |
41 | import numpy as np
42 | import matplotlib.pyplot as plt
43 |
44 | from atlalign.base import DisplacementField
45 | from atlalign.data import rectangles
46 |
47 | shape=(320, 456)
48 |
49 | img = np.squeeze(rectangles(n_samples=1, shape=shape, height=200, width=150, random_state=31))
50 | df = DisplacementField.generate(shape,
51 | approach='affine_simple',
52 | scale_x=1.9,
53 | scale_y=1,
54 | translation_x=-300,
55 | translation_y=0,
56 | rotation=0.2
57 | )
58 |
59 | img_aug = df.warp(img)
60 |
61 | _, (ax_orig, ax_aug) = plt.subplots(1, 2, figsize=(10, 14))
62 | ax_orig.imshow(img)
63 | ax_aug.imshow(img_aug)
64 |
65 | .. image:: ../_images/affine_simple.png
66 | :width: 600
67 | :alt: Affine simple
68 | :align: center
69 |
70 | Control points
71 | ~~~~~~~~~~~~~~
72 | Control points is a generic augmentation that gives the user the possibility to specify displacement only on
73 | a selected set of control points. For the remaining pixels the displacement is obtained by interpolation.
74 |
75 | .. testcode::
76 |
77 | import matplotlib.pyplot as plt
78 |
79 | from atlalign.base import DisplacementField
80 | from atlalign.data import rectangles
81 |
82 | shape = (320, 456)
83 |
84 | img = np.squeeze(rectangles(n_samples=1, shape=shape, height=200, width=150, random_state=31))
85 |
86 |
87 | points = np.array([[200, 150]])
88 |
89 | values_delta_x = np.array([-100])
90 | values_delta_y = np.array([0])
91 |
92 | df = DisplacementField.generate(shape,
93 | approach='control_points',
94 | points=points,
95 | values_delta_x=values_delta_x,
96 | values_delta_y=values_delta_y,
97 | interpolation_method='rbf')
98 |
99 | img_aug = df.warp(img)
100 |
101 | _, (ax_orig, ax_aug) = plt.subplots(1, 2, figsize=(10, 14))
102 | ax_orig.imshow(img)
103 | ax_aug.imshow(img_aug)
104 |
105 | .. image:: ../_images/control_points.png
106 | :width: 600
107 | :alt: Affine simple
108 | :align: center
109 |
110 |
111 |
112 | Edge stretching
113 | ~~~~~~~~~~~~~~~
114 | Edge stretching using :code:`control_points` in the background. However, instead of requiring the user to specify
115 | these points manually the user simply passes a mask array of edges. The algorithm then
116 | selects randomly :code:`n_perturbation_points` points out of the edges and randomly displaces them. Note that
117 | the :code:`interpolation_method='rbf'` and :code:`interpolator_kwargs={'function': 'linear'}` gives the nicest
118 | results.
119 |
120 | .. testcode::
121 |
122 | import matplotlib.pyplot as plt
123 |
124 | from atlalign.base import DisplacementField
125 | from atlalign.data import rectangles
126 |
127 | from skimage.feature import canny
128 |
129 | shape = (320, 456)
130 |
131 | img = np.squeeze(rectangles(n_samples=1, shape=shape, height=200, width=150, random_state=31))
132 |
133 |
134 | edge_mask = canny(img)
135 |
136 | df = DisplacementField.generate(shape,
137 | approach='edge_stretching',
138 | edge_mask=edge_mask,
139 | interpolation_method='rbf',
140 | interpolator_kwargs={'function': 'linear'},
141 | n_perturbation_points=5)
142 |
143 | img_aug = df.warp(img)
144 |
145 | _, (ax_orig, ax_aug, ax_edges) = plt.subplots(1, 3, figsize=(10, 14))
146 | ax_orig.imshow(img)
147 | ax_aug.imshow(img_aug)
148 | ax_edges.imshow(edge_mask)
149 |
150 | .. image:: ../_images/edge_stretching.png
151 | :width: 600
152 | :alt: Affine simple
153 | :align: center
154 |
155 |
156 | Intensity augmentations
157 | -----------------------
158 | Intensity augmentations do not change the label (displacement field). As opposed to the geometric ones, we fully
159 | delegate these augmentations to a 3rd party package - :code:`imgaug`. For more details see the official documentation.
160 |
161 | The user can use some preset augmentor pipelines in :code:`atlalign.ml_utils.augmentation` or create their own.
162 | See below an example of using a preset augmentor together with a small animation showing 10 random augmentations.
163 |
164 |
165 | .. testcode::
166 |
167 | from atlalign.ml_utils import augmenter_1
168 |
169 | img = np.squeeze(rectangles(n_samples=1, shape=shape, height=200, width=150, random_state=31))
170 | aug = augmenter_1()
171 |
172 | img_aug = aug.augment_image(img)
173 |
174 | .. image:: ../_images/int_augmentations.gif
175 |    :width: 600
176 | :alt: Affine simple
177 | :align: center
178 |
179 | Make sure that :code:`imgaug` pipelines do not contain any geometric transformation.
180 |
181 |
182 | Putting things together
183 | -----------------------
184 | With the tools described above and the ones from :ref:`building_blocks` we can significantly increase the size of
185 | our supervised learning datasets. Note that in general we want to augment the moving image with both the
186 | intensity and geometric augmentations. The reference image stays the same or only intensity augmentation is applied.
187 |
188 | To better demonstrate the geometric augmentation logic for real data, we refer the reader to the sketch below.
189 |
190 | .. image:: ../_images/aug_pipeline.png
191 | :width: 700
192 | :alt: Affine simple
193 | :align: center
194 |
195 | We assume that at the beginning we are given the **moving image** and transformation that registers this image -
196 | :code:`mov2reg` (in green). Note that if only registered images are provided this is equivalent to setting
197 | :code:`mov2reg` equal to an identity mapping. The actual augmentation is captured by :code:`mov2art` (in yellow).
198 | Once the user specifies it (or randomly generates it), :code:`atlalign` can infer the :code:`art2reg`. How?
199 |
200 | 1. Invert :code:`mov2art` to obtain :code:`art2mov`.
201 | 2. Compose :code:`art2mov` and :code:`mov2reg`
202 |
203 |
204 | One clearly sees that the final transformation :code:`art2reg` will be a combination of the :code:`mov2reg` and :code:`art2mov`.
205 | Ideally, we want to make sure that these transformations are as nice as possible - differentiable and invertible.
206 |
207 | Note that one can inspect :code:`df.jacobian` to programmatically determine how smooth the transformation is.
208 | Specifically, the pixels with nonpositive jacobian represent possible artifacts. Using :code:`edge_stretching` or
209 | similar it can happen occasionally that the transformations are ugly.
210 |
211 | See below an end-to-end example.
212 |
213 | Example
214 | -------
215 |
216 | .. code-block:: python
217 |
218 | import matplotlib.pyplot as plt
219 | import numpy as np
220 | from skimage.feature import canny
221 | from skimage.util import img_as_float32
222 |
223 | from atlalign.base import DisplacementField
224 | from atlalign.data import new_registration
225 | from atlalign.visualization import create_grid
226 |
227 |
228 | # helper function
229 | def generate_mov2art(img_mov, anchor=True):
230 | """Generate geometric augmentation and its inverse."""
231 | shape = img_mov.shape
232 | img_mov_float = img_as_float32(img_mov)
233 | edge_mask = canny(img_mov_float)
234 | mov2art = DisplacementField.generate(shape,
235 | approach='edge_stretching',
236 | edge_mask=edge_mask,
237 | interpolation_method='rbf',
238 | interpolator_kwargs={'function': 'linear'},
239 | n_perturbation_points=5)
240 |
241 | if anchor:
242 | mov2art = mov2art.anchor()
243 |
244 | art2mov = mov2art.pseudo_inverse()
245 |
246 | return mov2art, art2mov
247 |
248 |     # load dataset
249 | orig_dataset = new_registration()
250 | i = 10
251 |
252 | # load existing data
253 | img_mov = orig_dataset['img'][i]
254 | mov2reg = DisplacementField(orig_dataset['deltas_xy'][i,..., 0], orig_dataset['deltas_xy'][i,..., 1])
255 | img_grid = create_grid(img_mov.shape)
256 |
257 | # generate mov2art
258 | mov2art, art2mov = generate_mov2art(img_mov)
259 | img_art = mov2art.warp(img_mov)
260 |
261 | # numerically approximate composition
262 | art2reg = art2mov(mov2reg, border_mode='reflect')
263 |
264 |
265 | # Plotting
266 | _, ((ax_mov, ax_art),
267 | (ax_reg_mov, ax_reg_art),
268 | ( ax_reg_mov_grid, ax_reg_art_grid)) = plt.subplots(3, 2, figsize=(15, 10))
269 |
270 |
271 | ax_mov.imshow(img_mov)
272 | ax_mov.set_title('Moving')
273 | ax_art.imshow(img_art)
274 | ax_art.set_title('Artificial')
275 |
276 | ax_reg_mov.imshow(mov2reg.warp(img_mov))
277 | ax_reg_mov.set_title('Registered (starting from moving)')
278 | ax_reg_art.imshow(art2reg.warp(img_art))
279 |     ax_reg_art.set_title('Registered (starting from artificial)')
280 |
281 | ax_reg_mov_grid.imshow(mov2reg.warp(img_grid))
282 | ax_reg_art_grid.imshow(art2reg.warp(img_grid))
283 |
284 |
285 | .. image:: ../_images/example_augmentation.png
286 | :width: 600
287 | :alt: Augmentation example
288 | :align: center
289 |
290 | Since generation of the augmentations or finding the inverse numerically might be
291 | slow we highly recommend precomputing everything in advance (before training)
292 | and storing it in a :code:`.h5` file. See the :code:`atlalign.augmentations.py` module
293 | that implements a similar strategy. Note that the training logic requires
294 | the data to be stored in a :code:`.h5` file.
295 |
--------------------------------------------------------------------------------
/docs/source/deep_learning_inference.rst:
--------------------------------------------------------------------------------
1 | Deep Learning - Inference
2 | =========================
3 |
4 | Loading model
5 | -------------
6 | To load a pretrained network one should use :code:`atlalign.ml_utils.load_model`.
7 | Note that it loads all possible custom layers in the background so that the user does not have
8 | to worry about it.
9 |
10 | .. code-block:: python
11 |
12 | from atlalign.ml_utils import load_model
13 |
14 | path_1 = 'path/to/a/folder' # inside of this folder .json (architecture) and .h5 (weights)
15 | path_2 = 'path/to/a/file.h5' # architecture and weights not separated + additional info (loss and optimizer)
16 |
17 | model_path_1 = load_model(path_1)
18 | model_path_2 = load_model(path_2, compile=True)
19 |
20 | Merging
21 | -------
22 | To merge a global and a local alignment network one can perform the composition via the :code:`__call__` method
23 | of :code:`atlalign.base.DisplacementField` on a per sample basis. A better approach is to use a custom keras layer
24 | implementing composition. For the latter option we provide a utility function :code:`atlalign.ml_utils.merge_global_local`.
25 |
26 | .. code-block:: python
27 |
28 | from atlalign.ml_utils import load_model, merge_global_local
29 |
30 | path_global = 'global_model.h5'
31 | path_local = 'local_model.h5'
32 |
33 | model_global = load_model(path_global)
34 | model_local = load_model(path_local)
35 |
36 | model_merged = merge_global_local(model_global, model_local)
37 |
38 | Forward pass
39 | ------------
40 | Performing the actual inference is extremely simple. Please review the :ref:`dl_training.supervised_generator` to
41 | understand the shape of the expected input. To quickly summarize (for non-inverse models) the user only needs to create
42 | a 4D array of the following shape
43 |
44 | .. code-block:: python
45 |
46 | (batch_size, height=320, width=456, depth=2)
47 |
48 | The last dimension is simply a stack of the :code:`img_ref` and :code:`img_mov`.
49 |
50 | .. code-block:: python
51 |
52 | import os
53 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' # to enable GPU
54 |
55 | import numpy as np
56 |
57 | from atlalign.base import DisplacementField
58 | from atlalign.ml_utils import load_model
59 |
60 | batch_size = 32 # how many samples are grouped together at inference time
61 |
62 | model = load_model('path/to/model.h5')
63 | X = np.random.random((200, 320, 456, 2))
64 |
65 | [reg_images, deltas_xy] = model.predict(X, batch_size=batch_size)
66 |
67 | # one can also create instances of DisplacementFields to perform many other tasks
68 | dfs = [DisplacementField(deltas_xy[i, ..., 0], deltas_xy[i, ..., 1]) for i in range(len(X))]
69 |
--------------------------------------------------------------------------------
/docs/source/evaluation.rst:
--------------------------------------------------------------------------------
1 | Metrics and Evaluation
2 | ======================
3 | Assuming we have the ground truth labels (necessary for supervised learning) we can use multiple evaluation metrics.
4 | In general they can be grouped into 3 categories based on what they compare:
5 |
6 |
7 | - **Image similarity**
8 | - **Displacement field**
9 | - **Segmentation**
10 |
11 | .. image:: ../_images/evaluation_metrics.png
12 | :width: 600
13 | :alt: Evaluation metrics
14 | :align: center
15 |
16 |
17 | All these metrics are implemented in :code:`atlalign.metrics`. A subset of them is also available as drop-in
18 | losses for deep learning in :code:`atlalign.ml_utils.losses`.
19 |
20 | The following sub-sections list the available metrics in each of the three categories.
21 | They are all part of the ``atlalign.metrics`` module and have a common interface::
22 |
23 |     atlalign.metrics.<metric_name>(y_true, y_pred, **kwargs)
24 |
25 | The parameters ``y_true`` and ``y_pred`` are pairs of images, displacement fields,
26 | or segmentation maps. Multiple pairs of images can be processed at once by stacking
27 | them along the first dimension, so that ``y_true`` and ``y_pred`` have the shape
28 | ``(n_images, ...)``.
29 |
30 | Some metrics have optional keyword arguments that differ from metric to metric;
31 | see the API reference for more details.
32 |
33 | Image Similarity Metrics
34 | ------------------------
35 |
36 | Loss-like (the smaller the better):
37 |
38 | - ``mse_img`` -- mean squared error
39 | - ``mae_img`` -- mean absolute error
40 | - ``demons_img`` -- ANTsPy's demons metric
41 | - ``perceptual_loss_img`` -- perceptual loss
42 |
43 | Similarity-like (the higher the better):
44 |
45 | - ``psnr_img`` -- peak signal to noise ratio (max = infinity)
46 | - ``cross_correlation_img`` -- image cross correlation (max = 1)
47 | - ``ssmi_img`` -- structural similarity (max = 1)
48 | - ``mi_img`` -- mutual information (max = mutual information with itself)
49 |
50 | Displacement Field Metrics
51 | --------------------------
52 |
53 | - ``correlation_combined`` -- combined version of correlation
54 | - ``mae_combined`` -- combined version of mean absolute error
55 | - ``mse_combined`` -- combined version of mean squared error
56 | - ``r2_combined`` -- combined version of r2
57 | - ``vector_distance_combined`` -- combined version of vector distance
58 |
59 | Segmentation Metrics
60 | --------------------
61 |
62 | - ``iou_score`` -- intersection over union (between 0 and 1, the higher the better)
63 | - ``dice_score`` -- dice score (between 0 and 1, the higher the better)
64 |
65 | Compute Many Metrics at Once
66 | ----------------------------
67 |
68 | To get a comprehensive overview of how specific model performs, we implemented a utility function
69 | :code:`atlalign.metrics.evaluate` that computes multiple metrics at the same time and returns the results
70 | in a :code:`pandas.DataFrame`.
71 |
72 |
73 | .. code-block:: python
74 |
75 | import numpy as np
76 |
77 | from atlalign.metrics import evaluate
78 |
79 | n_samples = 5
80 | shape = (320, 456)
81 |
82 | y_true = np.random.randint(0, 20, size=(n_samples, *shape, 2))
83 | y_pred = np.random.randint(0, 20, size=(n_samples, *shape, 2))
84 |
85 | imgs_mov = np.random.random((n_samples, *shape))
86 | img_ids = np.array(range(n_samples))
87 | dataset_ids = np.array(range(n_samples))
88 | ps = np.linspace(0, 12200, num=n_samples).astype('int')
89 |
90 | _, res_df = evaluate(y_true,
91 | y_pred,
92 | imgs_mov=imgs_mov,
93 | img_ids=img_ids,
94 | ps=ps,
95 | dataset_ids=dataset_ids,
96 | depths=(1, 2, 3, 4, 5))
97 |
98 | print(res_df.columns)
99 |
100 | .. code-block:: python
101 |
102 | Index(['angular_error_a', 'cross_correlation_img_a', 'dataset_id',
103 | 'iou_depth_1', 'iou_depth_2', 'iou_depth_3', 'iou_depth_4',
104 | 'iou_depth_5', 'jacobian_nonpositive_pixels_a',
105 | 'jacobian_nonpositive_pixels_perc_a', 'mae_img_a', 'mi_img_a',
106 | 'mse_img_a', 'norm_a', 'p', 'psnr_img_a', 'ssmi_img_a',
107 | 'vector_distance_a'],
108 | dtype='object')
109 |
--------------------------------------------------------------------------------
/docs/source/image_registration.rst:
--------------------------------------------------------------------------------
1 | Image Registration 101
2 | ======================
3 | Image registration is a task of geometrically transforming one image (moving) to a domain of another image (reference).
4 | This problem can be seen as two separate steps.
5 |
6 | **1) Prediction of a geometric transformation**
7 |
8 | .. image:: ../_images/image_registration.png
9 | :width: 600
10 | :alt: Inverse
11 | :align: center
12 |
13 | **2) Warping the moving image with the transformation**
14 |
15 | .. image:: ../_images/image_registration_2.png
16 | :width: 600
17 | :alt: Inverse
18 | :align: center
19 |
20 |
21 | There are multiple different algorithms that achieve the first step. This project focuses on supervised **deep learning**
22 | methods however also provides easy interface to other approaches - **feature** and **intensity** based registration.
--------------------------------------------------------------------------------
/docs/source/installation.rst:
--------------------------------------------------------------------------------
1 | .. _installation:
2 |
3 | Installation
4 | ============
5 | It is highly recommended to install the project into a new virtual environment.
6 |
7 | Python Requirements
8 | -------------------
9 | The project is only available for **Python 3.7**. The main reason for this
10 | restriction is an external dependency **ANTsPy** that does
11 | not provide many precompiled wheels on PyPI.
12 |
13 | External Dependencies
14 | ---------------------
15 | Some of the functionalities of :code:`atlalign` depend on the
16 | `TensorFlow implementation of the Learned Perceptual Image Patch Similarity <https://github.com/alexlee-gk/lpips-tensorflow>`_.
17 | Unfortunately, the
18 | package is not available on PyPI and must be installed manually as follows.
19 |
20 | .. code-block:: bash
21 |
22 |     pip install git+https://github.com/alexlee-gk/lpips-tensorflow.git#egg=lpips_tf
23 |
24 | You can now move on to installing the actual `atlalign` package!
25 |
26 | Installation from PyPI
27 | ----------------------
28 | The :code:`atlalign` package can be easily installed from PyPI.
29 |
30 | .. code-block:: bash
31 |
32 | pip install atlalign
33 |
34 | Installation from source
35 | ------------------------
36 | As an alternative to installing from PyPI, if you want to try the latest version
37 | you can also install from source.
38 |
39 | .. code-block:: bash
40 |
41 | pip install git+https://github.com/BlueBrain/atlas_alignment#egg=atlalign
42 |
43 | Development installation
44 | ------------------------
45 | For development installation one needs additional dependencies grouped in :code:`extras_requires` in the
46 | following way:
47 |
48 | - **dev** - pytest + plugins, flake8, pydocstyle, tox
49 | - **docs** - sphinx
50 |
51 | .. code-block:: bash
52 |
53 | git clone https://github.com/BlueBrain/atlas_alignment
54 | cd atlas_alignment
55 | pip install -e .[dev,docs]
56 |
57 |
58 | Generating documentation
59 | ------------------------
60 | To generate the documentation make sure you have dependencies from :code:`extras_requires` - :code:`docs`.
61 |
62 | .. code-block:: bash
63 |
64 | cd docs
65 | make clean && make html
66 |
67 | One can view the docs by opening :code:`docs/_build/html/index.html` in a browser.
68 |
--------------------------------------------------------------------------------
/docs/source/intensity.rst:
--------------------------------------------------------------------------------
1 | Intensity Based Registration
2 | ============================
3 | Intensity based registration refers to a set of algorithms that try to solve the registration problem via per sample
4 | maximization of an image similarity metric.
5 |
6 |
7 | ANTsPy
8 | ------
9 | We provide a very simple interface to an existing registration package called **ANTsPy**. To find out more about
10 | the original package we refer the reader to
11 |
12 | - `github <https://github.com/ANTsX/ANTsPy>`_
13 | - `docs <https://antspy.readthedocs.io>`_
14 |
15 |
16 | To use **ANTsPy** within :code:`atlalign` one can simply use :code:`atlalign.non_ml.antspy_registration`. See below
17 | a minimal example.
18 |
19 | .. code-block:: python
20 |
21 | import matplotlib.pyplot as plt
22 | import numpy as np
23 |
24 | from atlalign.base import DisplacementField
25 | from atlalign.data import circles
26 | from atlalign.non_ml import antspy_registration
27 |
28 | random_state = 4
29 | shape = (200, 230)
30 | p_drop_pixel = 0.1
31 |
32 | df = DisplacementField.generate(shape, approach='affine_simple', scale_x=1.4)
33 |
34 | img_ref = circles(1, shape, radius=75, n_levels=15, random_state=random_state)[0,..., 0]
35 | img_mov = df.warp(img_ref)
36 |
37 | to_drop = np.random.choice((True, False), p=(p_drop_pixel, 1 - p_drop_pixel), size=shape)
38 | img_ref[to_drop] = 0
39 |
40 |
41 | df_mov2ref_anstpy, _ = antspy_registration(img_ref, img_mov)
42 | df_mov2ref_truth = df.pseudo_inverse()
43 |
44 | img_reg_antspy = df_mov2ref_anstpy.warp(img_mov)
45 | img_reg_truth = df_mov2ref_truth.warp(img_mov)
46 |
47 | fig, ((ax_ref, ax_mov), (ax_reg_antspy, ax_reg_truth)) = plt.subplots(2, 2, figsize=(15, 15))
48 | ax_ref.imshow(img_ref)
49 | ax_ref.set_axis_off()
50 | ax_ref.set_title('Reference')
51 |
52 | ax_mov.imshow(img_mov)
53 | ax_mov.set_axis_off()
54 | ax_mov.set_title('Moving')
55 |
56 | ax_reg_antspy.imshow(img_reg_antspy)
57 | ax_reg_antspy.set_axis_off()
58 | ax_reg_antspy.set_title('Registered - ANTsPy')
59 |
60 | ax_reg_truth.imshow(img_reg_truth)
61 | ax_reg_truth.set_axis_off()
62 | ax_reg_truth.set_title('Registered - Ground Truth')
63 |
64 | .. image:: ../_images/antspy.png
65 | :width: 600
66 | :alt: ANTsPy
67 | :align: center
--------------------------------------------------------------------------------
/docs/source/labeling_tool.rst:
--------------------------------------------------------------------------------
1 | Labeling tool
2 | =============
3 | The goal of the labeling tool is to allow the user to align manually any two images and extract
4 | the displacement field.
5 |
6 | Launching
7 | ---------
8 | To launch the labeling tool one needs to use the entry point ``label-tool``.
9 |
10 | .. code-block:: bash
11 |
12 | label-tool --help
13 | usage: label-tool [-h] [-s] ref mov output_path
14 |
15 | positional arguments:
16 | ref Either a path to a reference image or a number from [0, 528)
17 | representing the coronal dimension in the nissl stain volume.
18 | mov Path to a moving image. Needs to be of the same shape as
19 | reference.
20 | output_path Folder where the outputs will be stored.
21 |
22 | optional arguments:
23 | -h, --help show this help message and exit
24 | -s, --swap Swap to the moving to reference mode. (default: False)
25 |
26 | Arguments
27 | ~~~~~~~~~
28 |
29 | 1. ``ref`` - number from 0 - 527 representing coronal section or a path to a custom reference image
30 | 2. ``mov`` - path to the moving image that needs to have the same shape as the reference one
31 | 3. ``output_path`` - path to folder where output is saved
32 |
33 | Options
34 | ~~~~~~~
35 |
36 | - ``--swap`` - flag that if activated the first landmark of a pair is on moving image otherwise it
37 | is on the reference
38 | - ``--force-grayscale`` - force color images to be converted to grayscale, can be useful for taking
39 | advantage of the colormap selection which does not work for color images.
40 |
41 | Interface
42 | ---------
43 | .. image:: ../_images/labeling_tool.png
44 | :width: 600
45 |    :alt: Labeling tool
46 | :align: center
47 |
48 | The user adds landmark points until they are fully happy with the registration. Once the GUI is
49 | closed the resulting displacement field, the registered and moving images, the key points, and other
50 | metadata are stored into the :code:`OUTPUT_PATH`.
51 |
52 | 1
53 | ~
54 | The **interactive window** where the user specifies pairs of points that represent the landmarks in
55 | the moving (resp. reference) image. The landmarks in the moving image are represented by **crosses**
56 | and the ones in the reference image by **circles**. To delete a point hover your mouse over a
57 | reference point and press "d".
58 |
59 | 2
60 | ~
61 | The **result window** that displays the current registered image and the reference. Note that one
62 | can also see the deformation of a regular grid if one clicks on the **Show grid** button.
63 |
64 | 3
65 | ~
66 | **Sliders**
67 |
68 | - **Alpha Moving/Registered** - Controls the blending together of the reference and moving images.
69 | Value 0 corresponds to only the reference image, value 1 to only the moving image. The slider
70 | remembers two different settings that can be toggled using the space bar. This can be useful for
71 | quickly getting the full view of one of the images.
72 | - **Alpha Ref** - Control the translucency of the reference image.
73 | - **Moving/Registered Threshold** - Remove pixels below a given intensity from the moving image.
74 | - **Reference Threshold** - Remove pixels below a given intensity from the reference image.
75 |
76 | **Buttons**
77 |
78 | - **Reset** - Deletes all registration points.
79 | - **Symmetric Registration** - Toggle symmetry registration where all keypoints are automatically
80 | mirrored in the left/right direction.
81 | - **Change Order** - Swap the moving and reference image order.
82 | - **Show Arrows** - Toggle the visibility of lines between keypoints.
83 | - **Show Grid** - Show regular grid warped by the current displacement. Red spots represent regions
84 | that are not invertible.
85 |
86 | 4
87 | ~
88 | **Color maps** for the reference and moving image.
89 |
90 | 5
91 | ~
92 | **Overview statistics**
93 |
94 | - **Transformation quality** - If lower than 100% then displacement contains folds. Check the grid
95 | to see where exactly.
96 | - **Average displacement** - Average displacement size
97 |
98 | **Keyboard shortcuts**
99 |
100 | - **"a"** - pan
101 | - **"s"** - zoom
102 | - **"d"** - delete a keypoint
103 | - **"f"** - reset zoom
104 | - **space** - toggle the reference/moving blending alpha value
105 |
106 | Note that after zooming/panning the respective shortcut key needs to be pressed again to get back
107 | to the keypoint addition mode.
108 |
109 | 6
110 | ~
111 | **Interpolation window**. The first column represents the general algorithm
112 |
113 | - **griddata** - Delaunay triangulation followed by affine interpolation for each triangle
114 | - **rbf** - Interpolation using basis function
115 |
116 | The second column is only used for the **rbf** interpolation and it specifies the kernel.
117 |
118 | FAQ
119 | ---
120 |
121 | - **Can the moving and reference image have different sizes?** No, make sure they are the same and
122 | one can resize the displacement field after registration
123 | - **How do I save the displacement?** Simply by closing the GUI
124 | - **Can I delete some landmark pairs?** Yes, point at the reference landmark and press "d"
125 |
--------------------------------------------------------------------------------
/docs/source/logo/Atlas_Alignment_banner.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/source/logo/Atlas_Alignment_banner.jpg
--------------------------------------------------------------------------------
/scripts/experiment_1.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from atlalign.base import DisplacementField
7 | from atlalign.data import (
8 | annotation_volume,
9 | manual_registration,
10 | nissl_volume,
11 | segmentation_collapsing_labels,
12 | )
13 | from atlalign.metrics import evaluate_single
14 | from atlalign.ml_utils import load_model
15 |
# Define all relevant paths
# Everything is expected to live in the user's ~/.atlalign cache directory.
cache_dir = pathlib.Path.home() / ".atlalign"

# Cached inputs: atlas annotation volume, manually registered dataset,
# label hierarchy, the two trained models and the nissl stain volume.
annotation_path = cache_dir / "annotation.npy"
h5_path = cache_dir / "manual_registration.h5"
labels_path = cache_dir / "annotation_hierarchy.json"
model_local_path = cache_dir / "local.h5"
model_global_path = cache_dir / "global.h5"
nissl_path = cache_dir / "nissl.npy"
25 |
26 |
# Validation dataset definition: indices (into the manual-registration
# dataset) of the 54 samples held out for evaluation.
validation_ixs = [
    2, 7, 15, 21, 25, 28, 37, 40, 49, 50, 54, 55, 56, 58, 60, 62, 64, 66,
    78, 85, 93, 97, 121, 123, 125, 127, 129, 130, 133, 137, 140, 143, 144,
    154, 169, 175, 183, 187, 192, 204, 206, 214, 216, 219, 223, 225, 226,
    248, 251, 252, 254, 260, 271, 276,
]
84 |
85 |
# Load the manually-registered dataset and the atlas volumes from the cache.
manual_labels = manual_registration(h5_path)
nissl = nissl_volume(nissl_path)
annotation = annotation_volume(annotation_path)
labels = segmentation_collapsing_labels(labels_path)


# Build the validation split: one dict per held-out index, carrying every
# key of the original dataset (img, deltas_xy, p, ...) for that sample.
validation_set = {}
keys = manual_labels.keys()

for val_ix in validation_ixs:
    validation_set[val_ix] = {}
    for key in keys:
        validation_set[val_ix][key] = manual_labels[key][val_ix]
99 |
100 |
# Load the two trained registration networks from the cache.
model_local = load_model(model_local_path)
model_global = load_model(model_global_path)


# Per-sample evaluation metrics and the local DVF of every sample
# (kept for the corrupted-pixels analysis further below).
metrics = {}
df_locals = {}

for k, data in validation_set.items():
    print(k)

    # Preparation
    # Intensities are rescaled from uint8 range to [0, 1]; `p` is the coronal
    # coordinate in microns, divided by 25 to get the section index
    # (assumes 25 um spacing of the nissl volume - TODO confirm).
    img_mov = data["img"] / 255
    p = data["p"]
    section_num = p // 25
    img_ref = nissl[section_num][..., 0]

    # Global model
    # Inputs are stacked as (ref, mov) channels with a leading batch axis.
    # NOTE(review): `.predict(...)[1][0]` assumes the model's second output
    # holds the DVF deltas - confirm against the model definition.
    inp_global = np.stack([img_ref, img_mov], axis=-1)[None, ...]
    deltas_xy_global = model_global.predict(inp_global)[1][0]
    df_global = DisplacementField(deltas_xy_global[..., 0], deltas_xy_global[..., 1])

    # Local model
    # The local model refines the globally pre-warped moving image.
    inp_local = np.stack([img_ref, df_global.warp(img_mov)], axis=-1)[None, ...]
    deltas_xy_local = model_local.predict([inp_local, np.zeros_like(inp_local)])[1][0]
    df_local = DisplacementField(deltas_xy_local[..., 0], deltas_xy_local[..., 1])
    df_locals[k] = df_local

    # Overall model
    # Presumably DisplacementField.__call__ composes the two fields
    # (local after global) - verify against atlalign.base.
    df_pred = df_local(df_global)

    deltas_true = data["deltas_xy"]
    deltas_pred = np.stack([df_pred.delta_x, df_pred.delta_y], axis=-1)

    metrics[k], _ = evaluate_single(
        deltas_true,
        deltas_pred,
        img_mov,
        p=p,
        avol=annotation,
        collapsing_labels=labels,
        deltas_pred_inv=None,
        deltas_true_inv=data["inv_deltas_xy"],
        ds_f=8,
        depths=(0, 2, 4, 6, 8),
    )
    metrics[k]["p"] = p
    metrics[k]["image_id"] = data["image_id"]


# One row per validation sample, one column per metric.
df = pd.DataFrame(metrics).transpose()

# Summarize the dice scores at each annotation-hierarchy depth.
cols = ["dice_0", "dice_2", "dice_4", "dice_6", "dice_8"]
df_overview = pd.DataFrame({"mean": df[cols].mean(), "std": df[cols].std()})

print(df_overview)
156 |
157 |
# Corrupted pixels analysis - we exclude border pixels
h, w = 320, 456
margin_ud = 1
margin_lr = 6  # the network pads the sides with 0s

# A pixel counts as corrupted when the local DVF is non-invertible there,
# i.e. its Jacobian determinant is non-positive.
interior = np.s_[margin_ud : h - margin_ud, margin_lr : w - margin_lr]
corrupted_pixels = np.array(
    [np.sum(dvf.jacobian[interior] <= 0) for dvf in df_locals.values()]
)

# Number of pixels actually inspected after trimming the margins.
n_pixels = (h - 2 * margin_ud) * (w - 2 * margin_lr)


# Percentages of corrupted pixels across the validation samples.
mean = 100 * (corrupted_pixels.mean() / n_pixels)
std = 100 * (corrupted_pixels.std() / n_pixels)

print("Corrupted_pixels")
print(
    "Up and Down cuttoff: {}\nLeft and Right cutoff: {}\n".format(margin_ud, margin_lr)
)
print("Mean:\n{}\n".format(mean))
print("STD:\n{}".format(std))
182 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

import atlalign

# install with `pip install -e .`

# Preparations
# Single-source the version from the package itself.
VERSION = atlalign.__version__
DESCRIPTION = "Blue Brain multi-modal registration and alignment toolbox"

LONG_DESCRIPTION = """
Atlas Alignment is a toolbox to perform multimodal image registration. It
includes both traditional and supervised deep learning models.

This project originated from the Blue Brain Project efforts on aligning mouse
brain atlases obtained with ISH gene expression and Nissl stains."""

PYTHON_REQUIRES = ">=3.6.0"
INSTALL_REQUIRES = [
    "antspyx",
    "imgaug<0.3",
    "matplotlib>=3.0.3",
    "mlflow",
    "nibabel>=2.4.0",
    "numpy<1.24.0",
    "seaborn",
    "scikit-image>=0.17.1",
    "scikit-learn>=0.20.2",
    "scipy",
    "tensorflow>=2.6.0",
    "tensorflow-addons",  # For resampler in atlalign/ml_utils/layers.py
]

setup(
    name="atlalign",
    version=VERSION,
    description=DESCRIPTION,
    long_description=LONG_DESCRIPTION,
    # BUG FIX: the repository slug uses a hyphen ("atlas-alignment"); the
    # previous value "atlas_alignment" pointed at a non-existent project page.
    url="https://github.com/BlueBrain/atlas-alignment",
    author="Blue Brain Project, EPFL",
    license="LGPLv3",
    packages=find_packages(),
    classifiers=[
        "Development Status :: 4 - Beta",
        "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
        "Operating System :: Unix",
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3 :: Only",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Scientific/Engineering :: Image Processing",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    python_requires=PYTHON_REQUIRES,
    install_requires=INSTALL_REQUIRES,
    extras_require={
        "dev": [
            "black>=20.8b1",
            "flake8>=3.7.4",
            "pydocstyle>=3.0.0",
            "pytest>=3.10.1",
            "pytest-cov",
            "pytest-mock>=1.10.1",
        ],
        "docs": ["sphinx>=1.3", "sphinx-bluebrain-theme"],
    },
    # Installs the labeling GUI entry point (see atlalign/label/cli.py).
    entry_points={"console_scripts": ["label-tool = atlalign.label.cli:main"]},
)
73 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Define all fixtures.
2 |
3 | Notes
4 | -----
5 | A lot of the fixtures are built on top of the following 4 atomic fixtures:
6 | * img_grayscale_uint
7 | * img_grayscale_float
8 | * img_rgb_uint
9 | * img_rgb_float
10 |
11 | The fixtures below are just parameterized in a way that only a subset of the above
12 | atomic fixtures is used
13 | * img
14 | * img_grayscale
15 | * img_rgb
16 | * img_uint
17 | * img_float
18 | """
19 |
20 | """
21 | The package atlalign is a tool for registration of 2D images.
22 |
23 | Copyright (C) 2021 EPFL/Blue Brain Project
24 |
25 | This program is free software: you can redistribute it and/or modify
26 | it under the terms of the GNU Lesser General Public License as published by
27 | the Free Software Foundation, either version 3 of the License, or
28 | (at your option) any later version.
29 |
30 | This program is distributed in the hope that it will be useful,
31 | but WITHOUT ANY WARRANTY; without even the implied warranty of
32 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
33 | GNU Lesser General Public License for more details.
34 |
35 | You should have received a copy of the GNU Lesser General Public License
36 | along with this program. If not, see <https://www.gnu.org/licenses/>.
37 | """
38 |
39 | import pathlib
40 |
41 | import cv2
42 | import numpy as np
43 | import pytest
44 | from skimage.util import img_as_float32
45 |
46 | from atlalign.base import DisplacementField
47 |
# Image shape shared by the synthetic fixtures below.
SHAPE = (20, 30)  # used for `img_dummy` and `img_random`
# Seed for `img_random` so the generated image is reproducible.
RANDOM_STATE = 2

# Repository root (this file lives in <root>/tests/).
ROOT_PATH = pathlib.Path(__file__).resolve().parent.parent
52 |
53 |
@pytest.fixture(scope="session")
def path_test_data():
    """Return the path to the directory holding the bundled test data files."""
    return ROOT_PATH.joinpath("tests", "data")
57 |
58 |
@pytest.fixture(scope="session")
def img_grayscale_uint(path_test_data):
    """Load the test photo as a single-channel uint8 image."""
    # Flag 0 asks OpenCV for a grayscale decode.
    image = cv2.imread(str(path_test_data / "animals.jpg"), 0)

    # Sanity checks: 2D, uint8, values within the valid byte range.
    assert image.ndim == 2
    assert image.dtype == np.uint8
    assert np.all(image >= 0) and np.all(image <= 255)

    return image
71 |
72 |
@pytest.fixture()
def img_grayscale_float(img_grayscale_uint):
    """Convert the grayscale test image to float32 intensities in [0, 1]."""
    image = img_as_float32(img_grayscale_uint, force_copy=True)

    # Sanity checks: still 2D, now float32 and normalized.
    assert image.ndim == 2
    assert image.dtype == np.float32
    assert np.all(image >= 0) and np.all(image <= 1)

    return image
83 |
84 |
@pytest.fixture(scope="session")
def img_rgb_uint(path_test_data):
    """Just a RGB image with dtype uint8.

    Notes
    -----
    OpenCV reads a 3 channel image as 'BGR' but it is not a problem for our purposes.

    """
    file_path = path_test_data / "animals.jpg"

    # Flag 1 asks OpenCV for a 3-channel (color) decode.
    img_out = cv2.imread(str(file_path), 1)

    # Sanity checks: 3D, uint8, values within the valid byte range.
    assert img_out.ndim == 3
    assert img_out.dtype == np.uint8
    assert np.all(img_out >= 0) and np.all(img_out <= 255)

    return img_out
103 |
104 |
@pytest.fixture()
def img_rgb_float(img_rgb_uint):
    """Convert the RGB test image to float32 intensities in [0, 1]."""
    image = img_as_float32(img_rgb_uint, force_copy=True)

    # Sanity checks: still 3D, now float32 and normalized.
    assert image.ndim == 3
    assert image.dtype == np.float32
    assert np.all(image >= 0) and np.all(image <= 1)

    return image
115 |
116 |
@pytest.fixture(
    params=["grayscale_uint8", "grayscale_float32", "RGB_uint8", "RGB_float32"]
)
def img(request, img_rgb_uint, img_grayscale_uint, img_rgb_float, img_grayscale_float):
    """Parametrized fixture yielding every uint8/float32 x grayscale/RGB variant.

    Notes
    -----
    Any test requesting this fixture runs once per image variant.

    """
    variants = {
        "grayscale_uint8": img_grayscale_uint,
        "grayscale_float32": img_grayscale_float,
        "RGB_uint8": img_rgb_uint,
        "RGB_float32": img_rgb_float,
    }

    if request.param not in variants:
        raise ValueError("Unrecognized image type {}".format(request.param))

    return variants[request.param]
144 |
145 |
@pytest.fixture(params=["grayscale_float32", "RGB_float32"])
def img_float(request, img_rgb_float, img_grayscale_float):
    """Parametrized fixture yielding each float32 image (grayscale and RGB).

    Notes
    -----
    Any test requesting this fixture runs once per variant.

    """
    variants = {
        "grayscale_float32": img_grayscale_float,
        "RGB_float32": img_rgb_float,
    }

    if request.param not in variants:
        raise ValueError("Unrecognized image type {}".format(request.param))

    return variants[request.param]
165 |
166 |
@pytest.fixture(params=["grayscale_uint8", "RGB_uint8"])
def img_uint(request, img_rgb_uint, img_grayscale_uint):
    """Parametrized fixture yielding each uint8 image (grayscale and RGB).

    Notes
    -----
    Any test requesting this fixture runs once per variant.

    """
    variants = {
        "grayscale_uint8": img_grayscale_uint,
        "RGB_uint8": img_rgb_uint,
    }

    if request.param not in variants:
        raise ValueError("Unrecognized image type {}".format(request.param))

    return variants[request.param]
186 |
187 |
@pytest.fixture(params=["grayscale_uint8", "grayscale_float32"])
def img_grayscale(request, img_grayscale_uint, img_grayscale_float):
    """Parametrized fixture yielding each grayscale image (uint8 and float32).

    Notes
    -----
    Any test requesting this fixture runs once per variant.

    """
    variants = {
        "grayscale_uint8": img_grayscale_uint,
        "grayscale_float32": img_grayscale_float,
    }

    if request.param not in variants:
        raise ValueError("Unrecognized image type {}".format(request.param))

    return variants[request.param]
207 |
208 |
@pytest.fixture(params=["RGB_uint8", "RGB_float32"])
def img_rgb(request, img_rgb_uint, img_rgb_float):
    """Parametrized fixture yielding each RGB image (uint8 and float32).

    Notes
    -----
    Any test requesting this fixture runs once per variant.

    """
    variants = {
        "RGB_uint8": img_rgb_uint,
        "RGB_float32": img_rgb_float,
    }

    if request.param not in variants:
        raise ValueError("Unrecognized image type {}".format(request.param))

    return variants[request.param]
228 |
229 |
@pytest.fixture(scope="function")  # In order to allow for in place changes
def img_dummy():
    """Generate an all-zero float32 image of shape ``SHAPE``."""
    dummy = np.zeros(SHAPE, dtype=np.float32)
    return dummy
234 |
235 |
@pytest.fixture(scope="function")
def img_random():
    """Generate a reproducible image of random float32 intensities."""
    # Seeding the *global* RNG keeps the image identical across test runs
    # (and is an intentional side effect of this fixture).
    np.random.seed(RANDOM_STATE)
    image = np.random.random(SHAPE).astype(np.float32)

    assert image.dtype == np.float32

    return image
245 |
246 |
@pytest.fixture(scope="session")
def df_cached(path_test_data):
    """Load a DVF and its inverse of shape (80, 114).

    DVF represents a mild warping then can be relatively easily unwarped.

    Notes
    -----
    After composition the largest displacement in x ~ 0.05 and in y ~ 0.03.

    Returns
    -------
    delta_x : np.array
        DVF in the x direction.

    delta_y : np.array
        DVF in the y direction.

    delta_x_inv : np.array
        Inverse DVF in the x direction.

    delta_y_inv : np.array
        Inverse DVF in the y direction.

    """
    # The four displacement components are stacked along the last axis
    # of the cached array, in the order documented above.
    file_path = path_test_data / "mild_inversion.npy"
    a = np.load(str(file_path))

    delta_x = a[..., 0]
    delta_y = a[..., 1]
    delta_x_inv = a[..., 2]
    delta_y_inv = a[..., 3]

    return delta_x, delta_y, delta_x_inv, delta_y_inv
281 |
282 |
@pytest.fixture()
def df_id(request):
    """Generate an identity transformation.

    In order to specify a shape one decorates the test function in the following way:
    `@pytest.mark.parametrize('df_id', [(320, 456)], indirect=True)`

    """
    # Fall back to a small default shape when no indirect parameter is given.
    shape = getattr(request, "param", (10, 11))
    return DisplacementField.generate(shape, approach="identity")
296 |
297 |
@pytest.fixture()
def label_dict():
    """Generate a dictionary for the concatenation of labels (segmentation)."""
    # Small hierarchy: 2 -> {3, 4}, 4 -> {5, 6}.
    leaf_3 = {"id": 3, "children": []}
    leaf_5 = {"id": 5, "children": []}
    leaf_6 = {"id": 6, "children": []}
    node_4 = {"id": 4, "children": [leaf_5, leaf_6]}

    return {"id": 2, "children": [leaf_3, node_4]}
313 |
--------------------------------------------------------------------------------
/tests/data/animals.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/tests/data/animals.jpg
--------------------------------------------------------------------------------
/tests/data/mild_inversion.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/tests/data/mild_inversion.npy
--------------------------------------------------------------------------------
/tests/data/supervised_dataset.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/tests/data/supervised_dataset.h5
--------------------------------------------------------------------------------
/tests/test_augmentations.py:
--------------------------------------------------------------------------------
1 | """Collection of tests focused on the augmentations.py module."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see <https://www.gnu.org/licenses/>.
20 | """
21 | import pathlib
22 | from unittest.mock import Mock
23 |
24 | import numpy as np
25 | import pytest
26 |
27 | from atlalign.augmentations import DatasetAugmenter, load_dataset_in_memory
28 | from atlalign.base import DisplacementField
29 |
30 |
@pytest.mark.parametrize(
    "key",
    ["dataset_id", "deltas_xy", "image_id", "img", "inv_deltas_xy", "p"],
)
def test_load_dataset_in_memory(path_test_data, key):
    """Each known dataset key loads as a non-empty numpy array."""
    loaded = load_dataset_in_memory(path_test_data / "supervised_dataset.h5", key)

    assert isinstance(loaded, np.ndarray)
    assert len(loaded) > 0
49 |
50 |
class TestDatasetAugmenter:
    """Tests for `atlalign.augmentations.DatasetAugmenter`."""

    def test_construction(self, path_test_data):
        """Constructing the augmenter exposes a positive original sample count."""
        da = DatasetAugmenter(path_test_data / "supervised_dataset.h5")

        assert da.n_orig > 0

    @pytest.mark.parametrize("n_iter", [1, 2])
    @pytest.mark.parametrize("anchor", [True, False])
    @pytest.mark.parametrize("is_valid", [True, False])
    def test_augment(
        self, monkeypatch, path_test_data, tmpdir, n_iter, anchor, is_valid
    ):
        """Run a full augmentation and check the written h5 output.

        `edge_stretching` is mocked to always return one fixed affine DVF so
        the run is fast and deterministic.
        """
        fake_es = Mock(
            return_value=DisplacementField.generate(
                (320, 456), approach="affine_simple", rotation=0.2
            )
        )
        max_corrupted_pixels = 10

        if not is_valid:
            max_corrupted_pixels = 0  # hack that will force to use the original

        monkeypatch.setattr("atlalign.zoo.edge_stretching", fake_es)

        da = DatasetAugmenter(path_test_data / "supervised_dataset.h5")

        output_path = pathlib.Path(str(tmpdir)) / "output.h5"

        da.augment(
            output_path,
            n_iter=n_iter,
            anchor=anchor,
            max_trials=2,
            max_corrupted_pixels=max_corrupted_pixels,
            ds_f=32,
        )

        assert output_path.exists()

        # Every array in the output holds one copy of the dataset per
        # iteration and contains no NaNs.
        keys = ["dataset_id", "deltas_xy", "image_id", "img", "inv_deltas_xy", "p"]
        for key in keys:
            array = load_dataset_in_memory(output_path, key)
            assert da.n_orig * n_iter == len(array)
            assert not np.any(np.isnan(array))

        # Metadata columns are carried over unchanged, repeated n_iter times.
        keys = ["dataset_id", "image_id", "p"]
        for key in keys:
            original_a = load_dataset_in_memory(da.original_path, key)
            new_a = load_dataset_in_memory(output_path, key)
            new_a_expected = np.concatenate(n_iter * [original_a])

            np.testing.assert_equal(new_a, new_a_expected)
103 |
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | """
2 | The package atlalign is a tool for registration of 2D images.
3 |
4 | Copyright (C) 2021 EPFL/Blue Brain Project
5 |
6 | This program is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU Lesser General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | (at your option) any later version.
10 |
11 | This program is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU Lesser General Public License for more details.
15 |
16 | You should have received a copy of the GNU Lesser General Public License
17 | along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 |
20 | import pathlib
21 |
22 | import numpy as np
23 | import pytest
24 |
25 | from atlalign.data import (
26 | annotation_volume,
27 | circles,
28 | manual_registration,
29 | nissl_volume,
30 | rectangles,
31 | segmentation_collapsing_labels,
32 | )
33 | from atlalign.utils import _find_all_children
34 |
35 |
class TestAnnotationVolume:
    """A collection of tests focused on the `annotation_volume` function."""

    def test_load_works(self, monkeypatch, tmpdir):
        """Test that loading works"""
        # Replace the on-disk read with an in-memory float volume.
        monkeypatch.setattr(
            "numpy.load",
            lambda *args, **kwargs: np.zeros((528, 320, 456), dtype=np.float32),
        )

        volume = annotation_volume()

        # The shape is preserved, values are finite and cast to int32.
        assert volume.shape == (528, 320, 456)
        assert np.all(np.isfinite(volume))
        assert volume.dtype == np.int32
53 |
54 |
class TestNisslVolume:
    """A collection of tests focused on the `nissl_volume` function."""

    def test_load_works(self, monkeypatch):
        """Test that loading works."""
        # Lets do some patching
        # `np.load` returns a fake full volume and `img_as_float32` a fake
        # single section, so no real data files are needed.
        monkeypatch.setattr(
            "numpy.load",
            lambda *args, **kwargs: np.zeros((528, 320, 456), dtype=np.float32),
        )
        monkeypatch.setattr(
            "atlalign.data.img_as_float32",
            lambda *args, **kwargs: np.zeros((320, 456), dtype=np.float32),
        )

        x_atlas = nissl_volume()

        # Final output
        # A channel axis is appended and intensities are normalized to [0, 1].
        assert x_atlas.shape == (528, 320, 456, 1)
        assert np.all(np.isfinite(x_atlas))
        assert x_atlas.min() >= 0
        assert x_atlas.max() <= 1
        assert x_atlas.dtype == np.float32
78 |
79 |
class TestManualRegistration:
    """Tests for the `manual_registration` loader."""

    def test_correct_keys(self, monkeypatch):
        """Loading the bundled h5 dataset yields exactly the expected keys."""
        # NOTE(review): `monkeypatch` is unused; the real test file is read
        # directly ("manual patch"). Consider dropping the fixture argument.
        path = "tests/data/supervised_dataset.h5"  # manual patch

        res = manual_registration(path)

        # Final output
        assert set(res.keys()) == {
            "dataset_id",
            "deltas_xy",
            "image_id",
            "img",
            "inv_deltas_xy",
            "p",
        }
        assert np.all([isinstance(x, np.ndarray) for x in res.values()])
96 |
97 |
class TestCircles:
    """Collection of tests focused on the circles function."""

    def test_input_shape(self):
        """Test that only allows for 2D."""

        shape_wrong = (41, 21, 312)
        shape_correct = (100, 30)

        with pytest.raises(ValueError):
            circles(10, shape_wrong, 10)

        circles(10, shape_correct, 10)

    def test_correct_dtype(self):
        """Test that float32 ndarray is returned."""

        shape = (100, 30)
        res = circles(10, shape, 10)

        assert res.min() >= 0
        assert res.max() <= 1
        assert res.dtype == np.float32

    @pytest.mark.parametrize("n_levels", [1, 2, 3, 4, 5, 6])
    def test_correct_number_of_intensities(self, n_levels):
        """Test that each image has at most n_levels + 1 unique intensities."""

        res = circles(10, (200, 220), (40, 50), n_levels=n_levels)

        for row in res:
            # BUG FIX: the original assertion was
            #     assert len(np.unique(row) == n_levels + 1)
            # which takes the length of a *boolean* array, truthy for any
            # non-empty image, so the test could never fail. Assert on the
            # count of unique intensities instead; `+ 1` accounts for the
            # black background.
            assert len(np.unique(row)) <= n_levels + 1

    def test_reproducible(self):
        """Test that random_state works."""

        res_1 = circles(10, (100, 120), (20, 40), n_levels=(2, 6), random_state=None)
        res_2 = circles(10, (100, 120), (20, 40), n_levels=(2, 6), random_state=1)
        res_3 = circles(10, (100, 120), (20, 40), n_levels=(2, 6), random_state=2)
        res_4 = circles(10, (100, 120), (20, 40), n_levels=(2, 6), random_state=1)
        res_5 = circles(10, (100, 120), (20, 40), n_levels=(2, 6), random_state=None)

        # Same seed -> same output; different/absent seeds -> different output.
        assert np.all(res_2 == res_4)
        assert not np.all(res_2 == res_3)
        assert not np.all(res_2 == res_1)
        assert not np.all(res_3 == res_1)
        assert not np.all(res_1 == res_5)
145 |
146 |
class TestRectangles:
    """A collection of tests testing the rectangles function."""

    def test_input_shape(self):
        """Test that only 2D shapes are allowed."""
        shape_wrong = (40, 30, 10)
        shape_correct = (40, 30)

        with pytest.raises(ValueError):
            rectangles(100, shape_wrong, 10, 20)

        # A 2D shape is accepted.
        rectangles(100, shape_correct, 10, 20)

    def test_wrong_type(self):
        """Test that height, width and n_levels only work with integers."""
        shape = (40, 50)

        with pytest.raises(TypeError):
            rectangles(100, shape, 10.1, 20)

        with pytest.raises(TypeError):
            rectangles(100, shape, 10, 20.3)

        with pytest.raises(TypeError):
            rectangles(100, shape, (10, 20), 20, n_levels=3.4)

        # All-integer arguments are accepted.
        rectangles(10, shape, 10, 20, n_levels=2)

    def test_wrong_rectangle_size(self):
        """Test that the rectangle needs to fit inside the image."""
        shape = (img_h, img_w) = (40, 50)

        with pytest.raises(ValueError):
            rectangles(100, shape, img_h + 1, img_w - 1)

        with pytest.raises(ValueError):
            rectangles(100, shape, img_h - 1, img_w + 1)

        with pytest.raises(ValueError):
            rectangles(100, shape, img_h + 1, img_w + 1)

        # Strictly smaller in both dimensions is fine.
        rectangles(100, shape, img_h - 1, img_w - 1)

    def test_wrong_n_levels(self):
        """Test that n_levels needs to be correct."""
        # NOTE(review): both failing cases have a rectangle side of 10 while
        # n_levels=14; presumably n_levels must not exceed the smaller side —
        # confirm against the `rectangles` validation.
        with pytest.raises(ValueError):
            rectangles(5, (100, 120), 20, 10, n_levels=14)

        with pytest.raises(ValueError):
            rectangles(5, (100, 120), 10, 20, n_levels=14)

        rectangles(5, (100, 120), 20, 20, n_levels=14)

    def test_reproducible(self):
        """Test that random_state works."""
        res_1 = rectangles(
            10, (100, 120), (20, 30), (10, 40), n_levels=(1, 4), random_state=None
        )
        res_2 = rectangles(
            10, (100, 120), (20, 30), (10, 40), n_levels=(1, 4), random_state=1
        )
        res_3 = rectangles(
            10, (100, 120), (20, 30), (10, 40), n_levels=(1, 4), random_state=2
        )
        res_4 = rectangles(
            10, (100, 120), (20, 30), (10, 40), n_levels=(1, 4), random_state=1
        )
        res_5 = rectangles(
            10, (100, 120), (20, 30), (10, 40), n_levels=(1, 4), random_state=None
        )

        # Identical seeds reproduce; any other pairing must differ.
        assert np.all(res_2 == res_4)
        assert not np.all(res_2 == res_3)
        assert not np.all(res_2 == res_1)
        assert not np.all(res_3 == res_1)
        assert not np.all(res_1 == res_5)

    @pytest.mark.parametrize("random_state", [0, 1, 2, 3, 4])
    def test_no_empty_images(self, random_state):
        """Test that no generated image is all-black."""
        shape = (100, 120)
        res = rectangles(
            10, shape, (20, 30), (10, 50), n_levels=4, random_state=random_state
        )

        zeros = np.zeros((*shape, 1))
        for row in res:
            assert not np.all(row == zeros)

    def test_output_shape(self):
        """Test that the shape of the output is correct."""
        shape = (50, 100)
        res = rectangles(10, shape, (20, 30), (10, 50), n_levels=4)

        # (n_samples, height, width, channels)
        assert res.shape == (10, *shape, 1)

    def test_correct_dtype(self):
        """Test that a float32 ndarray with values in [0, 1] is returned."""
        shape = (50, 100)
        res = rectangles(10, shape, (20, 30), (10, 50), n_levels=4)

        assert res.min() >= 0
        assert res.max() <= 1
        assert res.dtype == np.float32

    def test_full_intensity(self):
        """Test that there exists a pixel with a full intensity = 1."""
        shape = (50, 100)
        res = rectangles(10, shape, (20, 30), (10, 50), n_levels=4)

        for row in res:
            assert np.any(row == 1)

    @pytest.mark.parametrize("n_levels", [1, 2, 3, 4, 5, 6])
    def test_correct_number_of_intensities(self, n_levels):
        """Test whether n_unique_intensities = n_levels + 1."""
        res = rectangles(10, (200, 220), (40, 50), (50, 100), n_levels=n_levels)

        for row in res:
            # BUGFIX: the closing parenthesis was misplaced — the original
            # `len(np.unique(row) == n_levels + 1)` measured the length of a
            # boolean comparison array, which is truthy for any nonempty row,
            # so the assertion could never fail.
            # +1 also counts the black background intensity.
            assert len(np.unique(row)) == n_levels + 1
276 |
277 |
class TestSegmentationCollapsingLabels:
    """Tests for the `segmentation_collapsing_labels` loader."""

    def test_load_works(self, monkeypatch, tmpdir):
        """Test that loading works."""
        tmpfile = pathlib.Path(str(tmpdir)) / "temp.json"
        tmpfile.touch()

        # Patch json.load so the empty temp file "parses" as an empty dict —
        # we only care that the loader returns whatever json.load produced.
        monkeypatch.setattr("json.load", lambda *args, **kwargs: {})

        res = segmentation_collapsing_labels(tmpfile)

        # Final output
        assert isinstance(res, dict)

    @pytest.mark.skip
    def test_no_id_equal_to_negative_one(self):
        """Make sure that -1 is not an existing label since we want to use it as a default not found value.

        We skip this because it assumes the dataset is downloaded and it wouldn't make sense to patch this.
        """

        all_children = _find_all_children(segmentation_collapsing_labels())

        assert isinstance(all_children, list)
        assert all_children
        assert -1 not in all_children
307 |
--------------------------------------------------------------------------------
/tests/test_ml_utils/test_augmentation.py:
--------------------------------------------------------------------------------
1 | """Collection of tests focused on the augmentation.py moduele."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see <https://www.gnu.org/licenses/>.
20 | """
21 |
22 | import imgaug.augmenters as iaa
23 |
24 | from atlalign.ml_utils import augmenter_1
25 |
26 |
def test_runnable():
    """The default augmenter factory returns an imgaug Augmenter instance."""
    augmenter = augmenter_1()

    assert isinstance(augmenter, iaa.Augmenter)
31 |
--------------------------------------------------------------------------------
/tests/test_ml_utils/test_callbacks.py:
--------------------------------------------------------------------------------
1 | """Tests focused on callbacks."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see <https://www.gnu.org/licenses/>.
20 | """
21 |
22 | import pathlib
23 | from unittest.mock import Mock
24 |
25 | import h5py
26 | import mlflow
27 | import numpy as np
28 | import pandas as pd
29 | import pytest
30 |
31 | from atlalign.augmentations import load_dataset_in_memory
32 | from atlalign.ml_utils import (
33 | MLFlowCallback,
34 | SupervisedGenerator,
35 | get_mlflow_artifact_path,
36 | )
37 | from atlalign.nn import supervised_model_factory
38 |
39 |
def test_get_mlflow_artifact_path(monkeypatch, tmpdir):
    """The reported artifact path matches the run's artifacts dir on disk."""
    monkeypatch.chdir(tmpdir)

    with mlflow.start_run():
        actual = get_mlflow_artifact_path()

        # mlflow lays its run data under the cwd; locate the artifacts folder.
        matches = [
            p for p in pathlib.Path(str(tmpdir)).rglob("*") if "artifacts" in str(p)
        ]

        assert actual == matches[0]
51 |
52 |
class TestMLFlowCallback:
    """Tests for MLFlowCallback: generator wiring, hooks and external metrics."""

    @staticmethod
    def create_h5(h5_path, n_samples, random_state):
        # Write a small synthetic supervised dataset mirroring the layout of
        # the real h5 file (img, image_id, dataset_id, p, deltas_xy,
        # inv_deltas_xy) at the full 320 x 456 atlas resolution.
        height, width = 320, 456
        np.random.seed(random_state)

        with h5py.File(h5_path, "w") as f:
            dset_img = f.create_dataset(
                "img", (n_samples, height, width), dtype="uint8"
            )
            dset_image_id = f.create_dataset("image_id", (n_samples,), dtype="int")
            dset_dataset_id = f.create_dataset("dataset_id", (n_samples,), dtype="int")
            dset_p = f.create_dataset("p", (n_samples,), dtype="int")
            dset_deltas_xy = f.create_dataset(
                "deltas_xy",
                (n_samples, height, width, 2),
                dtype=np.float16,
                fillvalue=0,
            )
            dset_inv_deltas_xy = f.create_dataset(
                "inv_deltas_xy",
                (n_samples, height, width, 2),
                dtype=np.float16,
                fillvalue=0,
            )

            # Populate with random data; ids are unique within the file
            # because `replace=False` samples without repetition.
            dset_deltas_xy[:] = np.random.random((n_samples, height, width, 2))
            dset_inv_deltas_xy[:] = np.random.random((n_samples, height, width, 2))
            dset_img[:] = np.random.randint(
                0, high=255, size=(n_samples, height, width)
            )
            dset_image_id[:] = 50 + np.random.choice(
                n_samples, size=n_samples, replace=False
            )
            dset_dataset_id[:] = 1000 + np.random.choice(
                n_samples, size=n_samples, replace=False
            )
            dset_p[:] = np.random.randint(0, high=12000, size=n_samples)

    def test_hooks(self, tmpdir, path_test_data, monkeypatch):
        """Constructor builds train/val generators; hooks create dirs and log."""
        tmpdir = pathlib.Path(str(tmpdir))
        monkeypatch.chdir(tmpdir)
        # Replace SupervisedGenerator with a factory returning a spec'd mock so
        # that no real h5 streaming happens during construction.
        fake_sg_c = Mock(return_value=Mock(spec=SupervisedGenerator))
        monkeypatch.setattr(
            "atlalign.ml_utils.callbacks.SupervisedGenerator", fake_sg_c
        )
        monkeypatch.setattr("atlalign.ml_utils.callbacks.keras", Mock())

        h5_path = path_test_data / "supervised_dataset.h5"

        train_ixs_path = "a"
        val_ixs_path = "b"

        with mlflow.start_run():
            cb = MLFlowCallback(h5_path, train_ixs_path, val_ixs_path, freq=2)

            # One generator for train, one for val.
            assert fake_sg_c.call_count == 2

            # Train generator construction kwargs
            train_kwargs = fake_sg_c.call_args_list[0][1]
            assert train_kwargs["indexes"] == train_ixs_path
            assert train_kwargs["shuffle"] is False
            assert train_kwargs["batch_size"] == 1

            # Val generator construction kwargs
            val_kwargs = fake_sg_c.call_args_list[1][1]
            assert val_kwargs["indexes"] == val_ixs_path
            assert val_kwargs["shuffle"] is False
            assert val_kwargs["batch_size"] == 1

            # On train begin: the artifact subfolders are created lazily.
            assert not (cb.root_path / "architecture").exists()
            assert not (cb.root_path / "checkpoints").exists()

            cb.on_train_begin()

            assert (cb.root_path / "architecture").exists()
            assert (cb.root_path / "checkpoints").exists()

            # On batch_end: inject a fake keras model and patched metrics so
            # the freq=2 logging path runs without a real evaluation.
            cb.model = Mock(metrics_names=[])  # Inject a keras model
            cb.model.evaluate_generator.return_value = []
            monkeypatch.setattr(
                cb,
                "compute_external_metrics",
                Mock(
                    return_value=pd.DataFrame(
                        {
                            "metric": [
                                1,
                            ]
                        }
                    )
                ),
            )

            cb.on_batch_end(None)  # 1
            cb.on_batch_end(None)  # 2

    @pytest.mark.parametrize("random_state", [3, 10])
    @pytest.mark.parametrize("return_inverse", [True, False])
    def test_compute_external_metrics(
        self, monkeypatch, tmpdir, random_state, return_inverse
    ):
        """Each val sample is forwarded untouched into evaluate_single."""
        evaluate_cache = []

        def fake_evaluate(*args, **kwargs):
            # Record what compute_external_metrics forwards to evaluate_single
            # so we can compare it against the raw h5 contents below.
            evaluate_cache.append(
                {
                    "deltas_true": args[0],
                    "img_mov": args[2],
                    "p": kwargs["p"],
                    "deltas_true_inv": kwargs["deltas_true_inv"],
                }
            )

            return pd.Series([2, 3])

        monkeypatch.setattr(
            "atlalign.ml_utils.callbacks.evaluate_single",
            Mock(side_effect=fake_evaluate),
        )
        monkeypatch.setattr("atlalign.ml_utils.callbacks.annotation_volume", Mock())
        monkeypatch.setattr(
            "atlalign.ml_utils.io.nissl_volume",
            Mock(return_value=np.zeros((528, 320, 456, 1))),
        )
        monkeypatch.setattr(
            "atlalign.ml_utils.callbacks.segmentation_collapsing_labels", Mock()
        )

        n_samples = 10
        n_val_samples = 4
        h5_path = pathlib.Path(str(tmpdir)) / "temp.h5"
        self.create_h5(h5_path, n_samples, random_state)

        val_indexes = list(np.random.choice(n_samples, n_val_samples, replace=False))

        val_gen = SupervisedGenerator(
            h5_path,
            indexes=val_indexes,
            shuffle=False,
            batch_size=1,
            return_inverse=return_inverse,
        )
        # With the inverse branch the model has 3 heads, otherwise 2.
        losses = ["mse", "mse", "mse"] if return_inverse else ["mse", "mse"]
        losses_weights = [1, 1, 1] if return_inverse else [1, 1]

        # Tiny filter counts keep the network small and the test fast.
        model = supervised_model_factory(
            compute_inv=return_inverse,
            losses=losses,
            losses_weights=losses_weights,
            start_filters=(2,),
            downsample_filters=(4, 2),
            middle_filters=(2,),
            upsample_filters=(2, 4),
        )

        df = MLFlowCallback.compute_external_metrics(model, val_gen)

        # One row per validation sample, indexed by image_id.
        assert len(df) == len(val_indexes)
        assert np.allclose(
            df.index.values, load_dataset_in_memory(h5_path, "image_id")[val_indexes]
        )
        assert len(evaluate_cache) == len(val_indexes)

        for ecache, val_index in zip(evaluate_cache, val_indexes):
            expected_deltas = load_dataset_in_memory(h5_path, "deltas_xy")[val_index]
            expected_deltas_inv = load_dataset_in_memory(h5_path, "inv_deltas_xy")[
                val_index
            ]
            # Images are stored as uint8 and streamed normalized to [0, 1].
            expected_image = load_dataset_in_memory(h5_path, "img")[val_index] / 255
            expected_p = load_dataset_in_memory(h5_path, "p")[val_index]

            assert np.allclose(expected_deltas, ecache["deltas_true"])
            assert np.allclose(expected_image, ecache["img_mov"])
            assert np.allclose(expected_p, ecache["p"])

            if return_inverse:
                assert np.allclose(expected_deltas_inv, ecache["deltas_true_inv"])
            else:
                assert ecache["deltas_true_inv"] is None  # they are not streamed
236 |
--------------------------------------------------------------------------------
/tests/test_ml_utils/test_io.py:
--------------------------------------------------------------------------------
1 | """Tests focues on the atlalign.io module."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see <https://www.gnu.org/licenses/>.
20 | """
21 |
22 | import pathlib
23 |
24 | import imgaug.augmenters as iaa
25 | import numpy as np
26 | import pytest
27 |
28 | from atlalign.ml_utils import SupervisedGenerator
29 |
30 | SUPERVISED_H5_PATH = (
31 | pathlib.Path(__file__).parent.parent / "data" / "supervised_dataset.h5"
32 | )
33 |
34 |
@pytest.fixture(scope="session")
def fake_nissl_volume():
    """Random stand-in for the Nissl volume: 528 slices of 320 x 456 x 1."""
    volume = np.random.random((528, 320, 456, 1))

    return volume.astype(np.float32)
38 |
39 |
class TestSupervisedGenerator:
    """Tests for SupervisedGenerator streaming from the bundled 2-sample h5."""

    def test_inexistent_h5(self, monkeypatch, fake_nissl_volume):
        """A missing h5 file raises OSError."""
        monkeypatch.setattr(
            "atlalign.ml_utils.io.nissl_volume",
            lambda *args, **kwargs: fake_nissl_volume,
        )
        path_wrong = SUPERVISED_H5_PATH.parent / "fake.h5"

        with pytest.raises(OSError):
            SupervisedGenerator(path_wrong)

    def test_correct_indexes(self, monkeypatch, fake_nissl_volume):
        """Inner indices are [0, 1]."""
        # The bundled test dataset contains exactly 2 samples.
        monkeypatch.setattr(
            "atlalign.ml_utils.io.nissl_volume",
            lambda *args, **kwargs: fake_nissl_volume,
        )

        assert SupervisedGenerator(SUPERVISED_H5_PATH).indexes == [0, 1]

    @pytest.mark.parametrize("batch_size", [1, 2])
    def test_length(self, batch_size, monkeypatch, fake_nissl_volume):
        """Test that length of the Sequence is computed correctly."""
        monkeypatch.setattr(
            "atlalign.ml_utils.io.nissl_volume",
            lambda *args, **kwargs: fake_nissl_volume,
        )
        # 2 samples total -> floor division by batch size.
        correct_len = 2 // batch_size

        assert (
            len(SupervisedGenerator(SUPERVISED_H5_PATH, batch_size=batch_size))
            == correct_len
        )

    def test_shuffling(self, monkeypatch, fake_nissl_volume):
        """Test shuffling works."""
        monkeypatch.setattr(
            "atlalign.ml_utils.io.nissl_volume",
            lambda *args, **kwargs: fake_nissl_volume,
        )
        n_trials = 10
        is_different = False
        gen = SupervisedGenerator(SUPERVISED_H5_PATH, shuffle=True)
        orig_indexes = gen.indexes[:]

        # With only 2 indices a single shuffle may be a no-op; retry a few
        # epochs and require at least one reordering.
        for _ in range(n_trials):
            gen.on_epoch_end()
            is_different = is_different or orig_indexes != gen.indexes

        assert is_different

    @pytest.mark.parametrize("batch_size", [1, 2])
    def test_getitem(self, batch_size, monkeypatch, fake_nissl_volume):
        """Get item."""
        monkeypatch.setattr(
            "atlalign.ml_utils.io.nissl_volume",
            lambda *args, **kwargs: fake_nissl_volume,
        )
        gen = SupervisedGenerator(SUPERVISED_H5_PATH, batch_size=batch_size)

        inp, out = gen[0]

        # Inputs: (batch, 320, 456, 2) float32 normalized to [0, 1].
        assert inp.shape == (batch_size, 320, 456, 2)
        assert inp.dtype == np.float32
        assert 0 <= inp.min() <= inp.max() <= 1

        # Outputs: [registered image, displacement deltas].
        assert isinstance(out, list)
        assert len(out) == 2

        assert out[0].shape == (batch_size, 320, 456, 1) and out[1].shape == (
            batch_size,
            320,
            456,
            2,
        )
        assert out[0].dtype == np.float32 and out[1].dtype == np.float16
        assert 0 <= out[0].min() <= out[0].max() <= 1

    @pytest.mark.parametrize("aug_ref", [True, False])
    @pytest.mark.parametrize("aug_mov", [True, False])
    def test_augmenters(self, aug_ref, aug_mov, monkeypatch, fake_nissl_volume):
        """Augmenting works."""
        monkeypatch.setattr(
            "atlalign.ml_utils.io.nissl_volume",
            lambda *args, **kwargs: fake_nissl_volume,
        )
        batch_size = 2
        augmenter = iaa.Sequential(
            [
                iaa.Fliplr(0.5),  # horizontal flips
                iaa.Crop(percent=(0, 0.1)),  # random crops
                # Small gaussian blur with random sigma between 0 and 0.5.
                # But we only blur about 50% of all images.
                iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))),
            ]
        )

        # Attach the augmenter to the reference and/or moving stream.
        kwargs = {
            "augmenter_ref": augmenter if aug_ref else None,
            "augmenter_mov": augmenter if aug_mov else None,
        }

        gen = SupervisedGenerator(SUPERVISED_H5_PATH, batch_size=batch_size, **kwargs)

        inp, out = gen[0]

        # Shapes and dtypes are preserved regardless of augmentation.
        assert inp.dtype == np.float32

        assert isinstance(out, list)
        assert len(out) == 2

        assert out[0].shape == (batch_size, 320, 456, 1) and out[1].shape == (
            batch_size,
            320,
            456,
            2,
        )
        assert out[0].dtype == np.float32 and out[1].dtype == np.float16

    @pytest.mark.parametrize("batch_size", [1, 2])
    def test_get_all_data(self, batch_size, monkeypatch, fake_nissl_volume):
        """get_all_data returns the full (inputs, outputs) pair in one call."""
        monkeypatch.setattr(
            "atlalign.ml_utils.io.nissl_volume",
            lambda *args, **kwargs: fake_nissl_volume,
        )

        gen = SupervisedGenerator(SUPERVISED_H5_PATH, batch_size=batch_size)

        all_inp, all_out = gen.get_all_data()

        assert len(all_inp) == 2
        assert len(all_out) == 2

    def test_indexes(self, monkeypatch, fake_nissl_volume):
        """Make sure indexes attribute work."""
        batch_size = 1
        monkeypatch.setattr(
            "atlalign.ml_utils.io.nissl_volume",
            lambda *args, **kwargs: fake_nissl_volume,
        )

        # Restricting `indexes` halves the number of available batches.
        gen_indexes = SupervisedGenerator(
            SUPERVISED_H5_PATH, batch_size=batch_size, indexes=[0]
        )
        gen = SupervisedGenerator(SUPERVISED_H5_PATH, batch_size=batch_size)

        assert len(gen) == 2
        assert len(gen_indexes) == 1
188 |
--------------------------------------------------------------------------------
/tests/test_ml_utils/test_models.py:
--------------------------------------------------------------------------------
1 | """Collection of tests focused on the atlalign.ml_utils.models module."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see <https://www.gnu.org/licenses/>.
20 | """
21 |
22 | import pathlib
23 | from copy import deepcopy
24 |
25 | import pytest
26 | from tensorflow import keras
27 |
28 | from atlalign.ml_utils import (
29 | load_model,
30 | merge_global_local,
31 | replace_lambda_in_config,
32 | save_model,
33 | )
34 |
35 |
@pytest.fixture(scope="function", params=["compiled", "uncompiled"])
def reg_model(request):
    """Minimal registration network (optionally compiled, per fixture param)."""
    should_compile = request.param == "compiled"

    inputs = keras.layers.Input((32, 45, 2))

    # Toy Lambda layers standing in for a real registration head.
    drop_last_channel = keras.layers.Lambda(lambda x: x[..., :-1])
    add_one = keras.layers.Lambda(lambda x: x + 1)
    identity = keras.layers.Lambda(lambda x: x * 1)

    registered_imgs = add_one(drop_last_channel(inputs))
    dvfs = identity(inputs)

    model = keras.models.Model(inputs=inputs, outputs=[registered_imgs, dvfs])

    if should_compile:
        model.compile(optimizer="adam", loss="mse")

    return model
57 |
58 |
@pytest.fixture()
def lambda_model():
    """Model containing (named) Lambda layers."""
    inputs = keras.layers.Input(shape=(10, 14, 2))
    hidden = keras.layers.Lambda(lambda t: t, name="inv_dvf")(inputs)
    outputs = keras.layers.Lambda(lambda t: t[..., 1:], name="extract_moving")(hidden)

    return keras.Model(inputs=inputs, outputs=outputs)
67 |
68 |
class TestGlobalLocal:
    """Tests focused on the merge_global_local function."""

    @pytest.mark.parametrize("expose_global", [True, False])
    def test_overall(self, reg_model, expose_global):
        """Merging yields a keras model with 2 outputs (4 when exposed)."""
        merged = merge_global_local(reg_model, reg_model, expose_global=expose_global)

        assert isinstance(merged, keras.models.Model)

        expected_n_outputs = 4 if expose_global else 2
        assert len(merged.outputs) == expected_n_outputs
79 |
80 |
class TestSaveModel:
    """Collection of tests focused on the `save_model` function."""

    @pytest.mark.parametrize("separate", [True, False])
    def test_correct(self, separate, tmpdir, reg_model):
        """Saving creates a .json + .h5 pair or a single .h5 per `separate`."""
        target = pathlib.Path(str(tmpdir)) / "temp_model"
        save_model(reg_model, target, separate=separate)

        if separate:
            # Architecture and weights land in a folder named after the model.
            assert (target / "temp_model.json").exists()
            assert (target / "temp_model.h5").exists()
        else:
            assert pathlib.Path(str(target) + ".h5").exists()

    def test_wrong_path(self, reg_model):
        """A path carrying an extension is rejected."""
        with pytest.raises(ValueError):
            save_model(reg_model, "aaa.extension")

    def test_already_exists(self, reg_model, tmpdir):
        """Refuse to clobber an existing file when overwrite=False."""
        orig_path = pathlib.Path(str(tmpdir)) / "temp_model"
        existing = orig_path / "temp_model.json"

        # The folder may have been created by an earlier save.
        existing.parent.mkdir(parents=True, exist_ok=True)
        existing.touch()

        with pytest.raises(FileExistsError):
            save_model(reg_model, orig_path, overwrite=False)
110 |
111 |
class TestLoadModel:
    """Collection of tests focused on the `load_model` function."""

    @pytest.mark.parametrize("separate", [True, False])
    @pytest.mark.parametrize("compile", [True, False])
    def test_correct(self, separate, compile, tmpdir, reg_model):
        """Round-trip a saved model through `load_model` in every mode."""
        original_compiled = reg_model.optimizer is not None

        temp_path = pathlib.Path(str(tmpdir)) / "temp_model"
        save_model(reg_model, temp_path, separate=separate)

        # Separate saves point at the folder; single-file saves at the .h5.
        load_target = str(temp_path) + ("" if separate else ".h5")

        if separate and compile:
            # Separately saved models carry no optimizer state, so asking for
            # a compiled load must fail.
            with pytest.raises(ValueError):
                load_model(load_target, compile=compile)
            return

        model = load_model(load_target, compile=compile)

        assert isinstance(model, keras.Model)

        if not separate and compile and original_compiled:
            assert model.optimizer is not None
        else:
            assert model.optimizer is None

        # Compare layer counts only — serialized JSON may differ because of
        # Lambda layers.
        n_layers = len(model.get_config()["layers"])
        n_layers_expected = len(reg_model.get_config()["layers"])
        assert n_layers == n_layers_expected

    def test_nonexistent_path(self, tmpdir):
        """Loading from a missing path raises OSError."""
        with pytest.raises(OSError):
            load_model(tmpdir / "fake")

    def test_ambiguous(self, tmpdir):
        """Multiple candidate architecture files make the load ambiguous."""
        folder = pathlib.Path(str(tmpdir))
        for filename in ("a_1.json", "a_2.json", "w_1.h5"):
            (folder / filename).touch()

        with pytest.raises(ValueError):
            load_model(folder)
163 |
164 |
class TestReplaceLambdaInConfig:
    """Collection of tests focused on the `replace_lambda_in_config` method."""

    # Commenting for now
    # @pytest.mark.parametrize("input_format", ["json", "keras", "dict", "path"])
    # @pytest.mark.parametrize("output_format", ["json", "keras", "dict"])
    # def test_identical_results(self, lambda_model, input_format, output_format, tmpdir):
    #
    #     if input_format == "keras":
    #         input_config = lambda_model
    #     elif input_format == "json":
    #         input_config = lambda_model.to_json()
    #     elif input_format == "dict":
    #         input_config = lambda_model.get_config()
    #     elif input_format == "path":
    #         path_ = pathlib.Path(str(tmpdir))
    #         input_config = path_ / "model.json"
    #         with input_config.open("w") as f_a:
    #             json.dump(lambda_model.to_json(), f_a)
    #     else:
    #         raise ValueError()
    #
    #     output = replace_lambda_in_config(input_config, output_format=output_format)
    #
    #     # check types
    #     if output_format == "keras":
    #         assert isinstance(output, keras.Model)
    #         new_model = output
    #     elif output_format == "json":
    #         assert isinstance(output, str)
    #         new_model = keras.models.model_from_json(
    #             output,
    #         )
    #     elif output_format == "dict":
    #         assert isinstance(output, dict)
    #         new_model = keras.Model.from_config(
    #             output,
    #             custom_objects={
    #                 "ExtractMoving": ExtractMoving,
    #                 "NoOp": NoOp,
    #             },
    #         )
    #
    #     shape_input = keras.backend.int_shape(lambda_model.input)[1:]
    #     x = np.random.random((2, *shape_input))
    #
    #     assert np.allclose(lambda_model.predict(x), new_model.predict(x))

    def test_incorrect_input(self, lambda_model):
        """Invalid input or output specifications raise TypeError/ValueError."""
        # incorrect input type
        with pytest.raises(TypeError):
            replace_lambda_in_config(1)

        # incorrect input path (unsupported extension)
        with pytest.raises(ValueError):
            replace_lambda_in_config(pathlib.Path.cwd() / "fake.wrong")

        # incorrect output type
        with pytest.raises(TypeError):
            replace_lambda_in_config(lambda_model, output_format="fake")

    def test_incorrect_layer_config(self, lambda_model):
        """Malformed layer configs raise KeyError."""
        correct_config = lambda_model.get_config()
        missing_class_name = deepcopy(correct_config)
        missing_config = deepcopy(correct_config)
        invalid_name = deepcopy(correct_config)

        # Break each copy in a different way.
        del missing_class_name["layers"][0]["class_name"]
        del missing_config["layers"][1]["config"]
        invalid_name["layers"][2]["config"]["name"] = "aaaaa"

        with pytest.raises(KeyError):
            replace_lambda_in_config(missing_class_name)

        with pytest.raises(KeyError):
            replace_lambda_in_config(missing_config)

        with pytest.raises(KeyError):
            replace_lambda_in_config(invalid_name)
246 |
--------------------------------------------------------------------------------
/tests/test_nn.py:
--------------------------------------------------------------------------------
1 | """Tests focused on the nn module."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
19 | along with this program. If not, see <https://www.gnu.org/licenses/>.
20 | """
21 |
22 | import numpy as np
23 | import pytest
24 | from tensorflow import keras
25 |
26 | from atlalign.nn import supervised_global_model_factory, supervised_model_factory
27 |
28 |
class TestSupervisedModelFactory:
    """Collection of tests focused on the `supervised_model_factory`."""

    def test_default_construction(self):
        """Make sure possible to use with the default setting"""
        model = supervised_model_factory()
        assert isinstance(model, keras.Model)

    @pytest.mark.parametrize("compute_inv", [True, False])
    @pytest.mark.parametrize("use_lambda", [True, False])
    def test_use_lambda(self, use_lambda, compute_inv):
        """Make sure the `use_lambda` flag is working"""
        # With the inverse branch the model has 3 outputs, otherwise 2, so the
        # loss lists must match.
        losses = ("mse", "mse", "mse") if compute_inv else ("mse", "mse")
        losses_weights = (1, 1, 1) if compute_inv else (1, 1)
        model = supervised_model_factory(
            losses=losses,
            losses_weights=losses_weights,
            compute_inv=compute_inv,
            use_lambda=use_lambda,
        )
        # Lambda layers are present exactly when use_lambda=True.
        lambda_list = [x for x in model.layers if isinstance(x, keras.layers.Lambda)]
        if use_lambda:
            assert lambda_list
        else:
            assert not lambda_list

    # @pytest.mark.parametrize("compute_inv", [True, False])
    # def test_equivalence(self, compute_inv):
    #     """Make sure moving Lambda layers does not affect the results."""
    #     losses = ("mse", "mse", "mse") if compute_inv else ("mse", "mse")
    #     losses_weights = (1, 1, 1) if compute_inv else (1, 1)
    #     params = {
    #         "start_filters": (2,),
    #         "downsample_filters": (2, 3),
    #         "middle_filters": (2,),
    #         "upsample_filters": (2, 3),
    #         "end_filters": tuple(),
    #         "compute_inv": compute_inv,
    #         "losses": losses,
    #         "losses_weights": losses_weights,
    #     }
    #
    #     np.random.seed(1337)
    #     tf.random.set_seed(1337)
    #     model_with = supervised_model_factory(use_lambda=True, **params)
    #     np.random.seed(1337)
    #     tf.random.set_seed(1337)
    #     model_without = supervised_model_factory(use_lambda=False, **params)
    #     x = np.random.random((1, 320, 456, 2))
    #     pred_with = model_with.predict([x, x] if compute_inv else x)
    #     pred_without = model_without.predict([x, x] if compute_inv else x)
    #
    #     assert np.allclose(pred_with[0], pred_without[0])
    #     assert np.allclose(pred_with[1], pred_without[1])
    #     if compute_inv:
    #         assert np.allclose(pred_with[2], pred_without[2])

    def test_down_up_samples(self):
        """Make sure raises an error if downsamples and upsamples have not the same number of layers"""
        with pytest.raises(ValueError):
            supervised_model_factory(downsample_filters=(2,), upsample_filters=(2, 3))
        # NOTE(review): this second case has *equal* lengths yet still raises —
        # presumably 7 downsampling steps exceed what the input resolution
        # allows; confirm against the factory's validation logic.
        with pytest.raises(ValueError):
            supervised_model_factory(
                downsample_filters=(2, 2, 2, 2, 2, 2, 2),
                upsample_filters=(2, 2, 2, 2, 2, 2, 2),
            )
97 |
98 |
class TestSupervisedGlobalModelFactory:
    """Collection of tests focused on the `supervised_global_model_factory`."""

    def test_default_construction(self):
        """Make sure possible to use with the default settings."""
        assert isinstance(supervised_global_model_factory(), keras.Model)

    @pytest.mark.parametrize("use_lambda", [True, False])
    def test_use_lambda(self, use_lambda):
        """Make sure the `use_lambda` flag is working."""
        model = supervised_global_model_factory(use_lambda=use_lambda)
        lambda_layers = [
            layer for layer in model.layers if isinstance(layer, keras.layers.Lambda)
        ]
        # Lambda layers are present exactly when use_lambda=True.
        if use_lambda:
            assert lambda_layers
        else:
            assert not lambda_layers

    def test_equivalence(self):
        """Make sure moving Lambda layers does not affect the results."""
        params = {"filters": (2, 2, 2, 2), "dense_layers": (2,)}

        # Seed before each construction so both models get identical weights.
        np.random.seed(1337)
        model_with = supervised_global_model_factory(use_lambda=True, **params)
        np.random.seed(1337)
        model_without = supervised_global_model_factory(use_lambda=False, **params)

        x = np.random.random((1, 320, 456, 2))
        pred_with = model_with.predict(x)
        pred_without = model_without.predict(x)

        assert np.allclose(pred_with[0], pred_without[0])
        assert np.allclose(pred_with[1], pred_without[1])
132 |
--------------------------------------------------------------------------------
/tests/test_non_ml.py:
--------------------------------------------------------------------------------
1 | """
2 | The package atlalign is a tool for registration of 2D images.
3 |
4 | Copyright (C) 2021 EPFL/Blue Brain Project
5 |
6 | This program is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU Lesser General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | (at your option) any later version.
10 |
11 | This program is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU Lesser General Public License for more details.
15 |
16 | You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 |
20 | import numpy as np
21 | import pytest
22 |
23 | from atlalign.base import DisplacementField
24 | from atlalign.non_ml import antspy_registration
25 |
26 |
class TestAntspyRegistration:
    """A collection of tests focused on the ANTsPy registration."""

    @pytest.mark.slow
    @pytest.mark.parametrize(
        "registration_type",
        [
            "Translation",
            "Rigid",
            "Similarity",
            "QuickRigid",
            "DenseRigid",
            "BOLDRigid",
            "Affine",
            "AffineFast",
            "BOLDAffine",
            "TRSAA",
            "ElasticSyN",
            "SyN",
            "SyNRA",
            "SyNOnly",
            "SyNCC",
            "SyNabp",
            "SyNBold",
            "SyNBoldAff",
            "SyNAggro",
            "TVMSQ",
        ],
    )
    def test_each_type(self, tmp_path, img_grayscale, registration_type):
        """Make sure that every type of registration gives a valid DisplacementField as output."""

        # Self-registration: fixed and moving are the same image.
        moving_img = img_grayscale
        fixed_img = img_grayscale

        # Very small iteration counts per resolution level to keep this slow test short.
        reg_iterations = (4, 2, 0)

        df_final, meta = antspy_registration(
            fixed_img,
            moving_img,
            registration_type=registration_type,
            reg_iterations=reg_iterations,
            path=tmp_path,
        )

        # Only the extracted displacement field is checked; `meta` is unused here.
        assert isinstance(df_final, DisplacementField)
        assert df_final.is_valid

    def test_displacement_field(self, tmp_path, img_grayscale):
        """Make sure that the displacement extracted can reproduce the registered moving image.

        Done by comparing with the image contained in the registration output.
        """

        fixed_img = img_grayscale
        size = fixed_img.shape
        # Build a synthetic moving image by translating the fixed one by (20, 20).
        df = DisplacementField.generate(
            size, approach="affine_simple", translation_x=20, translation_y=20
        )
        moving_img = df.warp(img_grayscale)

        df_final, meta = antspy_registration(fixed_img, moving_img, path=tmp_path)
        # `warpedmovout` is the registered moving image produced by ANTsPy itself.
        img1 = meta["warpedmovout"].numpy()
        # Reconstruct the same image by warping with the extracted displacement field.
        img2 = df_final.warp(
            moving_img, interpolation="linear", border_mode="constant", c=0
        )

        # Tolerance depends on dtype: ~1 intensity level for uint8 images,
        # a small absolute tolerance for float images.
        if img_grayscale.dtype == "uint8":
            epsilon = 1
        else:
            epsilon = 0.005

        assert abs(img1 - img2).mean() < epsilon

    @pytest.mark.todo
    @pytest.mark.parametrize(
        "registration_type",
        [
            "Translation",
            "Rigid",
            "Similarity",
            "QuickRigid",
            "DenseRigid",
            "BOLDRigid",
            "Affine",
            "AffineFast",
            "BOLDAffine",
            "TRSAA",
            "ElasticSyN",
            "SyN",
            "SyNRA",
            "SyNOnly",
            "SyNCC",
            "SyNabp",
        ],
    )
    def test_same_results(
        self, tmp_path, monkeypatch, img_grayscale, registration_type
    ):
        """Make sure that the registration is reproducible.

        Done by checking if the displacement field extracted are equal if we run the registration with
        the same parameters.

        Notes
        -----
        Marked as `todo` because we did not find a way how to force ANTsPY to be always reproducible.
        """

        # These environment variables are supposed to make ANTsPy deterministic.
        monkeypatch.setenv("ANTS_RANDOM_SEED", "1")
        monkeypatch.setenv("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS", "1")

        # Seeded random deformation sampled proportionally to image intensity.
        size = img_grayscale.shape
        p = img_grayscale / np.sum(img_grayscale)
        df = DisplacementField.generate(size, approach="paper", p=p, random_state=1)
        moving_img = df.warp(img_grayscale)

        # Run the identical registration twice.
        df_final1, meta1 = antspy_registration(
            img_grayscale,
            moving_img,
            registration_type=registration_type,
            path=tmp_path,
            verbose=True,
        )
        df_final2, meta2 = antspy_registration(
            img_grayscale,
            moving_img,
            registration_type=registration_type,
            path=tmp_path,
            verbose=True,
        )

        # atol=1: displacements only need to agree up to one pixel.
        assert np.allclose(df_final1.delta_x, df_final2.delta_x, atol=1)
        assert np.allclose(df_final1.delta_y, df_final2.delta_y, atol=1)

    @pytest.mark.slow
    @pytest.mark.parametrize(
        "registration_type",
        [
            "Translation",
            "Rigid",
            "Similarity",
            "QuickRigid",
            "DenseRigid",
            "BOLDRigid",
            "Affine",
            "AffineFast",
            "BOLDAffine",
            "TRSAA",
            "ElasticSyN",
            "SyN",
            "SyNRA",
            "SyNOnly",
            "SyNCC",
            "SyNabp",
            "SyNBold",
            "SyNBoldAff",
            "SyNAggro",
            "TVMSQ",
        ],
    )
    def test_different_results(self, tmp_path, img_grayscale_uint, registration_type):
        """Make sure that the registration is not reproducible if some environment variables not set.

        The environment variables are ANTS_RANDOM_SEED and ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS.
        """

        # NOTE(review): `registration_type` is parametrized but never forwarded to
        # `antspy_registration` below, so every parametrization runs the default
        # registration — looks like an omission; confirm intent before changing,
        # since some types might be deterministic even without the seed env vars.
        size = img_grayscale_uint.shape
        df = DisplacementField.generate(
            size, approach="affine_simple", translation_x=20
        )
        moving_img = df.warp(img_grayscale_uint)

        df_final1, meta1 = antspy_registration(
            img_grayscale_uint, moving_img, path=tmp_path
        )
        df_final2, meta2 = antspy_registration(
            img_grayscale_uint, moving_img, path=tmp_path
        )

        # Two unseeded runs are expected to differ by more than 0.1 px somewhere.
        assert not np.allclose(df_final1.delta_x, df_final2.delta_x, atol=0.1)
        assert not np.allclose(df_final1.delta_y, df_final2.delta_y, atol=0.1)

    @pytest.mark.todo
    def test_different_types(
        self, tmp_path, monkeypatch, img_grayscale_uint, img_grayscale_float
    ):
        """Make sure that the registration does not depend on the type of input images.

        Notes
        -----
        Marked as `todo` because we did not find a way how to force ANTsPY to be always reproducible.
        """

        monkeypatch.setenv("ANTS_RANDOM_SEED", "4")
        monkeypatch.setenv("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS", "1")

        # Apply the same geometric deformation to both the uint and float variants.
        size = img_grayscale_uint.shape
        p = img_grayscale_uint / np.sum(img_grayscale_uint)
        df = DisplacementField.generate(size, approach="paper", p=p)
        moving_img_uint = df.warp(img_grayscale_uint)
        moving_img_float = df.warp(img_grayscale_float)

        df_final1, meta1 = antspy_registration(
            img_grayscale_uint, moving_img_uint, path=tmp_path, verbose=False
        )
        df_final2, meta2 = antspy_registration(
            img_grayscale_float, moving_img_float, path=tmp_path, verbose=False
        )

        # The recovered fields should agree to within 0.1 px regardless of dtype.
        assert np.allclose(df_final1.delta_x, df_final2.delta_x, atol=0.1)
        assert np.allclose(df_final1.delta_y, df_final2.delta_y, atol=0.1)
240 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | The package atlalign is a tool for registration of 2D images.
3 |
4 | Copyright (C) 2021 EPFL/Blue Brain Project
5 |
6 | This program is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU Lesser General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | (at your option) any later version.
10 |
11 | This program is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU Lesser General Public License for more details.
15 |
16 | You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 |
20 | import numpy as np
21 | import pytest
22 |
23 | from atlalign.utils import _find_all_children, find_labels_dic
24 |
25 |
class TestSegmentationConcatenation:
    """A set of methods testing the concatenation of labels for the segmentation."""

    def test_find_children(self, label_dict):
        """A leaf node yields only itself; the full tree yields every id."""
        leaf = {"id": 2, "children": []}

        assert np.all(_find_all_children(leaf) == [2])
        assert np.all(set(_find_all_children(label_dict)) == {2, 3, 4, 5, 6})

    @pytest.mark.parametrize("depth", [0, 1, 10])
    def test_background(self, depth, label_dict):
        """Background (zero) pixels remain zero at any depth."""
        shape = (20, 20)
        result = find_labels_dic(np.zeros(shape), label_dict, depth)

        assert np.all(result == np.zeros(shape))

    @pytest.mark.parametrize("depth", [0, 1, 10])
    def test_labels_not_in_tree(self, depth, label_dict):
        """Labels absent from the tree are mapped to -1 at any depth."""
        shape = (20, 20)
        result = find_labels_dic(np.ones(shape) * 10, label_dict, depth)

        assert np.all(result == np.ones(shape) * -1)

    @pytest.mark.parametrize("label", [2, 3, 4, 5, 6])
    def test_tree_concatenation(self, label, label_dict):
        """At depth 0 every label in the tree collapses onto the root id 2."""
        depth = 0
        shape = (20, 20)
        result = find_labels_dic(np.ones(shape) * label, label_dict, depth)

        assert np.all(result == np.ones(shape) * 2)

    @pytest.mark.parametrize("unchanged_label", [2, 3, 4])
    def test_tree_concatenation_unchanged_labels(self, unchanged_label, label_dict):
        """Labels already at depth <= 1 are left untouched when truncating at depth 1."""
        depth = 1
        shape = (20, 20)
        result = find_labels_dic(np.ones(shape) * unchanged_label, label_dict, depth)

        assert np.all(result == np.ones(shape) * unchanged_label)

    def test_specific_example(self, label_dict):
        """A simple 3 x 3 matrix segmentation array checked at several depths."""
        segmentation_array = np.array([[[1, 2, 3], [20, 14, 50], [4, 5, 6]]])

        # Expected output per truncation depth; depths >= 2 leave the labels as-is.
        expected_per_depth = {
            0: np.array([[[-1, 2, 2], [-1, -1, -1], [2, 2, 2]]]),
            1: np.array([[[-1, 2, 3], [-1, -1, -1], [4, 4, 4]]]),
            2: np.array([[[-1, 2, 3], [-1, -1, -1], [4, 5, 6]]]),
            3: np.array([[[-1, 2, 3], [-1, -1, -1], [4, 5, 6]]]),
        }

        for depth, expected in expected_per_depth.items():
            result = find_labels_dic(segmentation_array, label_dict, depth)
            assert np.all(result == expected)
106 |
--------------------------------------------------------------------------------
/tests/test_visualization.py:
--------------------------------------------------------------------------------
1 | """
2 | The package atlalign is a tool for registration of 2D images.
3 |
4 | Copyright (C) 2021 EPFL/Blue Brain Project
5 |
6 | This program is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU Lesser General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | (at your option) any later version.
10 |
11 | This program is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU Lesser General Public License for more details.
15 |
16 | You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 |
20 | import numpy as np
21 | import pytest
22 | from matplotlib import animation
23 |
24 | from atlalign.base import DisplacementField
25 | from atlalign.visualization import (
26 | create_animation,
27 | create_segmentation_image,
28 | generate_df_plots,
29 | )
30 |
31 |
class TestUtils:
    """A collection of tests focused on the `utils` module."""

    def test_wrong_input_dtype(self):
        """Make sure that a float input dtype is not allowed."""
        shape = (4, 5)

        segmentation_array = np.ones(shape, dtype=np.float32) / 5

        with pytest.raises(TypeError):
            create_segmentation_image(segmentation_array)

    def test_correct_output_type(self):
        """Make sure the output image has uint8 dtype with values in [0, 255]."""
        shape = (4, 5)

        segmentation_array = np.random.randint(100, size=shape)

        segmentation_img, _ = create_segmentation_image(segmentation_array)

        assert segmentation_img.dtype == np.uint8
        assert np.all((0 <= segmentation_img) & (segmentation_img < 256))

    def test_different_classes_different_colors(self):
        """Test that different classes are assigned different colors."""
        segmentation_array = np.array([[0, 10], [2, 0]])

        segmentation_img, _ = create_segmentation_image(segmentation_array)

        # Three distinct labels (0, 10, 2) must map to exactly three colors.
        assert (
            len(
                {
                    tuple(x)
                    for x in [
                        segmentation_img[0, 0],
                        segmentation_img[0, 1],
                        segmentation_img[1, 0],
                        segmentation_img[1, 1],
                    ]
                }
            )
            == 3
        )

        # Both pixels labeled 0 must share the same color.
        assert np.all(segmentation_img[0, 0] == segmentation_img[1, 1])

    def test_predefined_colors(self):
        """Test that it is possible to pass predefined colors."""
        segmentation_array = np.array([[0, 1], [2, 22]])

        colors_dict = {0: (0, 0, 0), 1: (255, 0, 0)}

        segmentation_img, _ = create_segmentation_image(segmentation_array, colors_dict)

        assert np.all(segmentation_img[0, 0] == (0, 0, 0))
        assert np.all(segmentation_img[0, 1] == (255, 0, 0))

    def test_animation(self, img):
        """Make sure it is possible to generate animations."""
        df = DisplacementField.generate(img.shape, approach="identity")

        ani = create_animation(df, img)
        ani_many = create_animation([df, df], img)

        assert isinstance(ani, animation.Animation)
        assert isinstance(ani_many, animation.Animation)
103 |
104 |
class TestGenerateDFPlots:
    """Tests focused on the `generate_df_plots` function."""

    @pytest.mark.parametrize("df_id", [(320, 456)], indirect=True)
    def test_basic(self, df_id, tmpdir, monkeypatch):
        """Smoke test: plotting a field against itself into a temp dir does not raise."""
        # NOTE(review): `monkeypatch` is requested but unused — possibly a leftover;
        # confirm before removing, since dropping it changes the fixture set.
        generate_df_plots(df_id, df_id, tmpdir)
111 |
--------------------------------------------------------------------------------
/tests/test_volume.py:
--------------------------------------------------------------------------------
1 | """Collection of tests focused on the `volume` module."""
2 |
3 | """
4 | The package atlalign is a tool for registration of 2D images.
5 |
6 | Copyright (C) 2021 EPFL/Blue Brain Project
7 |
8 | This program is free software: you can redistribute it and/or modify
9 | it under the terms of the GNU Lesser General Public License as published by
10 | the Free Software Foundation, either version 3 of the License, or
11 | (at your option) any later version.
12 |
13 | This program is distributed in the hope that it will be useful,
14 | but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | GNU Lesser General Public License for more details.
17 |
18 | You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
20 | """
21 |
22 | from unittest.mock import MagicMock
23 |
24 | import numpy as np
25 | import pytest
26 |
27 | from atlalign.base import DisplacementField
28 | from atlalign.volume import CoronalInterpolator, GappedVolume, Volume
29 |
30 |
@pytest.fixture()
def minimal_vol(monkeypatch):
    """Build a tiny two-section `Volume` backed by a mocked nissl volume."""
    shape = (12, 13)
    section_numbers = [1, 13]
    # Note: both list entries deliberately reference the same object,
    # exactly as the original `2 * [...]` construction did.
    mov_imgs = 2 * [np.zeros(shape)]
    dvfs = 2 * [DisplacementField.generate(shape, approach="identity")]

    # Replace the real nissl volume with a mock returning an all-zero stack.
    nissl_mock = MagicMock()
    nissl_mock.__getitem__.return_value = np.zeros((len(mov_imgs), *shape, 1))
    monkeypatch.setattr("atlalign.volume.nissl_volume", lambda: nissl_mock)

    return Volume(section_numbers, mov_imgs, dvfs)
42 |
43 |
class TestVolume:
    """Tests for the `Volume` container."""

    def test_wrong_input(self):
        """Invalid constructor arguments raise ValueError."""
        invalid_args = [
            ([None], [], []),
            (["a", "a"], ["a", "a"], ["a", "a"]),
            ([1111, 12], ["a", "a"], ["a", "a"]),
        ]
        for sn, mov_imgs, dvfs in invalid_args:
            with pytest.raises(ValueError):
                Volume(sn, mov_imgs, dvfs)

    def test_construction(self, minimal_vol):
        """The fixture yields a proper Volume instance."""
        assert isinstance(minimal_vol, Volume)

    def test_getitem(self, minimal_vol):
        """Indexing a known section returns 3 arrays and a field; unknown raises KeyError."""
        with pytest.raises(KeyError):
            minimal_vol[4]

        outputs = minimal_vol[1]

        assert len(outputs) == 4
        assert all(isinstance(o, np.ndarray) for o in outputs[:3])
        assert isinstance(outputs[3], DisplacementField)

    def test_sorted_attributes(self, minimal_vol):
        """Each sorted_* attribute exposes a list as its first element."""
        for attr in ("sorted_dvfs", "sorted_mov", "sorted_reg", "sorted_ref"):
            assert isinstance(getattr(minimal_vol, attr)[0], list)
75 |
76 |
class TestGappedVolume:
    """Tests for the `GappedVolume` container."""

    def test_incorrect_input(self):
        """Mismatched lengths or inconsistent image shapes raise ValueError."""
        with pytest.raises(ValueError):
            GappedVolume([1], [])

        with pytest.raises(ValueError):
            GappedVolume([1, 1], [np.zeros((1, 2)), np.zeros((2, 3))])

    def test_array2list_conversion(self):
        """Images passed as a single ndarray are converted to lists internally."""
        section_numbers = [1, 44, 12]
        imgs = np.zeros((len(section_numbers), 10, 11))

        gv = GappedVolume(section_numbers, imgs)

        assert isinstance(gv.sn, list)
        assert isinstance(gv.imgs, list)
        assert np.allclose(section_numbers, gv.sn)
95 |
96 |
class TestCoronalInterpolator:
    """Tests for the `CoronalInterpolator`."""

    @pytest.mark.parametrize(
        "kind", ["linear", "nearest", "zero", "slinear", "previous", "next"]
    )
    def test_all_kinds(self, kind):
        """Interpolating an all-zero gapped volume yields an all-zero full volume."""
        interpolator = CoronalInterpolator(kind=kind, fill_value=0, bounds_error=False)

        section_numbers = [0, 527]
        gapped = GappedVolume(
            section_numbers, np.zeros((len(section_numbers), 10, 11))
        )

        result = interpolator.interpolate(gapped)

        assert np.allclose(np.zeros((528, *gapped.shape)), result)

    @pytest.mark.parametrize(
        "kind",
        [
            "linear",
            "quadratic",
            "cubic",
            "nearest",
            "zero",
            "slinear",
            "previous",
            "next",
        ],
    )
    def test_precise_on_known(self, kind):
        """Make sure that on the known slices the interpolation is precise."""
        interpolator = CoronalInterpolator(kind=kind, fill_value=0, bounds_error=False)

        section_numbers = list(range(0, 528, 8)) + [527]
        imgs = np.random.random((len(section_numbers), 10, 11))

        result = interpolator.interpolate(GappedVolume(section_numbers, imgs))

        # Every provided slice must be reproduced exactly at its section number.
        for img, s in zip(imgs, section_numbers):
            assert np.allclose(img, result[s])

    @pytest.mark.parametrize(
        "kind", ["linear", "nearest", "zero", "slinear", "previous", "next"]
    )
    def test_nan_all(self, kind):
        """An input section composed fully of NaN pixels must not poison the output."""
        interpolator = CoronalInterpolator(kind=kind, fill_value=0, bounds_error=False)

        shape = (10, 11)
        gapped = GappedVolume(
            [0, 100, 527],
            [np.zeros(shape), np.ones(shape) * np.nan, np.zeros(shape)],
        )

        result = interpolator.interpolate(gapped)

        assert np.all(np.isfinite(result))

    @pytest.mark.parametrize(
        "kind", ["linear", "nearest", "zero", "slinear", "previous", "next"]
    )
    def test_nan_some(self, kind):
        """A section with a NaN hole is interpolated; its valid pixels stay exact."""
        interpolator = CoronalInterpolator(kind=kind, fill_value=0, bounds_error=False)

        shape = (10, 11)
        section_numbers = [0, 100, 527]

        # Mask marking the pixels that stay finite in the middle section.
        valid = np.ones(shape, dtype=bool)
        valid[5:9, 2:7] = False

        weird_img = np.random.random(shape)
        weird_img[5:9, 2:7] = np.nan

        gapped = GappedVolume(
            section_numbers, [np.zeros(shape), weird_img, np.zeros(shape)]
        )

        result = interpolator.interpolate(gapped)

        assert np.all(np.isfinite(result))
        assert np.allclose(result[section_numbers[1]][valid], weird_img[valid])
179 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
# tox.ini — CI/dev automation for atlalign; read by tox (and pytest via [pytest]).

[tox]
minversion = 3.1.0
requires = virtualenv >= 20.0.0
# Custom key, referenced below as {[tox]source}.
source = atlalign
envlist =
    lint
    py38
    py39
    docs

[testenv]
download = true
# Use https for the VCS dependency: the former plain-http URL was
# susceptible to tampering in transit.
deps =
    lpips_tf @ git+https://github.com/alexlee-gk/lpips-tensorflow.git#egg=lpips_tf
extras =
    dev
commands =
    pytest {posargs:tests}

[testenv:lint]
skip_install = true
deps =
    flake8
    isort
    pydocstyle
    black==22.3.0
commands =
    flake8 setup.py {[tox]source} tests
    isort --honor-noqa --profile=black --check setup.py {[tox]source} tests
    pydocstyle {[tox]source}
    black --check setup.py {[tox]source} tests

[testenv:format]
skip_install = true
deps =
    isort
    black
commands =
    isort --honor-noqa --profile=black setup.py {[tox]source} tests
    black setup.py {[tox]source} tests

[testenv:docs]
changedir = docs
extras =
    dev
    docs
allowlist_externals = make
commands =
    make clean
    make doctest SPHINXOPTS=-W
    make html SPHINXOPTS=-W

# --strict-markers replaces the deprecated --strict alias (pytest >= 4.5):
# any marker not declared below is an error.
[pytest]
addopts =
    -v
    -m "not todo and not slow and not internet"
    --disable-warnings
    --strict-markers
    --cov=atlalign
    --cov-report=term-missing
testpaths = tests
markers =
    internet: requires connection to the internet
    slow: mark denoting a test that is too slow
    todo: mark denoting a test that is not written yet

[flake8]
count = True
max-line-length = 120
ignore = E402, W503, E203

[pydocstyle]
convention = numpy

# Mapping used by the tox-gh-actions plugin. Only interpreters whose envs
# exist in `envlist` are listed (3.6/3.7 entries were stale leftovers).
[gh-actions]
python =
    3.8: py38
    3.9: py39
81 |
--------------------------------------------------------------------------------