├── .github └── workflows │ └── run-tests.yml ├── .gitignore ├── .readthedocs.yml ├── AUTHORS.md ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── atlalign ├── __init__.py ├── augmentations.py ├── base.py ├── data.py ├── label │ ├── __init__.py │ ├── cli.py │ ├── io.py │ └── new_GUI.py ├── metrics.py ├── ml_utils │ ├── __init__.py │ ├── augmentation.py │ ├── callbacks.py │ ├── io.py │ ├── layers.py │ ├── losses.py │ └── models.py ├── nn.py ├── non_ml │ ├── __init__.py │ └── intensity.py ├── utils.py ├── visualization.py ├── volume.py └── zoo.py ├── docs ├── Makefile ├── _images │ ├── affine_simple.png │ ├── anchoring.png │ ├── annot_warping.png │ ├── antspy.png │ ├── aug_pipeline.png │ ├── clipped_vd.png │ ├── composition.png │ ├── control_points.png │ ├── coronal_interpolator.png │ ├── edge_stretching.png │ ├── evaluation_metrics.png │ ├── example_augmentation.png │ ├── feature_based.png │ ├── image_registration.png │ ├── image_registration_2.png │ ├── int_augmentations.gif │ ├── inverse.png │ ├── labeling_tool.png │ ├── metrics_overview.png │ ├── resizing.png │ ├── typical_dataset.png │ └── warping.png ├── _static │ └── .keep ├── conf.py ├── generate_metadata.py ├── index.rst └── source │ ├── 3d_interpolation.rst │ ├── api │ ├── atlalign.label.rst │ ├── atlalign.ml_utils.rst │ ├── atlalign.non_ml.rst │ ├── atlalign.rst │ └── modules.rst │ ├── building_blocks.rst │ ├── datasets.rst │ ├── deep_learning_data.rst │ ├── deep_learning_inference.rst │ ├── deep_learning_training.rst │ ├── evaluation.rst │ ├── image_registration.rst │ ├── installation.rst │ ├── intensity.rst │ ├── labeling_tool.rst │ └── logo │ └── Atlas_Alignment_banner.jpg ├── scripts └── experiment_1.py ├── setup.py ├── tests ├── conftest.py ├── data │ ├── animals.jpg │ ├── mild_inversion.npy │ └── supervised_dataset.h5 ├── test_augmentations.py ├── test_base.py ├── test_data.py ├── test_metrics.py ├── test_ml_utils │ ├── test_augmentation.py │ ├── test_callbacks.py │ ├── test_io.py │ ├── 
test_layers.py │ └── test_models.py ├── test_nn.py ├── test_non_ml.py ├── test_utils.py ├── test_visualization.py ├── test_volume.py └── test_zoo.py └── tox.ini /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | name: ci testing 2 | 3 | 4 | on: 5 | push: 6 | branches: master 7 | pull_request: 8 | 9 | 10 | jobs: 11 | 12 | run_test: 13 | 14 | runs-on: ${{ matrix.os }} 15 | 16 | strategy: 17 | 18 | matrix: 19 | os: [ubuntu-latest] # macos 11 is currently in preview, macos-latest == 1.10.15 20 | python-version: [ 21 | 3.8, 22 | 3.9, 23 | ] 24 | include: 25 | - python-version: 3.8 26 | tox-env: py38 27 | - python-version: 3.9 28 | tox-env: py39 29 | 30 | steps: 31 | 32 | - name: checkout latest commit 33 | uses: actions/checkout@v2 34 | 35 | - name: setup python ${{ matrix.python-version }} 36 | uses: actions/setup-python@v2 37 | with: 38 | python-version: ${{ matrix.python-version }} 39 | 40 | - name: install python dependencies 41 | run: | 42 | python -m pip install --upgrade pip 43 | pip install tox tox-gh-actions 44 | 45 | - name: linting and code style 46 | run: tox -vv -e lint 47 | 48 | - name: tests and coverage 49 | run: tox -vv -e ${{ matrix.tox-env }} -- --color=yes 50 | 51 | - name: docs 52 | run: tox -vv -e docs 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom entries 2 | /data/ 3 | /.idea/ 4 | .DS_Store 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # 
PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | # pytype static type analyzer 141 | .pytype/ 142 | 143 | # Cython debug symbols 144 | cython_debug/ 145 | 146 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | formats: [] 4 | 5 | sphinx: 6 | builder: html 7 | configuration: docs/conf.py 8 | 9 | build: 10 | image: latest 11 | 12 | python: 13 | version: 3.7 14 | install: 15 | - method: pip 16 | path: . 17 | extra_requirements: 18 | - tf 19 | - docs 20 | system_packages: true 21 | -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | ## Maintainers 2 | - Jan Krepl 3 | - Francesco Casalegno 4 | - Emilie Delattre 5 | 6 | ## Authors 7 | - Jan Krepl 8 | - Francesco Casalegno 9 | - Emilie Delattre 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | This page includes some guidelines to enable you to contribute to the project. 4 | 5 | ## Found a bug? 6 | 7 | If you find a bug in the source code or in the documentation, you can 8 | open an issue on GitHub. 
9 | Even better, you can submit a pull request with a fix. 10 | 11 | ## Submission guidelines 12 | 13 | ### Submitting an issue 14 | 15 | Before you submit an issue, please search the issue tracker, maybe an issue 16 | for your problem already exists and the discussion might inform you of workarounds 17 | readily available. 18 | 19 | We want to fix all the issues as soon as possible, but before fixing a bug we 20 | need to reproduce and confirm it. In order to reproduce bugs we will need as 21 | much information as possible, and preferably a sample demonstrating the issue. 22 | 23 | ### Submitting a pull request (PR) 24 | 25 | If you wish to contribute to the code base, please open a pull request by 26 | following GitHub's guidelines. 27 | 28 | ## Development Conventions 29 | 30 | `atlalign` uses: 31 | - Black for formatting code 32 | - Flake8 for linting code 33 | - PyDocStyle for checking docstrings 34 | 35 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 
21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 
48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 
90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 
129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 
160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Atlas Alignment 4 | 5 | 6 | 7 | 8 | 13 | 14 | 15 | 16 | 21 | 22 | 23 | 24 | 29 | 30 | 31 | 32 | 37 | 38 | 39 | 40 | 54 | 55 | 56 | 57 | 62 | 63 |
Latest Release 9 | 10 | Latest release 11 | 12 |
License 17 | 18 | License 19 | 20 |
Data 25 | 26 | Data 27 | 28 |
Build Status 33 | 34 | Build status 35 | 36 |
Code Style 41 | 42 | Black 43 | 44 | 45 | Isort 46 | 47 | 48 | Pydocstyle 49 | 50 | 51 | Pydocstyle 52 | 53 |
Python Versions 58 | 59 | Python Versions 60 | 61 |
64 | 65 | Atlas Alignment is a toolbox to perform multimodal image registration. It 66 | includes both traditional and supervised deep learning models. 67 | 68 | This project originated from the Blue Brain Project efforts on aligning mouse 69 | brain atlases obtained with ISH gene expression and Nissl stains. 70 | 71 | 72 | ### Official documentation 73 | All details related to installation and logic are described in the 74 | [official documentation](https://atlas-alignment.readthedocs.io/). 75 | 76 | 77 | ### Installation 78 | 79 | #### Installation Requirements 80 | 81 | Some of the functionalities of `atlalign` depend on the [TensorFlow implementation 82 | of the Learned Perceptual Image Patch Similarity (LPIPS)](https://github.com/alexlee-gk/lpips-tensorflow). Unfortunately, the 83 | package is not available on PyPI and must be installed manually as follows 84 | for full functionality. 85 | ```shell script 86 | pip install git+https://github.com/alexlee-gk/lpips-tensorflow.git#egg=lpips_tf 87 | ``` 88 | 89 | You can now move on to installing the actual `atlalign` package! 90 | 91 | #### Installation from PyPI 92 | The `atlalign` package can be easily installed from PyPI. 93 | ```shell script 94 | pip install atlalign 95 | ``` 96 | 97 | #### Installation from source 98 | As an alternative to installing from PyPI, if you want to try the latest version 99 | you can also install from source. 100 | ```shell script 101 | pip install git+https://github.com/BlueBrain/atlas_alignment#egg=atlalign 102 | ``` 103 | 104 | #### Installation for development 105 | If you want a dev install, you should install the latest version from source with 106 | all the extra requirements for running tests and generating docs. 107 | ```shell script 108 | git clone https://github.com/BlueBrain/atlas_alignment 109 | cd atlas_alignment 110 | pip install -e ".[dev,docs]" 111 | ``` 112 | 113 | ### Examples 114 | You can find multiple examples in the documentation. 
Specifically, make 115 | sure to read the 116 | [Building Blocks](https://atlas-alignment.readthedocs.io/en/latest/source/building_blocks.html) 117 | section of the docs to understand the basics. 118 | 119 | ### Data 120 | You can find example data on [Zenodo](https://zenodo.org/record/4541446#.YCqGFc9Kg4g). 121 | Unzip the files into the `~/.atlalign/` folder so that you can use the `data.py` module 122 | without manual specification of paths. 123 | 124 | #### Allen Brain Institute Database 125 | You can find and download ISH data from the Allen Brain Institute using the 126 | [Atlas Download Tools](https://github.com/BlueBrain/Atlas-Download-Tools) repository. 127 | 128 | ### Funding & Acknowledgment 129 | This project was supported by funding to the Blue Brain 130 | Project, a research center of the Ecole polytechnique fédérale de Lausanne, from 131 | the Swiss government's ETH Board of the Swiss Federal Institutes of Technology. 132 | 133 | COPYRIGHT (c) 2021-2022 Blue Brain Project/EPFL 134 | -------------------------------------------------------------------------------- /atlalign/__init__.py: -------------------------------------------------------------------------------- 1 | """Image registration package. 2 | 3 | Release markers: 4 | X.Y 5 | X.Y.Z for bug fixes 6 | """ 7 | 8 | """ 9 | The package atlalign is a tool for registration of 2D images. 10 | 11 | Copyright (C) 2021 EPFL/Blue Brain Project 12 | 13 | This program is free software: you can redistribute it and/or modify 14 | it under the terms of the GNU Lesser General Public License as published by 15 | the Free Software Foundation, either version 3 of the License, or 16 | (at your option) any later version. 17 | 18 | This program is distributed in the hope that it will be useful, 19 | but WITHOUT ANY WARRANTY; without even the implied warranty of 20 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 21 | GNU Lesser General Public License for more details. 
22 | 23 | You should have received a copy of the GNU Lesser General Public License 24 | along with this program. If not, see . 25 | """ 26 | 27 | __version__ = "0.6.2" 28 | -------------------------------------------------------------------------------- /atlalign/augmentations.py: -------------------------------------------------------------------------------- 1 | """Module creating one-to-many augmentations.""" 2 | 3 | """ 4 | The package atlalign is a tool for registration of 2D images. 5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 20 | """ 21 | 22 | import h5py 23 | import numpy as np 24 | from skimage.feature import canny 25 | from skimage.util import img_as_float32 26 | 27 | from atlalign.base import DisplacementField 28 | 29 | 30 | def load_dataset_in_memory(h5_path, dataset_name): 31 | """Load a dataset of a h5 file in memory.""" 32 | with h5py.File(h5_path, "r") as f: 33 | return f[dataset_name][:] 34 | 35 | 36 | class DatasetAugmenter: 37 | """Class that does the augmentation. 38 | 39 | Attributes 40 | ---------- 41 | original_path : str 42 | Path to where the original dataset is located. 
43 | 44 | """ 45 | 46 | def __init__(self, original_path): 47 | self.original_path = original_path 48 | 49 | self.n_orig = len(load_dataset_in_memory(self.original_path, "image_id")) 50 | 51 | def augment( 52 | self, 53 | output_path, 54 | n_iter=10, 55 | anchor=True, 56 | p_reg=0.5, 57 | random_state=None, 58 | max_corrupted_pixels=500, 59 | ds_f=8, 60 | max_trials=5, 61 | ): 62 | """Augment the original dataset and create a new one. 63 | 64 | Note that this not modify the original dataset. 65 | 66 | Parameters 67 | ---------- 68 | output_path : str 69 | Path to where the new h5 file stored. 70 | 71 | n_iter : int 72 | Number of augmented samples per each sample in the original dataset. 73 | 74 | anchor : bool 75 | If True, then dvf anchored before inverted. 76 | 77 | p_reg : bool 78 | Probability that we start from a registered image 79 | (rather than the moving). 80 | 81 | random_state : bool 82 | Random state 83 | 84 | max_corrupted_pixels : int 85 | Maximum numbr of corrupted pixels allowed for a dvf - the actual 86 | number is computed as np.sum(df.jacobian() < 0) 87 | 88 | ds_f : int 89 | Downsampling factor for inverses. 1 creates the least artifacts. 90 | 91 | max_trials : int 92 | Max number of attemps to augment before an identity displacement 93 | used as augmentation. 
94 | """ 95 | np.random.seed(random_state) 96 | 97 | n_new = n_iter * self.n_orig 98 | print(n_new) 99 | 100 | with h5py.File(self.original_path, "r") as f_orig: 101 | # extract 102 | dset_img_orig = f_orig["img"] 103 | dset_image_id_orig = f_orig["image_id"] 104 | dset_dataset_id_orig = f_orig["dataset_id"] 105 | dset_deltas_xy_orig = f_orig["deltas_xy"] 106 | dset_inv_deltas_xy_orig = f_orig["inv_deltas_xy"] 107 | dset_p_orig = f_orig["p"] 108 | 109 | with h5py.File(output_path, "w") as f_aug: 110 | dset_img_aug = f_aug.create_dataset( 111 | "img", (n_new, 320, 456), dtype="uint8" 112 | ) 113 | dset_image_id_aug = f_aug.create_dataset( 114 | "image_id", (n_new,), dtype="int" 115 | ) 116 | dset_dataset_id_aug = f_aug.create_dataset( 117 | "dataset_id", (n_new,), dtype="int" 118 | ) 119 | dset_p_aug = f_aug.create_dataset("p", (n_new,), dtype="int") 120 | dset_deltas_xy_aug = f_aug.create_dataset( 121 | "deltas_xy", (n_new, 320, 456, 2), dtype=np.float16 122 | ) 123 | dset_inv_deltas_xy_aug = f_aug.create_dataset( 124 | "inv_deltas_xy", (n_new, 320, 456, 2), dtype=np.float16 125 | ) 126 | 127 | for i in range(n_new): 128 | print(i) 129 | i_orig = i % self.n_orig 130 | 131 | mov2reg = DisplacementField( 132 | dset_deltas_xy_orig[i_orig, ..., 0], 133 | dset_deltas_xy_orig[i_orig, ..., 1], 134 | ) 135 | 136 | # copy 137 | dset_image_id_aug[i] = dset_image_id_orig[i_orig] 138 | dset_dataset_id_aug[i] = dset_dataset_id_orig[i_orig] 139 | dset_p_aug[i] = dset_p_orig[i_orig] 140 | 141 | use_reg = np.random.random() > p_reg 142 | print("Using registered: {}".format(use_reg)) 143 | 144 | if not use_reg: 145 | # mov != reg 146 | img_mov = dset_img_orig[i_orig] 147 | else: 148 | # mov=reg 149 | img_mov = mov2reg.warp(dset_img_orig[i_orig]) 150 | mov2reg = DisplacementField.generate( 151 | (320, 456), approach="identity" 152 | ) 153 | 154 | is_nice = False 155 | n_trials = 0 156 | 157 | while not is_nice: 158 | n_trials += 1 159 | 160 | if n_trials == max_trials: 161 | 
print("Replicating original: out of trials") 162 | dset_img_aug[i] = dset_img_orig[i_orig] 163 | dset_deltas_xy_aug[i] = dset_deltas_xy_orig[i_orig] 164 | dset_inv_deltas_xy_aug[i] = dset_inv_deltas_xy_orig[i_orig] 165 | break 166 | 167 | else: 168 | mov2art = self.generate_mov2art(img_mov) 169 | 170 | reg2mov = mov2reg.pseudo_inverse(ds_f=ds_f) 171 | reg2art = reg2mov(mov2art) 172 | 173 | # anchor 174 | if anchor: 175 | print("ANCHORING") 176 | reg2art = reg2art.anchor( 177 | ds_f=50, smooth=0, h_kept=0.9, w_kept=0.9 178 | ) 179 | 180 | art2reg = reg2art.pseudo_inverse(ds_f=ds_f) 181 | 182 | validity_check = np.all( 183 | np.isfinite(reg2art.delta_x) 184 | ) and np.all(np.isfinite(reg2art.delta_y)) 185 | validity_check &= np.all( 186 | np.isfinite(art2reg.delta_x) 187 | ) and np.all(np.isfinite(art2reg.delta_y)) 188 | jacobian_check = ( 189 | np.sum(reg2art.jacobian < 0) < max_corrupted_pixels 190 | ) 191 | jacobian_check &= ( 192 | np.sum(art2reg.jacobian < 0) < max_corrupted_pixels 193 | ) 194 | 195 | if validity_check and jacobian_check: 196 | is_nice = True 197 | print("Check passed") 198 | else: 199 | print("Check failed") 200 | 201 | if n_trials != max_trials: 202 | dset_img_aug[i] = mov2art.warp(img_mov) 203 | dset_deltas_xy_aug[i] = np.stack( 204 | [art2reg.delta_x, art2reg.delta_y], axis=-1 205 | ) 206 | dset_inv_deltas_xy_aug[i] = np.stack( 207 | [reg2art.delta_x, reg2art.delta_y], axis=-1 208 | ) 209 | 210 | @staticmethod 211 | def generate_mov2art(img_mov, verbose=True, radius_max=60, use_normal=True): 212 | """Generate geometric augmentation and its inverse.""" 213 | shape = img_mov.shape 214 | img_mov_float = img_as_float32(img_mov) 215 | edge_mask = canny(img_mov_float) 216 | 217 | if use_normal: 218 | c = np.random.normal(0.7, 0.3) 219 | else: 220 | c = np.random.random() 221 | 222 | if verbose: 223 | print("Scalar: {}".format(c)) 224 | 225 | mov2art = c * DisplacementField.generate( 226 | shape, 227 | approach="edge_stretching", 228 | 
edge_mask=edge_mask, 229 | interpolation_method="rbf", 230 | interpolator_kwargs={"function": "linear"}, 231 | n_perturbation_points=6, 232 | radius_max=radius_max, 233 | ) 234 | 235 | return mov2art 236 | -------------------------------------------------------------------------------- /atlalign/label/__init__.py: -------------------------------------------------------------------------------- 1 | """Module containing the interactive labeling tool.""" 2 | # The package atlalign is a tool for registration of 2D images. 3 | # 4 | # Copyright (C) 2021 EPFL/Blue Brain Project 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU Lesser General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | -------------------------------------------------------------------------------- /atlalign/label/cli.py: -------------------------------------------------------------------------------- 1 | """Command line interface implementation.""" 2 | # The package atlalign is a tool for registration of 2D images. 3 | # 4 | # Copyright (C) 2021 EPFL/Blue Brain Project 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 
10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU Lesser General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | import argparse 19 | import datetime 20 | import pathlib 21 | import sys 22 | from contextlib import redirect_stdout 23 | 24 | 25 | def main(argv=None): 26 | """Run the labeling CLI: parse ``argv`` (``sys.argv[1:]`` if None), launch the GUI and save all results under ``output_path``.""" 27 | parser = argparse.ArgumentParser( 28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 29 | ) 30 | parser.add_argument( 31 | "ref", 32 | type=str, 33 | help="Either a path to a reference image or a number from [0, 528) " 34 | "representing the coronal dimension in the nissl stain volume.", 35 | ) 36 | parser.add_argument( 37 | "mov", 38 | type=str, 39 | help="Path to a moving image. Needs to be of the same shape as " "reference.", 40 | ) 41 | parser.add_argument( 42 | "output_path", type=str, help="Folder where the outputs will be stored." 43 | ) 44 | parser.add_argument( 45 | "-s", 46 | "--swap", 47 | default=False, 48 | help="Swap to the moving to reference mode.", 49 | action="store_true", 50 | ) 51 | parser.add_argument( 52 | "-g", 53 | "--force-grayscale", 54 | default=False, 55 | help="Force the images to be grayscale. 
Convert RGB to grayscale if necessary.", 56 | action="store_true", 57 | ) 58 | args = parser.parse_args(argv) 59 | 60 | # Imports deferred until after argument parsing (presumably so that --help stays fast) 61 | import matplotlib.pyplot as plt 62 | import numpy as np 63 | 64 | from atlalign.data import nissl_volume 65 | from atlalign.label.io import load_image 66 | from atlalign.label.new_GUI import run_gui 67 | 68 | # Read input images 69 | output_channels = 1 if args.force_grayscale else None 70 | if args.ref.isdigit():  # numeric ref = coronal section index into the Nissl volume (see help text) 71 | img_ref = nissl_volume()[int(args.ref), ..., 0] 72 | else: 73 | img_ref_path = pathlib.Path(args.ref) 74 | img_ref = load_image( 75 | img_ref_path, output_channels=output_channels, output_dtype="float32" 76 | ) 77 | img_mov_path = pathlib.Path(args.mov) 78 | img_mov = load_image( 79 | img_mov_path, 80 | output_channels=output_channels, 81 | output_dtype="float32", 82 | ) 83 | 84 | # Launch GUI 85 | ( 86 | result_df, 87 | keypoints, 88 | symmetric_registration, 89 | img_reg, 90 | interpolation_method, 91 | kernel, 92 | ) = run_gui(img_ref, img_mov, mode="mov2ref" if args.swap else "ref2mov") 93 | 94 | # Dump results and metadata to disk 95 | output_path = pathlib.Path(args.output_path) 96 | output_path.mkdir(exist_ok=True, parents=True) 97 | 98 | result_df.save(output_path / "df.npy") 99 | np.save(output_path / "img_reg.npy", img_reg) 100 | np.save(output_path / "img_ref.npy", img_ref) 101 | np.save(output_path / "img_mov.npy", img_mov) 102 | plt.imsave(output_path / "img_reg.png", img_reg) 103 | plt.imsave(output_path / "img_ref.png", img_ref) 104 | plt.imsave(output_path / "img_mov.png", img_mov) 105 | with open(output_path / "keypoints.csv", "w") as file, redirect_stdout(file): 106 | if args.swap: 107 | print("mov x,mov y,ref x,ref y") 108 | else: 109 | print("ref x,ref y,mov x,mov y") 110 | for (x1, y1), (x2, y2) in keypoints.items():  # one CSV row per (point -> point) pair, columns as in the header above 111 | print(f"{x1},{y1},{x2},{y2}") 112 | with open(output_path / "info.log", "w") as file, redirect_stdout(file): 113 | print("Timestamp :", datetime.datetime.now().ctime()) 114 | 
print("") 115 | print("Parameters") 116 | print("----------") 117 | print("ref :", args.ref) 118 | print("mov :", args.mov) 119 | print("output_path :", output_path.resolve()) 120 | print("swap :", args.swap) 121 | print("force-grayscale :", args.force_grayscale) 122 | print() 123 | print("Interpolation") 124 | print("-------------") 125 | print("Symmetric :", symmetric_registration) 126 | print("Method :", interpolation_method) 127 | print("Kernel :", kernel) 128 | print("Results were saved to", output_path.resolve()) 129 | 130 | 131 | if __name__ == "__main__": 132 | sys.exit(main()) 133 | -------------------------------------------------------------------------------- /atlalign/label/io.py: -------------------------------------------------------------------------------- 1 | """Input and output utilities for the command line interface.""" 2 | 3 | """ 4 | The package atlalign is a tool for registration of 2D images. 5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 20 | """ 21 | 22 | import cv2 23 | 24 | 25 | def load_image( 26 | file_path, 27 | allowed_suffix=("jpg", "png"), 28 | output_dtype=None, 29 | output_channels=None, 30 | output_shape=None, 31 | input_shape=None, 32 | keep_last=False, 33 | ): 34 | """Load image. 
35 | 36 | Parameters 37 | ---------- 38 | file_path : str or pathlib.Path 39 | Path to where the image stored. 40 | 41 | allowed_suffix : tuple 42 | List of allowed suffixes. 43 | 44 | output_dtype : str or None 45 | Determines the dtype of the output image. If None, then the same as input. 46 | 47 | output_channels : int, {1, 3} or None 48 | If 1 then grayscale, if 3 then RGB. If None then the sampe as the input image. 49 | 50 | output_shape : tuple 51 | Two element tuple representing (h_output, w_output). 52 | 53 | input_shape : tuple or None 54 | If None no assertion on the input shape. If not None then a tuple representing 55 | (h_input_expected, w_input_expected). 56 | 57 | keep_last : bool 58 | Only active if `output_channels=1`. If True, then the output has shape (h, w, 1). Else (h, w). 59 | 60 | Returns 61 | ------- 62 | img : np.array 63 | Array of shape (h, w) 64 | 65 | """ 66 | raw_input_ = cv2.imread(str(file_path), cv2.IMREAD_UNCHANGED) 67 | 68 | if input_shape is not None and input_shape != raw_input_[:2]: 69 | raise ValueError( 70 | "Asserted input shape {} different than actual one {}".format( 71 | input_shape, raw_input_.shape 72 | ) 73 | ) 74 | 75 | # If len(shape) == 2, i.e. 
shape = (width, height), then raw_input_ is 76 | # already a grayscale image and we don't need to do any conversion 77 | if len(raw_input_.shape) > 2: 78 | if output_channels is None or output_channels == 3: 79 | input_img = cv2.cvtColor(raw_input_, cv2.COLOR_BGR2RGB) 80 | elif output_channels == 1: 81 | input_img = cv2.cvtColor(raw_input_, cv2.COLOR_BGR2GRAY) 82 | else: 83 | raise ValueError("Invalid output channels: {}".format(output_channels)) 84 | else: 85 | input_img = raw_input_ 86 | 87 | if output_shape is not None: 88 | input_img = cv2.resize(input_img, (output_shape[1], output_shape[0])) 89 | 90 | if input_img.ndim == 3 and input_img.shape[2] == 1 and not keep_last: 91 | input_img = input_img[..., 0] 92 | 93 | if output_dtype: 94 | if "float" in output_dtype: 95 | input_img = (input_img / 255).astype(output_dtype) 96 | 97 | return input_img 98 | -------------------------------------------------------------------------------- /atlalign/ml_utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Prevent 3 level imports for important objects.""" 2 | 3 | """ 4 | The package atlalign is a tool for registration of 2D images. 5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 
20 | """ 21 | 22 | from atlalign.ml_utils.augmentation import augmenter_1 # noqa 23 | from atlalign.ml_utils.callbacks import MLFlowCallback, get_mlflow_artifact_path # noqa 24 | from atlalign.ml_utils.io import SupervisedGenerator # noqa 25 | from atlalign.ml_utils.layers import ( # noqa 26 | Affine2DVF, 27 | BilinearInterpolation, 28 | DVFComposition, 29 | ExtractMoving, 30 | NoOp, 31 | block_stn, 32 | ) 33 | from atlalign.ml_utils.losses import ( # noqa 34 | DVF2IMG, 35 | NCC, 36 | Grad, 37 | Mixer, 38 | PerceptualLoss, 39 | VDClipper, 40 | cross_correlation, 41 | jacobian, 42 | jacobian_distance, 43 | mse_po, 44 | psnr, 45 | ssim, 46 | vector_distance, 47 | ) 48 | from atlalign.ml_utils.models import ( # noqa 49 | load_model, 50 | merge_global_local, 51 | replace_lambda_in_config, 52 | save_model, 53 | ) 54 | 55 | # Create utility dictionary 56 | ALL_IMAGE_LOSSES = { 57 | "mae": "mae", 58 | "mse": "mse", 59 | "ncc_5": NCC(win=5).loss, 60 | "ncc_9": NCC(win=9).loss, 61 | "ncc_12": NCC(win=12).loss, 62 | "ncc_20": NCC(win=20).loss, 63 | "pearson": cross_correlation, 64 | "perceptual_loss_net-lin_alex": PerceptualLoss(model="net-lin", net="alex").loss, 65 | "perceptual_loss_net-lin_vgg": PerceptualLoss(model="net-lin", net="vgg").loss, 66 | "perceptual_loss_net_alex": PerceptualLoss(model="net", net="alex").loss, 67 | "perceptual_loss_net_vgg": PerceptualLoss(model="net", net="vgg").loss, 68 | "psnr": psnr, 69 | "ssim": ssim, 70 | } 71 | ALL_DVF_LOSSES = { 72 | "grad": Grad().loss, 73 | "jacobian": jacobian, 74 | "jacobian_distance": jacobian_distance, 75 | "mae": "mae", 76 | "mse": "mse", 77 | "mse_po": mse_po, 78 | "vector_distance": vector_distance, 79 | "vdclip2": VDClipper(20, power=2).loss, 80 | "vdclip3": VDClipper(20, power=3).loss, 81 | "vector_jacobian_distance": Mixer(vector_distance, jacobian_distance).loss, 82 | "vector_jacobian_distance_02": Mixer( 83 | vector_distance, jacobian_distance, weights=[0.2, 0.8] 84 | ).loss, 85 | } 86 | 87 | 
ALL_DVF_LOSSES = {
    **ALL_DVF_LOSSES,
    **{k: DVF2IMG(v).loss for k, v in ALL_IMAGE_LOSSES.items() if callable(v)},
}

all_dvf_losses_items = list(ALL_DVF_LOSSES.items())
all_dvf_losses_items.sort(key=lambda x: x[0])

# Pairwise mixtures of every two distinct callable DVF losses. Since the items
# are sorted by name and j < i, the key "<a>&<b>" always has a sorted before b.
MIXED_DVF_LOSSES = {
    "{}&{}".format(k_i, k_o): Mixer(v_i, v_o).loss
    for i, (k_o, v_o) in enumerate(all_dvf_losses_items)
    for j, (k_i, v_i) in enumerate(all_dvf_losses_items)
    if j < i and callable(v_i) and callable(v_o)
}

ALL_DVF_LOSSES = {**ALL_DVF_LOSSES, **MIXED_DVF_LOSSES}

# --------------------------------------------------------------------------------
# /atlalign/ml_utils/augmentation.py:
# --------------------------------------------------------------------------------
"""Augmentation related tools."""

"""
The package atlalign is a tool for registration of 2D images.

Copyright (C) 2021 EPFL/Blue Brain Project

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see .
"""

import imgaug.augmenters as iaa


def augmenter_1(p=0.99):
    """Create augmenter.

    Contains no coordinate transforms. With probability `p` exactly one of
    eleven intensity augmentations (multiply, sharpen, edge detect, blur,
    gaussian noise, add, invert, coarse dropout, sigmoid contrast, linear
    contrast, emboss) is applied; otherwise the image is left untouched.

    Parameters
    ----------
    p : float
        Number in [0, 1] representing the probability of a random augmentation happening.

    Returns
    -------
    seq : iaa.Augmenter
        Augmenter where each augmentation was manually inspected and makes
        sense.

    """
    subsubseq_1 = iaa.Multiply(mul=(0.8, 1.2))
    subsubseq_2 = iaa.Sequential([iaa.Sharpen(alpha=(0, 1))])

    subsubseq_3 = iaa.Sequential([iaa.EdgeDetect(alpha=(0, 0.9))])

    subsubseq_4 = iaa.OneOf([iaa.GaussianBlur((0, 3.0)), iaa.AverageBlur(k=(2, 7))])

    subsubseq_5 = iaa.AdditiveGaussianNoise(loc=(0, 0.5), scale=(0, 0.2))

    subsubseq_6 = iaa.Add((-0.3, 0.3))

    subsubseq_7 = iaa.Invert(p=1)

    subsubseq_8 = iaa.CoarseDropout(p=0.25, size_percent=(0.005, 0.06))

    subsubseq_9 = iaa.SigmoidContrast(gain=(0.8, 1.2))

    subsubseq_10 = iaa.LinearContrast(alpha=(0.8, 1.2))

    subsubseq_11 = iaa.Sequential([iaa.Emboss(alpha=(0, 1))])

    seq = iaa.Sometimes(
        p,
        iaa.OneOf(
            [
                subsubseq_1,
                subsubseq_2,
                subsubseq_3,
                subsubseq_4,
                subsubseq_5,
                subsubseq_6,
                subsubseq_7,
                subsubseq_8,
                subsubseq_9,
                subsubseq_10,
                subsubseq_11,
            ]
        ),
    )

    return seq

# --------------------------------------------------------------------------------
# /atlalign/ml_utils/callbacks.py:
# --------------------------------------------------------------------------------
"""Callbacks and aggregation functions."""

"""
The package atlalign is a tool for registration of 2D images.

Copyright (C) 2021 EPFL/Blue Brain Project

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see .
"""

import pathlib

import h5py
import mlflow
import pandas as pd
from tensorflow import keras

from atlalign.data import annotation_volume, segmentation_collapsing_labels
from atlalign.metrics import evaluate_single
from atlalign.ml_utils.io import SupervisedGenerator


def get_mlflow_artifact_path(start_char=7):
    """Get path to the MLFlow artifacts of the active run.

    The artifact URI of the active run is turned into a path by simply
    dropping its first `start_char` characters.

    Parameters
    ----------
    start_char : int
        Since the string will start like "file:///actual/path..." we just
        slice it.

    Returns
    -------
    pathlib.Path
        Local path to the artifact folder of the active run.
    """
    return pathlib.Path(mlflow.active_run().info.artifact_uri[start_char:])


class MLFlowCallback(keras.callbacks.Callback):
    """Log metrics into MLflow.

    Notes
    -----
    Only runs inside of an mlflow context.

    Parameters
    ----------
    merged_path : str
        Path to the master h5 file containing all the data.

    train_original_ixs_path : str
        Path to where original training indices are stored.

    val_original_ixs_path : str
        Path to where original validation indices are stored.

    freq : int
        Reports metrics on every `freq` batch.

    workers : int
        Number of workers to be used for each of the evaluations.

    return_inverse : bool
        Forwarded to both internal `SupervisedGenerator` instances; if True
        they also yield the inverse displacement fields.

    starting_step : int
        Initial value of the batch counter. Useful when we want to use a
        checkpointed model and log metrics as of a different step.

    use_validation : bool
        If True, then the custom metrics are computed on the validation set.
        Otherwise they will be computed on the training set.

    Raises
    ------
    ValueError
        If instantiated outside of an active mlflow run.
    """

    def __init__(
        self,
        merged_path,
        train_original_ixs_path,
        val_original_ixs_path,
        freq=10,
        workers=1,
        return_inverse=False,
        starting_step=0,
        use_validation=True,
    ):
        super().__init__()

        # Check if inside of an mlflow context
        if mlflow.active_run() is None:
            raise ValueError(
                "To use the MLFlowCallback one needs to be inside of a mlflow.start_run context."
            )

        # mlflow
        self.root_path = get_mlflow_artifact_path()
        mlflow.log_params(
            {
                "train_original_ixs_path": train_original_ixs_path,
                "val_original_ixs_path": val_original_ixs_path,
                "merged_path": merged_path,
            }
        )

        self.train_original_gen = SupervisedGenerator(
            merged_path,
            indexes=train_original_ixs_path,
            shuffle=False,
            batch_size=1,
            return_inverse=return_inverse,
        )

        self.val_original_gen = SupervisedGenerator(
            merged_path,
            indexes=val_original_ixs_path,
            shuffle=False,
            batch_size=1,
            return_inverse=return_inverse,
        )
        self.freq = freq

        self.workers = workers
        self.overall_batch = starting_step
        self.use_validation = use_validation

    def on_train_begin(self, logs=None):
        """Create the `architecture` and `checkpoints` artifact folders.

        Checkpoints are later written into `checkpoints` by `on_batch_end`.
        """
        arch_path = self.root_path / "architecture"
        checkpoints_path = self.root_path / "checkpoints"

        arch_path.mkdir(parents=True, exist_ok=True)
        checkpoints_path.mkdir(parents=True, exist_ok=True)

    def _evaluate(self, model, gen, suffix):
        """Evaluate `model` on `gen` and key the results as ``<metric>_<suffix>``."""
        values = model.evaluate(gen, workers=self.workers, verbose=0)
        # `evaluate` returns a bare scalar when the model has a single metric.
        if not isinstance(values, (list, tuple)):
            values = [values]
        return {
            "{}_{}".format(metric, suffix): value
            for metric, value in zip(model.metrics_names, values)
        }

    def on_batch_end(self, batch, logs=None):
        """Log metrics to mlflow and save a checkpoint every `freq` batches.

        The goal here is to extract 3 types of metrics:
            - train_merged - extracted from logs (it is a running average over epoch)
            - train_original - computed via `model.evaluate`
            - val_original - computed via `model.evaluate`

        Additionally, the custom metrics of `compute_external_metrics` are
        computed (on the validation or training set depending on
        `use_validation`), stored as csv/html and their means are logged.
        """
        self.overall_batch += 1

        if self.overall_batch % self.freq != 0:
            return

        logs = logs or {}
        model = self.model
        metric_names = model.metrics_names

        all_metrics = {}

        # Keras
        all_metrics.update(
            {
                "{}_train_merged".format(metric): logs[metric]
                for metric in metric_names
                if metric in logs
            }
        )
        all_metrics.update(
            self._evaluate(model, self.train_original_gen, "train_original")
        )
        all_metrics.update(self._evaluate(model, self.val_original_gen, "val_original"))

        # Custom
        print(
            "\nComputing custom metrics on {} set!".format(
                "val" if self.use_validation else "train"
            )
        )
        gen = self.val_original_gen if self.use_validation else self.train_original_gen

        external_metrics_df = self.compute_external_metrics(model, gen)

        stats_dir = self.root_path / str(self.overall_batch) / "stats"
        stats_dir.mkdir(parents=True, exist_ok=True)

        external_metrics_df.to_csv(str(stats_dir / "stats.csv"))
        external_metrics_df.to_html(str(stats_dir / "stats.html"))

        # Average every (numeric) custom metric over all samples.
        external_metrics = dict(external_metrics_df.mean(numeric_only=True))
        all_metrics.update(external_metrics)

        # log into mlflow
        mlflow.log_metrics(all_metrics, step=self.overall_batch)

        keras.models.save_model(
            model,
            str(
                self.root_path
                / "checkpoints"
                / "model_{}.h5".format(self.overall_batch)
            ),
        )

    @staticmethod
    def compute_external_metrics(model, gen):
        """Compute external metrics sample by sample.

        Parameters
        ----------
        model
            Keras model whose second output is the predicted displacement field.

        gen : SupervisedGenerator
            Generator. Must not shuffle and must have a batch size of 1.

        Returns
        -------
        external_metrics_df : pd.DataFrame
            One row per sample, indexed by the `image_id` of the sample, and one
            column per metric returned by `evaluate_single`.

        Raises
        ------
        ValueError
            If `gen` shuffles or its batch size is not 1.
        """
        # checks
        if gen.shuffle:
            raise ValueError("Shuffling is not allowed for external metrics!")
        if gen.batch_size != 1:
            raise ValueError("Batch size has to be 1 for external metrics")

        # Prepare annotation related stuff (load in RAM, small arrays)
        indexes = gen.indexes
        with h5py.File(gen.hdf_path, "r") as f:
            ps = f["p"][:][indexes]
            ids = f["image_id"][:][indexes]

        avol = annotation_volume()
        collapsing_labels = segmentation_collapsing_labels()

        external_metrics_per_sample = []

        # With batch_size == 1 and no shuffling, gen[i] holds sample indexes[i].
        for i, p in enumerate(ps):
            sample = gen[i]
            if gen.return_inverse:
                img_mov = sample[0][0][0, ..., 1]
                deltas_true = sample[1][1][0]
                deltas_true_inv = sample[1][2][0]

            else:
                img_mov = sample[0][0, ..., 1]
                deltas_true = sample[1][1][0]
                deltas_true_inv = None

            deltas_pred = model.predict(sample[0])[1][0]
            deltas_pred_inv = None  # we do not use the predicted one

            # compute external metrics
            res, _ = evaluate_single(
                deltas_true,
                deltas_pred,
                img_mov,
                ds_f=8,  # orig 8
                p=p,
                deltas_true_inv=deltas_true_inv,
                deltas_pred_inv=deltas_pred_inv,
                avol=avol,
                collapsing_labels=collapsing_labels,
                depths=(0, 1, 2, 3, 4, 5, 6, 7, 8),
            )
            external_metrics_per_sample.append(res)

        external_metrics_df = pd.DataFrame(external_metrics_per_sample, index=ids)

        return external_metrics_df
# --------------------------------------------------------------------------------
# /atlalign/ml_utils/io.py:
# --------------------------------------------------------------------------------
"""Collection of functions dealing with input and output."""

"""
The package atlalign is a tool for registration of 2D images.

Copyright (C) 2021 EPFL/Blue Brain Project

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see .
"""

import copy
import datetime
import pathlib

import h5py
import mlflow
import numpy as np
from tensorflow import keras

from atlalign.base import DisplacementField
from atlalign.data import nissl_volume


class SupervisedGenerator(keras.utils.Sequence):
    """Generator streaming supervised data from a HDF5 file.

    Parameters
    ----------
    hdf_path : str or pathlib.Path
        Path to where the hdf5 file is.

    batch_size : int
        Batch size. Trailing samples that do not fill a whole batch are dropped.

    shuffle : bool
        If True, then data shuffled at the end of each epoch and also in the constructor.

    augmenter_ref : None or imgaug.augmenters.Sequential
        If None, no augmentation. If instance of a imgaug `Sequential` then its a pipeline that will be used
        to augment all reference images in a batch.

    augmenter_mov : None or imgaug.augmenters.Sequential
        If None, no augmentation. If instance of a imgaug `Sequential` then its a pipeline that will be used
        to augment all moving images in a batch.

    return_inverse : bool
        If True, then inputs are [X, X_mr] and targets are [img_reg, dvf, dvf_inv], else inputs are X and
        targets are only [img_reg, dvf].

    mlflow_log : bool
        If True, the constructor arguments are logged as mlflow params (augmenters are logged
        as "active" or None).

    indexes : None or list or tuple or np.ndarray or str/pathlib.Path
        A user defined list of indices to be used for streaming. This is used to give user the chance
        to only stream parts of the data. If str or pathlib.Path then read from a .npy file. A copy is
        stored, so the caller's sequence is never shuffled in place.

    Attributes
    ----------
    indexes : list
        List of indices determining the slicing order.

    volume : np.array
        Array representing the nissl stain volume of dtype float32.

    times : list
        List of timedeltas representing the time spent on each batch yield.
    """

    def __init__(
        self,
        hdf_path,
        batch_size=32,
        shuffle=False,
        augmenter_ref=None,
        augmenter_mov=None,
        return_inverse=False,
        mlflow_log=False,
        indexes=None,
    ):
        # Must stay the first statement: captures exactly the constructor arguments.
        self._locals = locals()
        del self._locals["self"]
        # Augmenters are not loggable, only record whether they are set.
        self._locals["augmenter_mov"] = None if augmenter_mov is None else "active"
        self._locals["augmenter_ref"] = None if augmenter_ref is None else "active"

        if mlflow_log:
            mlflow.log_params(self._locals)

        self.hdf_path = str(hdf_path)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augmenter_ref = augmenter_ref
        self.augmenter_mov = augmenter_mov
        self.return_inverse = return_inverse

        with h5py.File(self.hdf_path, "r") as f:
            length = len(f["p"])

        if indexes is None:
            self.indexes = list(np.arange(length))

        elif isinstance(indexes, (list, tuple, np.ndarray)):
            # Copy so that shuffling in `on_epoch_end` never mutates the caller's object.
            self.indexes = list(indexes)

        elif isinstance(indexes, (str, pathlib.Path)):
            self.indexes = list(np.load(str(indexes)))

        else:
            raise TypeError("Invalid indexes type {}".format(type(indexes)))

        self.volume = nissl_volume()

        self.times = []
        self.temp = []

        self.on_epoch_end()

    def __len__(self):
        """Length of the iterator = number of full batches per epoch."""
        return len(self.indexes) // self.batch_size

    def __getitem__(self, index):
        """Load samples in memory and possibly augment.

        Parameters
        ----------
        index : int
            Integer representing the index of to be returned batch.

        Returns
        -------
        X : np.array or list
            Array of shape (`self.batch_size`, 320, 456, 2) representing the stacked samples of reference
            and moving images. If `self.return_inverse=True` then a 2 element list `[X, X_mr]` where `X_mr`
            stacks the moving and the reference images in the swapped order.

        targets : list
            If `self.return_inverse=False` then 2 element list. First element represents the true registered images
            - shape (`self.batch_size`, 320, 456, 1). The second element is a batch of ground truth displacement
            vector fields - shape (`self.batch_size`, 320, 456, 2).
            If `self.return_inverse=True` then 3 element list. The first two elements are like above and the
            third one is a batch of ground truth inverse displacements fields (warping images in reference space to
            moving space) of shape (`self.batch_size`, 320, 456, 2).

        Notes
        -----
        Within a batch, samples are ordered by their hdf5 index (ascending),
        not by their position in `self.indexes`.
        """
        begin_time = datetime.datetime.now()
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        sorted_indexes = sorted(indexes)  # hdf5 only supports sorted indexing

        with h5py.File(self.hdf_path, "r") as f:
            dset_img = f["img"]
            dset_deltas_xy = f["deltas_xy"]
            dset_inv_deltas_xy = f["inv_deltas_xy"]
            dset_p = f["p"]

            # Section of the nissl volume matching each sample, clipped to the
            # last section (527).
            sn = np.minimum(dset_p[sorted_indexes] // 25, 527)
            ref_images = self.volume[sn]
            mov_images = (
                dset_img[sorted_indexes][..., np.newaxis].astype("float32") / 255
            )
            batch_deltas_xy = dset_deltas_xy[sorted_indexes]
            batch_inv_deltas_xy = (
                dset_inv_deltas_xy[sorted_indexes] if self.return_inverse else None
            )

        if self.augmenter_ref is not None:
            ref_images = self.augmenter_ref.augment_images(ref_images)

        if self.augmenter_mov is not None:
            mov_images = self.augmenter_mov.augment_images(mov_images)

        X = np.concatenate([ref_images, mov_images], axis=3)
        if self.return_inverse:
            X_mr = np.concatenate([mov_images, ref_images], axis=3)

        # Registered images: warp each moving image with its ground truth field.
        reg_images = np.zeros_like(mov_images)

        for i in range(len(mov_images)):
            df = DisplacementField(
                batch_deltas_xy[i, ..., 0], batch_deltas_xy[i, ..., 1]
            )
            assert df.is_valid, "{} is not valid".format(sorted_indexes[i])
            reg_images[i, ..., 0] = df.warp(mov_images[i, ..., 0])

        self.times.append((datetime.datetime.now() - begin_time))

        if self.return_inverse:
            return [X, X_mr], [reg_images, batch_deltas_xy, batch_inv_deltas_xy]
        else:
            return X, [reg_images, batch_deltas_xy]

    def on_epoch_end(self):
        """Take end of epoch action (shuffle the indexes in place if requested)."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def get_all_data(self):
        """Load entire dataset into memory.

        A new generator with the same configuration, but `batch_size=1`,
        `shuffle=False` and `mlflow_log=False`, is created and fully iterated.

        Returns
        -------
        all_inps : list
            Inputs of every single-sample batch.

        all_outs : list
            Targets of every single-sample batch.
        """
        orig_params = copy.deepcopy(self._locals)

        orig_params["batch_size"] = 1
        orig_params["shuffle"] = False
        orig_params["mlflow_log"] = False
        # `self._locals` only holds placeholder strings for the augmenters (for
        # mlflow logging); pass the actual augmenter objects instead.
        orig_params["augmenter_ref"] = self.augmenter_ref
        orig_params["augmenter_mov"] = self.augmenter_mov

        new_gen = self.__class__(**orig_params)

        all_inps, all_outs = [], []
        for inps, outs in new_gen:
            all_inps.append(inps)
            all_outs.append(outs)

        return all_inps, all_outs

# --------------------------------------------------------------------------------
# /atlalign/non_ml/__init__.py:
# --------------------------------------------------------------------------------
"""Registration algorithms that do not use learning approaches."""

"""
The package atlalign is a tool for registration of 2D images.

Copyright (C) 2021 EPFL/Blue Brain Project

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see .
"""

from atlalign.non_ml.intensity import antspy_registration  # noqa

# --------------------------------------------------------------------------------
# /atlalign/non_ml/intensity.py:
# --------------------------------------------------------------------------------
"""Collection of intensity based registration methods."""

"""
The package atlalign is a tool for registration of 2D images.

Copyright (C) 2021 EPFL/Blue Brain Project

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see .
"""
import os
import warnings

# NOTE: this silences *all* warnings process-wide as an import side effect.
warnings.simplefilter(action="ignore")  # noqa

import ants  # noqa
import nibabel as nib  # noqa

from atlalign.base import GLOBAL_CACHE_FOLDER, DisplacementField  # noqa


def antspy_registration(
    fixed_img,
    moving_img,
    registration_type="SyN",
    reg_iterations=(40, 20, 0),
    aff_metric="mattes",
    syn_metric="mattes",
    verbose=False,
    initial_transform=None,
    path=GLOBAL_CACHE_FOLDER,
):
    """Register images using ANTsPY.

    Parameters
    ----------
    fixed_img: np.ndarray
        Fixed image.

    moving_img: np.ndarray
        Moving image to register.

    registration_type: {'Translation', 'Rigid', 'Similarity', 'QuickRigid', 'DenseRigid', 'BOLDRigid', 'Affine',
        'AffineFast', 'BOLDAffine', 'TRSAA', 'ElasticSyN', 'SyN', 'SyNRA', 'SyNOnly', 'SyNCC', 'SyNabp',
        'SyNBold', 'SyNBoldAff', 'SyNAggro', 'TVMSQ', 'TVMSQC'}, default 'SyN'

        Optimization algorithm to use to register (more info: https://antspy.readthedocs.io/en/latest/registration.
        html?highlight=registration#ants.registration)

    reg_iterations: tuple, default (40, 20, 0)
        Vector of iterations for SyN.

    aff_metric: {'GC', 'mattes', 'meansquares'}, default 'mattes'
        The metric for the affine part.

    syn_metric: {'CC', 'mattes', 'meansquares', 'demons'}, default 'mattes'
        The metric for the SyN part.

    verbose : bool, default False
        If True, then the inner solver prints convergence related information in standard output.

    initial_transform : list or None
        Transforms to prepend before the registration.

    path : str or pathlib.Path
        Path to a folder to where to save the `.nii.gz` file representing the composite transform.
        An empty string means the current working directory.

    Returns
    -------
    df: DisplacementField
        Displacement field between the moving and the fixed image

    meta : dict
        Contains relevant images and paths.

    """
    fixed_ants_image = ants.image_clone(ants.from_numpy(fixed_img), pixeltype="float")
    moving_ants_image = ants.image_clone(ants.from_numpy(moving_img), pixeltype="float")
    meta = ants.registration(
        fixed_ants_image,
        moving_ants_image,
        registration_type,
        reg_iterations=reg_iterations,
        aff_metric=aff_metric,
        syn_metric=syn_metric,
        verbose=verbose,
        initial_transform=initial_transform,
        syn_sampling=32,
        aff_sampling=32,
    )

    # Compose all forward transforms into one displacement field file; `compose`
    # is the output prefix and the written filename is returned.
    filename = ants.apply_transforms(
        fixed_ants_image,
        moving_ants_image,
        meta["fwdtransforms"],
        compose=os.path.join(str(path), "final_transform"),
    )

    composite = nib.load(filename)
    data = composite.get_fdata().squeeze()
    # Component 1 is used as dx and component 0 as dy (presumably an axis-order
    # swap between ANTs and numpy images).
    dx = data[:, :, 1]
    dy = data[:, :, 0]
    df_final = DisplacementField(dx, dy)

    return df_final, meta

# --------------------------------------------------------------------------------
# /atlalign/utils.py:
# --------------------------------------------------------------------------------
# """Collection of helper classes and function that do not deserve to be in
base.py.

Notes
-----
This module cannot import from anywhere else within this project to prevent circular dependencies.

"""

"""
The package atlalign is a tool for registration of 2D images.

Copyright (C) 2021 EPFL/Blue Brain Project

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see .
"""

import numpy as np
import scipy.spatial.qhull as qhull  # NOTE(review): deprecated module path in recent SciPy; `scipy.spatial.Delaunay` is the public name — confirm before upgrading


def _triangulate(xyz, uvw):
    """Perform Delaunay triangulation.

    Parameters
    ----------
    xyz : np.ndarray
        An array of shape (N, 2) where each row represents one point in 2D (stable points) for which we know the
        function value.

    uvw : np.ndarray
        An array of shape (K, 2) where each row represents one point in 2D (query point) for which we want to
        interpolate the function value.

    Returns
    -------
    vertices : np.ndarray
        An array of shape (K, 3) representing the triangle vertices of each query point. Note that these
        vertices are always stable points (from range [0, N)).

    wts : np.ndarray
        An array of shape (K, 3) representing the (barycentric) weights of respective vertices at each
        query point. They sum to one for every query point.

    Notes
    -----
    Query points outside the convex hull of `xyz` get the simplex index -1 from
    `find_simplex`. `np.take` then silently uses the last simplex, so such points
    receive extrapolated weights (some of them negative) instead of being flagged
    as invalid.

    """
    tri = qhull.Delaunay(xyz)
    # Index of the simplex (triangle) containing each query point, -1 if outside the hull.
    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)
    # Per-simplex affine map to barycentric coordinates: temp[:, :2] is T, temp[:, 2] is the offset r.
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, 2]
    bary = np.einsum("njk,nk->nj", temp[:, :2, :], delta)
    # The third barycentric coordinate follows from all three summing to one.
    wts = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

    return vertices, wts


def _interpolate(values, vertices, wts):
    """Interpolate inside a triangle.

    Parameters
    ----------
    values : np.ndarray
        An array of shape (N,) that represents function value on the known points which are the vertices after
        Delaunay triangulation.

    vertices : np.ndarray
        An array of shape (K, 3) representing the triangle vertices of each query point. Note that these
        vertices are always stable points (from range [0, N)).

    wts : np.ndarray
        An array of shape (K, 3) representing the weights of respective vertices at each query point.

    Returns
    -------
    interpolations : np.ndarray
        An array of shape (K,) representing the interpolated function values on the query points.

    """
    # Weighted sum of the function values at the 3 vertices of each query point's triangle.
    return np.einsum("nj,nj->n", np.take(values, vertices), wts)


def griddata_custom(points, values_f_1, values_f_2, xi):
    """Run griddata extensions that performs only one triangulation.

    Notes
    -----
    The scipy implementation does not allow to separate triangulation from interpolation. Since we need
    to evaluate 2 different functions on the same non-regular grid of points, the triangulation can
    be done just once and reused for both.

    Unlike `scipy.interpolate.griddata` (whose default fill value is NaN), query points outside the
    convex hull of `points` are not set to NaN but extrapolated with respect to an arbitrary
    triangle (see `_triangulate`).

    Parameters
    ----------
    points : np.ndarray
        An array of shape (N, 2) where each row represents one point in 2D for which we know the function value.

    values_f_1 : np.ndarray
        An array of shape (N,) where each row represents a value of function f_1 on the corresponding point in `points`.

    values_f_2 : np.ndarray
        An array of shape (N,) where each row represents a value of function f_2 on the corresponding point in `points`.

    xi : tuple
        Tuple of 2 np.ndarray of shapes (h, w) representing the x and y coordinates of the points where we want to
        interpolate data. Note that this is simply the result of `np.meshgrid` if our points of interest lie on a
        regular grid.

    Returns
    -------
    f_1_interpolation_on_xi : np.ndarray
        An array of shape (h, w) representing the interpolation of f_1 on the `xi` points.

    f_2_interpolation_on_xi : np.ndarray
        An array of shape (h, w) representing the interpolation of f_2 on the `xi` points.

    Raises
    ------
    TypeError
        If `xi` is not a tuple.

    References
    ----------
    https://stackoverflow.com/questions/20915502/speedup-scipy-griddata-for-multiple-interpolations-between-two-irregular-grids # noqa

    """
    if isinstance(xi, tuple):
        shape = xi[0].shape
        # Flatten the two coordinate grids into an array of (x, y) query points of shape (h * w, 2).
        xi = np.hstack((xi[0].reshape(-1, 1), xi[1].reshape(-1, 1)))  # possible speedup

    else:
        raise TypeError("The xi needs to be a tuple of equally shaped np.ndarrays.")

    # Triangulate once, then reuse the vertices / weights for both functions.
    vertices, wts = _triangulate(points, xi)

    f_1_interpolation_on_xi = _interpolate(values_f_1, vertices, wts).reshape(shape)
    f_2_interpolation_on_xi = _interpolate(values_f_2, vertices, wts).reshape(shape)

    return f_1_interpolation_on_xi, f_2_interpolation_on_xi


def _find_all_children(d, children_list=None):
    """Construct a list of all the ids of the children of a node and the node itself.

    Parameters
    ----------
    d : dict
        Dictionary node from whom we want the list of all the children and children's children.

    children_list : list, default None
        List of children which has to be empty for the first iteration of the function.

    Returns
    -------
    children_list : list
        List of children's ids.
159 | 160 | """ 161 | if children_list is None: 162 | children_list = [] 163 | 164 | for key, value in d.items(): 165 | if key == "id": 166 | children_list.append(value) 167 | if isinstance(value, list): 168 | for child in value: 169 | _find_all_children(child, children_list) 170 | 171 | return children_list 172 | 173 | 174 | def _find_concatenate_labels(d, chosen_depth, dict_of_labels=None, current_depth=0): 175 | """Construct a dictionary which has for each key, the value of the new label after concatenation. 176 | 177 | Parameters 178 | ---------- 179 | d : dict 180 | Dictionary node for which we want to concatenate some ids depending on the depth branch. 181 | 182 | chosen_depth : int 183 | Depth at which it is wanted to concatenate the labels. 184 | 185 | dict_of_labels : dict, default {} 186 | Dictionary of corresponding labels (empty at the first call). 187 | 188 | current_depth : int, default 0 189 | Depth of the dictionary node. 190 | 191 | Returns 192 | ------- 193 | dict_of_labels: dict 194 | Dictionary of corresponding labels after concatenation of labels tree. 195 | 196 | """ 197 | if dict_of_labels is None: 198 | dict_of_labels = {} 199 | 200 | if current_depth < chosen_depth: 201 | for key, value in d.items(): 202 | if key == "id": 203 | dict_of_labels[value] = value 204 | if isinstance(value, list): 205 | current_depth = current_depth + 1 206 | for child in value: 207 | _find_concatenate_labels( 208 | child, 209 | chosen_depth, 210 | dict_of_labels=dict_of_labels, 211 | current_depth=current_depth, 212 | ) 213 | else: 214 | children_list = [] 215 | _find_all_children(d, children_list) 216 | for key, value in d.items(): 217 | if key == "id": 218 | for child in children_list: 219 | dict_of_labels[child] = value 220 | 221 | return dict_of_labels 222 | 223 | 224 | def find_labels_dic(segmentation_array, dic, chosen_depth): 225 | """Collapse existing labels into parent labels corresponding to the tree provided in a dictionary. 
226 | 227 | Parameters 228 | ---------- 229 | segmentation_array : np.array 230 | Annotation array before the concatenation of the labels. 231 | 232 | dic : dict 233 | Dictionary of tree of labels. 234 | 235 | chosen_depth : int 236 | Depth at which it is wanted to concatenate the labels. 237 | 238 | Returns 239 | ------- 240 | new_segmentation_array : np.array 241 | New Annotation array with the concatenation of the labels at the desired depth. If a specific label 242 | does not exist in the tree it is assigned -1. 243 | 244 | """ 245 | labels_dic = _find_concatenate_labels(dic, chosen_depth) 246 | 247 | new_segmentation_array = segmentation_array.copy() 248 | all_labels = np.unique(segmentation_array) 249 | 250 | for label in all_labels: 251 | if label != 0: 252 | new_label = labels_dic.get(label, -1) 253 | new_segmentation_array[new_segmentation_array == label] = new_label 254 | 255 | return new_segmentation_array 256 | -------------------------------------------------------------------------------- /atlalign/volume.py: -------------------------------------------------------------------------------- 1 | """Collection of tools for aggregating slices to 3D models.""" 2 | 3 | """ 4 | The package atlalign is a tool for registration of 2D images. 5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 
20 | """ 21 | 22 | import numpy as np 23 | import scipy 24 | 25 | from atlalign.data import nissl_volume 26 | 27 | 28 | class Volume: 29 | """Class representing mutliple slices. 30 | 31 | Parameters 32 | ---------- 33 | sn : list 34 | List of section numbers. 35 | 36 | mov_imgs : list 37 | List of np.ndarrays representing the moving images corresponding to the `sn`. 38 | 39 | dvfs : list 40 | List of displacement fields corresponding to the `sn`. 41 | """ 42 | 43 | def __init__(self, sn, mov_imgs, dvfs): 44 | # initial checks 45 | if not len(sn) == len(mov_imgs) == len(dvfs): 46 | raise ValueError("All the input lists need to have the same length") 47 | 48 | if len(set(sn)) != len(sn): 49 | raise ValueError("There are duplicate section numbers.") 50 | 51 | if not all([0 <= x < 528 for x in sn]): 52 | raise ValueError("All section numbers must lie in [0, 528).") 53 | 54 | self.sn = sn 55 | self.mov_imgs = mov_imgs 56 | self.dvfs = dvfs 57 | 58 | # attributes 59 | self.ref_imgs = nissl_volume()[self.sn, ..., 0] 60 | self.sn_to_ix = [ 61 | None if x not in self.sn else self.sn.index(x) for x in range(528) 62 | ] 63 | self.reg_imgs = self._warp() 64 | 65 | @property 66 | def sorted_dvfs(self): 67 | """Return displacement fields sorted by the coronal section.""" 68 | sorted_sn = sorted(self.sn) 69 | 70 | return [self.dvfs[self.sn_to_ix[s]] for s in sorted_sn], sorted_sn 71 | 72 | @property 73 | def sorted_mov(self): 74 | """Return moving images as sorted by the coronal section.""" 75 | sorted_sn = sorted(self.sn) 76 | 77 | return [self.mov_imgs[self.sn_to_ix[s]] for s in sorted_sn], sorted_sn 78 | 79 | @property 80 | def sorted_ref(self): 81 | """Return reference images as sorted by the coronal section.""" 82 | sorted_sn = sorted(self.sn) 83 | 84 | return [self.ref_imgs[self.sn_to_ix[s]] for s in sorted_sn], sorted_sn 85 | 86 | @property 87 | def sorted_reg(self): 88 | """Return registered images as sorted by the coronal section.""" 89 | sorted_sn = sorted(self.sn) 90 | 
91 | return [self.reg_imgs[self.sn_to_ix[s]] for s in sorted_sn], sorted_sn 92 | 93 | def _warp(self): 94 | """Warp the moving images to get registered ones.""" 95 | return [df.warp(img) for df, img in zip(self.dvfs, self.mov_imgs)] 96 | 97 | def __getitem__(self, key): 98 | """Get all relevant data for a specified section. 99 | 100 | Parameters 101 | ---------- 102 | key : int 103 | Section number to query. 104 | 105 | Returns 106 | ------- 107 | ref_img : np.ndarray 108 | Reference image. 109 | 110 | mov_img : np.ndarray 111 | Moving image. 112 | 113 | reg_img : np.ndarray 114 | Registered image. 115 | 116 | df : DisplacementField 117 | Displacement field (mov2reg). 118 | """ 119 | if self.sn_to_ix[key] is None: 120 | raise KeyError("The section {} not found".format(key)) 121 | 122 | ix = self.sn_to_ix[key] 123 | 124 | return self.ref_imgs[ix], self.mov_imgs[ix], self.reg_imgs[ix], self.dvfs[ix] 125 | 126 | 127 | class GappedVolume: 128 | """Volume containing gaps. 129 | 130 | Parameters 131 | ---------- 132 | sn : list 133 | List of section numbers. Note that not required to be ordered. 134 | 135 | imgs : np.ndarray or list 136 | Internally converted to list of grayscale images of same shape representing different coronal sections. 137 | Order corresponds to the one in `sn`. 
138 | 139 | """ 140 | 141 | def __init__(self, sn, imgs): 142 | 143 | if isinstance(imgs, np.ndarray): 144 | # turn into a list 145 | imgs = np.squeeze(imgs) 146 | imgs = [imgs[i] for i in range(len(imgs))] 147 | 148 | # checks 149 | if len(sn) != len(imgs): 150 | raise ValueError("Inconsitent lenghts") 151 | 152 | if len({img.shape for img in imgs}) != 1: 153 | raise ValueError("All the images need to have the same shape") 154 | 155 | if len(sn) != len(set(sn)): 156 | raise ValueError("There are duplicates in section numbers.") 157 | 158 | self.sn = sn 159 | self.imgs = imgs 160 | 161 | self.shape = imgs[0].shape 162 | 163 | 164 | class CoronalInterpolator: 165 | """Interpolator that works pixel by pixel in the coronal dimension.""" 166 | 167 | def __init__(self, kind="linear", fill_value=0, bounds_error=False): 168 | """Construct.""" 169 | self.kind = kind 170 | self.fill_value = fill_value 171 | self.bounds_error = bounds_error 172 | 173 | def interpolate(self, gv): 174 | """Interpolate. 175 | 176 | Note that some section images might have pixels equal to np.nan. In this case these pixels are skipped in the 177 | interpolation. 178 | 179 | Parameters 180 | ---------- 181 | gv : GappedVolume 182 | Instance of the ``GappedVolume`` to be interpolated. 183 | 184 | Returns 185 | ------- 186 | final_volume : np.ndarray 187 | Array of shape (528, 320, 456) that holds the entire interpolated volume without gaps. 
188 | 189 | """ 190 | grid = np.array(range(528)) 191 | final_volume = np.empty((len(grid), *gv.shape)) 192 | 193 | for r in range(gv.shape[0]): 194 | for c in range(gv.shape[1]): 195 | x_pixel, y_pixel = zip( 196 | *[ 197 | (s, img[r, c]) 198 | for s, img in zip(gv.sn, gv.imgs) 199 | if not np.isnan(img[r, c]) 200 | ] 201 | ) 202 | 203 | f = scipy.interpolate.interp1d( 204 | x_pixel, 205 | y_pixel, 206 | kind=self.kind, 207 | bounds_error=self.bounds_error, 208 | fill_value=self.fill_value, 209 | ) 210 | try: 211 | final_volume[:, r, c] = f(grid) 212 | except Exception as e: 213 | print(e) 214 | 215 | return final_volume 216 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_images/affine_simple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/affine_simple.png -------------------------------------------------------------------------------- /docs/_images/anchoring.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/anchoring.png -------------------------------------------------------------------------------- /docs/_images/annot_warping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/annot_warping.png -------------------------------------------------------------------------------- /docs/_images/antspy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/antspy.png -------------------------------------------------------------------------------- /docs/_images/aug_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/aug_pipeline.png -------------------------------------------------------------------------------- /docs/_images/clipped_vd.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/clipped_vd.png -------------------------------------------------------------------------------- /docs/_images/composition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/composition.png -------------------------------------------------------------------------------- /docs/_images/control_points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/control_points.png -------------------------------------------------------------------------------- /docs/_images/coronal_interpolator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/coronal_interpolator.png -------------------------------------------------------------------------------- /docs/_images/edge_stretching.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/edge_stretching.png -------------------------------------------------------------------------------- /docs/_images/evaluation_metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/evaluation_metrics.png -------------------------------------------------------------------------------- /docs/_images/example_augmentation.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/example_augmentation.png -------------------------------------------------------------------------------- /docs/_images/feature_based.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/feature_based.png -------------------------------------------------------------------------------- /docs/_images/image_registration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/image_registration.png -------------------------------------------------------------------------------- /docs/_images/image_registration_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/image_registration_2.png -------------------------------------------------------------------------------- /docs/_images/int_augmentations.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/int_augmentations.gif -------------------------------------------------------------------------------- /docs/_images/inverse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/inverse.png -------------------------------------------------------------------------------- /docs/_images/labeling_tool.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/labeling_tool.png -------------------------------------------------------------------------------- /docs/_images/metrics_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/metrics_overview.png -------------------------------------------------------------------------------- /docs/_images/resizing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/resizing.png -------------------------------------------------------------------------------- /docs/_images/typical_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/typical_dataset.png -------------------------------------------------------------------------------- /docs/_images/warping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_images/warping.png -------------------------------------------------------------------------------- /docs/_static/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/_static/.keep -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | import atlalign 17 | 18 | # sys.path.insert(0, os.path.abspath('.')) 19 | print(sys.path) 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = 'Atlas Alignment' 24 | author = 'Blue Brain Project, EPFL' 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = ['sphinx.ext.autodoc', 32 | 'sphinx.ext.doctest', 33 | 'sphinx.ext.napoleon', 34 | 'sphinx.ext.viewcode' 35 | ] 36 | 37 | 38 | # Add any paths that contain templates here, relative to this directory. 39 | templates_path = ['_templates'] 40 | 41 | # List of patterns, relative to source directory, that match files and 42 | # directories to ignore when looking for source files. 43 | # This pattern also affects html_static_path and html_extra_path. 44 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 45 | 46 | # -- Options for HTML output ------------------------------------------------- 47 | 48 | # The theme to use for HTML and HTML Help pages. See the documentation for 49 | # a list of builtin themes. 
50 | # 51 | html_theme = 'sphinx-bluebrain-theme' 52 | html_title = 'Atlas Alignment' 53 | version = atlalign.__version__ 54 | 55 | # Add any paths that contain custom static files (such as style sheets) here, 56 | # relative to this directory. They are copied after the builtin static files, 57 | # so a file named "default.css" will overwrite the builtin "default.css". 58 | html_static_path = ['_static'] 59 | 60 | # Do not mention module names 61 | add_module_names = False 62 | 63 | # Blue brain theme specific 64 | html_show_sourceling = False 65 | -------------------------------------------------------------------------------- /docs/generate_metadata.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import datetime 3 | 4 | import atlalign 5 | 6 | metadata_template = \ 7 | """--- 8 | packageurl: https://github.com/BlueBrain/atlas_alignment 9 | major: {major_version} 10 | description: Image registration with deep learning 11 | repository: https://github.com/BlueBrain/atlas_alignment 12 | externaldoc: https://bbpteam.epfl.ch/documentation/a.html#atlas-alignment 13 | updated: {date} 14 | maintainers: Jan Krepl 15 | name: Atlas Alignment 16 | license: BBP-internal-confidential 17 | issuesurl: https://github.com/BlueBrain/atlas_alignment 18 | version: {version} 19 | contributors: Jan Krepl 20 | minor: {minor_version} 21 | --- 22 | """ 23 | 24 | file_directory = pathlib.Path(__file__).parent.absolute() 25 | metadata_path = file_directory / 'metadata.md' 26 | 27 | version = atlalign.__version__ 28 | major_version = version.split('.')[0] 29 | minor_version = version.split('.')[1] 30 | date = datetime.datetime.now().strftime("%d/%m/%y") 31 | 32 | metadata_instance = metadata_template.format(version=version, 33 | major_version=major_version, 34 | minor_version=minor_version, 35 | date=date) 36 | 37 | with metadata_path.open('w') as f: 38 | f.write(metadata_instance) 39 | 40 | print('Finished') 41 | 
-------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Atlas Alignment documentation master file, created by 2 | sphinx-quickstart on Wed Oct 30 16:22:27 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Atlas Alignment 7 | ================ 8 | The goal of this project is to facilitate the registration of gene expression images to a reference (Nissl) atlas. 9 | 10 | .. toctree:: 11 | :maxdepth: 2 12 | :caption: Contents: 13 | 14 | source/installation 15 | source/image_registration 16 | source/building_blocks 17 | source/datasets 18 | source/intensity 19 | source/labeling_tool 20 | source/deep_learning_data 21 | source/deep_learning_training 22 | source/deep_learning_inference 23 | source/evaluation 24 | source/3d_interpolation 25 | 26 | .. toctree:: 27 | :maxdepth: 2 28 | :caption: API Reference 29 | 30 | source/api/modules 31 | 32 | 33 | 34 | Indices and tables 35 | ================== 36 | 37 | * :ref:`genindex` 38 | * :ref:`modindex` 39 | * :ref:`search` 40 | -------------------------------------------------------------------------------- /docs/source/3d_interpolation.rst: -------------------------------------------------------------------------------- 1 | 3D Interpolation 2 | ================ 3 | Once consecutive 2D slices are registered, the next step is to create 3D volumes. This is especially relevant 4 | for data from the **Allen Institute for Brain Science**, which come from experiments with a very specific design. 5 | 6 | There are multiple ways to perform this interpolation; some of them are implemented in the :code:`atlalign.volume` 7 | module: 8 | 9 | - :code:`CoronalInterpolator` 10 | 11 | CoronalInterpolator 12 | ------------------- 13 | The :code:`CoronalInterpolator` turns the entire problem into a 1D function interpolation.
See the sketch below: 14 | 15 | .. image:: ../_images/coronal_interpolator.png 16 | :width: 400 17 | :height: 300 18 | :alt: Coronal interpolator 19 | :align: center 20 | 21 | Each pixel is interpolated separately across the missing sections, based purely on the values of the corresponding pixel 22 | in all of the existing sections. 23 | 24 | .. testcode:: 25 | 26 | import numpy as np 27 | 28 | from atlalign.volume import CoronalInterpolator, GappedVolume 29 | 30 | n_sections = 55 31 | shape = (30, 40) 32 | 33 | sn = np.random.choice(np.arange(527), size=n_sections, replace=False) 34 | imgs = np.random.random((n_sections, *shape)) 35 | 36 | gv = GappedVolume(sn, imgs) 37 | ci = CoronalInterpolator() 38 | 39 | final_volume = ci.interpolate(gv) 40 | print(final_volume.shape) # (528, shape[0], shape[1]) 41 | 42 | .. testoutput:: 43 | :hide: 44 | :options: -ELLIPSIS, +NORMALIZE_WHITESPACE 45 | 46 | (528, 30, 40) 47 | -------------------------------------------------------------------------------- /docs/source/api/atlalign.label.rst: -------------------------------------------------------------------------------- 1 | atlalign.label package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | atlalign.label.cli module 8 | ------------------------- 9 | 10 | .. automodule:: atlalign.label.cli 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | atlalign.label.io module 16 | ------------------------ 17 | 18 | .. automodule:: atlalign.label.io 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | atlalign.label.new\_GUI module 24 | ------------------------------ 25 | 26 | .. automodule:: atlalign.label.new_GUI 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | ..
automodule:: atlalign.label 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/api/atlalign.ml_utils.rst: -------------------------------------------------------------------------------- 1 | atlalign.ml\_utils package 2 | ========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | atlalign.ml\_utils.augmentation module 8 | -------------------------------------- 9 | 10 | .. automodule:: atlalign.ml_utils.augmentation 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | atlalign.ml\_utils.callbacks module 16 | ----------------------------------- 17 | 18 | .. automodule:: atlalign.ml_utils.callbacks 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | atlalign.ml\_utils.io module 24 | ---------------------------- 25 | 26 | .. automodule:: atlalign.ml_utils.io 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | atlalign.ml\_utils.layers module 32 | -------------------------------- 33 | 34 | .. automodule:: atlalign.ml_utils.layers 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | atlalign.ml\_utils.losses module 40 | -------------------------------- 41 | 42 | .. automodule:: atlalign.ml_utils.losses 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | atlalign.ml\_utils.models module 48 | -------------------------------- 49 | 50 | .. automodule:: atlalign.ml_utils.models 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | Module contents 56 | --------------- 57 | 58 | .. 
automodule:: atlalign.ml_utils 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | -------------------------------------------------------------------------------- /docs/source/api/atlalign.non_ml.rst: -------------------------------------------------------------------------------- 1 | atlalign.non\_ml package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | atlalign.non\_ml.intensity module 8 | --------------------------------- 9 | 10 | .. automodule:: atlalign.non_ml.intensity 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: atlalign.non_ml 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/api/atlalign.rst: -------------------------------------------------------------------------------- 1 | atlalign package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | atlalign.label 11 | atlalign.ml_utils 12 | atlalign.non_ml 13 | 14 | Submodules 15 | ---------- 16 | 17 | atlalign.augmentations module 18 | ----------------------------- 19 | 20 | .. automodule:: atlalign.augmentations 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | atlalign.base module 26 | -------------------- 27 | 28 | .. automodule:: atlalign.base 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | atlalign.data module 34 | -------------------- 35 | 36 | .. automodule:: atlalign.data 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | 41 | atlalign.metrics module 42 | ----------------------- 43 | 44 | .. automodule:: atlalign.metrics 45 | :members: 46 | :undoc-members: 47 | :show-inheritance: 48 | 49 | atlalign.nn module 50 | ------------------ 51 | 52 | .. 
automodule:: atlalign.nn 53 | :members: 54 | :undoc-members: 55 | :show-inheritance: 56 | 57 | atlalign.utils module 58 | --------------------- 59 | 60 | .. automodule:: atlalign.utils 61 | :members: 62 | :undoc-members: 63 | :show-inheritance: 64 | 65 | atlalign.visualization module 66 | ----------------------------- 67 | 68 | .. automodule:: atlalign.visualization 69 | :members: 70 | :undoc-members: 71 | :show-inheritance: 72 | 73 | atlalign.volume module 74 | ---------------------- 75 | 76 | .. automodule:: atlalign.volume 77 | :members: 78 | :undoc-members: 79 | :show-inheritance: 80 | 81 | atlalign.zoo module 82 | ------------------- 83 | 84 | .. automodule:: atlalign.zoo 85 | :members: 86 | :undoc-members: 87 | :show-inheritance: 88 | 89 | Module contents 90 | --------------- 91 | 92 | .. automodule:: atlalign 93 | :members: 94 | :undoc-members: 95 | :show-inheritance: 96 | -------------------------------------------------------------------------------- /docs/source/api/modules.rst: -------------------------------------------------------------------------------- 1 | atlalign 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | atlalign 8 | -------------------------------------------------------------------------------- /docs/source/datasets.rst: -------------------------------------------------------------------------------- 1 | .. _datasets: 2 | 3 | Datasets 4 | ======== 5 | 6 | We include multiple dataset loading utilities. Note that for some functions 7 | implemented in the :code:`atlalign.data` module we assume that the user has 8 | the underlying raw data locally. 9 | 10 | Nissl volume 11 | ------------ 12 | Contains the 25 microns reference volume used within the entire project. The corresponding function 13 | is :code:`nissl_volume` that returns a (528, 320, 456, 1) numpy array. 14 | 15 | 16 | Annotation volume 17 | ----------------- 18 | Contains per pixel segmentation/annotation of the entire reference atlas. 
The corresponding 19 | function is :code:`annotation_volume` that returns a (528, 320, 456) numpy array. Note 20 | that one can use the corresponding :code:`segmentation_collapsing_labels` function 21 | to get the tree-like hierarchy of the classes. 22 | 23 | 24 | Manual registration 25 | ------------------- 26 | This dataset comes from the labeling tool that extracts displacement fields. The corresponding function is 27 | :code:`manual_registration`. 28 | Available genes: 29 | 30 | - **Calb1** 31 | - **Calb2** 32 | - **Cck** 33 | - **Npy** 34 | - **Pvalb** 35 | - **Sst** 36 | - **Vip** 37 | 38 | .. code-block:: python 39 | 40 | from atlalign.data import manual_registration 41 | 42 | res = manual_registration() 43 | 44 | assert set(res.keys()) == {'dataset_id', 'deltas_xy', 'image_id', 'img', 'inv_deltas_xy', 'p'} 45 | assert len(res['image_id']) == 278 46 | 47 | The returned dictionary contains the following keys: 48 | 49 | - :code:`dataset_id` - unique id of the section dataset 50 | - :code:`deltas_xy` - array of shape (320, 456, 2) where the last dimension represents the x (resp y) deltas of the transformation 51 | - :code:`image_id` - unique id of the section image 52 | - :code:`img` - moving image of shape (320, 456) that was preregistered with the Allen API 53 | - :code:`inv_deltas_xy` - same as :code:`deltas_xy` but represents the inverse transformation 54 | - :code:`p` - coronal coordinate in microns [0, 13200] 55 | 56 | To perform the registration instantiate :code:`atlalign.base.DisplacementField` using the :code:`deltas_xy` and warp the 57 | :code:`img` with it. 58 | 59 | .. 
code-block:: python 60 | 61 | from atlalign.base import DisplacementField 62 | from atlalign.data import manual_registration 63 | 64 | import numpy as np 65 | 66 | res = manual_registration() 67 | i = 10 68 | delta_x = res['deltas_xy'][i, ..., 0] 69 | delta_y = res['deltas_xy'][i, ..., 1] 70 | img_mov = res['img'][i] 71 | 72 | df = DisplacementField(delta_x, delta_y) 73 | img_reg = df.warp(img_mov) 74 | 75 | For more details on :code:`atlalign.base.DisplacementField` see :ref:`building_blocks`. 76 | 77 | Dummy 78 | ----- 79 | Artificially generated datasets. 80 | 81 | Rectangles 82 | ~~~~~~~~~~ 83 | Rectangles with stripes of different intensities. Corresponding function - :code:`rectangles`. 84 | 85 | Circles 86 | ~~~~~~~ 87 | Circles with inner circles of different intensities. Corresponding function - :code:`circles`. 88 | 89 | -------------------------------------------------------------------------------- /docs/source/deep_learning_data.rst: -------------------------------------------------------------------------------- 1 | .. _dl_data: 2 | 3 | 4 | Deep Learning - Generating a dataset 5 | ==================================== 6 | This section describes how to create a dataset for supervised learning. Note that we already described how to easily 7 | load existing datasets in :ref:`datasets`. However, in this chapter we will discuss in detail how to perform 8 | **augmentations** on these or similar datasets. 9 | 10 | 11 | Augmentation is a common strategy in deep learning. The goal is to use an existing dataset and make it larger 12 | (possibly infinitely) by altering both the inputs and the targets in some sensible way. 13 | 14 | Generally, we will distinguish two types of augmentations 15 | 16 | 1. **Intensity** - change the pixel intensities 17 | 2. 
**Geometric** - change the geometric structure of the underlying image 18 | 19 | 20 | The reason why in our use case we split augmentations into two groups is very simple: the **intensity** augmentations 21 | do not change the labels, whereas the **geometric** ones do change our labels. 22 | 23 | 24 | Geometric augmentations 25 | ----------------------- 26 | The geometric augmentations are a major part of the :code:`atlalign` functionality. All the augmentations can be found 27 | in the :code:`atlalign.zoo` module. They are easily accessible to the user via the :code:`generate` class method. 28 | 29 | - :code:`affine` 30 | - :code:`affine_simple` 31 | - :code:`control_points` 32 | - :code:`edge_stretching` 33 | - :code:`projective` 34 | 35 | 36 | Affine simple 37 | ~~~~~~~~~~~~~ 38 | 39 | .. testcode:: 40 | 41 | import numpy as np 42 | import matplotlib.pyplot as plt 43 | 44 | from atlalign.base import DisplacementField 45 | from atlalign.data import rectangles 46 | 47 | shape=(320, 456) 48 | 49 | img = np.squeeze(rectangles(n_samples=1, shape=shape, height=200, width=150, random_state=31)) 50 | df = DisplacementField.generate(shape, 51 | approach='affine_simple', 52 | scale_x=1.9, 53 | scale_y=1, 54 | translation_x=-300, 55 | translation_y=0, 56 | rotation=0.2 57 | ) 58 | 59 | img_aug = df.warp(img) 60 | 61 | _, (ax_orig, ax_aug) = plt.subplots(1, 2, figsize=(10, 14)) 62 | ax_orig.imshow(img) 63 | ax_aug.imshow(img_aug) 64 | 65 | .. image:: ../_images/affine_simple.png 66 | :width: 600 67 | :alt: Affine simple 68 | :align: center 69 | 70 | Control points 71 | ~~~~~~~~~~~~~~ 72 | Control points is a generic augmentation that gives the user the possibility to specify displacement only on 73 | a selected set of control points. For the remaining pixels the displacement is interpolated. 74 | 75 | ..
testcode:: 76 | 77 | import matplotlib.pyplot as plt 78 | 79 | from atlalign.base import DisplacementField 80 | from atlalign.data import rectangles 81 | 82 | shape = (320, 456) 83 | 84 | img = np.squeeze(rectangles(n_samples=1, shape=shape, height=200, width=150, random_state=31)) 85 | 86 | 87 | points = np.array([[200, 150]]) 88 | 89 | values_delta_x = np.array([-100]) 90 | values_delta_y = np.array([0]) 91 | 92 | df = DisplacementField.generate(shape, 93 | approach='control_points', 94 | points=points, 95 | values_delta_x=values_delta_x, 96 | values_delta_y=values_delta_y, 97 | interpolation_method='rbf') 98 | 99 | img_aug = df.warp(img) 100 | 101 | _, (ax_orig, ax_aug) = plt.subplots(1, 2, figsize=(10, 14)) 102 | ax_orig.imshow(img) 103 | ax_aug.imshow(img_aug) 104 | 105 | .. image:: ../_images/control_points.png 106 | :width: 600 107 | :alt: Control points 108 | :align: center 109 | 110 | 111 | 112 | Edge stretching 113 | ~~~~~~~~~~~~~~~ 114 | Edge stretching uses :code:`control_points` in the background. However, instead of requiring the user to specify 115 | these points manually, the user simply passes a mask array of edges. The algorithm then 116 | randomly selects :code:`n_perturbation_points` points out of the edges and randomly displaces them. Note that 117 | the settings :code:`interpolation_method='rbf'` and :code:`interpolator_kwargs={'function': 'linear'}` give the nicest 118 | results. 119 | 120 | ..
testcode:: 121 | 122 | import matplotlib.pyplot as plt 123 | 124 | from atlalign.base import DisplacementField 125 | from atlalign.data import rectangles 126 | 127 | from skimage.feature import canny 128 | 129 | shape = (320, 456) 130 | 131 | img = np.squeeze(rectangles(n_samples=1, shape=shape, height=200, width=150, random_state=31)) 132 | 133 | 134 | edge_mask = canny(img) 135 | 136 | df = DisplacementField.generate(shape, 137 | approach='edge_stretching', 138 | edge_mask=edge_mask, 139 | interpolation_method='rbf', 140 | interpolator_kwargs={'function': 'linear'}, 141 | n_perturbation_points=5) 142 | 143 | img_aug = df.warp(img) 144 | 145 | _, (ax_orig, ax_aug, ax_edges) = plt.subplots(1, 3, figsize=(10, 14)) 146 | ax_orig.imshow(img) 147 | ax_aug.imshow(img_aug) 148 | ax_edges.imshow(edge_mask) 149 | 150 | .. image:: ../_images/edge_stretching.png 151 | :width: 600 152 | :alt: Edge stretching 153 | :align: center 154 | 155 | 156 | Intensity augmentations 157 | ----------------------- 158 | Intensity augmentations do not change the label (displacement field). As opposed to the geometric ones, we fully 159 | delegate these augmentations to a 3rd party package - :code:`imgaug`. For more details see the official documentation. 160 | 161 | The user can use some preset augmentor pipelines in :code:`atlalign.ml_utils.augmentation` or create their own. 162 | See below an example of using a preset augmentor together with a small animation showing 10 random augmentations. 163 | 164 | 165 | .. testcode:: 166 | 167 | from atlalign.ml_utils import augmenter_1 168 | 169 | img = np.squeeze(rectangles(n_samples=1, shape=shape, height=200, width=150, random_state=31)) 170 | aug = augmenter_1() 171 | 172 | img_aug = aug.augment_image(img) 173 | 174 | .. image:: ../_images/int_augmentations.gif 175 | :width: 600 176 | :alt: Intensity augmentations 177 | :align: center 178 | 179 | Make sure that :code:`imgaug` pipelines do not contain any geometric transformations.
180 | 181 | 182 | Putting things together 183 | ----------------------- 184 | With the tools described above and the ones from :ref:`building_blocks` we can significantly increase the size of 185 | our supervised learning datasets. Note that in general we want to augment the moving image with both the 186 | intensity and geometric augmentations. The reference image stays the same or only intensity augmentation is applied. 187 | 188 | To better demonstrate the geometric augmentation logic for real data, we refer the reader to the sketch below. 189 | 190 | .. image:: ../_images/aug_pipeline.png 191 | :width: 700 192 | :alt: Augmentation pipeline 193 | :align: center 194 | 195 | We assume that at the beginning we are given the **moving image** and the transformation that registers this image - 196 | :code:`mov2reg` (in green). Note that if only registered images are provided this is equivalent to setting 197 | :code:`mov2reg` equal to an identity mapping. The actual augmentation is captured by :code:`mov2art` (in yellow). 198 | Once the user specifies it (e.g. generates it randomly), :code:`atlalign` can infer :code:`art2reg`. How? 199 | 200 | 1. Invert :code:`mov2art` to obtain :code:`art2mov`. 201 | 2. Compose :code:`art2mov` and :code:`mov2reg`. 202 | 203 | 204 | One clearly sees that the final transformation :code:`art2reg` will be a combination of the :code:`mov2reg` and :code:`art2mov`. 205 | Ideally, we want to make sure that these transformations are as nice as possible - differentiable and invertible. 206 | 207 | Note that one can inspect :code:`df.jacobian` to programmatically determine how smooth the transformation is. 208 | Specifically, the pixels with nonpositive jacobian represent possible artifacts. When using :code:`edge_stretching` or 209 | similar augmentations, the transformations can occasionally be ill-behaved. 210 | 211 | See below an end-to-end example. 212 | 213 | Example 214 | ------- 215 | 216 | ..
code-block:: python 217 | 218 | import matplotlib.pyplot as plt 219 | import numpy as np 220 | from skimage.feature import canny 221 | from skimage.util import img_as_float32 222 | 223 | from atlalign.base import DisplacementField 224 | from atlalign.data import new_registration 225 | from atlalign.visualization import create_grid 226 | 227 | 228 | # helper function 229 | def generate_mov2art(img_mov, anchor=True): 230 | """Generate geometric augmentation and its inverse.""" 231 | shape = img_mov.shape 232 | img_mov_float = img_as_float32(img_mov) 233 | edge_mask = canny(img_mov_float) 234 | mov2art = DisplacementField.generate(shape, 235 | approach='edge_stretching', 236 | edge_mask=edge_mask, 237 | interpolation_method='rbf', 238 | interpolator_kwargs={'function': 'linear'}, 239 | n_perturbation_points=5) 240 | 241 | if anchor: 242 | mov2art = mov2art.anchor() 243 | 244 | art2mov = mov2art.pseudo_inverse() 245 | 246 | return mov2art, art2mov 247 | 248 | # load datast 249 | orig_dataset = new_registration() 250 | i = 10 251 | 252 | # load existing data 253 | img_mov = orig_dataset['img'][i] 254 | mov2reg = DisplacementField(orig_dataset['deltas_xy'][i,..., 0], orig_dataset['deltas_xy'][i,..., 1]) 255 | img_grid = create_grid(img_mov.shape) 256 | 257 | # generate mov2art 258 | mov2art, art2mov = generate_mov2art(img_mov) 259 | img_art = mov2art.warp(img_mov) 260 | 261 | # numerically approximate composition 262 | art2reg = art2mov(mov2reg, border_mode='reflect') 263 | 264 | 265 | # Plotting 266 | _, ((ax_mov, ax_art), 267 | (ax_reg_mov, ax_reg_art), 268 | ( ax_reg_mov_grid, ax_reg_art_grid)) = plt.subplots(3, 2, figsize=(15, 10)) 269 | 270 | 271 | ax_mov.imshow(img_mov) 272 | ax_mov.set_title('Moving') 273 | ax_art.imshow(img_art) 274 | ax_art.set_title('Artificial') 275 | 276 | ax_reg_mov.imshow(mov2reg.warp(img_mov)) 277 | ax_reg_mov.set_title('Registered (starting from moving)') 278 | ax_reg_art.imshow(art2reg.warp(img_art)) 279 | 
ax_reg_art.set_title('Registered (starting from artificial)') 280 | 281 | ax_reg_mov_grid.imshow(mov2reg.warp(img_grid)) 282 | ax_reg_art_grid.imshow(art2reg.warp(img_grid)) 283 | 284 | 285 | .. image:: ../_images/example_augmentation.png 286 | :width: 600 287 | :alt: Augmentation example 288 | :align: center 289 | 290 | Since generating the augmentations or finding the inverse numerically might be 291 | slow, we highly recommend precomputing everything in advance (before training) 292 | and storing it in a :code:`.h5` file. See the :code:`atlalign.augmentations` module 293 | that implements a similar strategy. Note that the training logic requires 294 | the data to be stored in a :code:`.h5` file. 295 | -------------------------------------------------------------------------------- /docs/source/deep_learning_inference.rst: -------------------------------------------------------------------------------- 1 | Deep Learning - Inference 2 | ========================= 3 | 4 | Loading model 5 | ------------- 6 | To load a pretrained network one should use :code:`atlalign.ml_utils.load_model`. 7 | Note that it loads all possible custom layers in the background so that the user does not have 8 | to worry about it. 9 | 10 | .. code-block:: python 11 | 12 | from atlalign.ml_utils import load_model 13 | 14 | path_1 = 'path/to/a/folder' # inside of this folder .json (architecture) and .h5 (weights) 15 | path_2 = 'path/to/a/file.h5' # architecture and weights not separated + additional info (loss and optimizer) 16 | 17 | model_path_1 = load_model(path_1) 18 | model_path_2 = load_model(path_2, compile=True) 19 | 20 | Merging 21 | ------- 22 | To merge a global and a local alignment network one can perform the composition via the :code:`__call__` method 23 | of :code:`atlalign.base.DisplacementField` on a per sample basis. A better approach is to use a custom keras layer 24 | implementing composition.
For the latter option we provide a utility function :code:`atlalign.ml_utils.merge_global_local`. 25 | 26 | .. code-block:: python 27 | 28 | from atlalign.ml_utils import load_model, merge_global_local 29 | 30 | path_global = 'global_model.h5' 31 | path_local = 'local_model.h5' 32 | 33 | model_global = load_model(path_global) 34 | model_local = load_model(path_local) 35 | 36 | model_merged = merge_global_local(model_global, model_local) 37 | 38 | Forward pass 39 | ------------ 40 | Performing the actual inference is extremely simple. Please review the :ref:`dl_training.supervised_generator` to 41 | understand the shape of the expected input. To quickly summarize (for non-inverse models) the user only needs to create 42 | a 4D array of the following shape 43 | 44 | .. code-block:: python 45 | 46 | (batch_size, height=320, width=456, depth=2) 47 | 48 | The last dimension is simply a stack of the :code:`img_ref` and :code:`img_mov`. 49 | 50 | .. code-block:: python 51 | 52 | import os 53 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' # to enable GPU 54 | 55 | import numpy as np 56 | 57 | from atlalign.base import DisplacementField 58 | from atlalign.ml_utils import load_model 59 | 60 | batch_size = 32 # how many samples are grouped together at inference time 61 | 62 | model = load_model('path/to/model.h5') 63 | X = np.random.random((200, 320, 456, 2)) 64 | 65 | [reg_images, deltas_xy] = model.predict(X, batch_size=batch_size) 66 | 67 | # one can also create instances of DisplacementFields to perform many other tasks 68 | dfs = [DisplacementField(deltas_xy[i, ..., 0], deltas_xy[i, ..., 1]) for i in range(len(X))] 69 | -------------------------------------------------------------------------------- /docs/source/evaluation.rst: -------------------------------------------------------------------------------- 1 | Metrics and Evaluation 2 | ====================== 3 | Assuming we have the ground truth labels (necessary for supervised learning) we can use multiple evaluation 
metrics. 4 | In general they can be grouped into 3 categories based on what they compare: 5 | 6 | 7 | - **Image similarity** 8 | - **Displacement field** 9 | - **Segmentation** 10 | 11 | .. image:: ../_images/evaluation_metrics.png 12 | :width: 600 13 | :alt: Evaluation metrics 14 | :align: center 15 | 16 | 17 | All these metrics are implemented in :code:`atlalign.metrics`. A subset of them is also available as drop-in 18 | losses for deep learning in :code:`atlalign.ml_utils.losses`. 19 | 20 | The following sub-sections list the available metrics in each of the three categories. 21 | They are all part of the ``atlalign.metrics`` module and have a common interface:: 22 | 23 | atlalign.metrics.<metric_name>(y_true, y_pred, **kwargs) 24 | 25 | The parameters ``y_true`` and ``y_pred`` are pairs of images, displacement fields, 26 | or segmentation maps. Multiple pairs of images can be processed at once by stacking 27 | them along the first dimension, so that ``y_true`` and ``y_pred`` have the shape 28 | ``(n_images, ...)``. 29 | 30 | Some metrics have optional keyword arguments that differ from metric to metric; see 31 | the API reference for more details.
32 | 33 | Image Similarity Metrics 34 | ------------------------ 35 | 36 | Loss-like (the smaller the better): 37 | 38 | - ``mse_img`` -- mean squared error 39 | - ``mae_img`` -- mean absolute error 40 | - ``demons_img`` -- ANTsPy's demons metric 41 | - ``perceptual_loss_img`` -- perceptual loss 42 | 43 | Similarity-like (the higher the better): 44 | 45 | - ``psnr_img`` -- peak signal to noise ratio (max = infinity) 46 | - ``cross_correlation_img`` -- image cross correlation (max = 1) 47 | - ``ssmi_img`` -- structural similarity (max = 1) 48 | - ``mi_img`` -- mutual information (max = mutual information with itself) 49 | 50 | Displacement Field Metrics 51 | -------------------------- 52 | 53 | - ``correlation_combined`` -- combined version of correlation 54 | - ``mae_combined`` -- combined version of mean absolute error 55 | - ``mse_combined`` -- combined version of mean squared error 56 | - ``r2_combined`` -- combined version of r2 57 | - ``vector_distance_combined`` -- combined version of vector distance 58 | 59 | Segmentation Metrics 60 | -------------------- 61 | 62 | - ``iou_score`` -- intersection over union (between 0 and 1, the higher the better) 63 | - ``dice_score`` -- dice score (between 0 and 1, the higher the better) 64 | 65 | Compute Many Metrics at Once 66 | ---------------------------- 67 | 68 | To get a comprehensive overview of how a specific model performs, we implemented a utility function 69 | :code:`atlalign.metrics.evaluate` that computes multiple metrics at the same time and returns the results 70 | in a :code:`pandas.DataFrame`. 71 | 72 | 73 | ..
code-block:: python 74 | 75 | import numpy as np 76 | 77 | from atlalign.metrics import evaluate 78 | 79 | n_samples = 5 80 | shape = (320, 456) 81 | 82 | y_true = np.random.randint(0, 20, size=(n_samples, *shape, 2)) 83 | y_pred = np.random.randint(0, 20, size=(n_samples, *shape, 2)) 84 | 85 | imgs_mov = np.random.random((n_samples, *shape)) 86 | img_ids = np.array(range(n_samples)) 87 | dataset_ids = np.array(range(n_samples)) 88 | ps = np.linspace(0, 12200, num=n_samples).astype('int') 89 | 90 | _, res_df = evaluate(y_true, 91 | y_pred, 92 | imgs_mov=imgs_mov, 93 | img_ids=img_ids, 94 | ps=ps, 95 | dataset_ids=dataset_ids, 96 | depths=(1, 2, 3, 4, 5)) 97 | 98 | print(res_df.columns) 99 | 100 | .. code-block:: python 101 | 102 | Index(['angular_error_a', 'cross_correlation_img_a', 'dataset_id', 103 | 'iou_depth_1', 'iou_depth_2', 'iou_depth_3', 'iou_depth_4', 104 | 'iou_depth_5', 'jacobian_nonpositive_pixels_a', 105 | 'jacobian_nonpositive_pixels_perc_a', 'mae_img_a', 'mi_img_a', 106 | 'mse_img_a', 'norm_a', 'p', 'psnr_img_a', 'ssmi_img_a', 107 | 'vector_distance_a'], 108 | dtype='object') 109 | -------------------------------------------------------------------------------- /docs/source/image_registration.rst: -------------------------------------------------------------------------------- 1 | Image Registration 101 2 | ====================== 3 | Image registration is a task of geometrically transforming one image (moving) to a domain of another image (reference). 4 | This problem can be seen as two separate steps. 5 | 6 | **1) Prediction of a geometric transformation** 7 | 8 | .. image:: ../_images/image_registration.png 9 | :width: 600 10 | :alt: Inverse 11 | :align: center 12 | 13 | **2) Warping the moving image with the transformation** 14 | 15 | .. image:: ../_images/image_registration_2.png 16 | :width: 600 17 | :alt: Inverse 18 | :align: center 19 | 20 | 21 | There are multiple different algorithms that achieve the first step. 
This project focuses on supervised **deep learning** 22 | methods; however, it also provides an easy interface to other approaches - **feature** and **intensity** based registration. -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | .. _installation: 2 | 3 | Installation 4 | ============ 5 | It is highly recommended to install the project into a new virtual environment. 6 | 7 | Python Requirements 8 | ------------------- 9 | The project is only available for **Python 3.7**. The main reason for this 10 | restriction is an external dependency **ANTsPy** that does 11 | not provide many precompiled wheels on PyPI. 12 | 13 | External Dependencies 14 | --------------------- 15 | Some of the functionalities of :code:`atlalign` depend on the 16 | `TensorFlow implementation of the Learned Perceptual Image Patch Similarity <https://github.com/alexlee-gk/lpips-tensorflow>`_. 17 | Unfortunately, the 18 | package is not available on PyPI and must be installed manually as follows. 19 | 20 | .. code-block:: bash 21 | 22 | pip install git+https://github.com/alexlee-gk/lpips-tensorflow.git#egg=lpips_tf 23 | 24 | You can now move on to installing the actual `atlalign` package! 25 | 26 | Installation from PyPI 27 | ---------------------- 28 | The :code:`atlalign` package can be easily installed from PyPI. 29 | 30 | .. code-block:: bash 31 | 32 | pip install atlalign 33 | 34 | Installation from source 35 | ------------------------ 36 | As an alternative to installing from PyPI, if you want to try the latest version 37 | you can also install from source. 38 | 39 | ..
code-block:: bash 40 | 41 | pip install git+https://github.com/BlueBrain/atlas_alignment#egg=atlalign 42 | 43 | Development installation 44 | ------------------------ 45 | For development installation one needs additional dependencies grouped in :code:`extras_require` in the 46 | following way: 47 | 48 | - **dev** - pytest + plugins, flake8, pydocstyle, tox 49 | - **docs** - sphinx 50 | 51 | .. code-block:: bash 52 | 53 | git clone https://github.com/BlueBrain/atlas_alignment 54 | cd atlas_alignment 55 | pip install -e .[dev,docs] 56 | 57 | 58 | Generating documentation 59 | ------------------------ 60 | To generate the documentation, make sure you have installed the :code:`docs` dependencies from :code:`extras_require`. 61 | 62 | .. code-block:: bash 63 | 64 | cd docs 65 | make clean && make html 66 | 67 | One can view the docs by opening :code:`docs/_build/html/index.html` in a browser. 68 | -------------------------------------------------------------------------------- /docs/source/intensity.rst: -------------------------------------------------------------------------------- 1 | Intensity Based Registration 2 | ============================ 3 | Intensity based registration refers to a set of algorithms that try to solve the registration problem via per sample 4 | maximization of an image similarity metric. 5 | 6 | 7 | ANTsPy 8 | ------ 9 | We provide a very simple interface to an existing registration package called **ANTsPy**. To find out more about 10 | the original package we refer the reader to 11 | 12 | - `github <https://github.com/ANTsX/ANTsPy>`_ 13 | - `docs <https://antspyx.readthedocs.io/en/latest/>`_ 14 | 15 | 16 | To use **ANTsPy** within :code:`atlalign` one can simply use :code:`atlalign.non_ml.antspy_registration`. See below 17 | a minimal example. 18 | 19 | ..
code-block:: python 20 | 21 | import matplotlib.pyplot as plt 22 | import numpy as np 23 | 24 | from atlalign.base import DisplacementField 25 | from atlalign.data import circles 26 | from atlalign.non_ml import antspy_registration 27 | 28 | random_state = 4 29 | shape = (200, 230) 30 | p_drop_pixel = 0.1 31 | 32 | df = DisplacementField.generate(shape, approach='affine_simple', scale_x=1.4) 33 | 34 | img_ref = circles(1, shape, radius=75, n_levels=15, random_state=random_state)[0,..., 0] 35 | img_mov = df.warp(img_ref) 36 | 37 | to_drop = np.random.choice((True, False), p=(p_drop_pixel, 1 - p_drop_pixel), size=shape) 38 | img_ref[to_drop] = 0 39 | 40 | 41 | df_mov2ref_anstpy, _ = antspy_registration(img_ref, img_mov) 42 | df_mov2ref_truth = df.pseudo_inverse() 43 | 44 | img_reg_antspy = df_mov2ref_anstpy.warp(img_mov) 45 | img_reg_truth = df_mov2ref_truth.warp(img_mov) 46 | 47 | fig, ((ax_ref, ax_mov), (ax_reg_antspy, ax_reg_truth)) = plt.subplots(2, 2, figsize=(15, 15)) 48 | ax_ref.imshow(img_ref) 49 | ax_ref.set_axis_off() 50 | ax_ref.set_title('Reference') 51 | 52 | ax_mov.imshow(img_mov) 53 | ax_mov.set_axis_off() 54 | ax_mov.set_title('Moving') 55 | 56 | ax_reg_antspy.imshow(img_reg_antspy) 57 | ax_reg_antspy.set_axis_off() 58 | ax_reg_antspy.set_title('Registered - ANTsPy') 59 | 60 | ax_reg_truth.imshow(img_reg_truth) 61 | ax_reg_truth.set_axis_off() 62 | ax_reg_truth.set_title('Registered - Ground Truth') 63 | 64 | .. image:: ../_images/antspy.png 65 | :width: 600 66 | :alt: ANTsPy 67 | :align: center -------------------------------------------------------------------------------- /docs/source/labeling_tool.rst: -------------------------------------------------------------------------------- 1 | Labeling tool 2 | ============= 3 | The goal of the labeling tool is to allow the user to align manually any two images and extract 4 | the displacement field. 
5 | 6 | Launching 7 | --------- 8 | To launch the labeling tool one needs to use the entry point ``label-tool``. 9 | 10 | .. code-block:: bash 11 | 12 | label-tool --help 13 | usage: label-tool [-h] [-s] ref mov output_path 14 | 15 | positional arguments: 16 | ref Either a path to a reference image or a number from [0, 528) 17 | representing the coronal dimension in the nissl stain volume. 18 | mov Path to a moving image. Needs to be of the same shape as 19 | reference. 20 | output_path Folder where the outputs will be stored. 21 | 22 | optional arguments: 23 | -h, --help show this help message and exit 24 | -s, --swap Swap to the moving to reference mode. (default: False) 25 | 26 | Arguments 27 | ~~~~~~~~~ 28 | 29 | 1. ``ref`` - number from 0 - 527 representing a coronal section or a path to a custom reference image 30 | 2. ``mov`` - path to the moving image that needs to have the same shape as the reference one 31 | 3. ``output_path`` - path to the folder where the output is saved 32 | 33 | Options 34 | ~~~~~~~ 35 | 36 | - ``--swap`` - flag that, if activated, places the first landmark of a pair on the moving image; otherwise it 37 | is placed on the reference image 38 | - ``--force-grayscale`` - force color images to be converted to grayscale, can be useful for taking 39 | advantage of the colormap selection which does not work for color images. 40 | 41 | Interface 42 | --------- 43 | .. image:: ../_images/labeling_tool.png 44 | :width: 600 45 | :alt: Labeling tool 46 | :align: center 47 | 48 | The user adds landmark points until they are fully happy with the registration. Once the GUI is 49 | closed, the resulting displacement field, the registered and moving images, the key points, and other 50 | metadata are stored in the :code:`OUTPUT_PATH`. 51 | 52 | 1 53 | ~ 54 | The **interactive window** where the user specifies pairs of points that represent the landmarks in 55 | the moving (resp. reference) image.
The landmarks in the moving image are represented by **crosses** 56 | and the ones in the reference image by **circles**. To delete a point hover your mouse over a 57 | reference point and press "d". 58 | 59 | 2 60 | ~ 61 | The **result window** that displays the current registered image and the reference. Note that one 62 | can also see the deformation of a regular grid if one clicks on the **Show grid** button. 63 | 64 | 3 65 | ~ 66 | **Sliders** 67 | 68 | - **Alpha Moving/Registered** - Controls the blending together of the reference and moving images. 69 | Value 0 corresponds to only the reference image, value 1 to only the moving image. The slider 70 | remembers two different settings that can be toggled using the space bar. This can be useful for 71 | quickly getting the full view of one of the images. 72 | - **Alpha Ref** - Control the translucency of the reference image. 73 | - **Moving/Registered Threshold** - Remove pixels below a given intensity from the moving image. 74 | - **Reference Threshold** - Remove pixels below a given intensity from the reference image. 75 | 76 | **Buttons** 77 | 78 | - **Reset** - Deletes all registration points. 79 | - **Symmetric Registration** - Toggle symmetry registration where all keypoints are automatically 80 | mirrored in the left/right direction. 81 | - **Change Order** - Swap the moving and reference image order. 82 | - **Show Arrows** - Toggle the visibility of lines between keypoints. 83 | - **Show Grid** - Show regular grid warped by the current displacement. Red spots represent regions 84 | that are not invertible. 85 | 86 | 4 87 | ~ 88 | **Color maps** for the reference and moving image. 89 | 90 | 5 91 | ~ 92 | **Overview statistics** 93 | 94 | - **Transformation quality** - If lower than 100% then displacement contains folds. Check the grid 95 | to see where exactly. 
96 | - **Average displacement** - Average displacement size 97 | 98 | **Keyboard shortcuts** 99 | 100 | - **"a"** - pan 101 | - **"s"** - zoom 102 | - **"d"** - delete a keypoint 103 | - **"f"** - reset zoom 104 | - **space** - toggle the reference/moving blending alpha value 105 | 106 | Note that after zooming/panning the respective shortcut key needs to be pressed again to get back 107 | to the keypoint addition mode. 108 | 109 | 6 110 | ~ 111 | **Interpolation window**. The first column represents the general algorithm: 112 | 113 | - **griddata** - Delaunay triangulation followed by affine interpolation for each triangle 114 | - **rbf** - Interpolation using radial basis functions 115 | 116 | The second column is only used for the **rbf** interpolation and it specifies the kernel. 117 | 118 | FAQ 119 | --- 120 | 121 | - **Can the moving and reference image have different sizes?** No, make sure they are the same; 122 | one can resize the displacement field after registration 123 | - **How do I save the displacement?** Simply by closing the GUI 124 | - **Can I delete some landmark pairs?** Yes, hover your mouse over the reference landmark and press "d" 125 | -------------------------------------------------------------------------------- /docs/source/logo/Atlas_Alignment_banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/docs/source/logo/Atlas_Alignment_banner.jpg -------------------------------------------------------------------------------- /scripts/experiment_1.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from atlalign.base import DisplacementField 7 | from atlalign.data import ( 8 | annotation_volume, 9 | manual_registration, 10 | nissl_volume, 11 | segmentation_collapsing_labels, 12 | ) 13 | from atlalign.metrics
import evaluate_single 14 | from atlalign.ml_utils import load_model 15 | 16 | # Define all relevant paths 17 | cache_dir = pathlib.Path.home() / ".atlalign" 18 | 19 | annotation_path = cache_dir / "annotation.npy" 20 | h5_path = cache_dir / "manual_registration.h5" 21 | labels_path = cache_dir / "annotation_hierarchy.json" 22 | model_local_path = cache_dir / "local.h5" 23 | model_global_path = cache_dir / "global.h5" 24 | nissl_path = cache_dir / "nissl.npy" 25 | 26 | 27 | # Validation dataset definition 28 | validation_ixs = [ 29 | 2, 30 | 7, 31 | 15, 32 | 21, 33 | 25, 34 | 28, 35 | 37, 36 | 40, 37 | 49, 38 | 50, 39 | 54, 40 | 55, 41 | 56, 42 | 58, 43 | 60, 44 | 62, 45 | 64, 46 | 66, 47 | 78, 48 | 85, 49 | 93, 50 | 97, 51 | 121, 52 | 123, 53 | 125, 54 | 127, 55 | 129, 56 | 130, 57 | 133, 58 | 137, 59 | 140, 60 | 143, 61 | 144, 62 | 154, 63 | 169, 64 | 175, 65 | 183, 66 | 187, 67 | 192, 68 | 204, 69 | 206, 70 | 214, 71 | 216, 72 | 219, 73 | 223, 74 | 225, 75 | 226, 76 | 248, 77 | 251, 78 | 252, 79 | 254, 80 | 260, 81 | 271, 82 | 276, 83 | ] 84 | 85 | 86 | manual_labels = manual_registration(h5_path) 87 | nissl = nissl_volume(nissl_path) 88 | annotation = annotation_volume(annotation_path) 89 | labels = segmentation_collapsing_labels(labels_path) 90 | 91 | 92 | validation_set = {} 93 | keys = manual_labels.keys() 94 | 95 | for val_ix in validation_ixs: 96 | validation_set[val_ix] = {} 97 | for key in keys: 98 | validation_set[val_ix][key] = manual_labels[key][val_ix] 99 | 100 | 101 | model_local = load_model(model_local_path) 102 | model_global = load_model(model_global_path) 103 | 104 | 105 | metrics = {} 106 | df_locals = {} 107 | 108 | for k, data in validation_set.items(): 109 | print(k) 110 | 111 | # Preparation 112 | img_mov = data["img"] / 255 113 | p = data["p"] 114 | section_num = p // 25 115 | img_ref = nissl[section_num][..., 0] 116 | 117 | # Global model 118 | inp_global = np.stack([img_ref, img_mov], axis=-1)[None, ...] 
119 | deltas_xy_global = model_global.predict(inp_global)[1][0] 120 | df_global = DisplacementField(deltas_xy_global[..., 0], deltas_xy_global[..., 1]) 121 | 122 | # Local model 123 | inp_local = np.stack([img_ref, df_global.warp(img_mov)], axis=-1)[None, ...] 124 | deltas_xy_local = model_local.predict([inp_local, np.zeros_like(inp_local)])[1][0] 125 | df_local = DisplacementField(deltas_xy_local[..., 0], deltas_xy_local[..., 1]) 126 | df_locals[k] = df_local 127 | 128 | # Overall model 129 | df_pred = df_local(df_global) 130 | 131 | deltas_true = data["deltas_xy"] 132 | deltas_pred = np.stack([df_pred.delta_x, df_pred.delta_y], axis=-1) 133 | 134 | metrics[k], _ = evaluate_single( 135 | deltas_true, 136 | deltas_pred, 137 | img_mov, 138 | p=p, 139 | avol=annotation, 140 | collapsing_labels=labels, 141 | deltas_pred_inv=None, 142 | deltas_true_inv=data["inv_deltas_xy"], 143 | ds_f=8, 144 | depths=(0, 2, 4, 6, 8), 145 | ) 146 | metrics[k]["p"] = p 147 | metrics[k]["image_id"] = data["image_id"] 148 | 149 | 150 | df = pd.DataFrame(metrics).transpose() 151 | 152 | cols = ["dice_0", "dice_2", "dice_4", "dice_6", "dice_8"] 153 | df_overview = pd.DataFrame({"mean": df[cols].mean(), "std": df[cols].std()}) 154 | 155 | print(df_overview) 156 | 157 | 158 | # Corrupted pixels analysis - we exclude border pixels 159 | h, w = 320, 456 160 | margin_ud = 1 161 | margin_lr = 6 # the network pads the sides with 0s 162 | 163 | corrupted_pixels = np.array( 164 | [ 165 | np.sum(x.jacobian[margin_ud : h - margin_ud, margin_lr : w - margin_lr] <= 0) 166 | for x in df_locals.values() 167 | ] 168 | ) 169 | 170 | n_pixels = (h - 2 * (margin_ud)) * (w - 2 * (margin_lr)) 171 | 172 | 173 | mean = 100 * (corrupted_pixels.mean() / n_pixels) 174 | std = 100 * (corrupted_pixels.std() / n_pixels) 175 | 176 | print("Corrupted_pixels") 177 | print( 178 | "Up and Down cuttoff: {}\nLeft and Right cutoff: {}\n".format(margin_ud, margin_lr) 179 | ) 180 | print("Mean:\n{}\n".format(mean)) 181 | 
print("STD:\n{}".format(std)) 182 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | import atlalign 4 | 5 | # install with `pip install -e .` 6 | 7 | # Preparations 8 | VERSION = atlalign.__version__ 9 | DESCRIPTION = "Blue Brain multi-modal registration and alignment toolbox" 10 | 11 | LONG_DESCRIPTION = """ 12 | Atlas Alignment is a toolbox to perform multimodal image registration. It 13 | includes both traditional and supervised deep learning models. 14 | 15 | This project originated from the Blue Brain Project efforts on aligning mouse 16 | brain atlases obtained with ISH gene expression and Nissl stains.""" 17 | 18 | PYTHON_REQUIRES = ">=3.6.0" 19 | INSTALL_REQUIRES = [ 20 | "antspyx", 21 | "imgaug<0.3", 22 | "matplotlib>=3.0.3", 23 | "mlflow", 24 | "nibabel>=2.4.0", 25 | "numpy<1.24.0", 26 | "seaborn", 27 | "scikit-image>=0.17.1", 28 | "scikit-learn>=0.20.2", 29 | "scipy", 30 | "tensorflow>=2.6.0", 31 | "tensorflow-addons", # For resampler in atlalign/ml_utils/layers.py 32 | ] 33 | 34 | setup( 35 | name="atlalign", 36 | version=VERSION, 37 | description=DESCRIPTION, 38 | long_description=LONG_DESCRIPTION, 39 | url="https://github.com/BlueBrain/atlas_alignment", 40 | author="Blue Brain Project, EPFL", 41 | license="LGPLv3", 42 | packages=find_packages(), 43 | classifiers=[ 44 | "Development Status :: 4 - Beta", 45 | "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)", 46 | "Operating System :: Unix", 47 | "Programming Language :: Python", 48 | "Programming Language :: Python :: 3", 49 | "Programming Language :: Python :: 3.6", 50 | "Programming Language :: Python :: 3.7", 51 | "Programming Language :: Python :: 3.8", 52 | "Programming Language :: Python :: 3.9", 53 | "Programming Language :: Python :: 3 :: Only", 54 | "Topic :: Scientific/Engineering :: 
Artificial Intelligence", 55 | "Topic :: Scientific/Engineering :: Image Processing", 56 | "Topic :: Software Development :: Libraries :: Python Modules", 57 | ], 58 | python_requires=PYTHON_REQUIRES, 59 | install_requires=INSTALL_REQUIRES, 60 | extras_require={ 61 | "dev": [ 62 | "black>=20.8b1", 63 | "flake8>=3.7.4", 64 | "pydocstyle>=3.0.0", 65 | "pytest>=3.10.1", 66 | "pytest-cov", 67 | "pytest-mock>=1.10.1", 68 | ], 69 | "docs": ["sphinx>=1.3", "sphinx-bluebrain-theme"], 70 | }, 71 | entry_points={"console_scripts": ["label-tool = atlalign.label.cli:main"]}, 72 | ) 73 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Define all fixtures. 2 | 3 | Notes 4 | ----- 5 | A lot of the fixtures are built on top of the following 4 atomic fixtures: 6 | * img_grayscale_uint 7 | * img_grayscale_float 8 | * img_rgb_uint 9 | * img_rgb_float 10 | 11 | The fixtures below are just parameterized in a way that only a subset of the above 12 | atomic fixtures is used 13 | * img 14 | * img_grayscale 15 | * img_rgb 16 | * img_uint 17 | * img_float 18 | """ 19 | 20 | """ 21 | The package atlalign is a tool for registration of 2D images. 22 | 23 | Copyright (C) 2021 EPFL/Blue Brain Project 24 | 25 | This program is free software: you can redistribute it and/or modify 26 | it under the terms of the GNU Lesser General Public License as published by 27 | the Free Software Foundation, either version 3 of the License, or 28 | (at your option) any later version. 29 | 30 | This program is distributed in the hope that it will be useful, 31 | but WITHOUT ANY WARRANTY; without even the implied warranty of 32 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 33 | GNU Lesser General Public License for more details. 34 | 35 | You should have received a copy of the GNU Lesser General Public License 36 | along with this program. If not, see . 
37 | """ 38 | 39 | import pathlib 40 | 41 | import cv2 42 | import numpy as np 43 | import pytest 44 | from skimage.util import img_as_float32 45 | 46 | from atlalign.base import DisplacementField 47 | 48 | SHAPE = (20, 30) # used for `img_dummy` and `img_random` 49 | RANDOM_STATE = 2 50 | 51 | ROOT_PATH = pathlib.Path(__file__).resolve().parent.parent 52 | 53 | 54 | @pytest.fixture(scope="session") 55 | def path_test_data(): 56 | return ROOT_PATH / "tests" / "data" 57 | 58 | 59 | @pytest.fixture(scope="session") 60 | def img_grayscale_uint(path_test_data): 61 | """Generate a grayscale image with dtype uint8.""" 62 | file_path = path_test_data / "animals.jpg" 63 | 64 | img_out = cv2.imread(str(file_path), 0) 65 | 66 | assert img_out.ndim == 2 67 | assert img_out.dtype == np.uint8 68 | assert np.all(img_out >= 0) and np.all(img_out <= 255) 69 | 70 | return img_out 71 | 72 | 73 | @pytest.fixture() 74 | def img_grayscale_float(img_grayscale_uint): 75 | """Generate a float32 version of the grayscale image.""" 76 | img_out = img_as_float32(img_grayscale_uint, force_copy=True) 77 | 78 | assert img_out.ndim == 2 79 | assert img_out.dtype == np.float32 80 | assert np.all(img_out >= 0) and np.all(img_out <= 1) 81 | 82 | return img_out 83 | 84 | 85 | @pytest.fixture(scope="session") 86 | def img_rgb_uint(path_test_data): 87 | """Just a RGB image with dtype uint8. 88 | 89 | Notes 90 | ----- 91 | OpenCV reads a 3 channel image as 'BGR' but it is not a problem for our purposes. 
92 | 93 | """ 94 | file_path = path_test_data / "animals.jpg" 95 | 96 | img_out = cv2.imread(str(file_path), 1) 97 | 98 | assert img_out.ndim == 3 99 | assert img_out.dtype == np.uint8 100 | assert np.all(img_out >= 0) and np.all(img_out <= 255) 101 | 102 | return img_out 103 | 104 | 105 | @pytest.fixture() 106 | def img_rgb_float(img_rgb_uint): 107 | """Generate a float32 version of the rgb image.""" 108 | img_out = img_as_float32(img_rgb_uint, force_copy=True) 109 | 110 | assert img_out.ndim == 3 111 | assert img_out.dtype == np.float32 112 | assert np.all(img_out >= 0) and np.all(img_out <= 1) 113 | 114 | return img_out 115 | 116 | 117 | @pytest.fixture( 118 | params=["grayscale_uint8", "grayscale_float32", "RGB_uint8", "RGB_float32"] 119 | ) 120 | def img(request, img_rgb_uint, img_grayscale_uint, img_rgb_float, img_grayscale_float): 121 | """Generate parametrized fixture capturing all 4 possible uint8/float32 and grayscale/rgb combinations. 122 | 123 | Notes 124 | ----- 125 | If this fixture used then the test will run automatically on all 4 of these. 126 | 127 | """ 128 | img_type = request.param 129 | 130 | if img_type == "grayscale_uint8": 131 | return img_grayscale_uint 132 | 133 | elif img_type == "grayscale_float32": 134 | return img_grayscale_float 135 | 136 | elif img_type == "RGB_uint8": 137 | return img_rgb_uint 138 | 139 | elif img_type == "RGB_float32": 140 | return img_rgb_float 141 | 142 | else: 143 | raise ValueError("Unrecognized image type {}".format(img_type)) 144 | 145 | 146 | @pytest.fixture(params=["grayscale_float32", "RGB_float32"]) 147 | def img_float(request, img_rgb_float, img_grayscale_float): 148 | """Generate a parametrized fixture capturing all 2 possible float32 -> grayscale and rgb. 149 | 150 | Notes 151 | ----- 152 | If this fixture used then the test will run automatically on all 2 of these. 
153 | 154 | """ 155 | img_type = request.param 156 | 157 | if img_type == "grayscale_float32": 158 | return img_grayscale_float 159 | 160 | elif img_type == "RGB_float32": 161 | return img_rgb_float 162 | 163 | else: 164 | raise ValueError("Unrecognized image type {}".format(img_type)) 165 | 166 | 167 | @pytest.fixture(params=["grayscale_uint8", "RGB_uint8"]) 168 | def img_uint(request, img_rgb_uint, img_grayscale_uint): 169 | """Generate a parametrized fixture capturing all 2 possible uint8 -> grayscale and rgb. 170 | 171 | Notes 172 | ----- 173 | If this fixture used then the test will run automatically on all 2 of these. 174 | 175 | """ 176 | img_type = request.param 177 | 178 | if img_type == "grayscale_uint8": 179 | return img_grayscale_uint 180 | 181 | elif img_type == "RGB_uint8": 182 | return img_rgb_uint 183 | 184 | else: 185 | raise ValueError("Unrecognized image type {}".format(img_type)) 186 | 187 | 188 | @pytest.fixture(params=["grayscale_uint8", "grayscale_float32"]) 189 | def img_grayscale(request, img_grayscale_uint, img_grayscale_float): 190 | """Generate a parametrized fixture capturing all 2 possible grayscale -> uint8 and float32. 191 | 192 | Notes 193 | ----- 194 | If this fixture used then the test will run automatically on all 2 of these. 195 | 196 | """ 197 | img_type = request.param 198 | 199 | if img_type == "grayscale_uint8": 200 | return img_grayscale_uint 201 | 202 | elif img_type == "grayscale_float32": 203 | return img_grayscale_float 204 | 205 | else: 206 | raise ValueError("Unrecognized image type {}".format(img_type)) 207 | 208 | 209 | @pytest.fixture(params=["RGB_uint8", "RGB_float32"]) 210 | def img_rgb(request, img_rgb_uint, img_rgb_float): 211 | """Generate a parametrized fixture capturing all 2 possible rgb -> uint8 and float32. 212 | 213 | Notes 214 | ----- 215 | If this fixture used then the test will run automatically on all 2 of these. 
216 | 217 | """ 218 | img_type = request.param 219 | 220 | if img_type == "RGB_uint8": 221 | return img_rgb_uint 222 | 223 | elif img_type == "RGB_float32": 224 | return img_rgb_float 225 | 226 | else: 227 | raise ValueError("Unrecognized image type {}".format(img_type)) 228 | 229 | 230 | @pytest.fixture(scope="function") # In order to allow for in place changes 231 | def img_dummy(): 232 | """Generate a dummy image made out of all zeros.""" 233 | return np.zeros(SHAPE, dtype=np.float32) 234 | 235 | 236 | @pytest.fixture(scope="function") 237 | def img_random(): 238 | """Generate a dummy image made out of random intensities.""" 239 | np.random.seed(RANDOM_STATE) 240 | out = np.random.random(SHAPE).astype(dtype=np.float32) 241 | 242 | assert out.dtype == np.float32 243 | 244 | return out 245 | 246 | 247 | @pytest.fixture(scope="session") 248 | def df_cached(path_test_data): 249 | """Load a DVF and its inverse of shape (80, 114). 250 | 251 | The DVF represents a mild warping that can be relatively easily unwarped. 252 | 253 | Notes 254 | ----- 255 | After composition the largest displacement in x ~ 0.05 and in y ~ 0.03. 256 | 257 | Returns 258 | ------- 259 | delta_x : np.array 260 | DVF in the x direction. 261 | 262 | delta_y : np.array 263 | DVF in the y direction. 264 | 265 | delta_x_inv : np.array 266 | Inverse DVF in the x direction. 267 | 268 | delta_y_inv : np.array 269 | Inverse DVF in the y direction. 270 | 271 | """ 272 | file_path = path_test_data / "mild_inversion.npy" 273 | a = np.load(str(file_path)) 274 | 275 | delta_x = a[..., 0] 276 | delta_y = a[..., 1] 277 | delta_x_inv = a[..., 2] 278 | delta_y_inv = a[..., 3] 279 | 280 | return delta_x, delta_y, delta_x_inv, delta_y_inv 281 | 282 | 283 | @pytest.fixture() 284 | def df_id(request): 285 | """Generate an identity transformation.
286 | 287 | In order to specify a shape one decorates the test function in the following way: 288 | `@pytest.mark.parametrize('df_id', [(320, 456)], indirect=True)` 289 | 290 | """ 291 | if hasattr(request, "param"): 292 | shape = request.param 293 | else: 294 | shape = (10, 11) 295 | return DisplacementField.generate(shape, approach="identity") 296 | 297 | 298 | @pytest.fixture() 299 | def label_dict(): 300 | """Gnerate a dictionary for the concatenation of labels (segmentation).""" 301 | dic = { 302 | "id": 2, 303 | "children": [ 304 | {"id": 3, "children": []}, 305 | { 306 | "id": 4, 307 | "children": [{"id": 5, "children": []}, {"id": 6, "children": []}], 308 | }, 309 | ], 310 | } 311 | 312 | return dic 313 | -------------------------------------------------------------------------------- /tests/data/animals.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/tests/data/animals.jpg -------------------------------------------------------------------------------- /tests/data/mild_inversion.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/tests/data/mild_inversion.npy -------------------------------------------------------------------------------- /tests/data/supervised_dataset.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueBrain/atlas-alignment/9d3def5a68add1654b5e33c0b4b8c73130e600cc/tests/data/supervised_dataset.h5 -------------------------------------------------------------------------------- /tests/test_augmentations.py: -------------------------------------------------------------------------------- 1 | """Collection of tests focused on the augmentations.py module.""" 2 | 3 | """ 4 | The package atlalign is a tool 
for registration of 2D images. 5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 20 | """ 21 | import pathlib 22 | from unittest.mock import Mock 23 | 24 | import numpy as np 25 | import pytest 26 | 27 | from atlalign.augmentations import DatasetAugmenter, load_dataset_in_memory 28 | from atlalign.base import DisplacementField 29 | 30 | 31 | @pytest.mark.parametrize( 32 | "key", 33 | [ 34 | "dataset_id", 35 | "deltas_xy", 36 | "image_id", 37 | "img", 38 | "inv_deltas_xy", 39 | "p", 40 | ], 41 | ) 42 | def test_load_dataset_in_memory(path_test_data, key): 43 | h5_path = path_test_data / "supervised_dataset.h5" 44 | 45 | res = load_dataset_in_memory(h5_path, key) 46 | 47 | assert isinstance(res, np.ndarray) 48 | assert len(res) > 0 49 | 50 | 51 | class TestDatasetAugmenter: 52 | def test_construction(self, path_test_data): 53 | da = DatasetAugmenter(path_test_data / "supervised_dataset.h5") 54 | 55 | assert da.n_orig > 0 56 | 57 | @pytest.mark.parametrize("n_iter", [1, 2]) 58 | @pytest.mark.parametrize("anchor", [True, False]) 59 | @pytest.mark.parametrize("is_valid", [True, False]) 60 | def test_augment( 61 | self, monkeypatch, path_test_data, tmpdir, n_iter, anchor, is_valid 62 | ): 63 | fake_es = Mock( 64 | return_value=DisplacementField.generate( 65 | (320, 456), approach="affine_simple", rotation=0.2 66 | ) 67 | 
) 68 | max_corrupted_pixels = 10 69 | 70 | if not is_valid: 71 | max_corrupted_pixels = 0 # hack that will force to use the original 72 | 73 | monkeypatch.setattr("atlalign.zoo.edge_stretching", fake_es) 74 | 75 | da = DatasetAugmenter(path_test_data / "supervised_dataset.h5") 76 | 77 | output_path = pathlib.Path(str(tmpdir)) / "output.h5" 78 | 79 | da.augment( 80 | output_path, 81 | n_iter=n_iter, 82 | anchor=anchor, 83 | max_trials=2, 84 | max_corrupted_pixels=max_corrupted_pixels, 85 | ds_f=32, 86 | ) 87 | 88 | assert output_path.exists() 89 | 90 | keys = ["dataset_id", "deltas_xy", "image_id", "img", "inv_deltas_xy", "p"] 91 | for key in keys: 92 | array = load_dataset_in_memory(output_path, key) 93 | assert da.n_orig * n_iter == len(array) 94 | assert not np.any(np.isnan(array)) 95 | 96 | keys = ["dataset_id", "image_id", "p"] 97 | for key in keys: 98 | original_a = load_dataset_in_memory(da.original_path, key) 99 | new_a = load_dataset_in_memory(output_path, key) 100 | new_a_expected = np.concatenate(n_iter * [original_a]) 101 | 102 | np.testing.assert_equal(new_a, new_a_expected) 103 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | The package atlalign is a tool for registration of 2D images. 3 | 4 | Copyright (C) 2021 EPFL/Blue Brain Project 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU Lesser General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU Lesser General Public License for more details. 
15 | 16 | You should have received a copy of the GNU Lesser General Public License 17 | along with this program. If not, see . 18 | """ 19 | 20 | import pathlib 21 | 22 | import numpy as np 23 | import pytest 24 | 25 | from atlalign.data import ( 26 | annotation_volume, 27 | circles, 28 | manual_registration, 29 | nissl_volume, 30 | rectangles, 31 | segmentation_collapsing_labels, 32 | ) 33 | from atlalign.utils import _find_all_children 34 | 35 | 36 | class TestAnnotationVolume: 37 | """A collection of tests focused on the `annotation_volume` function.""" 38 | 39 | def test_load_works(self, monkeypatch, tmpdir): 40 | """Test that loading works""" 41 | # Lets do some patching 42 | monkeypatch.setattr( 43 | "numpy.load", 44 | lambda *args, **kwargs: np.zeros((528, 320, 456), dtype=np.float32), 45 | ) 46 | 47 | x_atlas = annotation_volume() 48 | 49 | # Final output 50 | assert x_atlas.shape == (528, 320, 456) 51 | assert np.all(np.isfinite(x_atlas)) 52 | assert x_atlas.dtype == np.int32 53 | 54 | 55 | class TestNisslVolume: 56 | """A collection of tests focused on the `nissl_volume` function.""" 57 | 58 | def test_load_works(self, monkeypatch): 59 | """Test that loading works.""" 60 | # Lets do some patching 61 | monkeypatch.setattr( 62 | "numpy.load", 63 | lambda *args, **kwargs: np.zeros((528, 320, 456), dtype=np.float32), 64 | ) 65 | monkeypatch.setattr( 66 | "atlalign.data.img_as_float32", 67 | lambda *args, **kwargs: np.zeros((320, 456), dtype=np.float32), 68 | ) 69 | 70 | x_atlas = nissl_volume() 71 | 72 | # Final output 73 | assert x_atlas.shape == (528, 320, 456, 1) 74 | assert np.all(np.isfinite(x_atlas)) 75 | assert x_atlas.min() >= 0 76 | assert x_atlas.max() <= 1 77 | assert x_atlas.dtype == np.float32 78 | 79 | 80 | class TestManualRegistration: 81 | def test_correct_keys(self, monkeypatch): 82 | path = "tests/data/supervised_dataset.h5" # manual patch 83 | 84 | res = manual_registration(path) 85 | 86 | # Final output 87 | assert set(res.keys()) == { 88 
| "dataset_id", 89 | "deltas_xy", 90 | "image_id", 91 | "img", 92 | "inv_deltas_xy", 93 | "p", 94 | } 95 | assert np.all([isinstance(x, np.ndarray) for x in res.values()]) 96 | 97 | 98 | class TestCircles: 99 | """Collection of tests focused on the circles function.""" 100 | 101 | def test_input_shape(self): 102 | """Test that only 2D shapes are allowed.""" 103 | 104 | shape_wrong = (41, 21, 312) 105 | shape_correct = (100, 30) 106 | 107 | with pytest.raises(ValueError): 108 | circles(10, shape_wrong, 10) 109 | 110 | circles(10, shape_correct, 10) 111 | 112 | def test_correct_dtype(self): 113 | """Test that float32 ndarray is returned.""" 114 | 115 | shape = (100, 30) 116 | res = circles(10, shape, 10) 117 | 118 | assert res.min() >= 0 119 | assert res.max() <= 1 120 | assert res.dtype == np.float32 121 | 122 | @pytest.mark.parametrize("n_levels", [1, 2, 3, 4, 5, 6]) 123 | def test_correct_number_of_intensities(self, n_levels): 124 | """Test whether n_unique_intensities = n_levels + 1""" 125 | 126 | res = circles(10, (200, 220), (40, 50), n_levels=n_levels) 127 | 128 | for row in res: 129 | assert len(np.unique(row) == n_levels + 1) # also count black background 130 | 131 | def test_reproducible(self): 132 | """Test that random_state works.""" 133 | 134 | res_1 = circles(10, (100, 120), (20, 40), n_levels=(2, 6), random_state=None) 135 | res_2 = circles(10, (100, 120), (20, 40), n_levels=(2, 6), random_state=1) 136 | res_3 = circles(10, (100, 120), (20, 40), n_levels=(2, 6), random_state=2) 137 | res_4 = circles(10, (100, 120), (20, 40), n_levels=(2, 6), random_state=1) 138 | res_5 = circles(10, (100, 120), (20, 40), n_levels=(2, 6), random_state=None) 139 | 140 | assert np.all(res_2 == res_4) 141 | assert not np.all(res_2 == res_3) 142 | assert not np.all(res_2 == res_1) 143 | assert not np.all(res_3 == res_1) 144 | assert not np.all(res_1 == res_5) 145 | 146 | 147 | class TestRectangles: 148 | """A collection of tests testing the rectangles function.""" 149 | 150 | def
test_input_shape(self): 151 | """Test that only allows for 2D.""" 152 | 153 | shape_wrong = (40, 30, 10) 154 | shape_correct = (40, 30) 155 | 156 | with pytest.raises(ValueError): 157 | rectangles(100, shape_wrong, 10, 20) 158 | 159 | rectangles(100, shape_correct, 10, 20) 160 | 161 | def test_wrong_type(self): 162 | """Test that height, width and n_levels only work with integers.""" 163 | 164 | shape = (40, 50) 165 | with pytest.raises(TypeError): 166 | rectangles(100, shape, 10.1, 20) 167 | 168 | with pytest.raises(TypeError): 169 | rectangles(100, shape, 10, 20.3) 170 | 171 | with pytest.raises(TypeError): 172 | rectangles(100, shape, (10, 20), 20, n_levels=3.4) 173 | 174 | rectangles(10, shape, 10, 20, n_levels=2) 175 | 176 | def test_wrong_rectangle_size(self): 177 | """Test that rectangle needs to fit the image.""" 178 | 179 | shape = (img_h, img_w) = (40, 50) 180 | 181 | with pytest.raises(ValueError): 182 | rectangles(100, shape, img_h + 1, img_w - 1) 183 | 184 | with pytest.raises(ValueError): 185 | rectangles(100, shape, img_h - 1, img_w + 1) 186 | 187 | with pytest.raises(ValueError): 188 | rectangles(100, shape, img_h + 1, img_w + 1) 189 | 190 | rectangles(100, shape, img_h - 1, img_w - 1) 191 | 192 | def test_wrong_n_levels(self): 193 | """Test that n_levels needs to be correct.""" 194 | 195 | with pytest.raises(ValueError): 196 | rectangles(5, (100, 120), 20, 10, n_levels=14) 197 | 198 | with pytest.raises(ValueError): 199 | rectangles(5, (100, 120), 10, 20, n_levels=14) 200 | 201 | rectangles(5, (100, 120), 20, 20, n_levels=14) 202 | 203 | def test_reproducible(self): 204 | """Test that random_state works.""" 205 | 206 | res_1 = rectangles( 207 | 10, (100, 120), (20, 30), (10, 40), n_levels=(1, 4), random_state=None 208 | ) 209 | res_2 = rectangles( 210 | 10, (100, 120), (20, 30), (10, 40), n_levels=(1, 4), random_state=1 211 | ) 212 | res_3 = rectangles( 213 | 10, (100, 120), (20, 30), (10, 40), n_levels=(1, 4), random_state=2 214 | ) 215 | res_4 = 
rectangles( 216 | 10, (100, 120), (20, 30), (10, 40), n_levels=(1, 4), random_state=1 217 | ) 218 | res_5 = rectangles( 219 | 10, (100, 120), (20, 30), (10, 40), n_levels=(1, 4), random_state=None 220 | ) 221 | 222 | assert np.all(res_2 == res_4) 223 | assert not np.all(res_2 == res_3) 224 | assert not np.all(res_2 == res_1) 225 | assert not np.all(res_3 == res_1) 226 | assert not np.all(res_1 == res_5) 227 | 228 | @pytest.mark.parametrize("random_state", [0, 1, 2, 3, 4]) 229 | def test_no_empty_images(self, random_state): 230 | """Test that no empty images.""" 231 | 232 | shape = (100, 120) 233 | res = rectangles( 234 | 10, shape, (20, 30), (10, 50), n_levels=4, random_state=random_state 235 | ) 236 | 237 | zeros = np.zeros((*shape, 1)) 238 | for row in res: 239 | assert not np.all(row == zeros) 240 | 241 | def test_output_shape(self): 242 | 243 | """Test that the shape of the output is correct.""" 244 | shape = (50, 100) 245 | res = rectangles(10, shape, (20, 30), (10, 50), n_levels=4) 246 | 247 | assert res.shape == (10, *shape, 1) 248 | 249 | def test_correct_dtype(self): 250 | """Test that float32 ndarray is returned.""" 251 | 252 | shape = (50, 100) 253 | res = rectangles(10, shape, (20, 30), (10, 50), n_levels=4) 254 | 255 | assert res.min() >= 0 256 | assert res.max() <= 1 257 | assert res.dtype == np.float32 258 | 259 | def test_full_intensity(self): 260 | """Test that there exists a pixel with a full intensity = 1.""" 261 | 262 | shape = (50, 100) 263 | res = rectangles(10, shape, (20, 30), (10, 50), n_levels=4) 264 | 265 | for row in res: 266 | assert np.any(row == 1) 267 | 268 | @pytest.mark.parametrize("n_levels", [1, 2, 3, 4, 5, 6]) 269 | def test_correct_number_of_intensities(self, n_levels): 270 | """Test whether n_unique_intensities = n_levels + 1""" 271 | 272 | res = rectangles(10, (200, 220), (40, 50), (50, 100), n_levels=n_levels) 273 | 274 | for row in res: 275 | assert len(np.unique(row) == n_levels + 1) # also count black background 276 | 277 
| 278 | class TestSegmentationCollapsingLabels: 279 | def test_load_works(self, monkeypatch, tmpdir): 280 | """Test that loading works.""" 281 | tmpfile = pathlib.Path(str(tmpdir)) / "temp.json" 282 | tmpfile.touch() 283 | 284 | # Lets do some patching 285 | monkeypatch.setattr("json.load", lambda *args, **kwargs: {}) 286 | 287 | res = segmentation_collapsing_labels(tmpfile) 288 | 289 | # Final output 290 | assert isinstance(res, dict) 291 | 292 | @pytest.mark.skip 293 | def test_no_id_equal_to_negative_one(self): 294 | """Make sure that -1 is not an existing label since we want to use it as a default not found value. 295 | 296 | We skip this bacause it assumes the dataset is downloaded and it wouldnt make sense to patch this. 297 | 298 | See Also 299 | -------- 300 | """ 301 | 302 | all_children = _find_all_children(segmentation_collapsing_labels()) 303 | 304 | assert isinstance(all_children, list) 305 | assert all_children 306 | assert -1 not in all_children 307 | -------------------------------------------------------------------------------- /tests/test_ml_utils/test_augmentation.py: -------------------------------------------------------------------------------- 1 | """Collection of tests focused on the augmentation.py moduele.""" 2 | 3 | """ 4 | The package atlalign is a tool for registration of 2D images. 5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 
17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 20 | """ 21 | 22 | import imgaug.augmenters as iaa 23 | 24 | from atlalign.ml_utils import augmenter_1 25 | 26 | 27 | def test_runnable(): 28 | aug_1 = augmenter_1() 29 | 30 | assert isinstance(aug_1, iaa.Augmenter) 31 | -------------------------------------------------------------------------------- /tests/test_ml_utils/test_callbacks.py: -------------------------------------------------------------------------------- 1 | """Tests focused on callbacks.""" 2 | 3 | """ 4 | The package atlalign is a tool for registration of 2D images. 5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 
20 | """ 21 | 22 | import pathlib 23 | from unittest.mock import Mock 24 | 25 | import h5py 26 | import mlflow 27 | import numpy as np 28 | import pandas as pd 29 | import pytest 30 | 31 | from atlalign.augmentations import load_dataset_in_memory 32 | from atlalign.ml_utils import ( 33 | MLFlowCallback, 34 | SupervisedGenerator, 35 | get_mlflow_artifact_path, 36 | ) 37 | from atlalign.nn import supervised_model_factory 38 | 39 | 40 | def test_get_mlflow_artifact_path(monkeypatch, tmpdir): 41 | monkeypatch.chdir(tmpdir) 42 | 43 | with mlflow.start_run(): 44 | artifact_path = get_mlflow_artifact_path() 45 | 46 | expected = [ 47 | x for x in pathlib.Path(str(tmpdir)).rglob("*") if "artifacts" in str(x) 48 | ][0] 49 | 50 | assert artifact_path == expected 51 | 52 | 53 | class TestMLFlowCallback: 54 | @staticmethod 55 | def create_h5(h5_path, n_samples, random_state): 56 | height, width = 320, 456 57 | np.random.seed(random_state) 58 | 59 | with h5py.File(h5_path, "w") as f: 60 | dset_img = f.create_dataset( 61 | "img", (n_samples, height, width), dtype="uint8" 62 | ) 63 | dset_image_id = f.create_dataset("image_id", (n_samples,), dtype="int") 64 | dset_dataset_id = f.create_dataset("dataset_id", (n_samples,), dtype="int") 65 | dset_p = f.create_dataset("p", (n_samples,), dtype="int") 66 | dset_deltas_xy = f.create_dataset( 67 | "deltas_xy", 68 | (n_samples, height, width, 2), 69 | dtype=np.float16, 70 | fillvalue=0, 71 | ) 72 | dset_inv_deltas_xy = f.create_dataset( 73 | "inv_deltas_xy", 74 | (n_samples, height, width, 2), 75 | dtype=np.float16, 76 | fillvalue=0, 77 | ) 78 | 79 | # Populate 80 | dset_deltas_xy[:] = np.random.random((n_samples, height, width, 2)) 81 | dset_inv_deltas_xy[:] = np.random.random((n_samples, height, width, 2)) 82 | dset_img[:] = np.random.randint( 83 | 0, high=255, size=(n_samples, height, width) 84 | ) 85 | dset_image_id[:] = 50 + np.random.choice( 86 | n_samples, size=n_samples, replace=False 87 | ) 88 | dset_dataset_id[:] = 1000 + 
np.random.choice( 89 | n_samples, size=n_samples, replace=False 90 | ) 91 | dset_p[:] = np.random.randint(0, high=12000, size=n_samples) 92 | 93 | def test_hooks(self, tmpdir, path_test_data, monkeypatch): 94 | tmpdir = pathlib.Path(str(tmpdir)) 95 | monkeypatch.chdir(tmpdir) 96 | fake_sg_c = Mock(return_value=Mock(spec=SupervisedGenerator)) 97 | monkeypatch.setattr( 98 | "atlalign.ml_utils.callbacks.SupervisedGenerator", fake_sg_c 99 | ) 100 | monkeypatch.setattr("atlalign.ml_utils.callbacks.keras", Mock()) 101 | 102 | h5_path = path_test_data / "supervised_dataset.h5" 103 | 104 | train_ixs_path = "a" 105 | val_ixs_path = "b" 106 | 107 | with mlflow.start_run(): 108 | cb = MLFlowCallback(h5_path, train_ixs_path, val_ixs_path, freq=2) 109 | 110 | assert fake_sg_c.call_count == 2 111 | 112 | # Train 113 | train_kwargs = fake_sg_c.call_args_list[0][1] 114 | assert train_kwargs["indexes"] == train_ixs_path 115 | assert train_kwargs["shuffle"] is False 116 | assert train_kwargs["batch_size"] == 1 117 | 118 | # Val 119 | val_kwargs = fake_sg_c.call_args_list[1][1] 120 | assert val_kwargs["indexes"] == val_ixs_path 121 | assert val_kwargs["shuffle"] is False 122 | assert val_kwargs["batch_size"] == 1 123 | 124 | # On train begin 125 | assert not (cb.root_path / "architecture").exists() 126 | assert not (cb.root_path / "checkpoints").exists() 127 | 128 | cb.on_train_begin() 129 | 130 | assert (cb.root_path / "architecture").exists() 131 | assert (cb.root_path / "checkpoints").exists() 132 | 133 | # On batch_end 134 | cb.model = Mock(metrics_names=[]) # Inject a keras model 135 | cb.model.evaluate_generator.return_value = [] 136 | monkeypatch.setattr( 137 | cb, 138 | "compute_external_metrics", 139 | Mock( 140 | return_value=pd.DataFrame( 141 | { 142 | "metric": [ 143 | 1, 144 | ] 145 | } 146 | ) 147 | ), 148 | ) 149 | 150 | cb.on_batch_end(None) # 1 151 | cb.on_batch_end(None) # 2 152 | 153 | @pytest.mark.parametrize("random_state", [3, 10]) 154 | 
@pytest.mark.parametrize("return_inverse", [True, False]) 155 | def test_compute_external_metrics( 156 | self, monkeypatch, tmpdir, random_state, return_inverse 157 | ): 158 | evaluate_cache = [] 159 | 160 | def fake_evaluate(*args, **kwargs): 161 | evaluate_cache.append( 162 | { 163 | "deltas_true": args[0], 164 | "img_mov": args[2], 165 | "p": kwargs["p"], 166 | "deltas_true_inv": kwargs["deltas_true_inv"], 167 | } 168 | ) 169 | 170 | return pd.Series([2, 3]) 171 | 172 | monkeypatch.setattr( 173 | "atlalign.ml_utils.callbacks.evaluate_single", 174 | Mock(side_effect=fake_evaluate), 175 | ) 176 | monkeypatch.setattr("atlalign.ml_utils.callbacks.annotation_volume", Mock()) 177 | monkeypatch.setattr( 178 | "atlalign.ml_utils.io.nissl_volume", 179 | Mock(return_value=np.zeros((528, 320, 456, 1))), 180 | ) 181 | monkeypatch.setattr( 182 | "atlalign.ml_utils.callbacks.segmentation_collapsing_labels", Mock() 183 | ) 184 | 185 | n_samples = 10 186 | n_val_samples = 4 187 | h5_path = pathlib.Path(str(tmpdir)) / "temp.h5" 188 | self.create_h5(h5_path, n_samples, random_state) 189 | 190 | val_indexes = list(np.random.choice(n_samples, n_val_samples, replace=False)) 191 | 192 | val_gen = SupervisedGenerator( 193 | h5_path, 194 | indexes=val_indexes, 195 | shuffle=False, 196 | batch_size=1, 197 | return_inverse=return_inverse, 198 | ) 199 | losses = ["mse", "mse", "mse"] if return_inverse else ["mse", "mse"] 200 | losses_weights = [1, 1, 1] if return_inverse else [1, 1] 201 | 202 | model = supervised_model_factory( 203 | compute_inv=return_inverse, 204 | losses=losses, 205 | losses_weights=losses_weights, 206 | start_filters=(2,), 207 | downsample_filters=(4, 2), 208 | middle_filters=(2,), 209 | upsample_filters=(2, 4), 210 | ) 211 | 212 | df = MLFlowCallback.compute_external_metrics(model, val_gen) 213 | 214 | assert len(df) == len(val_indexes) 215 | assert np.allclose( 216 | df.index.values, load_dataset_in_memory(h5_path, "image_id")[val_indexes] 217 | ) 218 | assert 
len(evaluate_cache) == len(val_indexes) 219 | 220 | for ecache, val_index in zip(evaluate_cache, val_indexes): 221 | expected_deltas = load_dataset_in_memory(h5_path, "deltas_xy")[val_index] 222 | expected_deltas_inv = load_dataset_in_memory(h5_path, "inv_deltas_xy")[ 223 | val_index 224 | ] 225 | expected_image = load_dataset_in_memory(h5_path, "img")[val_index] / 255 226 | expected_p = load_dataset_in_memory(h5_path, "p")[val_index] 227 | 228 | assert np.allclose(expected_deltas, ecache["deltas_true"]) 229 | assert np.allclose(expected_image, ecache["img_mov"]) 230 | assert np.allclose(expected_p, ecache["p"]) 231 | 232 | if return_inverse: 233 | assert np.allclose(expected_deltas_inv, ecache["deltas_true_inv"]) 234 | else: 235 | assert ecache["deltas_true_inv"] is None # they are not streamed 236 | -------------------------------------------------------------------------------- /tests/test_ml_utils/test_io.py: -------------------------------------------------------------------------------- 1 | """Tests focues on the atlalign.io module.""" 2 | 3 | """ 4 | The package atlalign is a tool for registration of 2D images. 5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 
20 | """ 21 | 22 | import pathlib 23 | 24 | import imgaug.augmenters as iaa 25 | import numpy as np 26 | import pytest 27 | 28 | from atlalign.ml_utils import SupervisedGenerator 29 | 30 | SUPERVISED_H5_PATH = ( 31 | pathlib.Path(__file__).parent.parent / "data" / "supervised_dataset.h5" 32 | ) 33 | 34 | 35 | @pytest.fixture(scope="session") 36 | def fake_nissl_volume(): 37 | return np.random.random((528, 320, 456, 1)).astype(np.float32) 38 | 39 | 40 | class TestSupervisedGenerator: 41 | def test_inexistent_h5(self, monkeypatch, fake_nissl_volume): 42 | monkeypatch.setattr( 43 | "atlalign.ml_utils.io.nissl_volume", 44 | lambda *args, **kwargs: fake_nissl_volume, 45 | ) 46 | path_wrong = SUPERVISED_H5_PATH.parent / "fake.h5" 47 | 48 | with pytest.raises(OSError): 49 | SupervisedGenerator(path_wrong) 50 | 51 | def test_correct_indexes(self, monkeypatch, fake_nissl_volume): 52 | """Inner indices are [0, 1].""" 53 | monkeypatch.setattr( 54 | "atlalign.ml_utils.io.nissl_volume", 55 | lambda *args, **kwargs: fake_nissl_volume, 56 | ) 57 | 58 | assert SupervisedGenerator(SUPERVISED_H5_PATH).indexes == [0, 1] 59 | 60 | @pytest.mark.parametrize("batch_size", [1, 2]) 61 | def test_length(self, batch_size, monkeypatch, fake_nissl_volume): 62 | """Test that length of the Sequence is computed correctly.""" 63 | monkeypatch.setattr( 64 | "atlalign.ml_utils.io.nissl_volume", 65 | lambda *args, **kwargs: fake_nissl_volume, 66 | ) 67 | correct_len = 2 // batch_size 68 | 69 | assert ( 70 | len(SupervisedGenerator(SUPERVISED_H5_PATH, batch_size=batch_size)) 71 | == correct_len 72 | ) 73 | 74 | def test_shuffling(self, monkeypatch, fake_nissl_volume): 75 | """Test shuffling works.""" 76 | monkeypatch.setattr( 77 | "atlalign.ml_utils.io.nissl_volume", 78 | lambda *args, **kwargs: fake_nissl_volume, 79 | ) 80 | n_trials = 10 81 | is_different = False 82 | gen = SupervisedGenerator(SUPERVISED_H5_PATH, shuffle=True) 83 | orig_indexes = gen.indexes[:] 84 | 85 | for _ in range(n_trials): 86 
| gen.on_epoch_end() 87 | is_different = is_different or orig_indexes != gen.indexes 88 | 89 | assert is_different 90 | 91 | @pytest.mark.parametrize("batch_size", [1, 2]) 92 | def test_getitem(self, batch_size, monkeypatch, fake_nissl_volume): 93 | """Get item.""" 94 | monkeypatch.setattr( 95 | "atlalign.ml_utils.io.nissl_volume", 96 | lambda *args, **kwargs: fake_nissl_volume, 97 | ) 98 | gen = SupervisedGenerator(SUPERVISED_H5_PATH, batch_size=batch_size) 99 | 100 | inp, out = gen[0] 101 | 102 | assert inp.shape == (batch_size, 320, 456, 2) 103 | assert inp.dtype == np.float32 104 | assert 0 <= inp.min() <= inp.max() <= 1 105 | 106 | assert isinstance(out, list) 107 | assert len(out) == 2 108 | 109 | assert out[0].shape == (batch_size, 320, 456, 1) and out[1].shape == ( 110 | batch_size, 111 | 320, 112 | 456, 113 | 2, 114 | ) 115 | assert out[0].dtype == np.float32 and out[1].dtype == np.float16 116 | assert 0 <= out[0].min() <= out[0].max() <= 1 117 | 118 | @pytest.mark.parametrize("aug_ref", [True, False]) 119 | @pytest.mark.parametrize("aug_mov", [True, False]) 120 | def test_augmenters(self, aug_ref, aug_mov, monkeypatch, fake_nissl_volume): 121 | """Augmenting works.""" 122 | monkeypatch.setattr( 123 | "atlalign.ml_utils.io.nissl_volume", 124 | lambda *args, **kwargs: fake_nissl_volume, 125 | ) 126 | batch_size = 2 127 | augmenter = iaa.Sequential( 128 | [ 129 | iaa.Fliplr(0.5), # horizontal flips 130 | iaa.Crop(percent=(0, 0.1)), # random crops 131 | # Small gaussian blur with random sigma between 0 and 0.5. 132 | # But we only blur about 50% of all images. 
133 | iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))), 134 | ] 135 | ) 136 | 137 | kwargs = { 138 | "augmenter_ref": augmenter if aug_ref else None, 139 | "augmenter_mov": augmenter if aug_mov else None, 140 | } 141 | 142 | gen = SupervisedGenerator(SUPERVISED_H5_PATH, batch_size=batch_size, **kwargs) 143 | 144 | inp, out = gen[0] 145 | 146 | assert inp.dtype == np.float32 147 | 148 | assert isinstance(out, list) 149 | assert len(out) == 2 150 | 151 | assert out[0].shape == (batch_size, 320, 456, 1) and out[1].shape == ( 152 | batch_size, 153 | 320, 154 | 456, 155 | 2, 156 | ) 157 | assert out[0].dtype == np.float32 and out[1].dtype == np.float16 158 | 159 | @pytest.mark.parametrize("batch_size", [1, 2]) 160 | def test_get_all_data(self, batch_size, monkeypatch, fake_nissl_volume): 161 | monkeypatch.setattr( 162 | "atlalign.ml_utils.io.nissl_volume", 163 | lambda *args, **kwargs: fake_nissl_volume, 164 | ) 165 | 166 | gen = SupervisedGenerator(SUPERVISED_H5_PATH, batch_size=batch_size) 167 | 168 | all_inp, all_out = gen.get_all_data() 169 | 170 | assert len(all_inp) == 2 171 | assert len(all_out) == 2 172 | 173 | def test_indexes(self, monkeypatch, fake_nissl_volume): 174 | """Make sure indexes attribute work.""" 175 | batch_size = 1 176 | monkeypatch.setattr( 177 | "atlalign.ml_utils.io.nissl_volume", 178 | lambda *args, **kwargs: fake_nissl_volume, 179 | ) 180 | 181 | gen_indexes = SupervisedGenerator( 182 | SUPERVISED_H5_PATH, batch_size=batch_size, indexes=[0] 183 | ) 184 | gen = SupervisedGenerator(SUPERVISED_H5_PATH, batch_size=batch_size) 185 | 186 | assert len(gen) == 2 187 | assert len(gen_indexes) == 1 188 | -------------------------------------------------------------------------------- /tests/test_ml_utils/test_models.py: -------------------------------------------------------------------------------- 1 | """Collection of tests focused on the atlalign.ml_utils.models module.""" 2 | 3 | """ 4 | The package atlalign is a tool for registration of 2D 
images. 5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 20 | """ 21 | 22 | import pathlib 23 | from copy import deepcopy 24 | 25 | import pytest 26 | from tensorflow import keras 27 | 28 | from atlalign.ml_utils import ( 29 | load_model, 30 | merge_global_local, 31 | replace_lambda_in_config, 32 | save_model, 33 | ) 34 | 35 | 36 | @pytest.fixture(scope="function", params=["compiled", "uncompiled"]) 37 | def reg_model(request): 38 | """Minimal registration network.""" 39 | 40 | compile = request.param == "compiled" 41 | 42 | inputs = keras.layers.Input((32, 45, 2)) 43 | 44 | extract_0 = keras.layers.Lambda(lambda x: x[..., :-1]) 45 | dummy_op_1 = keras.layers.Lambda(lambda x: x + 1) 46 | dummy_op_2 = keras.layers.Lambda(lambda x: x * 1) 47 | 48 | reg_imgs = dummy_op_1(extract_0(inputs)) 49 | dvfs = dummy_op_2(inputs) 50 | 51 | model = keras.models.Model(inputs=inputs, outputs=[reg_imgs, dvfs]) 52 | 53 | if compile: 54 | model.compile(optimizer="adam", loss="mse") 55 | 56 | return model 57 | 58 | 59 | @pytest.fixture() 60 | def lambda_model(): 61 | """Model containing Lambda layers""" 62 | inputs = keras.layers.Input(shape=(10, 14, 2)) 63 | x = keras.layers.Lambda(lambda x: x, name="inv_dvf")(inputs) 64 | x = keras.layers.Lambda(lambda x: x[..., 1:], name="extract_moving")(x) 65 | 66 | return 
keras.Model(inputs=inputs, outputs=x) 67 | 68 | 69 | class TestGlobalLocal: 70 | """Tests focused on the merge_global_local function.""" 71 | 72 | @pytest.mark.parametrize("expose_global", [True, False]) 73 | def test_overall(self, reg_model, expose_global): 74 | model_gl = merge_global_local(reg_model, reg_model, expose_global=expose_global) 75 | 76 | assert isinstance(model_gl, keras.models.Model) 77 | 78 | assert len(model_gl.outputs) == (4 if expose_global else 2) 79 | 80 | 81 | class TestSaveModel: 82 | """Collection of tests focused on the `save_model` function.""" 83 | 84 | @pytest.mark.parametrize("separate", [True, False]) 85 | def test_correct(self, separate, tmpdir, reg_model): 86 | temp_path = pathlib.Path(str(tmpdir)) / "temp_model" 87 | save_model(reg_model, temp_path, separate=separate) 88 | 89 | if separate: 90 | assert (temp_path / "temp_model.json").exists() 91 | assert (temp_path / "temp_model.h5").exists() 92 | else: 93 | assert pathlib.Path(str(temp_path) + ".h5").exists() 94 | 95 | def test_wrong_path(self, reg_model): 96 | with pytest.raises(ValueError): 97 | save_model(reg_model, "aaa.extension") 98 | 99 | def test_already_exists(self, reg_model, tmpdir): 100 | orig_path = pathlib.Path(str(tmpdir)) / "temp_model" 101 | temp_path = orig_path / "temp_model.json" 102 | 103 | temp_path.parent.mkdir( 104 | parents=True, exist_ok=True 105 | ) # maybe the folder was already created before 106 | temp_path.touch() 107 | 108 | with pytest.raises(FileExistsError): 109 | save_model(reg_model, orig_path, overwrite=False) 110 | 111 | 112 | class TestLoadModel: 113 | """Collection of tests focused on the `load_model` function.""" 114 | 115 | @pytest.mark.parametrize("separate", [True, False]) 116 | @pytest.mark.parametrize("compile", [True, False]) 117 | def test_correct(self, separate, compile, tmpdir, reg_model): 118 | original_compiled = reg_model.optimizer is not None 119 | 120 | temp_path = pathlib.Path(str(tmpdir)) / "temp_model" 121 | 
save_model(reg_model, temp_path, separate=separate) 122 | 123 | if separate and compile: 124 | with pytest.raises(ValueError): 125 | load_model( 126 | str(temp_path) + "{}".format("" if separate else ".h5"), 127 | compile=compile, 128 | ) 129 | return 130 | 131 | model = load_model( 132 | str(temp_path) + "{}".format("" if separate else ".h5"), compile=compile 133 | ) 134 | 135 | assert isinstance(model, keras.Model) 136 | 137 | if not separate and compile and original_compiled: 138 | assert model.optimizer is not None 139 | else: 140 | assert model.optimizer is None 141 | 142 | assert len(model.get_config()["layers"]) == len( 143 | reg_model.get_config()["layers"] 144 | ) 145 | # assert reg_model.to_json() == model.to_json() # might differ because of lambda layers:D:D 146 | 147 | def test_nonexistent_path(self, tmpdir): 148 | with pytest.raises(OSError): 149 | load_model(tmpdir / "fake") 150 | 151 | def test_ambiguous(self, tmpdir): 152 | path = pathlib.Path(str(tmpdir)) 153 | path_architecture_1 = path / "a_1.json" 154 | path_architecture_2 = path / "a_2.json" 155 | path_weights = path / "w_1.h5" 156 | 157 | path_architecture_1.touch() 158 | path_architecture_2.touch() 159 | path_weights.touch() 160 | 161 | with pytest.raises(ValueError): 162 | load_model(path) 163 | 164 | 165 | class TestReplaceLambdaInConfig: 166 | """Collection of tests focused on the `replace_lambda_in_config` method.""" 167 | 168 | # Commenting for now 169 | # @pytest.mark.parametrize("input_format", ["json", "keras", "dict", "path"]) 170 | # @pytest.mark.parametrize("output_format", ["json", "keras", "dict"]) 171 | # def test_identical_results(self, lambda_model, input_format, output_format, tmpdir): 172 | # 173 | # if input_format == "keras": 174 | # input_config = lambda_model 175 | # elif input_format == "json": 176 | # input_config = lambda_model.to_json() 177 | # elif input_format == "dict": 178 | # input_config = lambda_model.get_config() 179 | # elif input_format == "path": 180 | # 
path_ = pathlib.Path(str(tmpdir)) 181 | # input_config = path_ / "model.json" 182 | # with input_config.open("w") as f_a: 183 | # json.dump(lambda_model.to_json(), f_a) 184 | # else: 185 | # raise ValueError() 186 | # 187 | # output = replace_lambda_in_config(input_config, output_format=output_format) 188 | # 189 | # # check types 190 | # if output_format == "keras": 191 | # assert isinstance(output, keras.Model) 192 | # new_model = output 193 | # elif output_format == "json": 194 | # assert isinstance(output, str) 195 | # new_model = keras.models.model_from_json( 196 | # output, 197 | # ) 198 | # elif output_format == "dict": 199 | # assert isinstance(output, dict) 200 | # new_model = keras.Model.from_config( 201 | # output, 202 | # custom_objects={ 203 | # "ExtractMoving": ExtractMoving, 204 | # "NoOp": NoOp, 205 | # }, 206 | # ) 207 | # 208 | # shape_input = keras.backend.int_shape(lambda_model.input)[1:] 209 | # x = np.random.random((2, *shape_input)) 210 | # 211 | # assert np.allclose(lambda_model.predict(x), new_model.predict(x)) 212 | 213 | def test_incorrect_input(self, lambda_model): 214 | # incorrect input type 215 | with pytest.raises(TypeError): 216 | replace_lambda_in_config(1) 217 | 218 | # incorret input path 219 | with pytest.raises(ValueError): 220 | replace_lambda_in_config(pathlib.Path.cwd() / "fake.wrong") 221 | 222 | # incorrect output type 223 | with pytest.raises(TypeError): 224 | replace_lambda_in_config(lambda_model, output_format="fake") 225 | 226 | def test_incorrect_layer_config(self, lambda_model): 227 | 228 | correct_config = lambda_model.get_config() 229 | missing_class_name = deepcopy(correct_config) 230 | missing_config = deepcopy(correct_config) 231 | invalid_name = deepcopy(correct_config) 232 | 233 | # prep 234 | del missing_class_name["layers"][0]["class_name"] 235 | del missing_config["layers"][1]["config"] 236 | invalid_name["layers"][2]["config"]["name"] = "aaaaa" 237 | 238 | with pytest.raises(KeyError): 239 | 
replace_lambda_in_config(missing_class_name) 240 | 241 | with pytest.raises(KeyError): 242 | replace_lambda_in_config(missing_config) 243 | 244 | with pytest.raises(KeyError): 245 | replace_lambda_in_config(invalid_name) 246 | -------------------------------------------------------------------------------- /tests/test_nn.py: -------------------------------------------------------------------------------- 1 | """Tests focused on the nn module.""" 2 | 3 | """ 4 | The package atlalign is a tool for registration of 2D images. 5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 
20 | """ 21 | 22 | import numpy as np 23 | import pytest 24 | from tensorflow import keras 25 | 26 | from atlalign.nn import supervised_global_model_factory, supervised_model_factory 27 | 28 | 29 | class TestSupervisedModelFactory: 30 | """Collection of tests focused on the `supervised_model_factory`.""" 31 | 32 | def test_default_construction(self): 33 | """Make sure possible to use with the default setting""" 34 | 35 | model = supervised_model_factory() 36 | assert isinstance(model, keras.Model) 37 | 38 | @pytest.mark.parametrize("compute_inv", [True, False]) 39 | @pytest.mark.parametrize("use_lambda", [True, False]) 40 | def test_use_lambda(self, use_lambda, compute_inv): 41 | """Make sure the `use_lambda` flag is working""" 42 | 43 | losses = ("mse", "mse", "mse") if compute_inv else ("mse", "mse") 44 | losses_weights = (1, 1, 1) if compute_inv else (1, 1) 45 | model = supervised_model_factory( 46 | losses=losses, 47 | losses_weights=losses_weights, 48 | compute_inv=compute_inv, 49 | use_lambda=use_lambda, 50 | ) 51 | lambda_list = [x for x in model.layers if isinstance(x, keras.layers.Lambda)] 52 | if use_lambda: 53 | assert lambda_list 54 | else: 55 | assert not lambda_list 56 | 57 | # @pytest.mark.parametrize("compute_inv", [True, False]) 58 | # def test_equivalence(self, compute_inv): 59 | # """Make sure moving Lambda layers does not affect the results.""" 60 | # losses = ("mse", "mse", "mse") if compute_inv else ("mse", "mse") 61 | # losses_weights = (1, 1, 1) if compute_inv else (1, 1) 62 | # params = { 63 | # "start_filters": (2,), 64 | # "downsample_filters": (2, 3), 65 | # "middle_filters": (2,), 66 | # "upsample_filters": (2, 3), 67 | # "end_filters": tuple(), 68 | # "compute_inv": compute_inv, 69 | # "losses": losses, 70 | # "losses_weights": losses_weights, 71 | # } 72 | # 73 | # np.random.seed(1337) 74 | # tf.random.set_seed(1337) 75 | # model_with = supervised_model_factory(use_lambda=True, **params) 76 | # np.random.seed(1337) 77 | # 
tf.random.set_seed(1337) 78 | # model_without = supervised_model_factory(use_lambda=False, **params) 79 | # x = np.random.random((1, 320, 456, 2)) 80 | # pred_with = model_with.predict([x, x] if compute_inv else x) 81 | # pred_without = model_without.predict([x, x] if compute_inv else x) 82 | # 83 | # assert np.allclose(pred_with[0], pred_without[0]) 84 | # assert np.allclose(pred_with[1], pred_without[1]) 85 | # if compute_inv: 86 | # assert np.allclose(pred_with[2], pred_without[2]) 87 | 88 | def test_down_up_samples(self): 89 | """Make sure raises an error if downsamples and upsamples have not the same number of layers""" 90 | with pytest.raises(ValueError): 91 | supervised_model_factory(downsample_filters=(2,), upsample_filters=(2, 3)) 92 | with pytest.raises(ValueError): 93 | supervised_model_factory( 94 | downsample_filters=(2, 2, 2, 2, 2, 2, 2), 95 | upsample_filters=(2, 2, 2, 2, 2, 2, 2), 96 | ) 97 | 98 | 99 | class TestSupervisedGlobalModelFactory: 100 | """Collection of tests focused on the `supervised_model_factory`.""" 101 | 102 | def test_default_construction(self): 103 | """Make sure possible to use with the default setting""" 104 | 105 | model = supervised_global_model_factory() 106 | assert isinstance(model, keras.Model) 107 | 108 | @pytest.mark.parametrize("use_lambda", [True, False]) 109 | def test_use_lambda(self, use_lambda): 110 | """Make sure the `use_lambda` flag is working""" 111 | 112 | model = supervised_global_model_factory(use_lambda=use_lambda) 113 | lambda_list = [x for x in model.layers if isinstance(x, keras.layers.Lambda)] 114 | if use_lambda: 115 | assert lambda_list 116 | else: 117 | assert not lambda_list 118 | 119 | def test_equivalence(self): 120 | """Make sure moving Lambda layers does not affect the results.""" 121 | params = {"filters": (2, 2, 2, 2), "dense_layers": (2,)} 122 | np.random.seed(1337) 123 | model_with = supervised_global_model_factory(use_lambda=True, **params) 124 | np.random.seed(1337) 125 | model_without = 
supervised_global_model_factory(use_lambda=False, **params) 126 | x = np.random.random((1, 320, 456, 2)) 127 | pred_with = model_with.predict(x) 128 | pred_without = model_without.predict(x) 129 | 130 | assert np.allclose(pred_with[0], pred_without[0]) 131 | assert np.allclose(pred_with[1], pred_without[1]) 132 | -------------------------------------------------------------------------------- /tests/test_non_ml.py: -------------------------------------------------------------------------------- 1 | """ 2 | The package atlalign is a tool for registration of 2D images. 3 | 4 | Copyright (C) 2021 EPFL/Blue Brain Project 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU Lesser General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU Lesser General Public License for more details. 15 | 16 | You should have received a copy of the GNU Lesser General Public License 17 | along with this program. If not, see . 
18 | """ 19 | 20 | import numpy as np 21 | import pytest 22 | 23 | from atlalign.base import DisplacementField 24 | from atlalign.non_ml import antspy_registration 25 | 26 | 27 | class TestAntspyRegistration: 28 | """A collection of tests focused on the ANTsPy registration.""" 29 | 30 | @pytest.mark.slow 31 | @pytest.mark.parametrize( 32 | "registration_type", 33 | [ 34 | "Translation", 35 | "Rigid", 36 | "Similarity", 37 | "QuickRigid", 38 | "DenseRigid", 39 | "BOLDRigid", 40 | "Affine", 41 | "AffineFast", 42 | "BOLDAffine", 43 | "TRSAA", 44 | "ElasticSyN", 45 | "SyN", 46 | "SyNRA", 47 | "SyNOnly", 48 | "SyNCC", 49 | "SyNabp", 50 | "SyNBold", 51 | "SyNBoldAff", 52 | "SyNAggro", 53 | "TVMSQ", 54 | ], 55 | ) 56 | def test_each_type(self, tmp_path, img_grayscale, registration_type): 57 | """Make sure that every type of registration gives a valid DisplacementField as output.""" 58 | 59 | moving_img = img_grayscale 60 | fixed_img = img_grayscale 61 | 62 | reg_iterations = (4, 2, 0) 63 | 64 | df_final, meta = antspy_registration( 65 | fixed_img, 66 | moving_img, 67 | registration_type=registration_type, 68 | reg_iterations=reg_iterations, 69 | path=tmp_path, 70 | ) 71 | 72 | assert isinstance(df_final, DisplacementField) 73 | assert df_final.is_valid 74 | 75 | def test_displacement_field(self, tmp_path, img_grayscale): 76 | """Make sure that the displacement extracted can reproduce the registered moving image. 77 | 78 | Done by comparing with the image contained in the registration output. 
79 | """ 80 | 81 | fixed_img = img_grayscale 82 | size = fixed_img.shape 83 | df = DisplacementField.generate( 84 | size, approach="affine_simple", translation_x=20, translation_y=20 85 | ) 86 | moving_img = df.warp(img_grayscale) 87 | 88 | df_final, meta = antspy_registration(fixed_img, moving_img, path=tmp_path) 89 | img1 = meta["warpedmovout"].numpy() 90 | img2 = df_final.warp( 91 | moving_img, interpolation="linear", border_mode="constant", c=0 92 | ) 93 | 94 | if img_grayscale.dtype == "uint8": 95 | epsilon = 1 96 | else: 97 | epsilon = 0.005 98 | 99 | assert abs(img1 - img2).mean() < epsilon 100 | 101 | @pytest.mark.todo 102 | @pytest.mark.parametrize( 103 | "registration_type", 104 | [ 105 | "Translation", 106 | "Rigid", 107 | "Similarity", 108 | "QuickRigid", 109 | "DenseRigid", 110 | "BOLDRigid", 111 | "Affine", 112 | "AffineFast", 113 | "BOLDAffine", 114 | "TRSAA", 115 | "ElasticSyN", 116 | "SyN", 117 | "SyNRA", 118 | "SyNOnly", 119 | "SyNCC", 120 | "SyNabp", 121 | ], 122 | ) 123 | def test_same_results( 124 | self, tmp_path, monkeypatch, img_grayscale, registration_type 125 | ): 126 | """Make sure that the registration is reproducible. 127 | 128 | Done by checking if the displacement field extracted are equal if we run the registration with 129 | the same parameters. 130 | 131 | Notes 132 | ----- 133 | Marked as `todo` because we did not find a way how to force ANTsPY to be always reproducible. 
134 | """ 135 | 136 | monkeypatch.setenv("ANTS_RANDOM_SEED", "1") 137 | monkeypatch.setenv("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS", "1") 138 | 139 | size = img_grayscale.shape 140 | p = img_grayscale / np.sum(img_grayscale) 141 | df = DisplacementField.generate(size, approach="paper", p=p, random_state=1) 142 | moving_img = df.warp(img_grayscale) 143 | 144 | df_final1, meta1 = antspy_registration( 145 | img_grayscale, 146 | moving_img, 147 | registration_type=registration_type, 148 | path=tmp_path, 149 | verbose=True, 150 | ) 151 | df_final2, meta2 = antspy_registration( 152 | img_grayscale, 153 | moving_img, 154 | registration_type=registration_type, 155 | path=tmp_path, 156 | verbose=True, 157 | ) 158 | 159 | assert np.allclose(df_final1.delta_x, df_final2.delta_x, atol=1) 160 | assert np.allclose(df_final1.delta_y, df_final2.delta_y, atol=1) 161 | 162 | @pytest.mark.slow 163 | @pytest.mark.parametrize( 164 | "registration_type", 165 | [ 166 | "Translation", 167 | "Rigid", 168 | "Similarity", 169 | "QuickRigid", 170 | "DenseRigid", 171 | "BOLDRigid", 172 | "Affine", 173 | "AffineFast", 174 | "BOLDAffine", 175 | "TRSAA", 176 | "ElasticSyN", 177 | "SyN", 178 | "SyNRA", 179 | "SyNOnly", 180 | "SyNCC", 181 | "SyNabp", 182 | "SyNBold", 183 | "SyNBoldAff", 184 | "SyNAggro", 185 | "TVMSQ", 186 | ], 187 | ) 188 | def test_different_results(self, tmp_path, img_grayscale_uint, registration_type): 189 | """Make sure that the registration is not reproducible if some environment variables not set. 190 | 191 | 192 | The environment variables are ANTS_RANDOM_SEED and ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS. 
193 | """ 194 | 195 | size = img_grayscale_uint.shape 196 | df = DisplacementField.generate( 197 | size, approach="affine_simple", translation_x=20 198 | ) 199 | moving_img = df.warp(img_grayscale_uint) 200 | 201 | df_final1, meta1 = antspy_registration( 202 | img_grayscale_uint, moving_img, path=tmp_path 203 | ) 204 | df_final2, meta2 = antspy_registration( 205 | img_grayscale_uint, moving_img, path=tmp_path 206 | ) 207 | 208 | assert not np.allclose(df_final1.delta_x, df_final2.delta_x, atol=0.1) 209 | assert not np.allclose(df_final1.delta_y, df_final2.delta_y, atol=0.1) 210 | 211 | @pytest.mark.todo 212 | def test_different_types( 213 | self, tmp_path, monkeypatch, img_grayscale_uint, img_grayscale_float 214 | ): 215 | """Make sure that the registration does not depend on the type of input images. 216 | 217 | Notes 218 | ----- 219 | Marked as `todo` because we did not find a way how to force ANTsPY to be always reproducible. 220 | """ 221 | 222 | monkeypatch.setenv("ANTS_RANDOM_SEED", "4") 223 | monkeypatch.setenv("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS", "1") 224 | 225 | size = img_grayscale_uint.shape 226 | p = img_grayscale_uint / np.sum(img_grayscale_uint) 227 | df = DisplacementField.generate(size, approach="paper", p=p) 228 | moving_img_uint = df.warp(img_grayscale_uint) 229 | moving_img_float = df.warp(img_grayscale_float) 230 | 231 | df_final1, meta1 = antspy_registration( 232 | img_grayscale_uint, moving_img_uint, path=tmp_path, verbose=False 233 | ) 234 | df_final2, meta2 = antspy_registration( 235 | img_grayscale_float, moving_img_float, path=tmp_path, verbose=False 236 | ) 237 | 238 | assert np.allclose(df_final1.delta_x, df_final2.delta_x, atol=0.1) 239 | assert np.allclose(df_final1.delta_y, df_final2.delta_y, atol=0.1) 240 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | The package atlalign is a tool 
for registration of 2D images. 3 | 4 | Copyright (C) 2021 EPFL/Blue Brain Project 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU Lesser General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU Lesser General Public License for more details. 15 | 16 | You should have received a copy of the GNU Lesser General Public License 17 | along with this program. If not, see . 18 | """ 19 | 20 | import numpy as np 21 | import pytest 22 | 23 | from atlalign.utils import _find_all_children, find_labels_dic 24 | 25 | 26 | class TestSegmentationConcatenation: 27 | """A set of methods testing the concatenation of labels for the segmentation.""" 28 | 29 | def test_find_children(self, label_dict): 30 | dic = {"id": 2, "children": []} 31 | 32 | assert np.all(_find_all_children(dic) == [2]) 33 | assert np.all(set(_find_all_children(label_dict)) == {2, 3, 4, 5, 6}) 34 | 35 | @pytest.mark.parametrize("depth", [0, 1, 10]) 36 | def test_background(self, depth, label_dict): 37 | shape = (20, 20) 38 | segmentation_array = np.zeros(shape) 39 | new_segmentation_array = np.zeros(shape) 40 | 41 | assert np.all( 42 | find_labels_dic(segmentation_array, label_dict, depth) 43 | == new_segmentation_array 44 | ) 45 | 46 | @pytest.mark.parametrize("depth", [0, 1, 10]) 47 | def test_labels_not_in_tree(self, depth, label_dict): 48 | shape = (20, 20) 49 | segmentation_array = np.ones(shape) * 10 50 | new_segmentation_array = np.ones(shape) * -1 51 | 52 | assert np.all( 53 | find_labels_dic(segmentation_array, label_dict, depth) 54 | == new_segmentation_array 55 | ) 56 | 57 | @pytest.mark.parametrize("label", [2, 3, 4, 5, 6]) 58 | def 
test_tree_concatenation(self, label, label_dict): 59 | depth = 0 60 | shape = (20, 20) 61 | segmentation_array = np.ones(shape) * label 62 | new_segmentation_array = np.ones(shape) * 2 63 | 64 | assert np.all( 65 | find_labels_dic(segmentation_array, label_dict, depth) 66 | == new_segmentation_array 67 | ) 68 | 69 | @pytest.mark.parametrize("unchanged_label", [2, 3, 4]) 70 | def test_tree_concatenation_unchanged_labels(self, unchanged_label, label_dict): 71 | depth = 1 72 | shape = (20, 20) 73 | segmentation_array = np.ones(shape) * unchanged_label 74 | new_segmentation_array = np.ones(shape) * unchanged_label 75 | 76 | assert np.all( 77 | find_labels_dic(segmentation_array, label_dict, depth) 78 | == new_segmentation_array 79 | ) 80 | 81 | def test_specific_example(self, label_dict): 82 | """A simple 3 x 3 matrix segmentation array.""" 83 | 84 | segmentation_array = np.array([[[1, 2, 3], [20, 14, 50], [4, 5, 6]]]) 85 | 86 | segmentation_array_d0 = np.array([[[-1, 2, 2], [-1, -1, -1], [2, 2, 2]]]) 87 | 88 | segmentation_array_d1 = np.array([[[-1, 2, 3], [-1, -1, -1], [4, 4, 4]]]) 89 | 90 | segmentation_array_d2 = np.array([[[-1, 2, 3], [-1, -1, -1], [4, 5, 6]]]) 91 | 92 | segmentation_array_d3 = np.array([[[-1, 2, 3], [-1, -1, -1], [4, 5, 6]]]) 93 | 94 | assert np.all( 95 | find_labels_dic(segmentation_array, label_dict, 0) == segmentation_array_d0 96 | ) 97 | assert np.all( 98 | find_labels_dic(segmentation_array, label_dict, 1) == segmentation_array_d1 99 | ) 100 | assert np.all( 101 | find_labels_dic(segmentation_array, label_dict, 2) == segmentation_array_d2 102 | ) 103 | assert np.all( 104 | find_labels_dic(segmentation_array, label_dict, 3) == segmentation_array_d3 105 | ) 106 | -------------------------------------------------------------------------------- /tests/test_visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | The package atlalign is a tool for registration of 2D images. 
3 | 4 | Copyright (C) 2021 EPFL/Blue Brain Project 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU Lesser General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU Lesser General Public License for more details. 15 | 16 | You should have received a copy of the GNU Lesser General Public License 17 | along with this program. If not, see . 18 | """ 19 | 20 | import numpy as np 21 | import pytest 22 | from matplotlib import animation 23 | 24 | from atlalign.base import DisplacementField 25 | from atlalign.visualization import ( 26 | create_animation, 27 | create_segmentation_image, 28 | generate_df_plots, 29 | ) 30 | 31 | 32 | class TestUtils: 33 | """A collection of tests focues on the `utils` module.""" 34 | 35 | def test_wrong_input_dtype(self): 36 | """Make sure that float input dtype not allowed.""" 37 | 38 | shape = (4, 5) 39 | 40 | segmentation_array = np.ones(shape, dtype=np.float32) / 5 41 | 42 | with pytest.raises(TypeError): 43 | create_segmentation_image(segmentation_array) 44 | 45 | def test_correct_output_type(self): 46 | """Make sure that uint8 output dtype.""" 47 | 48 | shape = (4, 5) 49 | 50 | segmentation_array = np.random.randint(100, size=shape) 51 | 52 | segmentation_img, _ = create_segmentation_image(segmentation_array) 53 | 54 | assert segmentation_img.dtype == np.uint8 55 | assert np.all((0 <= segmentation_img) & (segmentation_img < 256)) 56 | 57 | def test_different_classes_different_colors(self): 58 | """Test different classes have different colors.""" 59 | 60 | segmentation_array = np.array([[0, 10], [2, 0]]) 61 | 62 | segmentation_img, _ = 
create_segmentation_image(segmentation_array) 63 | 64 | assert ( 65 | len( 66 | { 67 | tuple(x) 68 | for x in [ 69 | segmentation_img[0, 0], 70 | segmentation_img[0, 1], 71 | segmentation_img[1, 0], 72 | segmentation_img[1, 1], 73 | ] 74 | } 75 | ) 76 | == 3 77 | ) 78 | 79 | assert np.all(segmentation_img[0, 0] == segmentation_img[1, 1]) 80 | 81 | def test_predefined_colors(self): 82 | """Test possible to pass colors.""" 83 | 84 | segmentation_array = np.array([[0, 1], [2, 22]]) 85 | 86 | colors_dict = {0: (0, 0, 0), 1: (255, 0, 0)} 87 | 88 | segmentation_img, _ = create_segmentation_image(segmentation_array, colors_dict) 89 | 90 | assert np.all(segmentation_img[0, 0] == (0, 0, 0)) 91 | assert np.all(segmentation_img[0, 1] == (255, 0, 0)) 92 | 93 | def test_animation(self, img): 94 | """Possible to generate animations.""" 95 | 96 | df = DisplacementField.generate(img.shape, approach="identity") 97 | 98 | ani = create_animation(df, img) 99 | ani_many = create_animation([df, df], img) 100 | 101 | assert isinstance(ani, animation.Animation) 102 | assert isinstance(ani_many, animation.Animation) 103 | 104 | 105 | class TestGenerateDFPlots: 106 | """Tests focused on the `generate_df_plots` function.""" 107 | 108 | @pytest.mark.parametrize("df_id", [(320, 456)], indirect=True) 109 | def test_basic(self, df_id, tmpdir, monkeypatch): 110 | generate_df_plots(df_id, df_id, tmpdir) 111 | -------------------------------------------------------------------------------- /tests/test_volume.py: -------------------------------------------------------------------------------- 1 | """Collection of tests focused on the `volume` module.""" 2 | 3 | """ 4 | The package atlalign is a tool for registration of 2D images. 
5 | 6 | Copyright (C) 2021 EPFL/Blue Brain Project 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU Lesser General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU Lesser General Public License for more details. 17 | 18 | You should have received a copy of the GNU Lesser General Public License 19 | along with this program. If not, see . 20 | """ 21 | 22 | from unittest.mock import MagicMock 23 | 24 | import numpy as np 25 | import pytest 26 | 27 | from atlalign.base import DisplacementField 28 | from atlalign.volume import CoronalInterpolator, GappedVolume, Volume 29 | 30 | 31 | @pytest.fixture() 32 | def minimal_vol(monkeypatch): 33 | sn = [1, 13] 34 | mov_imgs = 2 * [np.zeros((12, 13))] 35 | dvfs = 2 * [DisplacementField.generate((12, 13), approach="identity")] 36 | 37 | nvol_mock = MagicMock() 38 | nvol_mock.__getitem__.return_value = np.zeros((len(mov_imgs), 12, 13, 1)) 39 | monkeypatch.setattr("atlalign.volume.nissl_volume", lambda: nvol_mock) 40 | 41 | return Volume(sn, mov_imgs, dvfs) 42 | 43 | 44 | class TestVolume: 45 | def test_wrong_input(self): 46 | with pytest.raises(ValueError): 47 | Volume([None], [], []) 48 | 49 | with pytest.raises(ValueError): 50 | Volume(["a", "a"], ["a", "a"], ["a", "a"]) 51 | 52 | with pytest.raises(ValueError): 53 | Volume([1111, 12], ["a", "a"], ["a", "a"]) 54 | 55 | def test_construction(self, minimal_vol): 56 | assert isinstance(minimal_vol, Volume) 57 | 58 | def test_getitem(self, minimal_vol): 59 | with pytest.raises(KeyError): 60 | minimal_vol[4] 61 | 62 | outputs = minimal_vol[1] 63 | 64 | assert len(outputs) == 4 65 | assert isinstance(outputs[0], 
np.ndarray) 66 | assert isinstance(outputs[1], np.ndarray) 67 | assert isinstance(outputs[2], np.ndarray) 68 | assert isinstance(outputs[3], DisplacementField) 69 | 70 | def test_sorted_attributes(self, minimal_vol): 71 | assert isinstance(minimal_vol.sorted_dvfs[0], list) 72 | assert isinstance(minimal_vol.sorted_mov[0], list) 73 | assert isinstance(minimal_vol.sorted_reg[0], list) 74 | assert isinstance(minimal_vol.sorted_ref[0], list) 75 | 76 | 77 | class TestGappedVolume: 78 | def test_incorrect_input(self): 79 | with pytest.raises(ValueError): 80 | GappedVolume([1], []) 81 | 82 | with pytest.raises(ValueError): 83 | GappedVolume([1, 1], [np.zeros((1, 2)), np.zeros((2, 3))]) 84 | 85 | def test_array2list_conversion(self): 86 | sn = [1, 44, 12] 87 | shape = (10, 11) 88 | imgs = np.array([np.zeros(shape) for _ in range(len(sn))]) 89 | 90 | gv = GappedVolume(sn, imgs) 91 | 92 | assert isinstance(gv.sn, list) 93 | assert np.allclose(sn, gv.sn) 94 | assert isinstance(gv.imgs, list) 95 | 96 | 97 | class TestCoronalInterpolator: 98 | @pytest.mark.parametrize( 99 | "kind", ["linear", "nearest", "zero", "slinear", "previous", "next"] 100 | ) 101 | def test_all_kinds(self, kind): 102 | ip = CoronalInterpolator(kind=kind, fill_value=0, bounds_error=False) 103 | 104 | sn = [0, 527] 105 | imgs = np.zeros((len(sn), 10, 11)) 106 | dummy_gv = GappedVolume(sn, imgs) 107 | 108 | final_volume = ip.interpolate(dummy_gv) 109 | 110 | assert np.allclose(np.zeros((528, *dummy_gv.shape)), final_volume) 111 | 112 | @pytest.mark.parametrize( 113 | "kind", 114 | [ 115 | "linear", 116 | "quadratic", 117 | "cubic", 118 | "nearest", 119 | "zero", 120 | "slinear", 121 | "previous", 122 | "next", 123 | ], 124 | ) 125 | def test_precise_on_known(self, kind): 126 | """Make sure that on the known slices the interpolation is precise.""" 127 | ip = CoronalInterpolator(kind=kind, fill_value=0, bounds_error=False) 128 | 129 | shape = (10, 11) 130 | sn = list(range(0, 528, 8)) + [527] 131 | imgs = 
np.random.random((len(sn), *shape)) 132 | 133 | gv = GappedVolume(sn, imgs) 134 | 135 | final_volume = ip.interpolate(gv) 136 | 137 | for i, s in enumerate(sn): 138 | assert np.allclose(imgs[i], final_volume[s]) 139 | 140 | @pytest.mark.parametrize( 141 | "kind", ["linear", "nearest", "zero", "slinear", "previous", "next"] 142 | ) 143 | def test_nan_all(self, kind): 144 | """Make sure that if one input section composed fully of NaN pixels then things work.""" 145 | ip = CoronalInterpolator(kind=kind, fill_value=0, bounds_error=False) 146 | 147 | shape = (10, 11) 148 | sn = [0, 100, 527] 149 | imgs = [np.zeros(shape), np.ones(shape) * np.nan, np.zeros(shape)] 150 | gv = GappedVolume(sn, imgs) 151 | 152 | final_volume = ip.interpolate(gv) 153 | 154 | assert np.all(np.isfinite(final_volume)) 155 | 156 | @pytest.mark.parametrize( 157 | "kind", ["linear", "nearest", "zero", "slinear", "previous", "next"] 158 | ) 159 | def test_nan_some(self, kind): 160 | """Make sure that if input section has a NaN pixel then things work.""" 161 | ip = CoronalInterpolator(kind=kind, fill_value=0, bounds_error=False) 162 | 163 | shape = (10, 11) 164 | sn = [0, 100, 527] 165 | 166 | valid = np.ones(shape, dtype=bool) 167 | valid[5:9, 2:7] = False 168 | 169 | weird_img = np.random.random(shape) 170 | weird_img[5:9, 2:7] = np.nan 171 | imgs = [np.zeros(shape), weird_img, np.zeros(shape)] 172 | 173 | gv = GappedVolume(sn, imgs) 174 | 175 | final_volume = ip.interpolate(gv) 176 | 177 | assert np.all(np.isfinite(final_volume)) 178 | assert np.allclose(final_volume[sn[1]][valid], weird_img[valid]) 179 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | minversion = 3.1.0 3 | requires = virtualenv >= 20.0.0 4 | source = atlalign 5 | envlist = 6 | lint 7 | py38 8 | py39 9 | docs 10 | 11 | [testenv] 12 | download = true 13 | deps = 14 | lpips_tf @ 
git+https://github.com/alexlee-gk/lpips-tensorflow.git#egg=lpips_tf
extras =
    dev
commands =
    pytest {posargs:tests}

[testenv:lint]
skip_install = true
deps =
    flake8
    isort
    pydocstyle
    black==22.3.0
commands =
    flake8 setup.py {[tox]source} tests
    isort --honor-noqa --profile=black --check setup.py {[tox]source} tests
    pydocstyle {[tox]source}
    black --check setup.py {[tox]source} tests

[testenv:format]
skip_install = true
# Same black pin as the lint env, so code formatted here passes `black --check` there.
deps =
    isort
    black==22.3.0
commands =
    isort --honor-noqa --profile=black setup.py {[tox]source} tests
    black setup.py {[tox]source} tests

[testenv:docs]
changedir = docs
extras =
    dev
    docs
allowlist_externals = make
commands =
    make clean
    make doctest SPHINXOPTS=-W
    make html SPHINXOPTS=-W

[pytest]
# --strict-markers makes any marker not registered below an error
# (explicit spelling of the legacy --strict flag).
addopts =
    -v
    -m "not todo and not slow and not internet"
    --disable-warnings
    --strict-markers
    --cov=atlalign
    --cov-report=term-missing
testpaths = tests
markers =
    internet: requires connection to the internet
    slow: mark denoting a test that is too slow to run by default
    todo: mark denoting a test that is unfinished or not yet reliably reproducible

[flake8]
count = True
max-line-length = 120
ignore = E402, W503, E203

[pydocstyle]
convention = numpy

[gh-actions]
python =
    3.6: py36
    3.7: py37
    3.8: py38
    3.9: py39
--------------------------------------------------------------------------------