├── .github
│   └── workflows
│       ├── build.yml
│       ├── ci.yml
│       ├── pypi_publish.yml
│       ├── release-drafter.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   ├── __init__.py
│   └── halfres_512_checkpoint_160000.h5
├── configs
│   ├── decidua_split.json
│   ├── hickey_split.json
│   ├── msk_colon_split.json
│   ├── msk_pancreas_split.json
│   ├── params.toml
│   └── tonic_split.json
├── conftest.py
├── pyproject.toml
├── scripts
│   ├── evaluation_script.py
│   └── hyperparameter_search.py
├── src
│   ├── cell_classification
│   │   ├── __init__.py
│   │   ├── application.py
│   │   ├── augmentation_pipeline.py
│   │   ├── inference.py
│   │   ├── loss.py
│   │   ├── metrics.py
│   │   ├── model_builder.py
│   │   ├── plot_utils.py
│   │   ├── post_processing.py
│   │   ├── prepare_data_script_MSK_colon.py
│   │   ├── prepare_data_script_MSK_pancreas.py
│   │   ├── prepare_data_script_decidua.py
│   │   ├── prepare_data_script_tonic.py
│   │   ├── prepare_data_script_tonic_annotated.py
│   │   ├── promix_naive.py
│   │   ├── segmentation_data_prep.py
│   │   ├── simple_data_prep.py
│   │   ├── unet.py
│   │   └── viewer_widget.py
│   └── deepcell
│       ├── __init__.py
│       ├── application.py
│       ├── backbone_utils.py
│       ├── deepcell_toolbox.py
│       ├── fpn.py
│       ├── layers.py
│       ├── panopticnet.py
│       ├── semantic_head.py
│       └── utils.py
├── templates
│   ├── 1_Nimbus_Predict.ipynb
│   └── 2_Generic_Cell_Clustering.ipynb
├── tests
│   ├── __init__.py
│   ├── application_test.py
│   ├── augmentation_pipeline_test.py
│   ├── deepcell_application_test.py
│   ├── deepcell_backbone_utils_test.py
│   ├── deepcell_layers_test.py
│   ├── deepcell_panopticnet_test.py
│   ├── deepcell_toolbox_test.py
│   ├── deepcell_utils_test.py
│   ├── inference_test.py
│   ├── loss_test.py
│   ├── metrics_test.py
│   ├── model_builder_test.py
│   ├── plot_utils_test.py
│   ├── post_processing_test.py
│   ├── promix_naive_test.py
│   ├── segmentation_data_prep_test.py
│   ├── simple_data_prep_test.py
│   ├── unet_test.py
│   └── viewer_widget_test.py
└── tox.ini

--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
name: Wheel Builder

on:
  workflow_call:

permissions:
  contents: read # to fetch code (actions/checkout)

jobs:
  build:
    name: Pure Python Wheel and Source Distribution
    runs-on: ubuntu-latest

    steps:
      - name: Checkout ${{ github.repository }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 0

      - name: Build Wheels
        run: pipx run build

      - name: Check Wheel Metadata
        run: pipx run twine check dist/*

      - name: Store Wheel Artifacts
        uses: actions/upload-artifact@v3
        with:
          name: distributions
          path: dist/*.whl

--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
    types: [labeled, opened, synchronize, reopened]
  workflow_dispatch:
  merge_group:
    types: [checks_requested]
    branches: [main]

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

permissions:
  contents: read # to fetch code (actions/checkout)

jobs:
  test:
    name: Test
    permissions:
      contents: read
      pull-requests: write
    secrets: inherit
    uses: ./.github/workflows/test.yml

  build:
    name: Build
    permissions:
      contents: read
      pull-requests: write
    secrets: inherit
    uses: ./.github/workflows/build.yml

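  # The `test` and `build` jobs above call the reusable workflows in this
  # directory via their `workflow_call` triggers; any other workflow can do the
  # same. A minimal sketch of the pattern (job name illustrative):
  #
  #   jobs:
  #     my_tests:
  #       uses: ./.github/workflows/test.yml
  #       secrets: inherit
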
  upload_coverage:
    needs: [test]
    name: Upload Coverage
    runs-on: ubuntu-latest
    steps:
      - name: Checkout ${{ github.repository }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 0

      - name: Download Coverage Artifact
        uses: actions/download-artifact@v3
        # if `name` is not specified, all artifacts are downloaded.

      - name: Upload Coverage to Coveralls
        uses: coverallsapp/github-action@v2
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}

--------------------------------------------------------------------------------
/.github/workflows/pypi_publish.yml:
--------------------------------------------------------------------------------
name: Build Wheels and upload to PyPI

on:
  pull_request:
    branches: ["releases/**"]
    types: [labeled, opened, synchronize, reopened]
  release:
    types: [published]
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

permissions:
  contents: read # to fetch code (actions/checkout)

jobs:
  test:
    name: Test
    permissions:
      contents: read
    secrets: inherit
    uses: ./.github/workflows/test.yml

  build_wheels_sdist:
    needs: [test]
    name: Build
    uses: ./.github/workflows/build.yml
    secrets: inherit

  test_pypi_publish:
    # Test PyPI publish; requires wheels and source dist (sdist)
    name: Publish ${{ github.repository }} to TestPyPI
    needs: [test, build_wheels_sdist]
    runs-on: ubuntu-latest
    steps:
      - uses: actions/download-artifact@v3
        with:
          name: distributions
          path: dist

      - uses: pypa/gh-action-pypi-publish@release/v1.6
        with:
          user: __token__
          password: ${{ secrets.TEST_PYPI_API_TOKEN }}
          repository_url: https://test.pypi.org/legacy/
          packages_dir: dist/
          verbose: true

  pypi_publish:
    name: Publish ${{ github.repository }} to PyPI
    needs: [test, build_wheels_sdist, test_pypi_publish]

    runs-on: ubuntu-latest
    # Publish only when a GitHub Release is created:
    if: github.event_name == 'release' && github.event.action == 'published'
    steps:
      - name: Download Artifact
        uses: actions/download-artifact@v3
        with:
          name: distributions
          path: dist

      - name: PyPI Publish
        uses: pypa/gh-action-pypi-publish@release/v1.6
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}
          packages_dir: dist/
          verbose: true

--------------------------------------------------------------------------------
/.github/workflows/release-drafter.yml:
--------------------------------------------------------------------------------
name: Release Drafter

on:
  push:
    # branches to consider in the event; optional, defaults to all
    branches: ["main"]
  # pull_request event is required only for autolabeler
  pull_request:
    # Only the following types are handled by the action, but one can default to all as well
    types: [opened, reopened, synchronize]

  issues:
    types: [closed]

permissions:
  contents: read

jobs:
  update_release_draft:
    runs-on: ubuntu-latest
    permissions:
      contents: write # for release-drafter/release-drafter to create a github release
      pull-requests: write # for release-drafter/release-drafter to add label to PR
    steps:
      # Drafts your next Release
notes as Pull Requests are merged into "master" 26 | - uses: release-drafter/release-drafter@v5.20.1 27 | # Specify config name to use, relative to .github/. Default: release-drafter.yml 28 | with: 29 | config-name: release-drafter.yml 30 | # disable-autolabeler: true 31 | env: 32 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 33 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | workflow_call: 5 | 6 | permissions: 7 | contents: read # to fetch code (actions/checkout) 8 | jobs: 9 | test: 10 | name: ${{ github.repository }} 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout ${{ github.repository }} 15 | uses: actions/checkout@v3 16 | with: 17 | fetch-depth: 0 18 | 19 | - name: Set up Python 3.9 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: 3.9 23 | cache-dependency-path: "**/pyproject.toml" 24 | cache: "pip" 25 | 26 | - name: Install Dependencies and ${{ github.repository }} 27 | run: | 28 | pip install .[test] 29 | 30 | - name: Run Tests 31 | run: | 32 | pytest 33 | 34 | - name: Archive Coverage 35 | uses: actions/upload-artifact@v3 36 | with: 37 | name: coverage 38 | path: | 39 | coverage.lcov 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | .vscode 162 | .DS_Store 163 | coverage.lcov 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Modified Apache License 2 | Version 2.0, January 2004 3 | 4 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 5 | 6 | 1. Definitions. 7 | 8 | "License" shall mean the terms and conditions for use, reproduction, 9 | and distribution as defined by Sections 1 through 9 of this document. 10 | 11 | "Licensor" shall mean the copyright owner or entity authorized by 12 | the copyright owner that is granting the License. 13 | 14 | "Legal Entity" shall mean the union of the acting entity and all 15 | other entities that control, are controlled by, or are under common 16 | control with that entity. For the purposes of this definition, 17 | "control" means (i) the power, direct or indirect, to cause the 18 | direction or management of such entity, whether by contract or 19 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 20 | outstanding shares, or (iii) beneficial ownership of such entity. 21 | 22 | "You" (or "Your") shall mean an individual or Legal Entity 23 | exercising permissions granted by this License. 24 | 25 | "Source" form shall mean the preferred form for making modifications, 26 | including but not limited to software source code, documentation 27 | source, and configuration files. 
28 | 29 | "Object" form shall mean any form resulting from mechanical 30 | transformation or translation of a Source form, including but 31 | not limited to compiled object code, generated documentation, 32 | and conversions to other media types. 33 | 34 | "Work" shall mean the work of authorship, whether in Source or 35 | Object form, made available under the License, as indicated by a 36 | copyright notice that is included in or attached to the work 37 | (an example is provided in the Appendix below). 38 | 39 | "Derivative Works" shall mean any work, whether in Source or Object 40 | form, that is based on (or derived from) the Work and for which the 41 | editorial revisions, annotations, elaborations, or other modifications 42 | represent, as a whole, an original work of authorship. For the purposes 43 | of this License, Derivative Works shall not include works that remain 44 | separable from, or merely link (or bind by name) to the interfaces of, 45 | the Work and Derivative Works thereof. 46 | 47 | "Contribution" shall mean any work of authorship, including 48 | the original version of the Work and any modifications or additions 49 | to that Work or Derivative Works thereof, that is intentionally 50 | submitted to Licensor for inclusion in the Work by the copyright owner 51 | or by an individual or Legal Entity authorized to submit on behalf of 52 | the copyright owner. For the purposes of this definition, "submitted" 53 | means any form of electronic, verbal, or written communication sent 54 | to the Licensor or its representatives, including but not limited to 55 | communication on electronic mailing lists, source code control systems, 56 | and issue tracking systems that are managed by, or on behalf of, the 57 | Licensor for the purpose of discussing and improving the Work, but 58 | excluding communication that is conspicuously marked or otherwise 59 | designated in writing by the copyright owner as "Not a Contribution." 60 | 61 | "Contributor" shall mean Licensor and any individual or Legal Entity 62 | on behalf of whom a Contribution has been received by Licensor and 63 | subsequently incorporated within the Work. 64 | 65 | 2. Grant of Copyright License. Subject to the terms and conditions of 66 | this License, each Contributor hereby grants to You a non-commercial, 67 | academic perpetual, worldwide, non-exclusive, no-charge, royalty-free, 68 | irrevocable copyright license to reproduce, prepare Derivative Works 69 | of, publicly display, publicly perform, sublicense, and distribute the 70 | Work and such Derivative Works in Source or Object form. For any other 71 | use, including commercial use, please contact: mangelo0@stanford.edu. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a non-commercial, 75 | academic perpetual, worldwide, non-exclusive, no-charge, royalty-free, 76 | irrevocable (except as stated in this section) patent license to make, 77 | have made, use, offer to sell, sell, import, and otherwise transfer the 78 | Work, where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | 10. Neither the name of Stanford nor the names of its contributors may be 177 | used to endorse or promote products derived from this software without 178 | specific prior written permission. 179 | 180 | END OF TERMS AND CONDITIONS 181 | 182 | APPENDIX: How to apply the Apache License to your work. 183 | 184 | To apply the Apache License to your work, attach the following 185 | boilerplate notice, with the fields enclosed by brackets "[]" 186 | replaced with your own identifying information. (Don't include 187 | the brackets!) The text should be enclosed in the appropriate 188 | comment syntax for the file format. We also recommend that a 189 | file or class name and description of purpose be included on the 190 | same "printed page" as the copyright notice for easier 191 | identification within third-party archives. 192 | 193 | Copyright [yyyy] [name of copyright owner] 194 | 195 | Licensed under the Apache License, Version 2.0 (the "License"); 196 | you may not use this file except in compliance with the License. 197 | You may obtain a copy of the License at 198 | 199 | http://www.apache.org/licenses/LICENSE-2.0 200 | 201 | Unless required by applicable law or agreed to in writing, software 202 | distributed under the License is distributed on an "AS IS" BASIS, 203 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Nimbus

The Nimbus repo contains code for training and validating a machine learning model that classifies cells as positive or negative for arbitrary markers across different imaging platforms.

The code for using the model and running inference on your own data can be found here: [Nimbus-Inference](https://github.com/angelolab/Nimbus-Inference). Code for generating the figures in the paper can be found here: [Publication plots](https://github.com/angelolab/publications/tree/main/2024-Rumberger_Greenwald_etal_Nimbus).

## Installation instructions

Clone the repository

`git clone https://github.com/angelolab/Nimbus.git`

Make a conda environment for Nimbus and activate it

`conda create -n Nimbus python==3.10`

`conda activate Nimbus`

Install CUDA libraries if you have an NVIDIA GPU available

`conda install -c conda-forge cudatoolkit=11.2 cudnn=8.1.0`

Install the package and all dependencies in the conda environment

`python -m pip install -e Nimbus`

Install tensorflow-metal if you have an Apple Silicon GPU

`python -m pip install tensorflow-metal`

Navigate to the example notebooks and start jupyter

`cd Nimbus/templates`

`jupyter notebook`

## Citation

```bibtex
@article{rum2024nimbus,
  title={Automated classification of cellular expression in multiplexed imaging data with Nimbus},
  author={Rumberger, J. Lorenz and Greenwald, Noah F. and Ranek, Jolene S. and Boonrat, Potchara and Walker, Cameron and Franzen, Jannik and Varra, Sricharan Reddy and Kong, Alex and Sowers, Cameron and Liu, Candace C. and Averbukh, Inna and Piyadasa, Hadeesha and Vanguri, Rami and Nederlof, Iris and Wang, Xuefei Julie and Van Valen, David and Kok, Marleen and Hollman, Travis J.
and Kainmueller, Dagmar and Angelo, Michael}, 44 | journal={bioRxiv}, 45 | pages={2024--05}, 46 | year={2024}, 47 | publisher={Cold Spring Harbor Laboratory} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /checkpoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angelolab/Nimbus/b418eb842ee095afd0a9abe79be98a66ba157127/checkpoints/__init__.py -------------------------------------------------------------------------------- /checkpoints/halfres_512_checkpoint_160000.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angelolab/Nimbus/b418eb842ee095afd0a9abe79be98a66ba157127/checkpoints/halfres_512_checkpoint_160000.h5 -------------------------------------------------------------------------------- /configs/decidua_split.json: -------------------------------------------------------------------------------- 1 | {"test": ["12_31750_16_12", "14_31758_20_4", "16_31762_16_21", "12_31750_1_10", "14_31759_20_6", "6_31728_15_5", "16_31773_4_9", "20_31797_6_9", "6_31731_11_8", "12_31754_16_14", "10_31745_16_9", "16_31770_14_16", "20_31796_6_8", "10_31745_1_6", "6_31731_11_7", "20_31798_6_10", "16_31771_14_14", "16_31767_3_3", "12_31748_1_8", "6_31729_9_4", "8_31736_12_18", "16_31774_14_13", "18_31783_14_11"], "validation": ["8_31734_12_5", "12_31751_1_12", "6_31730_10_10", "6_31729_10_5", "6_31726_8_8", "10_31747_1_7", "6_31730_10_6", "8_31737_12_22", "16_31771_4_3", "6_31733_11_15", "6_31728_9_3", "8_31734_16_1", "16_31770_14_15", "6_31728_9_1", "6_31731_11_5", "8_31734_12_3", "6_31729_10_4", "18_31785_5_14", "20_31793_13_9", "6_31731_15_7", "6_31732_11_10", "18_31783_5_8", "8_31734_12_1"], "train": ["8_31736_12_16", "12_31750_1_11", "6_31728_9_2", "8_31736_12_19", "6_31726_8_5", "14_31755_20_1", "6_31730_18_15", "10_31744_1_3", "18_31776_14_12", "6_31733_11_16", "16_31774_4_12", "14_31759_16_18", "20_31793_6_3", "6_31728_15_4", "20_31793_13_10", "16_31772_4_6", "18_31784_5_12", "20_31787_5_16", "20_31786_14_5", "14_31758_20_5", "6_31727_8_9", "20_31794_13_7", "6_31733_11_14", "12_31754_1_16", "6_31730_10_9", "8_31737_6_18", "16_31768_3_6", "14_31761_20_7", "10_31740_6_11", "8_31735_12_9", "12_31749_16_11", "6_31733_11_17", "20_31789_5_19", "6_31731_11_2", "6_31730_10_7", "10_31741_1_1", "10_31739_6_16", "16_31767_17_1", "18_31778_5_3", "6_31728_15_6", "16_31773_4_8", "20_31788_5_17", "16_31770_4_2", "8_31738_6_17", "8_31735_12_6", "6_31728_8_13", "14_31758_16_17", "12_31754_18_2", "18_31782_5_6", "6_31731_11_4", "12_31754_16_13", "6_31729_10_1", "6_31725_8_1", "6_31726_8_7", "6_31732_11_9", "8_31734_12_2", "10_31739_6_14", "6_31732_11_12", "6_31729_10_2", "20_31791_18_13", "6_31726_8_4", "8_31735_12_14", "10_31739_6_15", "16_31768_17_2", "8_31737_12_23", "18_31783_5_11", "16_31770_4_1", "20_31797_13_3", "20_31793_6_2", "8_31735_12_8", "18_31776_5_2", "16_31772_4_5", "16_31762_20_9", "20_31792_18_14", "18_31784_14_6", "16_31768_3_5", "8_31736_16_3", "6_31727_8_10", "20_31794_6_6", "18_31785_5_13", "20_31791_13_12", "6_31732_11_11", "8_31736_12_17", "18_31785_5_15", "18_31783_5_9", "20_31789_14_1", "20_31791_13_11", "14_31761_16_20", "6_31730_11_1", "16_31763_20_10", "8_31736_16_4", "14_31756_16_16", "20_31793_13_8", "12_31749_1_9", "6_31725_15_1", "8_31735_12_12", "6_31732_15_9", "6_31725_15_2", "20_31795_13_5", "12_31754_1_14", "6_31733_15_10", "8_31735_12_7", "10_31745_16_10", "20_31789_14_2", 
"20_31798_13_2", "20_31794_6_5", "18_31782_5_7", "16_31772_4_4", "6_31725_8_2", "18_31783_14_8", "6_31730_10_11", "16_31769_3_7", "16_31765_20_11", "20_31795_13_4", "12_31752_1_13", "6_31729_10_3", "14_31756_20_2", "10_31742_1_2", "8_31735_12_11", "10_31745_1_4", "18_31783_14_9", "20_31795_6_7", "20_31787_14_4", "18_31780_5_5", "6_31727_8_12", "16_31773_4_10", "12_31754_1_15", "20_31787_14_3", "20_31798_13_1", "16_31762_20_8", "20_31790_13_14", "6_31725_8_3", "6_31731_15_8", "6_31727_8_11", "6_31733_15_12", "8_31734_11_18", "16_31769_3_8", "16_31765_3_1", "6_31727_15_3", "6_31726_8_6", "10_31740_6_12", "18_31783_14_7", "6_31733_15_11", "10_31745_1_5", "16_31765_19_1", "6_31730_10_8", "10_31743_16_8", "14_31755_18_1", "16_31769_18_11", "18_31779_5_4", "16_31766_3_2", "6_31731_11_6", "8_31735_12_13", "8_31734_12_4", "16_31772_4_7", "20_31793_6_1", "16_31775_5_1", "8_31737_12_20", "14_31760_16_19", "20_31789_5_20", "16_31774_4_11", "20_31791_13_13", "8_31736_12_15", "8_31736_16_5", "10_31740_16_6", "12_31754_16_15", "20_31788_5_18", "8_31734_16_2", "10_31743_16_7", "14_31757_20_3", "8_31737_12_21", "20_31794_13_6", "18_31783_14_10", "6_31731_11_3", "6_31733_11_13", "20_31790_18_12", "20_31794_6_4", "8_31735_12_10", "18_31783_5_10", "10_31739_6_13", "16_31767_3_4", "6_31730_10_12"]} -------------------------------------------------------------------------------- /configs/hickey_split.json: -------------------------------------------------------------------------------- 1 | {"test": ["B010A_reg003_X01_Y01_Z01", "B011B_reg003_X01_Y01_Z01", "B011B_reg001_X01_Y01_Z01"], "validation": ["B011A_reg001_X01_Y01_Z01", "B012A_reg003_X01_Y01_Z01", "B011A_reg002_X01_Y01_Z01"], "train": ["B010B_reg004_X01_Y01_Z01", "B011A_reg003_X01_Y01_Z01", "B012A_reg004_X01_Y01_Z01", "B009A_reg003_X01_Y01_Z01", "B009A_reg002_X01_Y01_Z01", "B010B_reg003_X01_Y01_Z01", "B010B_reg002_X01_Y01_Z01", "B012B_reg001_X01_Y01_Z01", "B009A_reg004_X01_Y01_Z01", "B011B_reg002_X01_Y01_Z01", "B012A_reg001_X01_Y01_Z01", "B012B_reg004_X01_Y01_Z01", "B011B_reg004_X01_Y01_Z01", "B012B_reg003_X01_Y01_Z01", "B010B_reg001_X01_Y01_Z01", "B010A_reg001_X01_Y01_Z01", "B009B_reg002_X01_Y01_Z01", "B011A_reg004_X01_Y01_Z01", "B009B_reg001_X01_Y01_Z01", "B009B_reg003_X01_Y01_Z01", "B012B_reg002_X01_Y01_Z01", "B009B_reg004_X01_Y01_Z01", "B009A_reg001_X01_Y01_Z01", "B010A_reg004_X01_Y01_Z01", "B012A_reg002_X01_Y01_Z01", "B010A_reg002_X01_Y01_Z01"]} -------------------------------------------------------------------------------- /configs/params.toml: -------------------------------------------------------------------------------- 1 | record_path = "C:/Users/lorenz/Desktop/angelo_lab/MIBI_test/TNBC_CD45.tfrecord" 2 | path = "C:/Users/lorenz/OneDrive/Desktop/angelo_lab/" 3 | experiment = "test" 4 | project = "Nimbus" 5 | logging_mode = "offline" 6 | model = "ModelBuilder" 7 | num_steps = 20 8 | lr = 1e-3 9 | backbone = "resnet50" 10 | dataset_names = ["TNBC_CD45"] 11 | dataset_sample_probs = [1.0] 12 | input_shape = [256,256,4] 13 | batch_constituents = ["mplex_img", "binary_mask", "nuclei_img", "membrane_img"] 14 | num_validation = [3] 15 | num_test = [3] 16 | shuffle_buffer_size = 2000 17 | flip_prob = 0.5 18 | affine_prob = 0.5 19 | scale_min = 0.8 20 | scale_max = 1.2 21 | shear_angle = 0.2 22 | elastic_prob = 0.5 23 | elastic_alpha = 10 24 | elastic_sigma = 4 25 | rotate_prob = 0.5 26 | rotate_count = 4 27 | gaussian_noise_prob = 0.5 28 | gaussian_noise_min = 0.05 29 | gaussian_noise_max = 0.15 30 | gaussian_blur_prob = 0.5 31 | gaussian_blur_min = 
0.05 32 | gaussian_blur_max = 0.15 33 | contrast_prob = 0.5 34 | contrast_min = 0.8 35 | contrast_max = 1.2 36 | mixup_prob = 0.5 37 | mixup_alpha = 4.0 38 | batch_size = 4 39 | loss_fn = "BinaryCrossentropy" 40 | loss_selective_masking = true 41 | quantile = 0.5 42 | quantile_end = 0.9 43 | quantile_warmup_steps = 100000 44 | confidence_thresholds = [0.1, 0.9] 45 | ema = 0.01 46 | location = false 47 | [loss_kwargs] 48 | from_logits = false 49 | label_smoothing = 0.1 50 | 51 | [classes] 52 | marker_positive = 1 53 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterator 2 | import numpy as np 3 | import pytest 4 | import toml 5 | import os 6 | import sys 7 | 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "src")) 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "tests")) 11 | 12 | @pytest.fixture(scope="function") 13 | def config_params() -> Iterator[Dict]: 14 | params: Dict[str, Any] = toml.load("./configs/params.toml") 15 | yield params 16 | 17 | @pytest.fixture(scope="function") 18 | def rng() -> Iterator[np.random.Generator]: 19 | rng_ = np.random.default_rng(seed=42) 20 | yield rng_ -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | "setuptools_scm[toml]>=6.2", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [project] 10 | dependencies = [ 11 | "numpy>=1.20", 12 | "ark-analysis>=0.6.0", 13 | "spektral>=1", 14 | "scikit-image>=0.16", 15 | "scikit-learn>=1.2.2", 16 | "tifffile>=2023", 17 | "tqdm>=4", 18 | "imgaug>=0.4.0", 19 | "opencv-python>=4.6", 20 | "toml", 21 | "seaborn>=0.12", 22 | "alpineer>=0.1.5", 23 | "natsort>=7.1", 24 | "tensorflow~=2.8.0", 25 | "tensorflow_addons~=0.16.1", 26 | "pydot>=1.4.2,<2", 27 | "protobuf", 28 | "wandb" 29 | ] 30 | name = "cell_classification" 31 | authors = [{ name = "Angelo Lab", email = "theangelolab@gmail.com" }] 32 | description = "Cell classification tool for classifying cells into marker positive and negative for arbitrary markers." 
readme = "README.md"
requires-python = ">=3.9"
license = { text = "Modified Apache License 2.0" }
classifiers = [
    "Development Status :: 4 - Beta",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.9",
    "License :: OSI Approved :: Apache Software License",
    "Topic :: Scientific/Engineering :: Bio-Informatics",
    "Topic :: Scientific/Engineering :: Image Processing",
]
dynamic = ["version"]
urls = { repository = "https://github.com/angelolab/Nimbus" }

[project.optional-dependencies]
test = [
    "attrs",
    "coveralls[toml]",
    "pytest",
    "pytest-cases",
    "pytest-cov",
    "pytest-mock",
    "pytest-pycodestyle",
    "pytest-randomly",
]

[tool.setuptools_scm]
version_scheme = "release-branch-semver"
local_scheme = "no-local-version"

# Coverage
[tool.coverage.paths]
source = ["src", "*/site-packages"]

[tool.coverage.run]
branch = true
source = ["cell_classification"]

[tool.coverage.report]
exclude_lines = [
    "except ImportError",
    "raise AssertionError",
    "raise NotImplementedError",
]

# Pytest Options
[tool.pytest.ini_options]
filterwarnings = [
    "ignore::DeprecationWarning",
    "ignore::PendingDeprecationWarning",
]
addopts = [
    "-v",
    "-s",
    "--durations=20",
    "--randomly-seed=42",
    "--randomly-dont-reorganize",
    "--cov=cell_classification",
    "--cov-report=lcov",
    "--pycodestyle",
]
console_output_style = "count"
testpaths = ["tests"]

--------------------------------------------------------------------------------
/scripts/evaluation_script.py:
--------------------------------------------------------------------------------
import argparse
import os
import pickle

import numpy as np
import pandas as pd
import tensorflow as tf
import toml

from cell_classification.metrics import (average_roc, calc_metrics, calc_roc,
                                         load_model)
from cell_classification.plot_utils import (heatmap_plot, plot_average_roc,
                                            plot_metrics_against_threshold,
                                            plot_together, subset_plots)
from cell_classification.segmentation_data_prep import (feature_description,
                                                        parse_dict)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to model weights",
        default=None,
    )
    parser.add_argument(
        "--params_path",
        type=str,
        help="Path to model params",
        default="E:\\angelo_lab\\test\\params.toml",
    )
    parser.add_argument(
        "--worst_n",
        type=int,
        help="Number of worst predictions to plot",
        default=20,
    )
    parser.add_argument(
        "--best_n",
        type=int,
        help="Number of best predictions to plot",
        default=20,
    )
    parser.add_argument(
        "--split_by_marker",
        # argparse's type=bool treats any non-empty string as True, so parse
        # truthy strings explicitly
        type=lambda x: str(x).lower() in ("1", "true", "yes"),
        help="Split best/worst predictions by marker",
        default=True,
    )
    parser.add_argument(
        "--external_datasets",
        type=str,
        help="List of paths to tfrecord datasets",
        nargs='+',
        default=[],
    )
    args = parser.parse_args()
    with open(args.params_path, "r") as f:
        params = toml.load(f)
    if args.model_path is not None:
        params["model_path"] = args.model_path

    model = load_model(params)
    datasets = {name: dataset for name, dataset in
                zip(model.dataset_names, model.test_datasets)}
    if hasattr(args, "external_datasets"):
        external_datasets = {
            os.path.split(external_dataset)[-1].split(".")[0]: external_dataset.replace(",", "")
            for external_dataset in args.external_datasets
        }
        external_datasets = {
            key: tf.data.TFRecordDataset(external_datasets[key])
            for key in external_datasets.keys()
        }
        external_datasets = {
            name: dataset.map(
                lambda x: tf.io.parse_single_example(x, feature_description),
                num_parallel_calls=tf.data.AUTOTUNE,
            ) for name, dataset in external_datasets.items()
        }
        external_datasets = {
            name: dataset.map(
                parse_dict, num_parallel_calls=tf.data.AUTOTUNE
            ) for name, dataset in external_datasets.items()
        }
        external_datasets = {
            name: dataset.batch(
                params["batch_size"], drop_remainder=False
            ) for name, dataset in external_datasets.items()
        }

        datasets.update(external_datasets)

    for name, val_dset in datasets.items():
        params["eval_dir"] = os.path.join(*os.path.split(params["model_path"])[:-1], "eval", name)
        os.makedirs(params["eval_dir"], exist_ok=True)
        # iterate over datasets
        pred_list = model.predict_dataset(val_dset, False)

        # prepare cell_table
        activity_list = []
        for pred in pred_list:
            activity_df = pred["activity_df"].copy()
            for key in ["dataset", "marker", "folder_name"]:
                activity_df[key] = [pred[key]] * len(activity_df)
            activity_list.append(activity_df)
        activity_df = pd.concat(activity_list)
        activity_df.to_csv(os.path.join(params["eval_dir"], "pred_cell_table.csv"), index=False)

        # cell level evaluation
        roc = calc_roc(pred_list, gt_key="activity", pred_key="pred_activity", cell_level=True)
        with open(os.path.join(params["eval_dir"], "roc_cell_lvl.pkl"), "wb") as f:
            pickle.dump(roc, f)

        # find the indices of the n worst (lowest AUC) and n best (highest AUC)
        # predictions and save plots of them
        roc_df = pd.DataFrame(roc)
        if args.split_by_marker:
            worst_idx = []
            best_idx = []
            markers = np.unique(roc_df.marker)
            for marker in markers:
                marker_df = roc_df[roc_df.marker == marker]
                # sort ascending by AUC so the worst predictions come first
                sort_idx = marker_df.auc.sort_values().index
                worst_idx.extend(sort_idx[:args.worst_n])
                best_idx.extend(sort_idx[-args.best_n:])
        else:
            sort_idx = np.argsort(roc["auc"])
            worst_idx = sort_idx[:args.worst_n]
            best_idx = sort_idx[-args.best_n:]
        for idx_list, best_worst in [(best_idx, "best"), (worst_idx, "worst")]:
            for i, idx in enumerate(idx_list):
                pred = pred_list[idx]
                plot_together(
                    pred, keys=["mplex_img", "marker_activity_mask", "prediction"],
                    save_dir=os.path.join(params["eval_dir"], best_worst + "_predictions"),
                    save_file="{}_{}_{}_{}_{}.png".format(
                        best_worst, i, pred["marker"], pred["dataset"], pred["folder_name"]
                    )
                )

        tprs, mean_tprs, fpr, std, mean_thresh = average_roc(roc)
        plot_average_roc(
            mean_tprs, std, save_dir=params["eval_dir"], save_file="avg_roc_cell_lvl.png"
        )
        print("AUC: {}".format(np.mean(roc["auc"])))

        print("Calculate precision, recall, f1_score and accuracy on the cell level")
        avg_metrics = calc_metrics(
            pred_list, gt_key="activity", pred_key="pred_activity", cell_level=True
        )
        pd.DataFrame(avg_metrics).to_csv(
            os.path.join(params["eval_dir"], "cell_metrics.csv"), index=False
        )
        plot_metrics_against_threshold(
            avg_metrics,
            metric_keys=["precision", "recall", "f1_score", "specificity"],
            threshold_key="threshold",
            save_dir=params["eval_dir"],
            save_file="precision_recall_f1_cell_lvl.png",
        )

        print("Plot activity predictions split by markers and cell types")
        subset_plots(
            activity_df, subset_list=["marker"],
            save_dir=params["eval_dir"],
            save_file="split_by_marker.png",
            gt_key="activity",
            pred_key="pred_activity",
        )
        if "cell_type" in activity_df.columns:
            subset_plots(
                activity_df, subset_list=["cell_type"],
                save_dir=params["eval_dir"],
                save_file="split_by_cell_type.png",
                gt_key="activity",
                pred_key="pred_activity",
            )
        heatmap_plot(
            activity_df, subset_list=["marker"],
            save_dir=params["eval_dir"],
            save_file="heatmap_split_by_marker.png",
            gt_key="activity",
            pred_key="pred_activity",
        )

--------------------------------------------------------------------------------
/scripts/hyperparameter_search.py:
--------------------------------------------------------------------------------
from cell_classification.model_builder import ModelBuilder
from cell_classification.metrics import calc_scores
from joblib import Parallel, delayed
from tqdm import tqdm
import numpy as np
import argparse
import toml
import os


def hyperparameter_search(params, n_jobs=10, checkpoint=None):
    """
    Hyperparameter search for the best pos/neg threshold per marker and dataset
    Args:
        params: configs comparable to Nimbus/configs/params.toml
        n_jobs: number of jobs for parallelization
        checkpoint: name of checkpoint
    Returns:
        optimal_thresh_dict (dict): optimal threshold per dataset and marker
        df (pd.DataFrame): validation predictions with assigned classes
    """
    optimal_thresh_dict = {}
    model = ModelBuilder(params)
    model.prep_data()
    if checkpoint:
        ckpt_path = os.path.join(params["model_dir"], checkpoint)
    else:
        ckpt_path = None
    df = model.predict_dataset_list(
        model.validation_datasets, save_predictions=False, ckpt_path=ckpt_path
    )
    print("Run hyperparameter search")
    thresholds = np.linspace(0, 1, 101)
    thresholds = [np.round(thresh, 2) for thresh in thresholds]
    for dataset in df.dataset.unique():
        df_subset = df[df.dataset == dataset]
        if dataset not in optimal_thresh_dict.keys():
            optimal_thresh_dict[dataset] = {}
        for marker in tqdm(df_subset.marker.unique()):
            df_subset_marker = df_subset[df_subset.marker == marker]
            metrics = Parallel(n_jobs=n_jobs)(
                delayed(calc_scores)(
                    df_subset_marker["activity"].astype(np.int32), df_subset_marker["prediction"],
                    threshold=thresh
                ) for thresh in thresholds
            )
            f1_scores = [metric["f1_score"] for metric in metrics]
            optimal_thresh_dict[dataset][marker] = thresholds[np.argmax(f1_scores)]
    # assign classes based on optimal thresholds
    df["pred_class"] = df.apply(
        lambda row: 1 if row["prediction"] >= optimal_thresh_dict[row["dataset"]][row["marker"]] else 0,
        axis=1
    )
    df.to_csv(os.path.join(
        params["path"], params["experiment"], "{}_validation_predictions.csv".format(checkpoint)
    ))
    # save as toml
    fpath = os.path.join(
        params["path"], params["experiment"], "{}_optimal_thresholds.toml".format(checkpoint)
    )
    with open(fpath, "w") as f:
        toml.dump(optimal_thresh_dict, f)
    return optimal_thresh_dict, df

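# The "<checkpoint>_optimal_thresholds.toml" written above nests one table per
# dataset with one threshold per marker; a sketch of its layout (dataset and
# marker names and the values are illustrative):
#
#   [TNBC_CD45]
#   CD45 = 0.42
#   CD8 = 0.57
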
def prepare_testset_predictions(params, optimal_thresh_dict, checkpoint=None):
    """
    Prepare testset predictions based on optimal thresholds
    Args:
        params: configs comparable to Nimbus/configs/params.toml
        optimal_thresh_dict: optimal thresholds per marker and dataset
        checkpoint: name of checkpoint
    Returns:
        df (pd.DataFrame): test set predictions with assigned classes
    """
    model = ModelBuilder(params)
    model.prep_data()
    if checkpoint:
        ckpt_path = os.path.join(params["model_dir"], checkpoint)
    else:
        ckpt_path = None
    df = model.predict_dataset_list(model.test_datasets, save_predictions=False, ckpt_path=ckpt_path)
    # assign classes based on optimal thresholds
    df["pred_class"] = df.apply(
        lambda row: 1 if row["prediction"] >= optimal_thresh_dict[row["dataset"]][row["marker"]] else 0,
        axis=1
    )
    df.to_csv(os.path.join(
        params["path"], params["experiment"], "{}_test_predictions.csv".format(checkpoint)
    ))
    # calculate test set metrics
    metrics_dict = {}
    for dataset in df.dataset.unique():
        df_subset = df[df.dataset == dataset]
        if dataset not in metrics_dict.keys():
            metrics_dict[dataset] = {}
        for marker in df_subset.marker.unique():
            df_subset_marker = df_subset[df_subset.marker == marker]
            metrics_dict[dataset][marker] = calc_scores(
                df_subset_marker["activity"].astype(np.int32), df_subset_marker["prediction"],
                threshold=optimal_thresh_dict[dataset][marker]
            )
    # calculate per dataset scores and save as toml
    dataset_scores = {}
    for dataset in metrics_dict.keys():
        dataset_scores[dataset] = {
            "precision": np.mean([metrics_dict[dataset][marker]["precision"] for marker in metrics_dict[dataset].keys()]),
            "recall": np.mean([metrics_dict[dataset][marker]["recall"] for marker in metrics_dict[dataset].keys()]),
            "f1_score": np.mean([metrics_dict[dataset][marker]["f1_score"] for marker in metrics_dict[dataset].keys()]),
            "accuracy": np.mean([metrics_dict[dataset][marker]["accuracy"] for marker in metrics_dict[dataset].keys()]),
            "specificity": np.mean([metrics_dict[dataset][marker]["specificity"] for marker in metrics_dict[dataset].keys()])
        }
    fpath = os.path.join(
        params["path"], params["experiment"], "{}_test_scores.toml".format(checkpoint)
    )
    with open(fpath, "w") as f:
        toml.dump(dataset_scores, f)
    return df


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--params", type=str, default="configs/params.toml")
    parser.add_argument("--ckpt", type=str, default=None)
    args = parser.parse_args()
    params = toml.load(args.params)
    # reuse previously saved thresholds if present, otherwise run the search,
    # so that optimal_thresh_dict is defined in both cases
    thresh_path = os.path.join(
        params["path"], params["experiment"], "{}_optimal_thresholds.toml".format(args.ckpt)
    )
    if os.path.exists(thresh_path):
        optimal_thresh_dict = toml.load(thresh_path)
    else:
        optimal_thresh_dict, _ = hyperparameter_search(params, checkpoint=args.ckpt)
    prepare_testset_predictions(params, optimal_thresh_dict, checkpoint=args.ckpt)

--------------------------------------------------------------------------------
/src/cell_classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/angelolab/Nimbus/b418eb842ee095afd0a9abe79be98a66ba157127/src/cell_classification/__init__.py
--------------------------------------------------------------------------------
/src/cell_classification/application.py:
--------------------------------------------------------------------------------
from deepcell.panopticnet import PanopticNet
| from deepcell.semantic_head import create_semantic_head 3 | from deepcell.application import Application 4 | from alpineer import io_utils 5 | from cell_classification.inference import prepare_normalization_dict, predict_fovs 6 | import cell_classification 7 | from pathlib import Path 8 | from glob import glob 9 | import tensorflow as tf 10 | import numpy as np 11 | import json 12 | import os 13 | 14 | 15 | def nimbus_preprocess(image, **kwargs): 16 | """Preprocess input data for Nimbus model. 17 | Args: 18 | image: array to be processed 19 | Returns: 20 | np.array: processed image array 21 | """ 22 | output = np.copy(image) 23 | if len(image.shape) != 4: 24 | raise ValueError("Image data must be 4D, got image of shape {}".format(image.shape)) 25 | 26 | normalize = kwargs.get('normalize', True) 27 | if normalize: 28 | marker = kwargs.get('marker', None) 29 | normalization_dict = kwargs.get('normalization_dict', {}) 30 | if marker in normalization_dict.keys(): 31 | norm_factor = normalization_dict[marker] 32 | else: 33 | print("Norm_factor not found for marker {}, calculating directly from the image. \ 34 | ".format(marker)) 35 | norm_factor = np.quantile(output[..., 0], 0.999) 36 | # normalize only marker channel in chan 0 not binary mask in chan 1 37 | output[..., 0] /= norm_factor 38 | output = output.clip(0, 1) 39 | return output 40 | 41 | 42 | def nimbus_postprocess(model_output): 43 | return model_output 44 | 45 | 46 | def format_output(model_output): 47 | return model_output[0] 48 | 49 | 50 | def prep_deepcell_naming_convention(deepcell_output_dir): 51 | """Prepares the naming convention for the segmentation data 52 | Args: 53 | deepcell_output_dir (str): path to directory where segmentation data is saved 54 | Returns: 55 | segmentation_naming_convention (function): function that returns the path to the 56 | segmentation data for a given fov 57 | """ 58 | def segmentation_naming_convention(fov_path): 59 | """Prepares the path to the segmentation data for a given fov 60 | Args: 61 | fov_path (str): path to fov 62 | Returns: 63 | seg_path (str): paths to segmentation fovs 64 | """ 65 | fov_name = os.path.basename(fov_path) 66 | return os.path.join( 67 | deepcell_output_dir, fov_name + "_whole_cell.tiff" 68 | ) 69 | return segmentation_naming_convention 70 | 71 | 72 | class Nimbus(Application): 73 | """Nimbus application class for predicting marker activity for cells in multiplexed images. 74 | """ 75 | def __init__( 76 | self, fov_paths, segmentation_naming_convention, output_dir, 77 | save_predictions=True, exclude_channels=[], half_resolution=True, 78 | batch_size=4, test_time_aug=True, input_shape=[1024,1024] 79 | ): 80 | """Initializes a Nimbus Application. 81 | Args: 82 | fov_paths (list): List of paths to fovs to be analyzed. 83 | exclude_channels (list): List of channels to exclude from analysis. 84 | segmentation_naming_convention (function): Function that returns the path to the 85 | segmentation mask for a given fov path. 86 | output_dir (str): Path to directory to save output. 87 | save_predictions (bool): Whether to save predictions. 88 | half_resolution (bool): Whether to run model on half resolution images. 89 | batch_size (int): Batch size for model inference. 90 | test_time_aug (bool): Whether to use test time augmentation. 91 | input_shape (list): Shape of input images. 
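        Example: a minimal usage sketch (paths are illustrative; `naming_fn`
        assumes DeepCell-style segmentation output as produced by
        `prep_deepcell_naming_convention` above):
            fov_paths = ["/data/cohort/fov0", "/data/cohort/fov1"]
            naming_fn = prep_deepcell_naming_convention("/data/deepcell_output")
            nimbus = Nimbus(fov_paths, naming_fn, output_dir="/data/nimbus_output")
            cell_table = nimbus.predict_fovs()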
        """
        self.fov_paths = fov_paths
        self.exclude_channels = exclude_channels
        self.segmentation_naming_convention = segmentation_naming_convention
        self.output_dir = output_dir
        self.half_resolution = half_resolution
        self.save_predictions = save_predictions
        self._batch_size = batch_size
        self.checked_inputs = False
        self.test_time_aug = test_time_aug
        self.input_shape = input_shape
        # exclude segmentation channel from analysis
        seg_name = os.path.basename(self.segmentation_naming_convention(self.fov_paths[0]))
        self.exclude_channels.append(seg_name.split(".")[0])
        if self.output_dir != '':
            os.makedirs(self.output_dir, exist_ok=True)

        # initialize model and parent class
        self.initialize_model()

        super(Nimbus, self).__init__(
            model=self.model,
            model_image_shape=self.model.input_shape[1:],
            preprocessing_fn=nimbus_preprocess,
            postprocessing_fn=nimbus_postprocess,
            format_model_output_fn=format_output,
        )

    def check_inputs(self):
        """ check inputs for Nimbus model
        """
        # check if all paths in fov_paths exist
        io_utils.validate_paths(self.fov_paths)

        # check if segmentation_naming_convention returns valid paths
        path_to_segmentation = self.segmentation_naming_convention(self.fov_paths[0])
        if not os.path.exists(path_to_segmentation):
            raise FileNotFoundError(
                "Function segmentation_naming_convention does not return a valid path. "
                "Segmentation path {} does not exist.".format(path_to_segmentation)
            )
        # check if output_dir exists
        io_utils.validate_paths([self.output_dir])

        if isinstance(self.exclude_channels, str):
            self.exclude_channels = [self.exclude_channels]
        self.checked_inputs = True
        print("All inputs are valid.")

    def initialize_model(self):
        """Initializes the model and loads weights.
        """
        backbone = "efficientnetv2bs"
        input_shape = self.input_shape + [2]
        model = PanopticNet(
            backbone=backbone, input_shape=input_shape,
            norm_method="std", num_semantic_classes=[1],
            create_semantic_head=create_semantic_head, location=False,
        )
        # make sure the path can be resolved on any OS and when importing from anywhere
        self.checkpoint_path = os.path.normpath(
            "../cell_classification/checkpoints/halfres_512_checkpoint_160000.h5"
        )
        if not os.path.exists(self.checkpoint_path):
            path = os.path.abspath(cell_classification.__file__)
            path = Path(path).resolve()
            self.checkpoint_path = os.path.join(
                *path.parts[:-3], 'checkpoints', 'halfres_512_checkpoint_160000.h5'
            )
        if not os.path.exists(self.checkpoint_path):
            # glob needs recursive=True to expand '**' and may return no
            # matches, so guard before taking the first hit
            matches = glob('**/halfres_512_checkpoint_160000.h5', recursive=True)
            if matches:
                self.checkpoint_path = os.path.abspath(matches[0])

        if not os.path.exists(self.checkpoint_path):
            self.checkpoint_path = os.path.join(
                os.getcwd(), 'checkpoints', 'halfres_512_checkpoint_160000.h5'
            )

        if os.path.exists(self.checkpoint_path):
            model.load_weights(self.checkpoint_path)
            print("Loaded weights from {}".format(self.checkpoint_path))
        else:
            raise FileNotFoundError(
                "Could not find Nimbus weights at {ckpt_path}. "
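    # Illustrative sketch (hypothetical marker names and factors) of the
    # normalization dict that prepare_normalization_dict below loads or builds,
    # and that nimbus_preprocess consumes:
    #
    #   normalization_dict = {"CD45": 1350.0, "CD8": 980.0}
    #   batch = nimbus_preprocess(batch, normalize=True, marker="CD45",
    #                             normalization_dict=normalization_dict)
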
                "Current path is {current_path} and directory contains {dir_c}, "
                "path to cell_classification is {p}".format(
                    ckpt_path=self.checkpoint_path,
                    current_path=os.getcwd(),
                    dir_c=os.listdir(os.getcwd()),
                    p=os.path.abspath(cell_classification.__file__)
                )
            )
        self.model = model

    def prepare_normalization_dict(
        self, quantile=0.999, n_subset=10, multiprocessing=False, overwrite=False,
    ):
        """Load or prepare and save normalization dictionary for Nimbus model.
        Args:
            quantile (float): Quantile to use for normalization.
            n_subset (int): Number of fovs to use for normalization.
            multiprocessing (bool): Whether to use multiprocessing.
            overwrite (bool): Whether to overwrite existing normalization dict.
        Returns:
            dict: Dictionary of normalization factors.
        """
        self.normalization_dict_path = os.path.join(self.output_dir, "normalization_dict.json")
        if os.path.exists(self.normalization_dict_path) and not overwrite:
            self.normalization_dict = json.load(open(self.normalization_dict_path))
        else:
            n_jobs = os.cpu_count() if multiprocessing else 1
            self.normalization_dict = prepare_normalization_dict(
                self.fov_paths, self.output_dir, quantile, self.exclude_channels, n_subset, n_jobs
            )

    def predict_fovs(self):
        """Predicts cell classification for input data.
        Returns:
            np.array: Predicted cell classification.
        """
        if not self.checked_inputs:
            self.check_inputs()
        if not hasattr(self, "normalization_dict"):
            self.prepare_normalization_dict()
        # check if GPU is available
        print("Available GPUs: ", tf.config.list_physical_devices('GPU'))
        print("Predictions will be saved in {}".format(self.output_dir))
        print("Iterating through fovs will take a while...")
        self.cell_table = predict_fovs(
            self.fov_paths, self.output_dir, self, self.normalization_dict,
            self.segmentation_naming_convention, self.exclude_channels, self.save_predictions,
            self.half_resolution, batch_size=self._batch_size,
            test_time_augmentation=self.test_time_aug,
        )
        self.cell_table.to_csv(
            os.path.join(self.output_dir, "nimbus_cell_table.csv"), index=False
        )
        return self.cell_table


if __name__ == "__main__":
    # smoke test with placeholder arguments; guarded so that importing this
    # module does not build the model as a side effect
    Nimbus(["none_path"], lambda x: x, "")

--------------------------------------------------------------------------------
/src/cell_classification/inference.py:
--------------------------------------------------------------------------------
import os
import cv2
import json
import random
import numpy as np
import pandas as pd
from skimage import io
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
from joblib import Parallel, delayed
from skimage.segmentation import find_boundaries


def calculate_normalization(channel_path, quantile):
    """Calculates the normalization value for a given channel
    Args:
        channel_path (str): path to channel
        quantile (float): quantile to use for normalization
    Returns:
        normalization_value (float): normalization value
    """
    mplex_img = io.imread(channel_path)
    normalization_value = np.quantile(mplex_img, quantile)
    chan = os.path.basename(channel_path).split(".")[0]
    return chan, normalization_value


def prepare_normalization_dict(
    fov_paths, output_dir, quantile=0.999,
exclude_channels=[], n_subset=10, n_jobs=1, 31 | output_name="normalization_dict.json" 32 | ): 33 | """Prepares the normalization dict for a list of fovs 34 | Args: 35 | fov_paths (list): list of paths to fovs 36 | output_dir (str): path to output directory 37 | quantile (float): quantile to use for normalization 38 | exclude_channels (list): list of channels to exclude 39 | n_subset (int): number of fovs to use for normalization 40 | n_jobs (int): number of jobs to use for joblib multiprocessing 41 | output_name (str): name of output file 42 | Returns: 43 | normalization_dict (dict): dict with channel names as keys and norm factors as values 44 | """ 45 | normalization_dict = {} 46 | if n_subset is not None: 47 | random.shuffle(fov_paths) 48 | fov_paths = fov_paths[:n_subset] 49 | print("Iterate over fovs...") 50 | for fov_path in tqdm(fov_paths): 51 | channels = os.listdir(fov_path) 52 | channels = [ 53 | channel for channel in channels if channel.split(".")[0] not in exclude_channels 54 | ] 55 | channel_paths = [os.path.join(fov_path, channel) for channel in channels] 56 | if n_jobs > 1: 57 | normalization_values = Parallel(n_jobs=n_jobs)( 58 | delayed(calculate_normalization)(channel_path, quantile) 59 | for channel_path in channel_paths 60 | ) 61 | else: 62 | normalization_values = [ 63 | calculate_normalization(channel_path, quantile) 64 | for channel_path in channel_paths 65 | ] 66 | for channel, normalization_value in normalization_values: 67 | if channel not in normalization_dict: 68 | normalization_dict[channel] = [] 69 | normalization_dict[channel].append(normalization_value) 70 | for channel in normalization_dict.keys(): 71 | normalization_dict[channel] = np.mean(normalization_dict[channel]) 72 | # save normalization dict 73 | with open(os.path.join(output_dir, output_name), 'w') as f: 74 | json.dump(normalization_dict, f) 75 | return normalization_dict 76 | 77 | 78 | def prepare_input_data(mplex_img, instance_mask): 79 | """Prepares the input data for the segmentation model 80 | Args: 81 | mplex_img (np.array): multiplex image 82 | instance_mask (np.array): instance mask 83 | Returns: 84 | input_data (np.array): input data for segmentation model 85 | """ 86 | edge = find_boundaries(instance_mask, mode="inner").astype(np.uint8) 87 | binary_mask = np.logical_and(edge == 0, instance_mask > 0).astype(np.float32) 88 | input_data = np.stack([mplex_img, binary_mask], axis=-1)[np.newaxis,...] 
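# channel 0: marker image, channel 1: binary cell-interior mask; np.newaxis adds the batch dim ->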
# bhwc
89 | return input_data
90 |
91 |
92 | def segment_mean(instance_mask, prediction):
93 | """Calculates the mean prediction per instance
94 | Args:
95 | instance_mask (np.array): instance mask
96 | prediction (np.array): prediction
97 | Returns:
98 | uniques (np.array): unique instance ids
99 | mean_per_cell (np.array): mean prediction per instance
100 | """
101 | instance_mask_flat = tf.cast(tf.reshape(instance_mask, -1), tf.int32) # (h*w)
102 | pred_flat = tf.cast(tf.reshape(prediction, -1), tf.float32)
103 | sort_order = tf.argsort(instance_mask_flat)
104 | instance_mask_flat = tf.gather(instance_mask_flat, sort_order)
105 | uniques, _ = tf.unique(instance_mask_flat)
106 | pred_flat = tf.gather(pred_flat, sort_order)
107 | mean_per_cell = tf.math.segment_mean(pred_flat, instance_mask_flat)
108 | mean_per_cell = tf.gather(mean_per_cell, uniques)
109 | return [uniques.numpy()[1:], mean_per_cell.numpy()[1:]] # discard background
110 |
111 |
112 | def test_time_aug(
113 | input_data, channel, app, normalization_dict, rotate=True, flip=True, batch_size=4
114 | ):
115 | """Performs test time augmentation
116 | Args:
117 | input_data (np.array): input data for segmentation model, mplex_img and binary mask
118 | channel (str): channel name
119 | app (deepcell.applications.Application): segmentation application
120 | normalization_dict (dict): dict with channel names as keys and norm factors as values
121 | rotate (bool): whether to rotate
122 | flip (bool): whether to flip
123 | batch_size (int): batch size
124 | Returns:
125 | seg_map (np.array): predicted segmentation map
126 | """
127 | forward_augmentations = []
128 | backward_augmentations = []
129 | if rotate:
130 | for k in [0, 1, 2, 3]:
131 | forward_augmentations.append(lambda x, k=k: tf.image.rot90(x, k=k))
132 | backward_augmentations.append(lambda x, k=k: tf.image.rot90(x, k=-k))
133 | if flip:
134 | forward_augmentations += [
135 | lambda x: tf.image.flip_left_right(x),
136 | lambda x: tf.image.flip_up_down(x)
137 | ]
138 | backward_augmentations += [
139 | lambda x: tf.image.flip_left_right(x),
140 | lambda x: tf.image.flip_up_down(x)
141 | ]
142 | input_batch = []
143 | for forw_aug in forward_augmentations:
144 | input_data_tmp = forw_aug(input_data).numpy() # bhwc
145 | input_batch.append(np.concatenate(input_data_tmp))
146 | input_batch = np.stack(input_batch, 0)
147 | seg_map = app._predict_segmentation(
148 | input_batch,
149 | batch_size=batch_size,
150 | preprocess_kwargs={
151 | "normalize": True,
152 | "marker": channel,
153 | "normalization_dict": normalization_dict},
154 | )
155 | tmp = []
156 | for backw_aug, seg_map_tmp in zip(backward_augmentations, seg_map):
157 | seg_map_tmp = backw_aug(seg_map_tmp[np.newaxis,...])
158 | seg_map_tmp = np.squeeze(seg_map_tmp)
159 | tmp.append(seg_map_tmp)
160 | seg_map = np.stack(tmp, -1)
161 | seg_map = np.mean(seg_map, axis=-1, keepdims=True)
162 | return seg_map
163 |
164 |
165 | def predict_fovs(
166 | fov_paths, cell_classification_output_dir, app, normalization_dict,
167 | segmentation_naming_convention, exclude_channels=[], save_predictions=True,
168 | half_resolution=False, batch_size=4, test_time_augmentation=True
169 | ):
170 | """Predicts the segmentation map for each mplex image in each fov
171 | Args:
172 | fov_paths (list): list of fov paths
173 | cell_classification_output_dir (str): path to cell classification output dir
174 | app (deepcell.applications.Application): segmentation model
175 | normalization_dict (dict): dict with channel names as keys and norm factors as values
176 | segmentation_naming_convention (function): function to get instance mask path from fov path
177 | exclude_channels (list): list of channels to exclude
178 | save_predictions (bool): whether to save predictions
179 | half_resolution (bool): whether to use half resolution
180 | batch_size (int): batch size
181 | test_time_augmentation (bool): whether to use test time augmentation
182 | Returns:
183 | cell_table (pd.DataFrame): cell table with predicted confidence scores per fov and cell
184 | """
185 | fov_dict_list = []
186 | for fov_path in tqdm(fov_paths):
187 | out_fov_path = os.path.join(
188 | os.path.normpath(cell_classification_output_dir), os.path.basename(fov_path)
189 | )
190 | fov_dict = {}
191 | for channel in os.listdir(fov_path):
192 | channel_path = os.path.join(fov_path, channel)
193 | if not channel.endswith(".tiff"):
194 | continue
195 | if channel[:2] == "._":
196 | continue
197 | channel = channel.split(".")[0]
198 | if channel in exclude_channels:
199 | continue
200 | mplex_img = np.squeeze(io.imread(channel_path))
201 | instance_path = segmentation_naming_convention(fov_path)
202 | instance_mask = np.squeeze(io.imread(instance_path))
203 | input_data = prepare_input_data(mplex_img, instance_mask)
204 | if half_resolution:
205 | scale = 0.5
206 | input_data = np.squeeze(input_data)
207 | h, w, _ = input_data.shape
208 | img = cv2.resize(input_data[..., 0], [int(w*scale), int(h*scale)])
209 | binary_mask = cv2.resize(
210 | input_data[..., 1], [int(w*scale), int(h*scale)], interpolation=cv2.INTER_NEAREST
211 | )
212 | input_data = np.stack([img, binary_mask], axis=-1)[np.newaxis, ...]
213 | if test_time_augmentation:
214 | prediction = test_time_aug(
215 | input_data, channel, app, normalization_dict, batch_size=batch_size
216 | )
217 | else:
218 | prediction = app._predict_segmentation(
219 | input_data,
220 | preprocess_kwargs={
221 | "normalize": True, "marker": channel,
222 | "normalization_dict": normalization_dict
223 | },
224 | batch_size=batch_size
225 | )
226 | prediction = np.squeeze(prediction)
227 | if half_resolution:
228 | prediction = cv2.resize(prediction, (w, h))
229 | instance_mask = np.expand_dims(instance_mask, axis=-1)
230 | labels, mean_per_cell = segment_mean(instance_mask, prediction)
231 | if "label" not in fov_dict.keys():
232 | fov_dict["fov"] = [os.path.basename(fov_path)]*len(labels)
233 | fov_dict["label"] = labels
234 | fov_dict[channel+"_pred"] = mean_per_cell
235 | if save_predictions:
236 | os.makedirs(out_fov_path, exist_ok=True)
237 | pred_int = tf.cast(prediction*255.0, tf.uint8).numpy()
238 | io.imsave(
239 | os.path.join(out_fov_path, channel+".tiff"), pred_int,
240 | photometric="minisblack", compression="zlib"
241 | )
242 | fov_dict_list.append(pd.DataFrame(fov_dict))
243 | cell_table = pd.concat(fov_dict_list, ignore_index=True)
244 | return cell_table
245 |
--------------------------------------------------------------------------------
/src/cell_classification/loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class Loss():
5 | """ Wrapper for loss functions that allows for selective masking of the loss
6 | """
7 | def __init__(self, loss_name, selective_masking, **kwargs):
8 | """ Initialize the loss function
9 | Args:
10 | loss_name (str):
11 | name of the loss function
12 | selective_masking (bool):
13 | whether to use selective masking
14 | **kwargs:
15 | additional arguments for the loss function
16 | """
17 | self.loss_fn = getattr(tf.keras.losses, loss_name)(
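# loss_name is resolved dynamically, e.g. loss_name="BinaryCrossentropy" yields tf.keras.losses.BinaryCrossentropy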
18 | reduction=tf.keras.losses.Reduction.NONE, **kwargs
19 | )
20 | self.selective_masking = selective_masking
21 |
22 | def mask_out(self, loss_img, y_true):
23 | """ Selectively mask the loss by setting it to zero where y_true == 2
24 | Args:
25 | loss_img (tf.Tensor):
26 | loss image
27 | y_true (tf.Tensor):
28 | ground truth image
29 | Returns:
30 | tf.Tensor:
31 | masked loss image
32 | """
33 | y_true = tf.reshape(y_true, tf.shape(loss_img))
34 | return tf.where(y_true == 2, tf.zeros_like(loss_img), loss_img)
35 |
36 | def __call__(self, y_true, y_pred):
37 | """ Call the loss function
38 | Args:
39 | y_true (tf.Tensor):
40 | ground truth image
41 | y_pred (tf.Tensor):
42 | prediction image
43 | Returns:
44 | tf.Tensor:
45 | loss image
46 | """
47 | loss_img = self.loss_fn(y_true=tf.clip_by_value(y_true, 0, 1), y_pred=y_pred)
48 | if self.selective_masking:
49 | loss_img = self.mask_out(loss_img, y_true)
50 | return loss_img
51 |
52 | def get_config(self):
53 | """ Get the configuration of the loss function
54 | Returns:
55 | dict:
56 | configuration of the loss function
57 | """
58 | return {
59 | "loss_name": self.loss_fn.__class__.__name__,
60 | "selective_masking": self.selective_masking,
61 | **self.loss_fn.get_config(),
62 | }
63 |
--------------------------------------------------------------------------------
/src/cell_classification/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 |
4 | import h5py
5 | import numpy as np
6 | import pandas as pd
7 | from sklearn.metrics import auc, confusion_matrix, roc_curve
8 |
9 |
10 | def calc_roc(pred_list, gt_key="marker_activity_mask", pred_key="prediction", cell_level=False):
11 | """Calculate ROC curve
12 | Args:
13 | pred_list (list):
14 | list of samples with predictions
15 | gt_key (str):
16 | key for ground truth labels
17 | pred_key (str): key for predictions
18 | cell_level (bool): whether to calculate the ROC per cell instead of per pixel
19 | Returns:
20 | roc (dict):
21 | dictionary containing ROC curve data
22 | """
23 | roc = {"fpr": [], "tpr": [], "thresholds": [], "auc": [], "marker": []}
24 | for sample in pred_list:
25 | if cell_level:
26 | # filter out cells with gt activity == 2
27 | df = sample["activity_df"].copy()
28 | df = df[df[gt_key] != 2]
29 | gt = df[gt_key].to_numpy()
30 | pred = df[pred_key].to_numpy()
31 | else:
32 | foreground = sample["binary_mask"] > 0
33 | gt = sample[gt_key][foreground].flatten()
34 | pred = sample[pred_key][foreground].flatten()
35 | if gt.size > 0 and gt.min() == 0 and gt.max() > 0: # roc requires both classes to be present
36 | fpr, tpr, thresholds = roc_curve(gt, pred)
37 | roc["fpr"].append(fpr)
38 | roc["tpr"].append(tpr)
39 | roc["thresholds"].append(thresholds)
40 | roc["auc"].append(auc(fpr, tpr))
41 | roc["marker"].append(sample["marker"])
42 | return roc
43 |
44 |
45 | def calc_scores(gt, pred, threshold):
46 | """Calculate scores for a given threshold
47 | Args:
48 | gt (np.array):
49 | ground truth labels
50 | pred (np.array):
51 | predictions
52 | threshold (float):
53 | threshold for predictions
54 | Returns:
55 | scores (dict):
56 | dictionary containing scores
57 | """
58 | # exclude masked out regions from metric calculation
59 | pred = pred[gt < 2]
60 | gt = gt[gt < 2]
61 | tn, fp, fn, tp = confusion_matrix(
62 | y_true=gt, y_pred=(pred >= threshold).astype(int), labels=[0, 1]
63 | ).ravel()
64 | metrics = {
65 | "tp": tp, "tn": tn, "fp": fp, "fn": fn,
66 | "accuracy": (tp + tn) / (tp + tn + fp + fn + 1e-8),
67 | "precision": tp / (tp + fp + 1e-8),
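# the 1e-8 terms guard against division by zero when a class is absent from the sample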
"recall": tp / (tp + fn + 1e-8), 69 | "specificity": tn / (tn + fp + 1e-8), 70 | "f1_score": 2 * tp / (2 * tp + fp + fn + 1e-8), 71 | } 72 | return metrics 73 | 74 | 75 | def calc_metrics( 76 | pred_list, gt_key="marker_activity_mask", pred_key="prediction", cell_level=False 77 | ): 78 | """Calculate metrics 79 | Args: 80 | pred_list (list): 81 | list of samples with predictions 82 | gt_key (str): 83 | key of ground truth in pred_list 84 | pred_key (str): 85 | key of prediction in pred_list 86 | Returns: 87 | avg_metrics (dict): 88 | dictionary containing metrics averaged over all samples 89 | """ 90 | metrics_dict = { 91 | "accuracy": [], "precision": [], "recall": [], "specificity": [], "f1_score": [], "tp": [], 92 | "tn": [], "fp": [], "fn": [], 93 | } 94 | 95 | def _calc_metrics(threshold): 96 | """Helper function to calculate metrics for a given threshold in parallel""" 97 | metrics = deepcopy(metrics_dict) 98 | 99 | for sample in pred_list: 100 | if cell_level: 101 | df = sample["activity_df"] 102 | # filter out cells with gt activity == 2 103 | df = df[df[gt_key] != 2] 104 | gt = np.array(df[gt_key]) 105 | pred = np.array(df[pred_key]) 106 | else: 107 | foreground = sample["binary_mask"] > 0 108 | gt = sample[gt_key][foreground].flatten() 109 | pred = sample[pred_key][foreground].flatten() 110 | if gt.size == 0: 111 | continue 112 | scores = calc_scores(gt, pred, threshold) 113 | 114 | # only add specificity for samples that have no positives 115 | if np.sum(gt) == 0: 116 | keys = ["specificity"] 117 | else: 118 | keys = scores.keys() 119 | for key in keys: 120 | metrics[key].append(scores[key]) 121 | metrics["threshold"] = threshold 122 | for key in ["dataset", "imaging_platform", "marker"]: 123 | metrics[key] = sample[key] 124 | return metrics 125 | 126 | # calculate metrics for all thresholds in parallel 127 | thresholds = np.linspace(0.01, 1, 50) 128 | # metric_list = Parallel(n_jobs=8)(delayed(_calc_metrics)(i) for i in thresholds) 129 | metric_list = [_calc_metrics(i) for i in thresholds] 130 | # reduce metrics over all samples for each threshold 131 | avg_metrics = deepcopy(metrics_dict) 132 | for key in ["dataset", "imaging_platform", "marker", "threshold"]: 133 | avg_metrics[key] = [] 134 | for metrics in metric_list: 135 | for key in ["accuracy", "precision", "recall", "specificity", "f1_score"]: 136 | avg_metrics[key].append(np.mean(metrics[key])) 137 | for key in ["tp", "tn", "fp", "fn"]: # sum fn, fp, tn, tp 138 | avg_metrics[key].append(np.sum(metrics[key])) 139 | for key in ["dataset", "imaging_platform", "marker", "threshold"]: # copy strings 140 | avg_metrics[key].append(metrics[key]) 141 | return avg_metrics 142 | 143 | 144 | def average_roc(roc_list): 145 | """Average ROC curves 146 | Args: 147 | roc_list (list): 148 | list of ROC curves 149 | Returns: 150 | tprs (np.array): 151 | standardized true positive rates for each sample 152 | mean_tprs (np.array): 153 | mean true positive rates over all samples 154 | std np.array: 155 | standard deviation of true positive rates over all samples 156 | base (np.array): 157 | fpr values for interpolation 158 | mean_thresh (np.array): 159 | mean of the threshold values over all samples 160 | """ 161 | base = np.linspace(0, 1, 101) 162 | tpr_list = [] 163 | thresh_list = [] 164 | for i in range(len(roc_list["tpr"])): 165 | tpr_list.append(np.interp(base, roc_list["fpr"][i], roc_list["tpr"][i])) 166 | thresh_list.append(np.interp(base, roc_list["tpr"][i], roc_list["thresholds"][i])) 167 | 168 | tprs = np.array(tpr_list) 169 | 
thresh_list = np.array(thresh_list) 170 | mean_thresh = np.mean(thresh_list, axis=0) 171 | mean_tprs = tprs.mean(axis=0) 172 | std = tprs.std(axis=0) 173 | return tprs, mean_tprs, base, std, mean_thresh 174 | 175 | 176 | class HDF5Loader(object): 177 | """HDF5 iterator for loading data from HDF5 files""" 178 | 179 | def __init__(self, folder): 180 | """Initialize HDF5 generator 181 | Args: 182 | folder (str): 183 | path to folder containing HDF5 files 184 | """ 185 | self.folder = folder 186 | self.files = os.listdir(folder) 187 | # filter out hdf files 188 | self.files = [os.path.join(folder, f) for f in self.files if f.endswith(".hdf")] 189 | self.file_idx = 0 190 | 191 | def __len__(self): 192 | return len(self.files) 193 | 194 | def load_hdf(self, file): 195 | """Load HDF5 file 196 | Args: 197 | file (str): 198 | path to HDF5 file 199 | Returns: 200 | data (dict): 201 | dictionary containing data from HDF5 file 202 | """ 203 | out_dict = {} 204 | with h5py.File(file, "r") as f: 205 | keys = [key for key in f.keys() if key != "activity_df"] 206 | for key in keys: 207 | if isinstance(f[key][()], bytes): 208 | out_dict[key] = f[key][()].decode("utf-8") 209 | else: 210 | out_dict[key] = f[key][()] 211 | out_dict["activity_df"] = pd.read_json(f["activity_df"][()].decode()) 212 | return out_dict 213 | 214 | def __iter__(self): 215 | self.file_idx = 0 216 | return self 217 | 218 | def __next__(self): 219 | if self.file_idx >= len(self.files): 220 | raise StopIteration 221 | else: 222 | self.file_idx += 1 223 | return self.load_hdf(self.files[self.file_idx - 1]) 224 | -------------------------------------------------------------------------------- /src/cell_classification/post_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def process_to_cells(instance_mask, prediction): 6 | """Process predictions from pixel level to cell level 7 | Args: 8 | instance_mask (np.array): 9 | 2D array of instance mask 10 | prediction (np.array): 11 | 2D array of pixel level prediction 12 | Returns: 13 | np.array: 14 | 2D array of cell level averaged prediction 15 | pd.DataFrame: 16 | DataFrame of cell level predictions 17 | """ 18 | 19 | unique_labels = np.unique(instance_mask) 20 | unique_labels = unique_labels[unique_labels != 0] 21 | mean_per_cell_mask = np.zeros_like(instance_mask, dtype=np.float32) 22 | df = pd.DataFrame(columns=['labels', 'pred_activity']) 23 | i = 0 24 | for unique_label in unique_labels: 25 | mask = instance_mask == unique_label 26 | mean_pred = prediction[mask].mean() 27 | mean_per_cell_mask[mask] = mean_pred 28 | df = pd.concat( 29 | [df, pd.DataFrame( 30 | {'labels': [unique_label], 'pred_activity': [mean_pred]}, index=[i] 31 | )], 32 | ) 33 | i += 1 34 | return mean_per_cell_mask, df 35 | 36 | 37 | def merge_activity_df(gt_df, pred_df): 38 | """Merge ground truth and prediction dataframes over labels 39 | Args: 40 | gt_df (pd.DataFrame): 41 | DataFrame of ground truth 42 | pred_df (pd.DataFrame): 43 | DataFrame of prediction 44 | Returns: 45 | pd.DataFrame: 46 | DataFrame of merged ground truth and prediction 47 | """ 48 | pred_df.labels = pred_df.labels.astype(int) 49 | gt_df.labels = gt_df.labels.astype(int) 50 | df = gt_df.merge(pred_df, on='labels', how='left') 51 | return df 52 | -------------------------------------------------------------------------------- /src/cell_classification/prepare_data_script_MSK_colon.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from cell_classification.simple_data_prep import SimpleTFRecords 4 | 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 6 | 7 | 8 | def naming_convention(fname): 9 | return os.path.join( 10 | "E:/angelo_lab/data/MSKCC_colon/segmentation", 11 | fname + "feature_0.ome.tif" 12 | ) 13 | 14 | 15 | data_prep = SimpleTFRecords( 16 | data_dir=os.path.normpath( 17 | "E:/angelo_lab/data/MSKCC_colon/raw_structured" 18 | ), 19 | cell_table_path=os.path.normpath( 20 | "E:/angelo_lab/data/MSKCC_colon/cell_table.csv" 21 | ), 22 | imaging_platform="Vectra", 23 | dataset="MSK_colon", 24 | tissue_type="colon", 25 | nuclei_channels=["DAPI"], 26 | membrane_channels=["CD3", "CD8", "ICOS", "panCK+CK7+CAM5.2"], 27 | tile_size=[256, 256], 28 | stride=[240, 240], 29 | tf_record_path=os.path.normpath("E:/angelo_lab/data/MSKCC_colon"), 30 | normalization_quantile=0.999, 31 | selected_markers=["CD3", "CD8", "Foxp3", "ICOS", "panCK+CK7+CAM5.2", "PD-L1"], 32 | sample_key="fov", 33 | segment_label_key="labels", 34 | segmentation_naming_convention=naming_convention, 35 | exclude_background_tiles=True, 36 | img_suffix=".ome.tif", 37 | # normalization_dict_path=os.path.normpath( 38 | # "E:/angelo_lab/data/MSKCC_colon/normalization_dict.json" 39 | # ), 40 | ) 41 | 42 | data_prep.make_tf_record() 43 | -------------------------------------------------------------------------------- /src/cell_classification/prepare_data_script_MSK_pancreas.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from cell_classification.simple_data_prep import SimpleTFRecords 4 | 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 6 | 7 | 8 | def naming_convention(fname): 9 | return os.path.join( 10 | "E:/angelo_lab/data/MSKCC_pancreas/segmentation", 11 | fname + "feature_0.ome.tif" 12 | ) 13 | 14 | 15 | data_prep = SimpleTFRecords( 16 | data_dir=os.path.normpath( 17 | "E:/angelo_lab/data/MSKCC_pancreas/raw_structured" 18 | ), 19 | cell_table_path=os.path.normpath( 20 | "E:/angelo_lab/data/MSKCC_pancreas/cell_table.csv" 21 | ), 22 | imaging_platform="Vectra", 23 | dataset="MSK_pancreas", 24 | tissue_type="pancreas", 25 | nuclei_channels=["DAPI"], 26 | membrane_channels=["CD8", "CD40", "CD40-L", "panCK"], 27 | tile_size=[256, 256], 28 | stride=[240, 240], 29 | tf_record_path=os.path.normpath("E:/angelo_lab/data/MSKCC_pancreas"), 30 | normalization_quantile=0.999, 31 | selected_markers=["CD8", "CD40", "CD40-L", "panCK", "PD-1", "PD-L1"], 32 | sample_key="fov", 33 | segment_label_key="labels", 34 | segmentation_naming_convention=naming_convention, 35 | exclude_background_tiles=True, 36 | img_suffix=".ome.tif", 37 | gt_suffix="", 38 | normalization_dict_path=os.path.normpath( 39 | "E:/angelo_lab/data/MSKCC_pancreas/normalization_dict.json" 40 | ), 41 | ) 42 | 43 | data_prep.make_tf_record() 44 | -------------------------------------------------------------------------------- /src/cell_classification/prepare_data_script_decidua.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from cell_classification.segmentation_data_prep import SegmentationTFRecords 4 | 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 6 | 7 | 8 | def naming_convention(fname): 9 | return os.path.join( 10 | "E:/angelo_lab/data/decidua/segmentation_data/", 11 | fname + "_segmentation_labels.tiff", 12 | ) 13 | 14 | 15 | data_prep = SegmentationTFRecords( 16 | 
data_dir=os.path.normpath("E:/angelo_lab/data/decidua/image_data"), 17 | cell_table_path=os.path.normpath( 18 | "E:/angelo_lab/data/decidua/" 19 | "Supplementary_table_3_single_cells_updated.csv" 20 | ), 21 | conversion_matrix_path=os.path.normpath( 22 | "E:/angelo_lab/data/decidua/conversion_matrix.csv" 23 | ), 24 | imaging_platform="MIBI", 25 | dataset="decidua_erin", 26 | tissue_type="decidua", 27 | nuclei_channels=["H3"], 28 | membrane_channels=["VIM", "HLAG", "CD3", "CD14", "CD56"], 29 | tile_size=[256, 256], 30 | stride=[240, 240], 31 | tf_record_path=os.path.normpath("E:/angelo_lab/data/decidua"), 32 | normalization_dict_path=os.path.normpath( 33 | "E:/angelo_lab/data/decidua/normalization_dict.json" 34 | ), 35 | selected_markers=[ 36 | "CD45", "CD14", "HLADR", "CD11c", "DCSIGN", "CD68", "CD206", "CD163", "CD3", "Ki67", "IDO", 37 | "CD8", "CD4", "CD16", "CD56", "CD57", "SMA", "VIM", "CD31", "CK7", "HLAG", "FoxP3", "PDL1", 38 | ], 39 | normalization_quantile=0.999, 40 | cell_type_key="lineage", 41 | sample_key="Point", 42 | segmentation_naming_convention=naming_convention, 43 | segment_label_key="cell_ID_in_Point", 44 | exclude_background_tiles=True, 45 | img_suffix=".tif", 46 | ) 47 | 48 | data_prep.make_tf_record() 49 | -------------------------------------------------------------------------------- /src/cell_classification/prepare_data_script_tonic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from cell_classification.segmentation_data_prep import SegmentationTFRecords 4 | 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 6 | 7 | 8 | def naming_convention(fname): 9 | return os.path.join( 10 | "E:/angelo_lab/data/TONIC/raw/segmentation_data/deepcell_output", 11 | fname + "_feature_0.tif" 12 | ) 13 | 14 | 15 | data_prep = SegmentationTFRecords( 16 | data_dir=os.path.normpath( 17 | "E:/angelo_lab/data/TONIC/raw/image_data/samples" 18 | ), 19 | cell_table_path=os.path.normpath( 20 | "E:/angelo_lab/data/TONIC/raw/" + 21 | "combined_cell_table_normalized_cell_labels_updated.csv" 22 | ), 23 | conversion_matrix_path=os.path.normpath( 24 | "E:/angelo_lab/data/TONIC/raw/TONIC_conversion_matrix.csv" 25 | ), 26 | imaging_platform="MIBI", 27 | dataset="TONIC", 28 | tissue_type="TNBC", 29 | nuclei_channels=["H3K27me3", "H3K9ac"], 30 | membrane_channels=["CD45", "ECAD", "CD14", "CD38", "CK17"], 31 | tile_size=[256, 256], 32 | stride=[240, 240], 33 | tf_record_path=os.path.normpath("E:/angelo_lab/data/TONIC"), 34 | # normalization_dict_path=os.path.normpath( 35 | # "E:/angelo_lab/data/TONIC/normalization_dict.json" 36 | # ), 37 | normalization_quantile=0.999, 38 | cell_type_key="cell_meta_cluster", 39 | sample_key="fov", 40 | segmentation_fname="cell_segmentation", 41 | segmentation_naming_convention=naming_convention, 42 | segment_label_key="label", 43 | exclude_background_tiles=True, 44 | ) 45 | 46 | data_prep.make_tf_record() 47 | -------------------------------------------------------------------------------- /src/cell_classification/prepare_data_script_tonic_annotated.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from cell_classification.simple_data_prep import SimpleTFRecords 4 | 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 6 | 7 | 8 | def naming_convention(fname): 9 | return os.path.join( 10 | "E:/angelo_lab/data/TONIC/raw/segmentation_data/deepcell_output", 11 | fname + "_feature_0.tif" 12 | ) 13 | 14 | 15 | data_prep = SimpleTFRecords( 16 | data_dir=os.path.normpath( 17 
| "C:/Users/Lorenz/OneDrive - Charité - Universitätsmedizin Berlin/cell_classification/" + 18 | "data_annotation/tonic/raw" 19 | ), 20 | cell_table_path=os.path.normpath( 21 | "C:/Users/Lorenz/OneDrive - Charité - Universitätsmedizin Berlin/cell_classification/" + 22 | "data_annotation/tonic/ground_truth.csv" 23 | ), 24 | imaging_platform="MIBI", 25 | dataset="TONIC", 26 | tissue_type="TNBC", 27 | nuclei_channels=["H3K27me3", "H3K9ac"], 28 | membrane_channels=["CD45", "ECAD", "CD14", "CD38", "CK17"], 29 | selected_markers=[ 30 | "Calprotectin", "CD14", "CD163", "CD20", "CD3", "CD31", "CD4", "CD45", "CD56", "CD68", 31 | "CD8", "ChyTr", "CK17", "Collagen1", "ECAD", "FAP", "Fibronectin", "FOXP3", "HLADR", "SMA", 32 | "VIM" 33 | ], 34 | tile_size=[256, 256], 35 | stride=[240, 240], 36 | tf_record_path=os.path.normpath("E:/angelo_lab/data/TONIC/annotated"), 37 | normalization_dict_path=os.path.normpath( 38 | "E:/angelo_lab/data/TONIC/normalization_dict.json" 39 | ), 40 | normalization_quantile=0.999, 41 | sample_key="fov", 42 | segmentation_fname="cell_segmentation", 43 | segmentation_naming_convention=naming_convention, 44 | segment_label_key="labels", 45 | exclude_background_tiles=True, 46 | gt_suffix="", 47 | ) 48 | 49 | data_prep.make_tf_record() 50 | -------------------------------------------------------------------------------- /src/cell_classification/simple_data_prep.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from cell_classification.segmentation_data_prep import SegmentationTFRecords 4 | 5 | 6 | class SimpleTFRecords(SegmentationTFRecords): 7 | """Prepares the data for the segmentation model""" 8 | def __init__( 9 | self, data_dir, cell_table_path, imaging_platform, dataset, tissue_type, tile_size, stride, 10 | tf_record_path, nuclei_channels, membrane_channels, selected_markers=None, 11 | normalization_dict_path=None, normalization_quantile=0.999, 12 | segmentation_naming_convention=None, segmentation_fname="cell_segmentation", 13 | exclude_background_tiles=False, resize=None, img_suffix=".tiff", sample_key="SampleID", 14 | segment_label_key="labels", gt_suffix="_gt", 15 | 16 | ): 17 | """Initializes SegmentationTFRecords and loads everything except the images 18 | 19 | Args: 20 | data_dir str: 21 | Path where the data is stored 22 | cell_table_path (str): 23 | Path to the cell table 24 | imaging_platform (str): 25 | The imaging platform used to generate the multiplexed imaging data 26 | dataset (str): 27 | The dataset where the imaging data comes from 28 | tissue_type (str): 29 | The tissue type of the imaging data 30 | tile_size list [int,int]: 31 | The size of the tiles to use for the segmentation model 32 | stride list [int,int]: 33 | The stride to tile the data 34 | tf_record_path (str): 35 | The path to the tf record to make 36 | selected_markers (list): 37 | The markers of interest for generating the tf record. If None, all markers 38 | mentioned in the conversion_matrix are used 39 | normalization_dict_path (str): 40 | Path to the normalization dict json 41 | normalization_quantile (float): 42 | The quantile to use for normalization of multiplexed data 43 | segmentation_naming_convention (Function): 44 | Function that takes in the sample name and returns the path to the segmentation 45 | .tiff file. 
Default is None, then it is assumed that the segmentation file is in
46 | the sample folder and is named $segmentation_fname.tiff
47 | exclude_background_tiles (bool):
48 | Whether to exclude all tiles that only contain background
49 | resize (float):
50 | The resize factor to use for the images
51 | img_suffix (str):
52 | The suffix of the image files
53 | sample_key (str):
54 | The key in the cell table that contains the sample name
55 | segment_label_key (str):
56 | The key in the cell table that contains the segmentation labels
57 | gt_suffix (str):
58 | The suffix of the ground truth column in the cell_table
59 | """
60 | super().__init__(
61 | data_dir=data_dir, cell_table_path=cell_table_path, imaging_platform=imaging_platform,
62 | dataset=dataset, tile_size=tile_size, stride=stride, tf_record_path=tf_record_path,
63 | nuclei_channels=nuclei_channels, membrane_channels=membrane_channels,
64 | selected_markers=selected_markers, normalization_dict_path=normalization_dict_path,
65 | normalization_quantile=normalization_quantile, segmentation_fname=segmentation_fname,
66 | segmentation_naming_convention=segmentation_naming_convention, resize=resize,
67 | exclude_background_tiles=exclude_background_tiles, img_suffix=img_suffix,
68 | sample_key=sample_key, segment_label_key=segment_label_key, tissue_type=tissue_type,
69 | conversion_matrix_path=None,
70 | )
71 | self.selected_markers = selected_markers
72 | self.data_dir = data_dir
73 | self.normalization_dict_path = normalization_dict_path
74 | self.normalization_quantile = normalization_quantile
75 | self.segmentation_fname = segmentation_fname
76 | self.segment_label_key = segment_label_key
77 | self.sample_key = sample_key
78 | self.dataset = dataset
79 | self.tissue_type = tissue_type
80 | self.imaging_platform = imaging_platform
81 | self.tf_record_path = tf_record_path
82 | self.cell_table_path = cell_table_path
83 | self.tile_size = tile_size
84 | self.stride = stride
85 | self.segmentation_naming_convention = segmentation_naming_convention
86 | self.exclude_background_tiles = exclude_background_tiles
87 | self.resize = resize
88 | self.img_suffix = img_suffix
89 | self.gt_suffix = gt_suffix
90 | self.nuclei_channels = nuclei_channels
91 | self.membrane_channels = membrane_channels
92 |
93 | def get_marker_activity(self, sample_name, marker):
94 | """Gets the marker activity for the given labels
95 | Args:
96 | sample_name (str):
97 | The name of the sample
98 | marker (str, list):
99 | The markers to get the activity for
100 | Returns:
101 | tuple (pd.DataFrame, None):
102 | DataFrame with the marker activity for the given labels, 1 if the marker is
103 | active, 0 otherwise and -1 if the marker is not specific enough to be
104 | considered active; the second element is a placeholder for the conversion
105 | matrix returned by the parent class
106 | """
107 |
108 | df = pd.DataFrame(
109 | {
110 | "labels": self.sample_subset[self.segment_label_key],
111 | "activity": self.sample_subset[marker + self.gt_suffix],
112 | }
113 | )
114 | return df, None
115 |
116 | def check_additional_inputs(self):
117 | """Checks additional inputs for correctness"""
--------------------------------------------------------------------------------
/src/cell_classification/viewer_widget.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ipywidgets as widgets
3 | from IPython.display import display
4 | from io import BytesIO
5 | from skimage import io
6 | from copy import copy
7 | import numpy as np
8 | from natsort import natsorted
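# natsorted orders fov names naturally, e.g. fov2 before fov10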
9 |
10 |
11 | class NimbusViewer(object):
12 | def __init__(self, input_dir, output_dir, img_width='600px'):
13 | """Viewer for Nimbus application.
14 | Args:
15 | input_dir (str): Path to directory containing individual channels of multiplexed images
16 | output_dir (str): Path to directory containing output of Nimbus application.
17 | img_width (str): Width of images in viewer.
18 | """
19 | self.image_width = img_width
20 | self.input_dir = input_dir
21 | self.output_dir = output_dir
22 | self.fov_names = [os.path.basename(p) for p in os.listdir(output_dir) if \
23 | os.path.isdir(os.path.join(output_dir, p))]
24 | self.fov_names = natsorted(self.fov_names)
25 | self.update_button = widgets.Button(description="Update Image")
26 | self.update_button.on_click(self.update_button_click)
27 |
28 | self.fov_select = widgets.Select(
29 | options=self.fov_names,
30 | description='FOV:',
31 | disabled=False
32 | )
33 | self.fov_select.observe(self.select_fov, names='value')
34 |
35 | self.red_select = widgets.Select(
36 | options=[],
37 | description='Red:',
38 | disabled=False
39 | )
40 | self.green_select = widgets.Select(
41 | options=[],
42 | description='Green:',
43 | disabled=False
44 | )
45 | self.blue_select = widgets.Select(
46 | options=[],
47 | description='Blue:',
48 | disabled=False
49 | )
50 | self.input_image = widgets.Image()
51 | self.output_image = widgets.Image()
52 |
53 | def select_fov(self, change):
54 | """Selects fov to display.
55 | Args:
56 | change (dict): Change dictionary from ipywidgets.
57 | """
58 | fov_path = os.path.join(self.output_dir, self.fov_select.value)
59 | channels = [
60 | ch for ch in os.listdir(fov_path) if os.path.isfile(os.path.join(fov_path, ch))
61 | ]
62 | self.red_select.options = natsorted(channels)
63 | self.green_select.options = natsorted(channels)
64 | self.blue_select.options = natsorted(channels)
65 |
66 | def create_composite_image(self, path_dict):
67 | """Creates composite image from input paths.
68 | Args:
69 | path_dict (dict): Dictionary of paths to images.
70 | Returns:
71 | composite_image (np.array): Composite image.
72 | """
73 | output_image = []
74 | img = None
75 | for k, p in path_dict.items():
76 | if p:
77 | img = io.imread(p)
78 | output_image.append(img)
79 | if p is None:
80 | non_none = [p for p in path_dict.values() if p]
81 | if img is None:
82 | img = io.imread(non_none[0])
83 | output_image.append(img*0)
84 |
85 | composite_image = np.stack(output_image, axis=-1)
86 | return composite_image
87 |
88 | def layout(self):
89 | """Creates layout for viewer."""
90 | channel_selectors = widgets.VBox([
91 | self.red_select,
92 | self.green_select,
93 | self.blue_select
94 | ])
95 | self.input_image.layout.width = self.image_width
96 | self.output_image.layout.width = self.image_width
97 | viewer_html = widgets.HTML("<h2>Select files</h2>")
98 | input_html = widgets.HTML("<h2>Input</h2>")
99 | output_html = widgets.HTML("<h2>Nimbus Output</h2>
") 100 | 101 | layout = widgets.HBox([ 102 | widgets.VBox([ 103 | viewer_html, 104 | self.fov_select, 105 | channel_selectors, 106 | self.update_button 107 | ]), 108 | widgets.VBox([ 109 | input_html, 110 | self.input_image 111 | ]), 112 | widgets.VBox([ 113 | output_html, 114 | self.output_image 115 | ]) 116 | ]) 117 | display(layout) 118 | 119 | def search_for_similar(self, select_value): 120 | """Searches for similar filename in input directory. 121 | Args: 122 | select_value (str): Filename to search for. 123 | Returns: 124 | similar_path (str): Path to similar filename. 125 | """ 126 | in_f_path = os.path.join(self.input_dir, self.fov_select.value) 127 | # search for similar filename in in_f_path 128 | in_f_files = [ 129 | f for f in os.listdir(in_f_path) if os.path.isfile(os.path.join(in_f_path, f)) 130 | ] 131 | similar_path = None 132 | for f in in_f_files: 133 | if select_value.split(".")[0]+"." in f: 134 | similar_path = os.path.join(self.input_dir, self.fov_select.value, f) 135 | return similar_path 136 | 137 | def update_img(self, image_viewer, composite_image): 138 | """Updates image in viewer by saving it as png and loading it with the viewer widget. 139 | Args: 140 | image_viewer (ipywidgets.Image): Image widget to update. 141 | composite_image (np.array): Composite image to display. 142 | """ 143 | # Convert composite image to bytes and assign it to the output_image widget 144 | with BytesIO() as output_buffer: 145 | io.imsave(output_buffer, composite_image, format="png") 146 | output_buffer.seek(0) 147 | image_viewer.value = output_buffer.read() 148 | 149 | def update_composite(self): 150 | """Updates composite image in viewer.""" 151 | path_dict = { 152 | "red": None, 153 | "green": None, 154 | "blue": None 155 | } 156 | in_path_dict = copy(path_dict) 157 | if self.red_select.value: 158 | path_dict["red"] = os.path.join( 159 | self.output_dir, self.fov_select.value, self.red_select.value 160 | ) 161 | in_path_dict["red"] = self.search_for_similar(self.red_select.value) 162 | if self.green_select.value: 163 | path_dict["green"] = os.path.join( 164 | self.output_dir, self.fov_select.value, self.green_select.value 165 | ) 166 | in_path_dict["green"] = self.search_for_similar(self.green_select.value) 167 | if self.blue_select.value: 168 | path_dict["blue"] = os.path.join( 169 | self.output_dir, self.fov_select.value, self.blue_select.value 170 | ) 171 | in_path_dict["blue"] = self.search_for_similar(self.blue_select.value) 172 | non_none = [p for p in path_dict.values() if p] 173 | if not non_none: 174 | return 175 | composite_image = self.create_composite_image(path_dict) 176 | in_composite_image = self.create_composite_image(in_path_dict) 177 | in_composite_image = in_composite_image / np.quantile( 178 | in_composite_image, 0.999, axis=(0,1) 179 | ) 180 | in_composite_image = np.clip(in_composite_image*255, 0, 255).astype(np.uint8) 181 | # update image viewers 182 | self.update_img(self.input_image, in_composite_image) 183 | self.update_img(self.output_image, composite_image) 184 | 185 | def update_button_click(self, button): 186 | """Updates composite image in viewer when update button is clicked.""" 187 | self.update_composite() 188 | 189 | def display(self): 190 | """Displays viewer.""" 191 | self.select_fov(None) 192 | self.layout() 193 | self.update_composite() 194 | -------------------------------------------------------------------------------- /src/deepcell/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/angelolab/Nimbus/b418eb842ee095afd0a9abe79be98a66ba157127/src/deepcell/__init__.py -------------------------------------------------------------------------------- /src/deepcell/panopticnet.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of 3 | # Technology (Caltech), with support from the Paul Allen Family Foundation, 4 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. 5 | # All rights reserved. 6 | # 7 | # Licensed under a modified Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE 12 | # 13 | # The Work provided may be used for non-commercial academic purposes only. 14 | # For any other use of the Work, including commercial use, please contact: 15 | # vanvalenlab@gmail.com 16 | # 17 | # Neither the name of Caltech nor the names of its contributors may be used 18 | # to endorse or promote products derived from this software without specific 19 | # prior written permission. 20 | # 21 | # Unless required by applicable law or agreed to in writing, software 22 | # distributed under the License is distributed on an "AS IS" BASIS, 23 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | # See the License for the specific language governing permissions and 25 | # limitations under the License. 26 | # ============================================================================== 27 | """Feature pyramid network utility functions""" 28 | 29 | 30 | import math 31 | import re 32 | 33 | from tensorflow.keras import backend as K 34 | from tensorflow.keras.models import Model 35 | from tensorflow.keras.layers import Conv2D, Conv3D 36 | from tensorflow.keras.layers import TimeDistributed, ConvLSTM2D 37 | from tensorflow.keras.layers import Input, Concatenate 38 | from tensorflow.keras.layers import Activation, BatchNormalization 39 | 40 | from deepcell.layers import ImageNormalization2D, Location2D 41 | from deepcell.fpn import __create_pyramid_features 42 | from deepcell.fpn import __create_semantic_head 43 | from deepcell.backbone_utils import get_backbone 44 | 45 | 46 | def __merge_temporal_features(feature, mode='conv', feature_size=256, 47 | frames_per_batch=1): 48 | """Merges feature with its temporal residual through addition. 49 | 50 | Input feature (x) --> Temporal convolution* --> Residual feature (x') 51 | *Type of temporal convolution specified by ``mode``. 52 | 53 | Output: ``y = x + x'`` 54 | 55 | Args: 56 | feature (tensorflow.keras.Layer): Input layer 57 | mode (str): Mode of temporal convolution. One of 58 | ``{'conv','lstm', None}``. 59 | feature_size (int): Length of convolutional kernel 60 | frames_per_batch (int): Size of z axis in generated batches. 61 | If equal to 1, assumes 2D data. 62 | 63 | Raises: 64 | ValueError: ``mode`` not 'conv', 'lstm' or ``None`` 65 | 66 | Returns: 67 | tensorflow.keras.Layer: Input feature merged with its residual 68 | from a temporal convolution. If mode is ``None``, 69 | the output is exactly the input. 70 | """ 71 | # Check inputs to mode 72 | acceptable_modes = {'conv', 'lstm', None} 73 | if mode is not None: 74 | mode = str(mode).lower() 75 | if mode not in acceptable_modes: 76 | raise ValueError(f'Mode {mode} not supported. 
Please choose ' 77 | f'from {str(acceptable_modes)}.') 78 | 79 | f_name = str(feature.name)[:2] 80 | 81 | if mode == 'conv': 82 | x = Conv3D(feature_size, 83 | (frames_per_batch, 3, 3), 84 | strides=(1, 1, 1), 85 | padding='same', 86 | name=f'conv3D_mtf_{f_name}', 87 | )(feature) 88 | x = BatchNormalization(axis=-1, name=f'bnorm_mtf_{f_name}')(x) 89 | x = Activation('relu', name=f'acti_mtf_{f_name}')(x) 90 | elif mode == 'lstm': 91 | x = ConvLSTM2D(feature_size, 92 | (3, 3), 93 | padding='same', 94 | activation='relu', 95 | return_sequences=True, 96 | name=f'convLSTM_mtf_{f_name}')(feature) 97 | else: 98 | x = feature 99 | 100 | temporal_feature = x 101 | 102 | return temporal_feature 103 | 104 | 105 | def PanopticNet(backbone, 106 | input_shape, 107 | inputs=None, 108 | backbone_levels=['C3', 'C4', 'C5'], 109 | pyramid_levels=['P3', 'P4', 'P5', 'P6', 'P7'], 110 | create_pyramid_features=__create_pyramid_features, 111 | create_semantic_head=__create_semantic_head, 112 | frames_per_batch=1, 113 | temporal_mode=None, 114 | num_semantic_classes=[3], 115 | required_channels=3, 116 | norm_method=None, 117 | pooling=None, 118 | location=True, 119 | use_imagenet=True, 120 | lite=False, 121 | upsample_type='upsampling2d', 122 | interpolation='bilinear', 123 | name='panopticnet', 124 | z_axis_convolutions=False, 125 | **kwargs): 126 | """Constructs a Mask-RCNN model using a backbone from 127 | ``keras-applications`` with optional semantic segmentation transforms. 128 | 129 | Args: 130 | backbone (str): Name of backbone to use. 131 | input_shape (tuple): The shape of the input data. 132 | backbone_levels (list): The backbone levels to be used. 133 | to create the feature pyramid. 134 | pyramid_levels (list): Pyramid levels to use. 135 | create_pyramid_features (function): Function to get the pyramid 136 | features from the backbone. 137 | create_semantic_head (function): Function to build a semantic head 138 | submodel. 139 | frames_per_batch (int): Size of z axis in generated batches. 140 | If equal to 1, assumes 2D data. 141 | temporal_mode: Mode of temporal convolution. Choose from 142 | ``{'conv','lstm', None}``. 143 | num_semantic_classes (list or dict): Number of semantic classes 144 | for each semantic head. If a ``dict``, keys will be used as 145 | head names and values will be the number of classes. 146 | norm_method (str): Normalization method to use with the 147 | :mod:`deepcell.layers.normalization.ImageNormalization2D` layer. 148 | location (bool): Whether to include a 149 | :mod:`deepcell.layers.location.Location2D` layer. 150 | use_imagenet (bool): Whether to load imagenet-based pretrained weights. 151 | lite (bool): Whether to use a depthwise conv in the feature pyramid 152 | rather than regular conv. 153 | upsample_type (str): Choice of upsampling layer to use from 154 | ``['upsamplelike', 'upsampling2d', 'upsampling3d']``. 155 | interpolation (str): Choice of interpolation mode for upsampling 156 | layers from ``['bilinear', 'nearest']``. 157 | pooling (str): optional pooling mode for feature extraction 158 | when ``include_top`` is ``False``. 159 | 160 | - None means that the output of the model will be 161 | the 4D tensor output of the 162 | last convolutional layer. 163 | - 'avg' means that global average pooling 164 | will be applied to the output of the 165 | last convolutional layer, and thus 166 | the output of the model will be a 2D tensor. 167 | - 'max' means that global max pooling will 168 | be applied. 
169 | 170 | z_axis_convolutions (bool): Whether or not to do convolutions on 171 | 3D data across the z axis. 172 | required_channels (int): The required number of channels of the 173 | backbone. 3 is the default for all current backbones. 174 | kwargs (dict): Other standard inputs for ``retinanet_mask``. 175 | 176 | Raises: 177 | ValueError: ``temporal_mode`` not 'conv', 'lstm' or ``None`` 178 | 179 | Returns: 180 | tensorflow.keras.Model: Panoptic model with a backbone. 181 | """ 182 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 183 | conv = Conv3D if frames_per_batch > 1 else Conv2D 184 | conv_kernel = (1, 1, 1) if frames_per_batch > 1 else (1, 1) 185 | 186 | # Check input to __merge_temporal_features 187 | acceptable_modes = {'conv', 'lstm', None} 188 | if temporal_mode is not None: 189 | temporal_mode = str(temporal_mode).lower() 190 | if temporal_mode not in acceptable_modes: 191 | raise ValueError(f'temporal_mode {temporal_mode} not supported. Please choose ' 192 | f'from {acceptable_modes}.') 193 | 194 | # TODO only works for 2D: do we check for 3D as well? 195 | # What are the requirements for 3D data? 196 | img_shape = input_shape[1:] if channel_axis == 1 else input_shape[:-1] 197 | if img_shape[0] != img_shape[1]: 198 | raise ValueError(f'Input data must be square, got dimensions {img_shape}') 199 | 200 | if not math.log(img_shape[0], 2).is_integer(): 201 | raise ValueError('Input data dimensions must be a power of 2, ' 202 | f'got {img_shape[0]}') 203 | 204 | # Check input to interpolation 205 | acceptable_interpolation = {'bilinear', 'nearest'} 206 | if interpolation not in acceptable_interpolation: 207 | raise ValueError(f'Interpolation mode "{interpolation}" not supported. ' 208 | f'Choose from {list(acceptable_interpolation)}.') 209 | 210 | if inputs is None: 211 | if frames_per_batch > 1: 212 | if channel_axis == 1: 213 | input_shape_with_time = tuple( 214 | [input_shape[0], frames_per_batch] + list(input_shape)[1:]) 215 | else: 216 | input_shape_with_time = tuple( 217 | [frames_per_batch] + list(input_shape)) 218 | inputs = Input(shape=input_shape_with_time, name='input_0') 219 | else: 220 | inputs = Input(shape=input_shape, name='input_0') 221 | 222 | # Normalize input images 223 | if norm_method is None: 224 | norm = inputs 225 | else: 226 | if frames_per_batch > 1: 227 | norm = TimeDistributed(ImageNormalization2D( 228 | norm_method=norm_method, name='norm'), name='td_norm')(inputs) 229 | else: 230 | norm = ImageNormalization2D(norm_method=norm_method, 231 | name='norm')(inputs) 232 | 233 | # Add location layer 234 | if location: 235 | if frames_per_batch > 1: 236 | # TODO: TimeDistributed is incompatible with channels_first 237 | loc = TimeDistributed(Location2D(name='location'), 238 | name='td_location')(norm) 239 | else: 240 | loc = Location2D(name='location')(norm) 241 | concat = Concatenate(axis=channel_axis, 242 | name='concatenate_location')([norm, loc]) 243 | else: 244 | concat = norm 245 | 246 | # Force the channel size for backbone input to be `required_channels` 247 | fixed_inputs = conv(required_channels, conv_kernel, strides=1, 248 | padding='same', name='conv_channels')(concat) 249 | 250 | # Force the input shape 251 | axis = 0 if K.image_data_format() == 'channels_first' else -1 252 | fixed_input_shape = list(input_shape) 253 | fixed_input_shape[axis] = required_channels 254 | fixed_input_shape = tuple(fixed_input_shape) 255 | 256 | model_kwargs = { 257 | 'include_top': False, 258 | 'weights': None, 259 | 'input_shape': 
fixed_input_shape, 260 | 'pooling': pooling 261 | } 262 | 263 | _, backbone_dict = get_backbone(backbone, fixed_inputs, 264 | use_imagenet=use_imagenet, 265 | frames_per_batch=frames_per_batch, 266 | return_dict=True, 267 | **model_kwargs) 268 | 269 | backbone_dict_reduced = {k: backbone_dict[k] for k in backbone_dict 270 | if k in backbone_levels} 271 | 272 | ndim = 2 if frames_per_batch == 1 else 3 273 | 274 | pyramid_dict = create_pyramid_features(backbone_dict_reduced, 275 | ndim=ndim, 276 | lite=lite, 277 | interpolation=interpolation, 278 | upsample_type=upsample_type, 279 | z_axis_convolutions=z_axis_convolutions) 280 | 281 | features = [pyramid_dict[key] for key in pyramid_levels] 282 | 283 | if frames_per_batch > 1: 284 | temporal_features = [__merge_temporal_features(f, mode=temporal_mode, 285 | frames_per_batch=frames_per_batch) 286 | 287 | for f in features] 288 | for f, k in zip(temporal_features, pyramid_levels): 289 | pyramid_dict[k] = f 290 | 291 | semantic_levels = [int(re.findall(r'\d+', k)[0]) for k in pyramid_dict] 292 | target_level = min(semantic_levels) 293 | 294 | semantic_head_list = [] 295 | if not isinstance(num_semantic_classes, dict): 296 | num_semantic_classes = { 297 | k: v for k, v in enumerate(num_semantic_classes) 298 | } 299 | 300 | for k, v in num_semantic_classes.items(): 301 | semantic_head_list.append(create_semantic_head( 302 | pyramid_dict, n_classes=v, 303 | input_target=inputs, target_level=target_level, 304 | semantic_id=k, ndim=ndim, upsample_type=upsample_type, 305 | interpolation=interpolation, **kwargs)) 306 | 307 | outputs = semantic_head_list 308 | 309 | model = Model(inputs=inputs, outputs=outputs, name=name) 310 | return model 311 | -------------------------------------------------------------------------------- /src/deepcell/semantic_head.py: -------------------------------------------------------------------------------- 1 | from deepcell.fpn import semantic_upsample 2 | from deepcell.utils import get_sorted_keys 3 | from tensorflow.keras import backend as K 4 | from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, 5 | Conv3D, Softmax) 6 | 7 | 8 | def create_semantic_head( 9 | pyramid_dict, input_target=None, n_classes=3, n_filters=128, n_dense=128, semantic_id=0, 10 | ndim=2, include_top=True, target_level=2, upsample_type="upsamplelike", 11 | interpolation="bilinear", **kwargs 12 | ): 13 | """Creates a semantic head from a feature pyramid network. 14 | Args: 15 | pyramid_dict (dict): Dictionary of pyramid names and features. 16 | input_target (tensor): Optional tensor with the input image. 17 | n_classes (int): The number of classes to be predicted. 18 | n_filters (int): The number of convolutional filters. 19 | n_dense (int): Number of dense filters. 20 | semantic_id (int): ID of the semantic head. 21 | ndim (int): The spatial dimensions of the input data. 22 | Must be either 2 or 3. 23 | include_top (bool): Whether to include the final layer of the model 24 | target_level (int): The level we need to reach. Performs 25 | 2x upsampling until we're at the target level. 26 | upsample_type (str): Choice of upsampling layer to use from 27 | ``['upsamplelike', 'upsampling2d', 'upsampling3d']``. 28 | interpolation (str): Choice of interpolation mode for upsampling 29 | layers from ``['bilinear', 'nearest']``. 
30 | Raises: 31 | ValueError: ``ndim`` must be 2 or 3 32 | ValueError: ``interpolation`` not in ``['bilinear', 'nearest']`` 33 | ValueError: ``upsample_type`` not in 34 | ``['upsamplelike','upsampling2d', 'upsampling3d']`` 35 | Returns: 36 | tensorflow.keras.Layer: The semantic segmentation head 37 | """ 38 | # Check input to ndims 39 | if ndim not in {2, 3}: 40 | raise ValueError("ndim must be either 2 or 3. " "Received ndim = {}".format(ndim)) 41 | 42 | # Check input to interpolation 43 | acceptable_interpolation = {"bilinear", "nearest"} 44 | if interpolation not in acceptable_interpolation: 45 | raise ValueError( 46 | 'Interpolation mode "{}" not supported. ' 47 | "Choose from {}.".format(interpolation, list(acceptable_interpolation)) 48 | ) 49 | 50 | # Check input to upsample_type 51 | acceptable_upsample = {"upsamplelike", "upsampling2d", "upsampling3d"} 52 | if upsample_type not in acceptable_upsample: 53 | raise ValueError( 54 | 'Upsample method "{}" not supported. ' 55 | "Choose from {}.".format(upsample_type, list(acceptable_upsample)) 56 | ) 57 | 58 | # Check that there is an input_target if upsamplelike is used 59 | if upsample_type == "upsamplelike" and input_target is None: 60 | raise ValueError("upsamplelike requires an input_target.") 61 | 62 | conv = Conv2D if ndim == 2 else Conv3D 63 | conv_kernel = (1,) * ndim 64 | 65 | if K.image_data_format() == "channels_first": 66 | channel_axis = 1 67 | else: 68 | channel_axis = -1 69 | 70 | if n_classes == 1: 71 | include_top = False 72 | 73 | # Get pyramid names and features into list form 74 | pyramid_names = get_sorted_keys(pyramid_dict) 75 | pyramid_features = [pyramid_dict[name] for name in pyramid_names] 76 | 77 | # Reverse pyramid names and features 78 | pyramid_names.reverse() 79 | pyramid_features.reverse() 80 | 81 | # Previous method of building feature pyramids 82 | # semantic_features, semantic_names = [], [] 83 | # for N, P in zip(pyramid_names, pyramid_features): 84 | # # Get level and determine how much to upsample 85 | # level = int(re.findall(r'\d+', N)[0]) 86 | # 87 | # n_upsample = level - target_level 88 | # target = semantic_features[-1] if len(semantic_features) > 0 else None 89 | # 90 | # # Use semantic upsample to get semantic map 91 | # semantic_features.append(semantic_upsample( 92 | # P, n_upsample, n_filters=n_filters, target=target, ndim=ndim, 93 | # upsample_type=upsample_type, interpolation=interpolation, 94 | # semantic_id=semantic_id)) 95 | # semantic_names.append('Q{}'.format(level)) 96 | 97 | # Add all the semantic features 98 | # semantic_sum = semantic_features[0] 99 | # for semantic_feature in semantic_features[1:]: 100 | # semantic_sum = Add()([semantic_sum, semantic_feature]) 101 | 102 | # TODO: bad name but using the same name more clearly indicates 103 | # how to integrate the previous version 104 | semantic_sum = pyramid_features[-1] 105 | 106 | # Final upsampling 107 | # min_level = int(re.findall(r'\d+', pyramid_names[-1])[0]) 108 | # n_upsample = min_level - target_level 109 | n_upsample = target_level 110 | x = semantic_upsample( 111 | semantic_sum, 112 | n_upsample, 113 | # n_filters=n_filters, # TODO: uncomment and retrain 114 | target=input_target, 115 | ndim=ndim, 116 | upsample_type=upsample_type, 117 | semantic_id=semantic_id, 118 | interpolation=interpolation, 119 | ) 120 | 121 | # Apply conv in place of previous tensor product 122 | x = conv( 123 | n_dense, 124 | conv_kernel, 125 | strides=1, 126 | padding="same", 127 | name="conv_0_semantic_{}".format(semantic_id), 128 | )(x) 
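# the 1x1 conv above projects to n_dense filters; BN + ReLU refine it before the final classification conv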
129 | x = BatchNormalization( 130 | axis=channel_axis, name="batch_normalization_0_semantic_{}".format(semantic_id) 131 | )(x) 132 | x = Activation("relu", name="relu_0_semantic_{}".format(semantic_id))(x) 133 | 134 | # Apply conv and softmax layer 135 | x = conv( 136 | n_classes, 137 | conv_kernel, 138 | strides=1, 139 | padding="same", 140 | name="conv_1_semantic_{}".format(semantic_id), 141 | )(x) 142 | 143 | if include_top: 144 | x = Softmax(axis=channel_axis, dtype=K.floatx(), name="semantic_{}".format(semantic_id))(x) 145 | else: 146 | x = Activation("sigmoid", dtype=K.floatx(), name="semantic_{}".format(semantic_id))(x) 147 | 148 | return x 149 | -------------------------------------------------------------------------------- /src/deepcell/utils.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of 3 | # Technology (Caltech), with support from the Paul Allen Family Foundation, 4 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. 5 | # All rights reserved. 6 | # 7 | # Licensed under a modified Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE 12 | # 13 | # The Work provided may be used for non-commercial academic purposes only. 14 | # For any other use of the Work, including commercial use, please contact: 15 | # vanvalenlab@gmail.com 16 | # 17 | # Neither the name of Caltech nor the names of its contributors may be used 18 | # to endorse or promote products derived from this software without specific 19 | # prior written permission. 20 | # 21 | # Unless required by applicable law or agreed to in writing, software 22 | # distributed under the License is distributed on an "AS IS" BASIS, 23 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | # See the License for the specific language governing permissions and 25 | # limitations under the License. 26 | # ============================================================================== 27 | from tensorflow.python.client import device_lib 28 | 29 | 30 | def count_gpus(): 31 | """Get the number of available GPUs. 32 | 33 | Returns: 34 | int: count of GPUs as integer 35 | """ 36 | devices = device_lib.list_local_devices() 37 | gpus = [d for d in devices if d.name.lower().startswith('/device:gpu')] 38 | return len(gpus) 39 | 40 | 41 | def get_sorted_keys(dict_to_sort): 42 | """Gets the keys from a dict and sorts them in ascending order. 43 | Assumes keys are of the form ``Ni``, where ``N`` is a letter and ``i`` 44 | is an integer. 
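    For example (mirroring the check in ``tests/deepcell_utils_test.py``),
    ``get_sorted_keys({'C1': 1, 'C3': 2, 'C2': 3})`` returns
    ``['C1', 'C2', 'C3']``.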
45 | 46 | Args: 47 | dict_to_sort (dict): dict whose keys need sorting 48 | 49 | Returns: 50 | list: list of sorted keys from ``dict_to_sort`` 51 | """ 52 | sorted_keys = list(dict_to_sort.keys()) 53 | sorted_keys.sort(key=lambda x: int(x[1:])) 54 | return sorted_keys -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angelolab/Nimbus/b418eb842ee095afd0a9abe79be98a66ba157127/tests/__init__.py -------------------------------------------------------------------------------- /tests/application_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import keras 3 | import pytest 4 | import tempfile 5 | import numpy as np 6 | from segmentation_data_prep_test import prep_object_and_inputs 7 | from cell_classification.application import Nimbus, nimbus_preprocess 8 | from cell_classification.application import prep_deepcell_naming_convention 9 | 10 | 11 | def test_prep_deepcell_naming_convention(): 12 | segmentation_naming_convention = prep_deepcell_naming_convention("test_dir") 13 | seg_path = segmentation_naming_convention("fov0") 14 | assert isinstance(segmentation_naming_convention, type(lambda x: x)) 15 | assert isinstance(seg_path, str) 16 | assert seg_path == os.path.join("test_dir", "fov0_whole_cell.tiff") 17 | 18 | 19 | def test_nimbus_preprocess(): 20 | input_data = np.random.rand(1, 1024, 1024, 2) 21 | expected_output = np.copy(input_data) 22 | expected_output[..., 0] = input_data[..., 0] / 1.2 23 | 24 | output = nimbus_preprocess( 25 | input_data, normalize=True, marker="test", normalization_dict={"test": 1.2} 26 | ) 27 | # check if shape and values are in the expected range 28 | assert output.shape == (1, 1024, 1024, 2) 29 | assert output.max() <= 1.0 30 | assert output.min() >= 0.0 31 | # check if normalization was applied 32 | assert np.allclose(expected_output, output, atol=1e-5) 33 | 34 | # check if normalization works when no normalization dict is given 35 | output = nimbus_preprocess(input_data, normalize=True, marker="test") 36 | expected_output[..., 0] = input_data[..., 0] / np.quantile(input_data[..., 0], 0.999) 37 | expected_output = expected_output.clip(0, 1) 38 | 39 | assert np.allclose(expected_output, output, atol=1e-5) 40 | 41 | # check if normalization works when marker is not in dict 42 | output = nimbus_preprocess( 43 | input_data, normalize=True, marker="test2", normalization_dict={"test": 1.2} 44 | ) 45 | assert np.allclose(expected_output, output, atol=1e-5) 46 | 47 | # check if normalization works when normalization is set to False 48 | output = nimbus_preprocess(input_data, normalize=False) 49 | assert np.array_equal(input_data, output) 50 | 51 | 52 | def test_initialize_model(): 53 | with tempfile.TemporaryDirectory() as temp_dir: 54 | def segmentation_naming_convention(fov_path): 55 | return os.path.join(fov_path, "cell_segmentation.tiff") 56 | 57 | nimbus = Nimbus(["none_path"], segmentation_naming_convention, temp_dir) 58 | assert isinstance(nimbus.model, keras.engine.functional.Functional) 59 | 60 | 61 | def test_check_inputs(): 62 | with tempfile.TemporaryDirectory() as temp_dir: 63 | def segmentation_naming_convention(fov_path): 64 | return os.path.join(fov_path, "cell_segmentation.tiff") 65 | _, fov_paths, _, _ = prep_object_and_inputs(temp_dir) 66 | output_dir = temp_dir 67 | 68 | # check if no errors are raised when all inputs are valid 69 | nimbus = 
Nimbus(fov_paths, segmentation_naming_convention, output_dir) 70 | nimbus.check_inputs() 71 | 72 | # check if error is raised when a path in fov_paths does not exist on disk 73 | nimbus.fov_paths.append("invalid_path") 74 | with pytest.raises(FileNotFoundError, match="invalid_path"): 75 | nimbus.check_inputs() 76 | 77 | # check if error is raised when segmentation_naming_convention does not return a valid path 78 | nimbus.fov_paths = fov_paths 79 | 80 | def segmentation_naming_convention(fov_path): 81 | return os.path.join(fov_path, "invalid_path.tiff") 82 | 83 | with pytest.raises(FileNotFoundError, match="invalid_path"): 84 | nimbus.check_inputs() 85 | 86 | # check if error is raised when output_dir does not exist on disk 87 | nimbus.segmentation_naming_convention = segmentation_naming_convention 88 | nimbus.output_dir = "invalid_path" 89 | with pytest.raises(FileNotFoundError, match="invalid_path"): 90 | nimbus.check_inputs() 91 | 92 | 93 | def test_prepare_normalization_dict(): 94 | # test if normalization dict gets prepared and saved; in-depth tests are in inference_test.py 95 | with tempfile.TemporaryDirectory() as temp_dir: 96 | def segmentation_naming_convention(fov_path): 97 | return os.path.join(fov_path, "cell_segmentation.tiff") 98 | 99 | _, fov_paths, _, _ = prep_object_and_inputs(temp_dir) 100 | nimbus = Nimbus( 101 | fov_paths, segmentation_naming_convention, temp_dir, exclude_channels=["CD57"] 102 | ) 103 | # test if normalization dict gets prepared and saved 104 | nimbus.prepare_normalization_dict(overwrite=True) 105 | assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json")) 106 | assert "CD57" not in nimbus.normalization_dict.keys() 107 | 108 | # test if normalization dict gets loaded 109 | nimbus_2 = Nimbus( 110 | fov_paths, segmentation_naming_convention, temp_dir, exclude_channels=["CD57"] 111 | ) 112 | nimbus_2.prepare_normalization_dict() 113 | assert nimbus_2.normalization_dict == nimbus.normalization_dict 114 | 115 | 116 | def test_predict_fovs(): 117 | with tempfile.TemporaryDirectory() as temp_dir: 118 | def segmentation_naming_convention(fov_path): 119 | return os.path.join(fov_path, "cell_segmentation.tiff") 120 | 121 | _, fov_paths, _, _ = prep_object_and_inputs(temp_dir) 122 | os.remove(os.path.join(temp_dir, 'normalization_dict.json')) 123 | output_dir = os.path.join(temp_dir, "nimbus_output") 124 | fov_paths = fov_paths[:1] 125 | nimbus = Nimbus( 126 | fov_paths, segmentation_naming_convention, output_dir, 127 | exclude_channels=["CD57", "CD11c", "XYZ"] 128 | ) 129 | cell_table = nimbus.predict_fovs() 130 | 131 | # check if all channels are in the cell_table 132 | for channel in nimbus.normalization_dict.keys(): 133 | assert channel+"_pred" in cell_table.columns 134 | # check if cell_table was saved 135 | assert os.path.exists(os.path.join(output_dir, "nimbus_cell_table.csv")) 136 | 137 | # check if fov folders were created in output_dir 138 | for fov_path in fov_paths: 139 | fov = os.path.basename(fov_path) 140 | assert os.path.exists(os.path.join(output_dir, fov)) 141 | # check if all predictions were saved 142 | for channel in nimbus.normalization_dict.keys(): 143 | assert os.path.exists(os.path.join(output_dir, fov, channel+".tiff")) 144 | # check if CD57 was excluded 145 | assert not os.path.exists(os.path.join(output_dir, fov, "CD57.tiff")) 146 | assert not os.path.exists(os.path.join(output_dir, fov, "CD11c.tiff")) 147 | assert not os.path.exists(os.path.join(output_dir, fov, "XYZ.tiff")) 148 | 
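# Taken together, the tests above mirror the intended end-to-end use of the
# application module; a sketch with hypothetical paths:
#
#     naming_fn = prep_deepcell_naming_convention("/data/deepcell_output")
#     nimbus = Nimbus(fov_paths, naming_fn, "/data/nimbus_output",
#                     exclude_channels=["CD57"])
#     nimbus.check_inputs()                # validate fov and mask paths up front
#     nimbus.prepare_normalization_dict()  # writes normalization_dict.json
#     cell_table = nimbus.predict_fovs()   # per-cell predictions (+ tiffs)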
-------------------------------------------------------------------------------- /tests/augmentation_pipeline_test.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import imgaug.augmenters as iaa 4 | import numpy as np 5 | import pytest 6 | import tensorflow as tf 7 | 8 | from cell_classification.augmentation_pipeline import ( 9 | Flip, GaussianBlur, GaussianNoise, LinearContrast, MixUp, Rot90, Zoom, 10 | augment_images, get_augmentation_pipeline, prepare_keras_aug, 11 | prepare_tf_aug, py_aug) 12 | 13 | parametrize = pytest.mark.parametrize 14 | 15 | 16 | def get_params(): 17 | return { 18 | # flip 19 | "flip_prob": 1.0, 20 | # affine 21 | "affine_prob": 1.0, 22 | "scale_min": 0.5, 23 | "scale_max": 1.5, 24 | "shear_angle": 10, 25 | # elastic 26 | "elastic_prob": 1.0, 27 | "elastic_alpha": [0, 5.0], 28 | "elastic_sigma": 0.5, 29 | # rotate 30 | "rotate_prob": 1.0, 31 | "rotate_count": 3, 32 | # gaussian noise 33 | "gaussian_noise_prob": 1.0, 34 | "gaussian_noise_min": 0.1, 35 | "gaussian_noise_max": 0.5, 36 | # gaussian blur 37 | "gaussian_blur_prob": 0.0, 38 | "gaussian_blur_min": 0.1, 39 | "gaussian_blur_max": 0.5, 40 | # contrast aug 41 | "contrast_prob": 0.0, 42 | "contrast_min": 0.1, 43 | "contrast_max": 2.0, 44 | } 45 | 46 | 47 | def test_get_augmentation_pipeline(): 48 | params = get_params() 49 | augmentation_pipeline = get_augmentation_pipeline(params) 50 | assert isinstance(augmentation_pipeline, iaa.Sequential) 51 | 52 | 53 | @parametrize("batch_num", [1, 2, 3]) 54 | @parametrize("chan_num", [1, 2, 3]) 55 | def test_augment_images(batch_num, chan_num): 56 | params = get_params() 57 | augmentation_pipeline = get_augmentation_pipeline(params) 58 | images = np.zeros([batch_num, 100, 100, chan_num], dtype=np.float32) 59 | masks = np.zeros([batch_num, 100, 100, chan_num], dtype=np.int32) 60 | images[0, :50, :50, :] = 10.1 61 | images[0, 50:, 50:, :] = 201.12 62 | masks[0, :50, :50] = 1 63 | masks[0, 50:, 50:] = 2 64 | augmented_images, augmented_masks = augment_images(images, masks, augmentation_pipeline) 65 | 66 | # check if right types and shapes are returned 67 | assert isinstance(augmented_images, np.ndarray) 68 | assert isinstance(augmented_masks, np.ndarray) 69 | assert augmented_images.dtype == np.float32 70 | assert augmented_masks.dtype == np.int32 71 | assert augmented_images.shape == images.shape 72 | assert augmented_masks.shape == masks.shape 73 | 74 | # check if images are augmented 75 | assert not np.array_equal(augmented_images, images) 76 | assert not np.array_equal(augmented_masks, masks) 77 | 78 | # check if masks still contain only the original labels 79 | assert list(np.unique(augmented_masks)) == [0, 1, 2] 80 | 81 | # check if images and masks were augmented with the same spatial augmentations approx. 
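    # (the class-1 region was set to ~10.1 and the class-2 region to ~201.12
    # above, so interpolation at label boundaries shifts the class-2 mean the
    # most; hence the looser tolerance of 5 in the last of the three checks)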
82 | assert np.abs(augmented_images[augmented_masks == 0].mean() - images[masks == 0].mean()) < 1.5 83 | assert np.abs(augmented_images[augmented_masks == 1].mean() - images[masks == 1].mean()) < 1.5 84 | assert np.abs(augmented_images[augmented_masks == 2].mean() - images[masks == 2].mean()) < 5 85 | 86 | # check control flow for no channel dimensions for the masks 87 | masks = np.zeros([batch_num, 100, 100], dtype=np.int32) 88 | _, augmented_masks = augment_images(images, masks, augmentation_pipeline) 89 | assert augmented_masks.shape == masks.shape 90 | 91 | 92 | def prepare_data(batch_num, return_tensor=False): 93 | mplex_img = np.zeros([batch_num, 100, 100, 2], dtype=np.float32) 94 | binary_mask = np.zeros([batch_num, 100, 100, 1], dtype=np.int32) 95 | marker_activity_mask = np.zeros([batch_num, 100, 100, 1], dtype=np.int32) 96 | mplex_img[0, :30, :50, :] = 10.1 97 | mplex_img[0, 50:, 50:, :] = 21.12 98 | mplex_img[-1, 30:60, :50, :] = 14.11 99 | mplex_img[-1, :50, 50:, :] = 18.12 100 | binary_mask[0, :30, :50] = 1 101 | binary_mask[0, 50:, 50:] = 1 102 | binary_mask[-1, 30:60, :50, :] = 1 103 | binary_mask[-1, :50, 50:, :] = 1 104 | marker_activity_mask[0, :30, :50] = 1 105 | marker_activity_mask[0, 50:, 50:] = 2 106 | marker_activity_mask[-1, 30:60, :50, :] = 1 107 | marker_activity_mask[-1, :50, 50:, :] = 2 108 | 109 | if return_tensor: 110 | mplex_img = tf.constant(mplex_img, tf.float32) 111 | binary_mask = tf.constant(binary_mask, tf.int32) 112 | marker_activity_mask = tf.constant(marker_activity_mask, tf.int32) 113 | return mplex_img, binary_mask, marker_activity_mask 114 | 115 | 116 | @parametrize("batch_num", [1, 2, 3]) 117 | def test_prepare_tf_aug(batch_num): 118 | params = get_params() 119 | augmentation_pipeline = get_augmentation_pipeline(params) 120 | tf_aug = prepare_tf_aug(augmentation_pipeline) 121 | mplex_img, binary_mask, marker_activity_mask = prepare_data(batch_num) 122 | mplex_aug, mask_out, marker_activity_aug = tf_aug( 123 | tf.constant(mplex_img, tf.float32), binary_mask, marker_activity_mask 124 | ) 125 | mplex_img, binary_mask, marker_activity_mask = prepare_data(batch_num) 126 | # check if right types and shapes are returned 127 | assert isinstance(mplex_aug, np.ndarray) 128 | assert isinstance(mask_out, np.ndarray) 129 | assert isinstance(marker_activity_aug, np.ndarray) 130 | assert mplex_aug.dtype == np.float32 131 | assert mask_out.dtype == np.int32 132 | assert marker_activity_aug.dtype == np.int32 133 | assert mplex_aug.shape == mplex_img.shape 134 | assert mask_out.shape == binary_mask.shape 135 | assert marker_activity_aug.shape == marker_activity_mask.shape 136 | 137 | 138 | @parametrize("batch_num", [1, 2, 3]) 139 | def test_py_aug(batch_num): 140 | params = get_params() 141 | augmentation_pipeline = get_augmentation_pipeline(params) 142 | tf_aug = prepare_tf_aug(augmentation_pipeline) 143 | mplex_img, binary_mask, marker_activity_mask = prepare_data(batch_num) 144 | batch = { 145 | "mplex_img": tf.constant(mplex_img, tf.float32), 146 | "binary_mask": tf.constant(binary_mask, tf.int32), 147 | "marker_activity_mask": tf.constant(marker_activity_mask, tf.int32), 148 | "dataset": "test_dataset", 149 | "marker": "test_marker", 150 | "imaging_platform": "test_platform", 151 | } 152 | batch_aug = py_aug(deepcopy(batch), tf_aug) 153 | 154 | # check if right types and shapes are returned 155 | for key in batch.keys(): 156 | assert isinstance(batch_aug[key], type(batch[key])) 157 | 158 | for key in ["mplex_img", "binary_mask", "marker_activity_mask"]: 159 
| assert batch_aug[key].shape == batch[key].shape 160 | assert not np.array_equal(batch_aug[key], batch[key]) 161 | 162 | 163 | @parametrize("batch_num", [2, 4, 8]) 164 | def test_prepare_keras_aug(batch_num): 165 | params = get_params() 166 | augmentation_pipeline = prepare_keras_aug(params) 167 | images, _, masks = prepare_data(batch_num, True) 168 | augmented_images, augmented_masks = augmentation_pipeline(images, masks) 169 | 170 | # check if right types and shapes are returned 171 | assert augmented_images.dtype == tf.float32 172 | assert augmented_masks.dtype == tf.int32 173 | assert augmented_images.shape == images.shape 174 | assert augmented_masks.shape == masks.shape 175 | 176 | 177 | @parametrize("batch_num", [2, 4, 8]) 178 | def test_flip(batch_num): 179 | images, _, masks = prepare_data(batch_num, True) 180 | flip = Flip(prob=1.0) 181 | aug_img, aug_mask = flip(images, masks) 182 | 183 | # check if right types and shapes are returned 184 | assert aug_img.dtype == images.dtype 185 | assert aug_mask.dtype == masks.dtype 186 | assert aug_img.shape == images.shape 187 | assert aug_mask.shape == masks.shape 188 | 189 | # check if data got flipped 190 | assert not np.array_equal(aug_img, images) 191 | assert not np.array_equal(aug_mask, masks) 192 | assert np.sum(aug_img) == np.sum(images) 193 | assert np.sum(aug_mask) == np.sum(masks) 194 | 195 | 196 | @parametrize("batch_num", [2, 4, 8]) 197 | def test_rot90(batch_num): 198 | images, _, masks = prepare_data(batch_num, True) 199 | rot90 = Rot90(prob=1.0, rotate_count=2) 200 | aug_img, aug_mask = rot90(images, masks) 201 | 202 | # check if right types and shapes are returned 203 | assert aug_img.dtype == images.dtype 204 | assert aug_mask.dtype == masks.dtype 205 | assert aug_img.shape == images.shape 206 | assert aug_mask.shape == masks.shape 207 | 208 | # check if data got rotated 209 | assert not np.array_equal(aug_img, images) 210 | assert not np.array_equal(aug_mask, masks) 211 | assert np.sum(aug_img) == np.sum(images) 212 | assert np.sum(aug_mask) == np.sum(masks) 213 | 214 | 215 | @parametrize("batch_num", [2, 4, 8]) 216 | def test_gaussian_noise(batch_num): 217 | images, _, masks = prepare_data(batch_num, True) 218 | gaussian_noise = GaussianNoise(prob=1.0) 219 | aug_img, aug_mask = gaussian_noise(images, masks) 220 | 221 | # check if right types and shapes are returned 222 | assert aug_img.dtype == images.dtype 223 | assert aug_mask.dtype == masks.dtype 224 | assert aug_img.shape == images.shape 225 | assert aug_mask.shape == masks.shape 226 | 227 | # check if data got augmented 228 | assert not np.array_equal(aug_img, images) 229 | assert np.array_equal(aug_mask, masks) 230 | assert np.isclose(np.mean(aug_img), np.mean(images), atol=0.1) 231 | 232 | 233 | @parametrize("batch_num", [2, 4, 8]) 234 | def test_gaussian_blur(batch_num): 235 | images, _, masks = prepare_data(batch_num, True) 236 | gaussian_blur = GaussianBlur(1.0, 0.5, 1.5, 5) 237 | aug_img, aug_mask = gaussian_blur(images, masks) 238 | 239 | # check if right types and shapes are returned 240 | assert aug_img.dtype == images.dtype 241 | assert aug_mask.dtype == masks.dtype 242 | assert aug_img.shape == images.shape 243 | assert aug_mask.shape == masks.shape 244 | 245 | # check if data got augmented 246 | assert not np.array_equal(aug_img, images) 247 | assert np.array_equal(aug_mask, masks) 248 | assert np.isclose(np.mean(aug_img), np.mean(images), atol=0.2) 249 | 250 | 251 | @parametrize("batch_num", [2, 4, 8]) 252 | def test_zoom(batch_num): 253 | images, _, 
masks = prepare_data(batch_num, True) 254 | zoom = Zoom(1.0, 0.5, 0.5) 255 | aug_img, aug_mask = zoom(images, masks) 256 | 257 | # check if right types and shapes are returned 258 | assert aug_img.dtype == images.dtype 259 | assert aug_mask.dtype == masks.dtype 260 | assert aug_img.shape == images.shape 261 | assert aug_mask.shape == masks.shape 262 | 263 | # check if data got augmented 264 | assert not np.array_equal(aug_img, images) 265 | assert not np.array_equal(aug_mask, masks) 266 | 267 | # check if data got zoomed to 0.5 268 | assert np.sum(aug_img) == np.sum(images) / 4 269 | assert np.sum(aug_mask) == np.sum(masks) / 4 270 | 271 | 272 | @parametrize("batch_num", [2, 4, 8]) 273 | def test_linear_contrast(batch_num): 274 | images, _, masks = prepare_data(batch_num, True) 275 | linear_contrast = LinearContrast(1.0, 0.75, 0.75) 276 | aug_img, aug_mask = linear_contrast(images, masks) 277 | 278 | # check if right types and shapes are returned 279 | assert aug_img.dtype == images.dtype 280 | assert aug_mask.dtype == masks.dtype 281 | assert aug_img.shape == images.shape 282 | assert aug_mask.shape == masks.shape 283 | 284 | # check if data got augmented 285 | assert not np.array_equal(aug_img, images) 286 | assert np.array_equal(aug_mask, masks) 287 | assert np.isclose(np.mean(aug_img), np.mean(images) * 0.75, atol=0.01) 288 | 289 | 290 | @parametrize("batch_num", [2, 4, 8]) 291 | def test_mixup(batch_num): 292 | images, _, labels = prepare_data(batch_num, True) 293 | mixup = MixUp(1.0, 0.5) 294 | x_mplex, x_binary = tf.split(images, 2, axis=-1) 295 | loss_mask = tf.cast(labels, tf.float32) 296 | 297 | x_mplex_aug, x_binary_aug, labels_aug, loss_mask_aug = mixup( 298 | x_mplex, x_binary, labels, loss_mask 299 | ) 300 | 301 | # check if right types and shapes are returned 302 | assert x_mplex_aug.dtype == x_mplex.dtype 303 | assert x_binary_aug.dtype == x_binary.dtype 304 | assert labels_aug.dtype == tf.float32 305 | assert loss_mask_aug.dtype == loss_mask.dtype 306 | assert x_mplex_aug.shape == x_mplex.shape 307 | assert x_binary_aug.shape == x_binary.shape 308 | assert labels_aug.shape == labels.shape 309 | assert loss_mask_aug.shape == loss_mask.shape 310 | 311 | # check if data got augmented 312 | assert not np.array_equal(x_mplex_aug, x_mplex) 313 | assert not np.array_equal(x_binary_aug, x_binary) 314 | assert not np.array_equal(labels_aug, labels) 315 | assert not np.array_equal(loss_mask_aug, loss_mask) 316 | 317 | # check if data got mixed up 318 | assert np.isclose(np.mean(x_mplex_aug), np.mean(x_mplex), atol=0.1) 319 | assert np.isclose(np.mean(x_binary_aug), np.mean(x_binary), atol=0.1) 320 | assert np.isclose(np.mean(labels_aug), np.mean(labels), atol=0.1) 321 | assert np.isclose( 322 | np.mean(loss_mask_aug), np.mean(loss_mask*tf.reverse(loss_mask, [0])), atol=0.1 323 | ) 324 | -------------------------------------------------------------------------------- /tests/deepcell_application_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of 2 | # Technology (Caltech), with support from the Paul Allen Family Foundation, 3 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. 4 | # All rights reserved. 5 | # 6 | # Licensed under a modified Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE 11 | # 12 | # The Work provided may be used for non-commercial academic purposes only. 13 | # For any other use of the Work, including commercial use, please contact: 14 | # vanvalenlab@gmail.com 15 | # 16 | # Neither the name of Caltech nor the names of its contributors may be used 17 | # to endorse or promote products derived from this software without specific 18 | # prior written permission. 19 | # 20 | # Unless required by applicable law or agreed to in writing, software 21 | # distributed under the License is distributed on an "AS IS" BASIS, 22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | # See the License for the specific language governing permissions and 24 | # limitations under the License. 25 | # ============================================================================== 26 | """Tests for Application""" 27 | 28 | 29 | from itertools import product 30 | from unittest.mock import Mock 31 | 32 | import pytest 33 | import numpy as np 34 | 35 | from tensorflow.python.platform import test 36 | 37 | from deepcell.application import Application 38 | 39 | 40 | class DummyModel: 41 | 42 | def __init__(self, n_out=1): 43 | self.n_out = n_out 44 | 45 | def predict(self, x, batch_size=4): 46 | y = np.random.rand(*x.shape) 47 | return [y] * self.n_out 48 | 49 | 50 | class TestApplication(test.TestCase): 51 | 52 | def test_predict_notimplemented(self): 53 | model = DummyModel() 54 | kwargs = {'model_mpp': 0.65, 55 | 'model_image_shape': (128, 128, 1)} 56 | app = Application(model, **kwargs) 57 | 58 | x = np.random.rand(1, 500, 500, 1) 59 | 60 | with self.assertRaises(NotImplementedError): 61 | app.predict(x) 62 | 63 | def test_resize(self): 64 | model = DummyModel() 65 | kwargs = {'model_mpp': 0.65, 66 | 'model_image_shape': (128, 128, 1)} 67 | app = Application(model, **kwargs) 68 | 69 | x = np.random.rand(1, 500, 500, 1) 70 | 71 | # image_mpp = None --> No resize 72 | y = app._resize_input(x, image_mpp=None) 73 | self.assertEqual(x.shape, y.shape) 74 | 75 | # image_mpp = model_mpp --> No resize 76 | y = app._resize_input(x, image_mpp=kwargs['model_mpp']) 77 | self.assertEqual(x.shape, y.shape) 78 | 79 | # image_mpp > model_mpp --> resize 80 | y = app._resize_input(x, image_mpp=2.1 * kwargs['model_mpp']) 81 | self.assertEqual(2.1, np.round(y.shape[1] / x.shape[1], decimals=1)) 82 | 83 | # image_mpp < model_mpp --> resize 84 | y = app._resize_input(x, image_mpp=0.7 * kwargs['model_mpp']) 85 | self.assertEqual(0.7, np.round(y.shape[1] / x.shape[1], decimals=1)) 86 | 87 | def test_preprocess(self): 88 | 89 | def _preprocess(x): 90 | y = np.ones(x.shape) 91 | return y 92 | 93 | model = DummyModel() 94 | x = np.random.rand(1, 30, 30, 1) 95 | 96 | # Test no preprocess input 97 | app = Application(model) 98 | y = app._preprocess(x) 99 | self.assertAllEqual(x, y) 100 | 101 | # Test ones function 102 | kwargs = {'preprocessing_fn': _preprocess} 103 | app = Application(model, **kwargs) 104 | y = app._preprocess(x) 105 | self.assertAllEqual(np.ones(x.shape), y) 106 | 107 | # Test bad input 108 | kwargs = {'preprocessing_fn': 'x'} 109 | with self.assertRaises(ValueError): 110 | app = Application(model, **kwargs) 111 | 112 | def test_tile_input(self): 113 | model = DummyModel() 114 | kwargs = {'model_mpp': 0.65, 115 | 'model_image_shape': (128, 128, 1)} 116 | app = Application(model, **kwargs) 117 | 118 | # No tiling 119 | x = np.random.rand(1, 128, 128, 
1) 120 | y, tile_info = app._tile_input(x) 121 | self.assertEqual(x.shape, y.shape) 122 | self.assertIsInstance(tile_info, dict) 123 | 124 | # Tiling square 125 | x = np.random.rand(1, 400, 400, 1) 126 | y, tile_info = app._tile_input(x) 127 | self.assertEqual(kwargs['model_image_shape'][:-1], y.shape[1:-1]) 128 | self.assertIsInstance(tile_info, dict) 129 | 130 | # Tiling rectangle 131 | x = np.random.rand(1, 300, 500, 1) 132 | y, tile_info = app._tile_input(x) 133 | self.assertEqual(kwargs['model_image_shape'][:-1], y.shape[1:-1]) 134 | self.assertIsInstance(tile_info, dict) 135 | 136 | # Smaller than expected 137 | x = np.random.rand(1, 100, 100, 1) 138 | y, tile_info = app._tile_input(x) 139 | self.assertEqual(kwargs['model_image_shape'][:-1], y.shape[1:-1]) 140 | self.assertIsInstance(tile_info, dict) 141 | 142 | def test_postprocess(self): 143 | 144 | def _postprocess(Lx): 145 | y = np.ones(Lx[0].shape) 146 | return y 147 | 148 | model = DummyModel() 149 | x = np.random.rand(1, 30, 30, 1) 150 | 151 | # No input 152 | app = Application(model) 153 | y = app._postprocess(x) 154 | self.assertAllEqual(x, y) 155 | 156 | # Ones 157 | kwargs = {'postprocessing_fn': _postprocess} 158 | app = Application(model, **kwargs) 159 | y = app._postprocess([x]) 160 | self.assertAllEqual(np.ones(x.shape), y) 161 | 162 | # Bad input 163 | kwargs = {'postprocessing_fn': 'x'} 164 | with self.assertRaises(ValueError): 165 | app = Application(model, **kwargs) 166 | 167 | def test_untile_output(self): 168 | model = DummyModel() 169 | kwargs = {'model_image_shape': (128, 128, 1)} 170 | app = Application(model, **kwargs) 171 | 172 | # No tiling 173 | x = np.random.rand(1, 128, 128, 1) 174 | tiles, tile_info = app._tile_input(x) 175 | y = app._untile_output(tiles, tile_info) 176 | self.assertEqual(x.shape, y.shape) 177 | 178 | # Tiling square 179 | x = np.random.rand(1, 400, 400, 1) 180 | tiles, tile_info = app._tile_input(x) 181 | y = app._untile_output(tiles, tile_info) 182 | self.assertEqual(x.shape, y.shape) 183 | 184 | # Tiling rectangle 185 | x = np.random.rand(1, 300, 500, 1) 186 | tiles, tile_info = app._tile_input(x) 187 | y = app._untile_output(tiles, tile_info) 188 | self.assertEqual(x.shape, y.shape) 189 | 190 | # Smaller than expected 191 | x = np.random.rand(1, 100, 100, 1) 192 | tiles, tile_info = app._tile_input(x) 193 | y = app._untile_output(tiles, tile_info) 194 | self.assertEqual(x.shape, y.shape) 195 | 196 | def test_resize_output(self): 197 | model = DummyModel() 198 | kwargs = {'model_image_shape': (128, 128, 1)} 199 | app = Application(model, **kwargs) 200 | 201 | x = np.random.rand(1, 128, 128, 1) 202 | 203 | # x.shape = original_shape --> no resize 204 | y = app._resize_output(x, x.shape) 205 | self.assertEqual(x.shape, y.shape) 206 | 207 | # x.shape != original_shape --> resize 208 | original_shape = (1, 500, 500, 1) 209 | y = app._resize_output(x, original_shape) 210 | self.assertEqual(original_shape, y.shape) 211 | 212 | # test multiple outputs are also resized 213 | x_list = [x, x] 214 | 215 | # x.shape = original_shape --> no resize 216 | y = app._resize_output(x_list, x.shape) 217 | self.assertIsInstance(y, list) 218 | for y_sub in y: 219 | self.assertEqual(x.shape, y_sub.shape) 220 | 221 | # x.shape != original_shape --> resize 222 | original_shape = (1, 500, 500, 1) 223 | y = app._resize_output(x_list, original_shape) 224 | self.assertIsInstance(y, list) 225 | for y_sub in y: 226 | self.assertEqual(original_shape, y_sub.shape) 227 | 228 | def test_format_model_output(self): 229 
| def _format_model_output(Lx): 230 | return {'inner-distance': Lx} 231 | 232 | model = DummyModel() 233 | x = np.random.rand(1, 30, 30, 1) 234 | 235 | # No function 236 | app = Application(model) 237 | y = app._format_model_output(x) 238 | self.assertAllEqual(x, y) 239 | 240 | # single image 241 | kwargs = {'format_model_output_fn': _format_model_output} 242 | app = Application(model, **kwargs) 243 | y = app._format_model_output(x) 244 | self.assertAllEqual(x, y['inner-distance']) 245 | 246 | def test_batch_predict(self): 247 | 248 | def predict1(x, batch_size=4): 249 | y = np.random.rand(*x.shape) 250 | return [y] 251 | 252 | def predict2(x, batch_size=4): 253 | y = np.random.rand(*x.shape) 254 | return [y] * 2 255 | 256 | num_images = [4, 8, 10] 257 | num_pred_heads = [1, 2] 258 | batch_sizes = [1, 4, 5] 259 | prod = product(num_images, num_pred_heads, batch_sizes) 260 | 261 | for num_image, num_pred_head, batch_size in prod: 262 | model = DummyModel(n_out=num_pred_head) 263 | app = Application(model) 264 | 265 | x = np.random.rand(num_image, 128, 128, 1) 266 | 267 | if num_pred_head == 1: 268 | app.model.predict = Mock(side_effect=predict1) 269 | else: 270 | app.model.predict = Mock(side_effect=predict2) 271 | y = app._batch_predict(x, batch_size=batch_size) 272 | 273 | assert app.model.predict.call_count == np.ceil(num_image / batch_size) 274 | 275 | self.assertEqual(x.shape, y[0].shape) 276 | if num_pred_head == 2: 277 | self.assertEqual(x.shape, y[1].shape) 278 | 279 | def test_run_model(self): 280 | model = DummyModel(n_out=2) 281 | app = Application(model) 282 | 283 | x = np.random.rand(1, 128, 128, 1) 284 | y = app._run_model(x) 285 | self.assertEqual(x.shape, y[0].shape) 286 | self.assertEqual(x.shape, y[1].shape) 287 | 288 | def test_predict_segmentation(self): 289 | model = DummyModel() 290 | app = Application(model) 291 | 292 | x = np.random.rand(1, 128, 128, 1) 293 | y = app._predict_segmentation(x) 294 | self.assertEqual(x.shape, y.shape) 295 | 296 | # test with different MPP 297 | model = DummyModel() 298 | app = Application(model) 299 | 300 | x = np.random.rand(1, 128, 128, 1) 301 | y = app._predict_segmentation(x, image_mpp=1.3) 302 | self.assertEqual(x.shape, y.shape) 303 | 304 | 305 | @pytest.mark.parametrize( 306 | "orig_img_shape", 307 | ( 308 | (1, 216, 256, 1), # x < model_img_shape 309 | (1, 256, 64, 1), # y < model_img_shape 310 | ), 311 | ) 312 | def test_untile_non_empty_slices(orig_img_shape): 313 | """Test corner cases when tile_info["padding"] is True, but the padding in 314 | at least one of the dimensions is (0, 0). See gh-665.""" 315 | # Padding is "activated" whenever either of the input image dimensions is 316 | # smaller than app.model_image_shape 317 | app = Application(DummyModel, model_image_shape=(256, 256, 1)) 318 | 319 | img = np.ones(orig_img_shape) 320 | # Use _tile_input to generate a valid tile_info object 321 | tiles, tile_info = app._tile_input(img) 322 | # Input validation - tile_info should have a "padding" key and one of the 323 | # two pads should be (0, 0). 
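    # (for the (1, 216, 256, 1) case the row axis must be padded from 216 up
    # to 256 while the column axis needs none, so exactly one of the two pad
    # tuples is (0, 0); how the 40 padded pixels are split between the two
    # sides is up to deepcell's tiling utilities)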
324 | assert tile_info.get("padding") and ( 325 | tile_info["x_pad"] == (0, 0) or tile_info["y_pad"] == (0, 0) 326 | ) 327 | untiled = app._untile_output(tiles, tile_info) 328 | assert untiled.shape == orig_img_shape 329 | -------------------------------------------------------------------------------- /tests/deepcell_backbone_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of 2 | # Technology (Caltech), with support from the Paul Allen Family Foundation, 3 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. 4 | # All rights reserved. 5 | # 6 | # Licensed under a modified Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE 11 | # 12 | # The Work provided may be used for non-commercial academic purposes only. 13 | # For any other use of the Work, including commercial use, please contact: 14 | # vanvalenlab@gmail.com 15 | # 16 | # Neither the name of Caltech nor the names of its contributors may be used 17 | # to endorse or promote products derived from this software without specific 18 | # prior written permission. 19 | # 20 | # Unless required by applicable law or agreed to in writing, software 21 | # distributed under the License is distributed on an "AS IS" BASIS, 22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | # See the License for the specific language governing permissions and 24 | # limitations under the License. 25 | # ============================================================================== 26 | """Tests for backbone_utils""" 27 | 28 | 29 | from absl.testing import parameterized 30 | 31 | from tensorflow.python.framework import test_util as tf_test_util 32 | from tensorflow.python.platform import test 33 | 34 | from tensorflow.keras import backend as K 35 | from tensorflow.keras.layers import Input 36 | from tensorflow.keras.models import Model 37 | from keras import keras_parameterized 38 | 39 | from deepcell import backbone_utils 40 | 41 | 42 | class TestBackboneUtils(keras_parameterized.TestCase): 43 | 44 | @keras_parameterized.run_with_all_model_types 45 | @keras_parameterized.run_all_keras_modes 46 | @parameterized.named_parameters( 47 | *tf_test_util.generate_combinations_with_testcase_name( 48 | data_format=[ 49 | # 'channels_first', 50 | 'channels_last'])) 51 | def test_get_featurenet_backbone(self, data_format): 52 | backbone = 'featurenet' 53 | input_shape = (256, 256, 3) 54 | inputs = Input(shape=input_shape) 55 | with self.cached_session(): 56 | K.set_image_data_format(data_format) 57 | model, output_dict = backbone_utils.get_backbone( 58 | backbone, inputs, return_dict=True) 59 | assert isinstance(output_dict, dict) 60 | assert all(k.startswith('C') for k in output_dict) 61 | assert isinstance(model, Model) 62 | 63 | # No imagenet weights for featurenet backbone 64 | with self.assertRaises(ValueError): 65 | backbone_utils.get_backbone(backbone, inputs, use_imagenet=True) 66 | 67 | # @keras_parameterized.run_all_keras_modes 68 | @parameterized.named_parameters( 69 | *tf_test_util.generate_combinations_with_testcase_name( 70 | data_format=[ 71 | # 'channels_first', 72 | 'channels_last'])) 73 | def test_get_featurenet3d_backbone(self, data_format): 74 | backbone = 'featurenet3d' 75 | input_shape = (40, 256, 256, 
3) 76 | inputs = Input(shape=input_shape) 77 | with self.cached_session(): 78 | K.set_image_data_format(data_format) 79 | model, output_dict = backbone_utils.get_backbone( 80 | backbone, inputs, return_dict=True) 81 | assert isinstance(output_dict, dict) 82 | assert all(k.startswith('C') for k in output_dict) 83 | assert isinstance(model, Model) 84 | 85 | # No imagenet weights for featurenet backbone 86 | with self.assertRaises(ValueError): 87 | backbone_utils.get_backbone(backbone, inputs, use_imagenet=True) 88 | 89 | # @keras_parameterized.run_with_all_model_types 90 | # @keras_parameterized.run_all_keras_modes 91 | @parameterized.named_parameters( 92 | *tf_test_util.generate_combinations_with_testcase_name( 93 | backbone=[ 94 | 'resnet50', 95 | 'resnet101', 96 | 'resnet152', 97 | 'resnet50v2', 98 | 'resnet101v2', 99 | 'resnet152v2', 100 | # 'resnext50', 101 | # 'resnext101', 102 | 'vgg16', 103 | 'vgg19', 104 | 'densenet121', 105 | 'densenet169', 106 | 'densenet201', 107 | 'mobilenet', 108 | 'mobilenetv2', 109 | 'efficientnetb0', 110 | 'efficientnetb1', 111 | 'efficientnetb2', 112 | 'efficientnetb3', 113 | 'efficientnetb4', 114 | 'efficientnetb5', 115 | 'efficientnetb6', 116 | 'efficientnetb7', 117 | 'nasnet_large', 118 | 'nasnet_mobile'])) 119 | def test_get_backbone(self, backbone): 120 | with self.cached_session(): 121 | K.set_image_data_format('channels_last') 122 | inputs = Input(shape=(256, 256, 3)) 123 | model, output_dict = backbone_utils.get_backbone( 124 | backbone, inputs, return_dict=True) 125 | assert isinstance(output_dict, dict) 126 | assert all(k.startswith('C') for k in output_dict) 127 | assert isinstance(model, Model) 128 | 129 | def test_invalid_backbone(self): 130 | inputs = Input(shape=(4, 2, 3)) 131 | with self.assertRaises(ValueError): 132 | backbone_utils.get_backbone('bad', inputs, return_dict=True) 133 | 134 | 135 | if __name__ == '__main__': 136 | test.main() 137 | -------------------------------------------------------------------------------- /tests/deepcell_layers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of 2 | # Technology (Caltech), with support from the Paul Allen Family Foundation, 3 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. 4 | # All rights reserved. 5 | # 6 | # Licensed under a modified Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE 11 | # 12 | # The Work provided may be used for non-commercial academic purposes only. 13 | # For any other use of the Work, including commercial use, please contact: 14 | # vanvalenlab@gmail.com 15 | # 16 | # Neither the name of Caltech nor the names of its contributors may be used 17 | # to endorse or promote products derived from this software without specific 18 | # prior written permission. 19 | # 20 | # Unless required by applicable law or agreed to in writing, software 21 | # distributed under the License is distributed on an "AS IS" BASIS, 22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | # See the License for the specific language governing permissions and 24 | # limitations under the License. 
25 | # ============================================================================== 26 | """Tests for the upsampling layers""" 27 | 28 | import numpy as np 29 | import tensorflow as tf 30 | from tensorflow.keras import backend as K 31 | from absl.testing import parameterized 32 | from keras import keras_parameterized 33 | from keras import testing_utils 34 | from tensorflow.python.framework import test_util as tf_test_util 35 | from tensorflow.python.platform import test 36 | from tensorflow.keras.utils import custom_object_scope 37 | 38 | from deepcell import layers 39 | 40 | 41 | @keras_parameterized.run_all_keras_modes 42 | class TestUpsampleLike(keras_parameterized.TestCase): 43 | 44 | def test_simple(self): 45 | # channels_last 46 | # create simple UpsampleLike layer 47 | upsample_like_layer = layers.UpsampleLike() 48 | 49 | # create input source 50 | source = np.zeros((1, 2, 2, 1), dtype=K.floatx()) 51 | source = K.variable(source) 52 | target = np.zeros((1, 5, 5, 1), dtype=K.floatx()) 53 | expected = target 54 | target = K.variable(target) 55 | 56 | # compute output 57 | computed_shape = upsample_like_layer.compute_output_shape( 58 | [source.shape, target.shape]) 59 | 60 | actual = upsample_like_layer.call([source, target]) 61 | actual = K.get_value(actual) 62 | 63 | self.assertEqual(actual.shape, computed_shape) 64 | self.assertAllEqual(actual, expected) 65 | # channels_first 66 | # create simple UpsampleLike layer 67 | upsample_like_layer = layers.UpsampleLike( 68 | data_format='channels_first') 69 | 70 | # create input source 71 | source = np.zeros((1, 1, 2, 2), dtype=K.floatx()) 72 | source = K.variable(source) 73 | target = np.zeros((1, 1, 5, 5), dtype=K.floatx()) 74 | expected = target 75 | target = K.variable(target) 76 | 77 | # compute output 78 | computed_shape = upsample_like_layer.compute_output_shape( 79 | [source.shape, target.shape]) 80 | actual = upsample_like_layer.call([source, target]) 81 | actual = K.get_value(actual) 82 | 83 | self.assertEqual(actual.shape, computed_shape) 84 | self.assertAllEqual(actual, expected) 85 | 86 | def test_simple_3d(self): 87 | # create simple UpsampleLike layer 88 | upsample_like_layer = layers.UpsampleLike() 89 | 90 | # create input source 91 | source = np.zeros((1, 2, 2, 2, 1), dtype=K.floatx()) 92 | source = K.variable(source) 93 | target = np.zeros((1, 5, 5, 5, 1), dtype=K.floatx()) 94 | expected = target 95 | target = K.variable(target) 96 | 97 | # compute output 98 | computed_shape = upsample_like_layer.compute_output_shape( 99 | [source.shape, target.shape]) 100 | 101 | actual = upsample_like_layer.call([source, target]) 102 | actual = K.get_value(actual) 103 | 104 | self.assertEqual(actual.shape, computed_shape) 105 | self.assertAllEqual(actual, expected) 106 | 107 | # channels_first 108 | # create simple UpsampleLike layer 109 | upsample_like_layer = layers.UpsampleLike( 110 | data_format='channels_first') 111 | 112 | # create input source 113 | source = np.zeros((1, 1, 2, 2, 2), dtype=K.floatx()) 114 | source = K.variable(source) 115 | target = np.zeros((1, 1, 5, 5, 5), dtype=K.floatx()) 116 | expected = target 117 | target = K.variable(target) 118 | 119 | # compute output 120 | computed_shape = upsample_like_layer.compute_output_shape( 121 | [source.shape, target.shape]) 122 | actual = upsample_like_layer.call([source, target]) 123 | actual = K.get_value(actual) 124 | 125 | self.assertEqual(actual.shape, computed_shape) 126 | self.assertAllEqual(actual, expected) 127 | 128 | def test_mini_batch(self): 129 | # create 
simple UpsampleLike layer 130 | upsample_like_layer = layers.UpsampleLike() 131 | 132 | # create input source 133 | source = np.zeros((2, 2, 2, 1), dtype=K.floatx()) 134 | source = K.variable(source) 135 | 136 | target = np.zeros((2, 5, 5, 1), dtype=K.floatx()) 137 | expected = target 138 | target = K.variable(target) 139 | 140 | # compute output 141 | actual = upsample_like_layer.call([source, target]) 142 | actual = K.get_value(actual) 143 | 144 | self.assertAllEqual(actual, expected) 145 | 146 | 147 | @keras_parameterized.run_all_keras_modes 148 | @parameterized.named_parameters( 149 | *tf_test_util.generate_combinations_with_testcase_name( 150 | norm_method=[None, 'std', 'max', 'whole_image'])) 151 | class ImageNormalizationTest(keras_parameterized.TestCase): 152 | 153 | def test_normalize_2d(self, norm_method): 154 | custom_objects = {'ImageNormalization2D': layers.ImageNormalization2D} 155 | with tf.keras.utils.custom_object_scope(custom_objects): 156 | testing_utils.layer_test( 157 | layers.ImageNormalization2D, 158 | kwargs={'norm_method': norm_method, 159 | 'filter_size': 3, 160 | 'data_format': 'channels_last'}, 161 | input_shape=(3, 5, 6, 4)) 162 | testing_utils.layer_test( 163 | layers.ImageNormalization2D, 164 | kwargs={'norm_method': norm_method, 165 | 'filter_size': 3, 166 | 'data_format': 'channels_first'}, 167 | input_shape=(3, 4, 5, 6)) 168 | # test constraints and bias 169 | k_constraint = tf.keras.constraints.max_norm(0.01) 170 | b_constraint = tf.keras.constraints.max_norm(0.01) 171 | layer = layers.ImageNormalization2D( 172 | use_bias=True, 173 | kernel_constraint=k_constraint, 174 | bias_constraint=b_constraint) 175 | layer(tf.keras.backend.variable(np.ones((3, 5, 6, 4)))) 176 | # self.assertEqual(layer.kernel.constraint, k_constraint) 177 | # self.assertEqual(layer.bias.constraint, b_constraint) 178 | # test bad norm_method 179 | with self.assertRaises(ValueError): 180 | layer = layers.ImageNormalization2D(norm_method='invalid') 181 | # test bad input dimensions 182 | with self.assertRaises(ValueError): 183 | layer = layers.ImageNormalization2D() 184 | layer.build([3, 10, 11, 12, 4]) 185 | # test invalid channel 186 | with self.assertRaises(ValueError): 187 | layer = layers.ImageNormalization2D() 188 | layer.build([3, 5, 6, None]) 189 | 190 | 191 | @keras_parameterized.run_all_keras_modes 192 | class LocationTest(keras_parameterized.TestCase): 193 | 194 | def test_location_2d(self): 195 | with custom_object_scope({'Location2D': layers.Location2D}): 196 | testing_utils.layer_test( 197 | layers.Location2D, 198 | kwargs={'data_format': 'channels_last'}, 199 | input_shape=(3, 5, 6, 4)) 200 | testing_utils.layer_test( 201 | layers.Location2D, 202 | kwargs={'in_shape': (4, 5, 6), 203 | 'data_format': 'channels_first'}, 204 | input_shape=(3, 4, 5, 6)) 205 | -------------------------------------------------------------------------------- /tests/deepcell_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of 2 | # Technology (Caltech), with support from the Paul Allen Family Foundation, 3 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. 4 | # All rights reserved. 5 | # 6 | # Licensed under a modified Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE 11 | # 12 | # The Work provided may be used for non-commercial academic purposes only. 13 | # For any other use of the Work, including commercial use, please contact: 14 | # vanvalenlab@gmail.com 15 | # 16 | # Neither the name of Caltech nor the names of its contributors may be used 17 | # to endorse or promote products derived from this software without specific 18 | # prior written permission. 19 | # 20 | # Unless required by applicable law or agreed to in writing, software 21 | # distributed under the License is distributed on an "AS IS" BASIS, 22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | # See the License for the specific language governing permissions and 24 | # limitations under the License. 25 | # ============================================================================== 26 | """Tests for misc_utils""" 27 | 28 | from tensorflow.python.platform import test 29 | 30 | from deepcell.utils import get_sorted_keys 31 | 32 | 33 | class MiscUtilsTest(test.TestCase): 34 | 35 | def test_get_sorted_keys(self): 36 | d = {'C1': 1, 'C3': 2, 'C2': 3} 37 | self.assertListEqual(get_sorted_keys(d), ['C1', 'C2', 'C3']) 38 | -------------------------------------------------------------------------------- /tests/inference_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pytest 4 | import tempfile 5 | import numpy as np 6 | from skimage import io 7 | from cell_classification.application import Nimbus 8 | from segmentation_data_prep_test import prepare_test_data_folders, prep_object_and_inputs 9 | from cell_classification.inference import calculate_normalization, prepare_normalization_dict 10 | from cell_classification.inference import prepare_input_data, segment_mean, predict_fovs 11 | from cell_classification.inference import test_time_aug as tt_aug 12 | 13 | 14 | def test_calculate_normalization(): 15 | with tempfile.TemporaryDirectory() as temp_dir: 16 | fov_paths = prepare_test_data_folders( 17 | 1, temp_dir, ["CD4"], random=True, 18 | scale=[0.5] 19 | ) 20 | channel = "CD4" 21 | channel_path = os.path.join(fov_paths[0], channel + ".tiff") 22 | channel_out, norm_val = calculate_normalization(channel_path, 0.999) 23 | # test if we get the correct channel and normalization value 24 | assert channel_out == channel 25 | assert np.isclose(norm_val, 0.5, 0.01) 26 | 27 | 28 | def test_prepare_normalization_dict(): 29 | with tempfile.TemporaryDirectory() as temp_dir: 30 | scales = [0.5, 1.0, 1.5, 2.0, 5.0] 31 | channels = ["CD4", "CD11c", "CD14", "CD56", "CD57"] 32 | fov_paths = prepare_test_data_folders( 33 | 5, temp_dir, channels, random=True, 34 | scale=scales 35 | ) 36 | normalization_dict = prepare_normalization_dict( 37 | fov_paths, temp_dir, quantile=0.999, exclude_channels=["CD57"], n_subset=10, n_jobs=1, 38 | output_name="normalization_dict.json" 39 | ) 40 | # test if normalization dict got saved 41 | assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json")) 42 | assert normalization_dict == json.load( 43 | open(os.path.join(temp_dir, "normalization_dict.json")) 44 | ) 45 | # test if channel got excluded 46 | assert "CD57" not in normalization_dict.keys() 47 | # test if normalization dict is correct 48 | for channel, scale in zip(channels, scales): 49 | if channel == "CD57": 50 | continue 51 | assert np.isclose(normalization_dict[channel], scale, 
0.01) 52 | 53 | # test if multiprocessing yields approximately the same results 54 | normalization_dict_mp = prepare_normalization_dict( 55 | fov_paths, temp_dir, quantile=0.999, exclude_channels=["CD57"], n_subset=10, n_jobs=2, 56 | output_name="normalization_dict.json" 57 | ) 58 | for key in normalization_dict.keys(): 59 | assert np.isclose(normalization_dict[key], normalization_dict_mp[key], 1e-6) 60 | 61 | 62 | def test_prepare_input_data(): 63 | with tempfile.TemporaryDirectory() as temp_dir: 64 | scales = [0.5] 65 | channels = ["CD4"] 66 | fov_paths = prepare_test_data_folders( 67 | 1, temp_dir, channels, random=True, 68 | scale=scales 69 | ) 70 | mplex_img = io.imread(os.path.join(fov_paths[0], "CD4.tiff")) 71 | instance_mask = io.imread(os.path.join(fov_paths[0], "cell_segmentation.tiff")) 72 | input_data = prepare_input_data(mplex_img, instance_mask) 73 | # check shape 74 | assert input_data.shape == (1, 256, 256, 2) 75 | # check if instance mask got binarized and eroded 76 | assert np.alltrue(np.unique(input_data[..., 1]) == np.array([0, 1])) 77 | assert np.sum(input_data[..., 1]) < np.sum(instance_mask) 78 | # check if mplex image is the same as before 79 | assert np.alltrue(input_data[0, ..., 0] == mplex_img) 80 | 81 | 82 | def test_segment_mean(): 83 | with tempfile.TemporaryDirectory() as temp_dir: 84 | scales = [0.5] 85 | channels = ["CD4"] 86 | fov_paths = prepare_test_data_folders( 87 | 1, temp_dir, channels, random=True, 88 | scale=scales 89 | ) 90 | mplex_img = io.imread(os.path.join(fov_paths[0], "CD4.tiff")) 91 | prediction = (mplex_img > 0.25).astype(np.float32) 92 | instance_mask = io.imread(os.path.join(fov_paths[0], "cell_segmentation.tiff")) 93 | instance_ids, mean_per_cell = segment_mean(instance_mask, prediction) 94 | # check if we get the correct number of cells 95 | assert len(instance_ids) == len(np.unique(instance_mask)[1:]) 96 | # check if we get the correct mean per cell 97 | for i in np.unique(instance_mask)[1:]: 98 | assert mean_per_cell[i-1] == np.mean(prediction[instance_mask == i]) 99 | 100 | 101 | def test_tt_aug(): 102 | with tempfile.TemporaryDirectory() as temp_dir: 103 | def segmentation_naming_convention(fov_path): 104 | return os.path.join(fov_path, "cell_segmentation.tiff") 105 | 106 | _, fov_paths, _, _ = prep_object_and_inputs(temp_dir) 107 | os.remove(os.path.join(temp_dir, 'normalization_dict.json')) 108 | output_dir = os.path.join(temp_dir, "nimbus_output") 109 | fov_paths = fov_paths[:1] 110 | nimbus = Nimbus( 111 | fov_paths, segmentation_naming_convention, output_dir, 112 | exclude_channels=["CD57", "CD11c", "XYZ"] 113 | ) 114 | nimbus.prepare_normalization_dict() 115 | channel = "CD4" 116 | mplex_img = io.imread(os.path.join(fov_paths[0], channel+".tiff")) 117 | instance_mask = io.imread(os.path.join(fov_paths[0], "cell_segmentation.tiff")) 118 | input_data = prepare_input_data(mplex_img, instance_mask) 119 | pred_map = tt_aug( 120 | input_data, channel, nimbus, nimbus.normalization_dict, rotate=True, flip=True, 121 | batch_size=32 122 | ) 123 | # check if we get the correct shape 124 | assert pred_map.shape == (256, 256, 1) 125 | 126 | pred_map_2 = tt_aug( 127 | input_data, channel, nimbus, nimbus.normalization_dict, rotate=False, flip=True, 128 | batch_size=32 129 | ) 130 | pred_map_3 = tt_aug( 131 | input_data, channel, nimbus, nimbus.normalization_dict, rotate=True, flip=False, 132 | batch_size=32 133 | ) 134 | pred_map_no_tt_aug = nimbus._predict_segmentation( 135 | input_data, 136 | batch_size=1, 137 | preprocess_kwargs={ 138 | 
"normalize": True, 139 | "marker": channel, 140 | "normalization_dict": nimbus.normalization_dict}, 141 | ) 142 | # check if we get roughly the same results for non augmented and augmented predictions 143 | assert np.allclose(pred_map, pred_map_no_tt_aug, atol=0.05) 144 | assert np.allclose(pred_map_2, pred_map_no_tt_aug, atol=0.05) 145 | assert np.allclose(pred_map_3, pred_map_no_tt_aug, atol=0.05) 146 | 147 | 148 | def test_predict_fovs(): 149 | with tempfile.TemporaryDirectory() as temp_dir: 150 | def segmentation_naming_convention(fov_path): 151 | return os.path.join(fov_path, "cell_segmentation.tiff") 152 | 153 | exclude_channels = ["CD57", "CD11c", "XYZ"] 154 | _, fov_paths, _, _ = prep_object_and_inputs(temp_dir) 155 | os.remove(os.path.join(temp_dir, 'normalization_dict.json')) 156 | output_dir = os.path.join(temp_dir, "nimbus_output") 157 | fov_paths = fov_paths[:1] 158 | nimbus = Nimbus( 159 | fov_paths, segmentation_naming_convention, output_dir, 160 | exclude_channels=exclude_channels 161 | ) 162 | output_dir = os.path.join(temp_dir, "nimbus_output") 163 | nimbus.prepare_normalization_dict() 164 | cell_table = predict_fovs( 165 | fov_paths, output_dir, nimbus, nimbus.normalization_dict, 166 | segmentation_naming_convention, exclude_channels=exclude_channels, 167 | save_predictions=False, half_resolution=True, 168 | ) 169 | # check if we get the correct number of cells 170 | assert len(cell_table) == 15 171 | # check if we get the correct columns (fov, label, CD4_pred, CD56_pred) 172 | assert np.alltrue( 173 | set(cell_table.columns) == set(["fov", "label", "CD4_pred", "CD56_pred"]) 174 | ) 175 | # check if predictions don't get written to output_dir 176 | assert not os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff")) 177 | assert not os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff")) 178 | # 179 | # run again with save_predictions=True and check if predictions get written to output_dir 180 | cell_table = predict_fovs( 181 | fov_paths, output_dir, nimbus, nimbus.normalization_dict, 182 | segmentation_naming_convention, exclude_channels=exclude_channels, 183 | save_predictions=True, half_resolution=True, test_time_augmentation=False, 184 | ) 185 | assert os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff")) 186 | assert os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff")) 187 | -------------------------------------------------------------------------------- /tests/loss_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from cell_classification.loss import Loss 5 | 6 | 7 | def test_loss(): 8 | pred = tf.constant(np.random.rand(1, 256, 256, 1)) 9 | target = tf.constant(np.random.randint(0, 2, size=(1, 256, 256, 1))) 10 | loss_fn = Loss("BinaryCrossentropy", False) 11 | loss = loss_fn(target, pred) 12 | 13 | # check if loss has the right shape 14 | assert isinstance(loss, tf.Tensor) 15 | assert loss.shape == (1, 256, 256) 16 | # check if loss is in the right range 17 | assert tf.reduce_mean(loss).numpy() >= 0 18 | 19 | # check if wrapper works with label_smoothing 20 | loss_fn = Loss("BinaryCrossentropy", False, label_smoothing=0.1) 21 | loss_smoothed = loss_fn(target, pred) 22 | 23 | # check if loss is a scalar 24 | assert isinstance(loss_smoothed, tf.Tensor) 25 | assert loss_smoothed.shape == (1, 256, 256) 26 | # check if loss is in the right range 27 | assert tf.reduce_mean(loss_smoothed).numpy() >= 0 28 | # check if smoothed loss is reasonably 
28 |     # check if smoothed loss is reasonably near the unsmoothed loss
29 |     assert np.isclose(loss_smoothed.numpy().mean(), loss.numpy().mean(), atol=0.1)
30 | 
31 |     # check if loss is zero when target and pred are equal
32 |     loss_fn = Loss("BinaryCrossentropy", False)
33 |     loss = loss_fn(target, tf.cast(target, tf.float32))
34 |     assert loss.numpy().mean() == 0
35 | 
36 |     # check if loss is masked to zero when target is 2 and selective masking is on
37 |     loss_fn = Loss("BinaryCrossentropy", True)
38 |     loss = loss_fn(2 * np.ones_like(target), pred)
39 |     assert loss.numpy().mean() == 0
40 | 
41 |     # check if the wrapper works with BinaryFocalCrossentropy
42 |     loss_fn = Loss("BinaryFocalCrossentropy", True, label_smoothing=0.1, gamma=2)
43 |     loss = loss_fn(target, pred)
44 |     assert loss.numpy().mean() >= 0
45 | 
46 |     # check if config is returned correctly
47 |     config = loss_fn.get_config()
48 |     assert config["loss_name"] == "BinaryFocalCrossentropy"
49 |     assert config["label_smoothing"] == 0.1
50 |     assert config["gamma"] == 2
51 |     assert config["selective_masking"]
52 | 
--------------------------------------------------------------------------------
/tests/metrics_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | 
4 | import numpy as np
5 | import pandas as pd
6 | import wandb
7 | 
8 | from cell_classification.metrics import (HDF5Loader, average_roc, calc_metrics,
9 |                                          calc_roc)
10 | from cell_classification.model_builder import ModelBuilder
11 | 
12 | from .segmentation_data_prep_test import prep_object_and_inputs
13 | 
14 | 
15 | def make_pred_list():
16 |     pred_list = []
17 |     for i in range(10):
18 |         instance_mask = np.random.randint(0, 10, size=(256, 256, 1))
19 |         binary_mask = (instance_mask > 0).astype(np.uint8)
20 |         activity_df = pd.DataFrame(
21 |             {
22 |                 "labels": np.array([1, 2, 5, 7, 9, 11], dtype=np.uint16),
23 |                 "activity": [1, 0, 0, 2, 0, 1],
24 |                 "cell_type": ["T cell", "B cell", "T cell", "B cell", "T cell", "B cell"],
25 |                 "sample": [str(i)]*6,
26 |                 "imaging_platform": ["test"]*6,
27 |                 "dataset": ["test"]*6,
28 |                 "marker": ["CD4"]*6 if i % 2 == 0 else ["CD8"]*6,
29 |                 "prediction": np.random.rand(6),
30 |             }
31 |         )
32 |         pred_list.append(
33 |             {
34 |                 "marker_activity_mask": np.random.randint(0, 2, (256, 256, 1))*binary_mask,
35 |                 "prediction": np.random.rand(256, 256, 1),
36 |                 "instance_mask": instance_mask,
37 |                 "binary_mask": binary_mask,
38 |                 "dataset": "test", "imaging_platform": "test",
39 |                 "marker": "CD4" if i % 2 == 0 else "CD8",
40 |                 "activity_df": activity_df,
41 |             }
42 |         )
43 |     pred_list[-1]["marker_activity_mask"] = np.zeros((256, 256, 1))
44 |     return pred_list
45 | 
46 | 
47 | def test_calc_roc():
48 |     pred_list = make_pred_list()
49 |     roc = calc_roc(pred_list)
50 | 
51 |     # check if roc has the right keys
52 |     assert set(roc.keys()) == set(["fpr", "tpr", "auc", "thresholds", "marker"])
53 | 
54 |     # check if roc has the right number of items
55 |     assert len(roc["fpr"]) == len(roc["tpr"]) == len(roc["thresholds"]) == len(roc["auc"]) == 9
56 | 
57 | 
58 | def test_calc_metrics():
59 |     pred_list = make_pred_list()
60 |     avg_metrics = calc_metrics(pred_list)
61 |     keys = [
62 |         "accuracy", "precision", "recall", "specificity", "f1_score", "tp", "tn", "fp", "fn",
63 |         "dataset", "imaging_platform", "marker", "threshold",
64 |     ]
65 | 
66 |     # check if avg_metrics has the right keys
67 |     assert set(avg_metrics.keys()) == set(keys)
68 | 
69 |     # check if avg_metrics has the right number of items
70 |     assert (
71 |         len(avg_metrics["accuracy"]) == len(avg_metrics["precision"])
72 |         == len(avg_metrics["recall"]) ==
len(avg_metrics["f1_score"]) 73 | == len(avg_metrics["tp"]) == len(avg_metrics["tn"]) 74 | == len(avg_metrics["fp"]) == len(avg_metrics["fn"]) 75 | == len(avg_metrics["dataset"]) == len(avg_metrics["imaging_platform"]) 76 | == len(avg_metrics["marker"]) == len(avg_metrics["threshold"]) == 50 77 | ) 78 | 79 | 80 | def test_average_roc(): 81 | pred_list = make_pred_list() 82 | roc_list = calc_roc(pred_list) 83 | tprs, mean_tprs, base, std, mean_thresh = average_roc(roc_list) 84 | 85 | # check if mean_tprs, base, std and mean_thresh have the same length 86 | assert len(mean_tprs) == len(base) == len(std) == len(mean_thresh) == tprs.shape[1] 87 | 88 | # check if mean and std give reasonable results 89 | assert np.array_equal(np.mean(tprs, axis=0), mean_tprs) 90 | assert np.array_equal(np.std(tprs, axis=0), std) 91 | 92 | 93 | def test_HDF5Generator(config_params): 94 | with tempfile.TemporaryDirectory() as temp_dir: 95 | data_prep, _, _, _ = prep_object_and_inputs(temp_dir) 96 | data_prep.tf_record_path = temp_dir 97 | data_prep.make_tf_record() 98 | tf_record_path = os.path.join(data_prep.tf_record_path, data_prep.dataset + ".tfrecord") 99 | config_params["record_path"] = [tf_record_path] 100 | config_params["path"] = temp_dir 101 | config_params["experiment"] = "test" 102 | config_params["dataset_names"] = ["test1"] 103 | config_params["num_steps"] = 2 104 | config_params["dataset_sample_probs"] = [1.0] 105 | config_params["num_validation"] = [2] 106 | config_params["num_test"] = [2] 107 | config_params["snap_steps"] = 100 108 | config_params["val_steps"] = 100 109 | model = ModelBuilder(config_params) 110 | model.train() 111 | model.predict_dataset(model.validation_datasets[0], save_predictions=True) 112 | wandb.finish() 113 | generator = HDF5Loader(model.params['eval_dir']) 114 | 115 | # check if generator has the right number of items 116 | assert [len(generator)] == config_params['num_validation'] 117 | 118 | # check if generator returns the right items 119 | for sample in generator: 120 | assert isinstance(sample, dict) 121 | assert set(list(sample.keys())) == set([ 122 | 'binary_mask', 'dataset', 'folder_name', 'imaging_platform', 'instance_mask', 123 | 'marker', 'marker_activity_mask', 'membrane_img', 'mplex_img', 'nuclei_img', 124 | 'prediction', 'prediction_mean', 'activity_df', 'tissue_type' 125 | ]) 126 | -------------------------------------------------------------------------------- /tests/plot_utils_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import tensorflow as tf 5 | 6 | from cell_classification.metrics import average_roc, calc_metrics, calc_roc 7 | from cell_classification.plot_utils import (collapse_activity_dfs, 8 | heatmap_plot, plot_average_roc, 9 | plot_metrics_against_threshold, 10 | plot_overlay, plot_together, 11 | subset_activity_df, subset_plots) 12 | from cell_classification.segmentation_data_prep import (feature_description, 13 | parse_dict) 14 | 15 | from .metrics_test import make_pred_list 16 | from .segmentation_data_prep_test import prep_object_and_inputs 17 | 18 | 19 | def prepare_dataset(temp_dir): 20 | data_prep, _, _, _ = prep_object_and_inputs(temp_dir) 21 | data_prep.tf_record_path = temp_dir 22 | data_prep.make_tf_record() 23 | tf_record_path = os.path.join(temp_dir, data_prep.dataset + ".tfrecord") 24 | dataset = tf.data.TFRecordDataset(tf_record_path) 25 | return dataset 26 | 27 | 28 | def test_plot_overlay(): 29 | with tempfile.TemporaryDirectory() as temp_dir: 
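        # context: prepare_dataset() above writes a TFRecord into temp_dir and
        # returns a tf.data.TFRecordDataset; below, a single serialized record is
        # pulled from the iterator, decoded with tf.io.parse_single_example using
        # the shared feature_description, and converted back to dense tensors by
        # parse_dict before plotting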
30 |         dataset = iter(prepare_dataset(temp_dir))
31 |         plot_path = os.path.join(temp_dir, "plots")
32 |         os.makedirs(plot_path, exist_ok=True)
33 |         record = next(dataset)
34 |         example_encoded = tf.io.parse_single_example(record, feature_description)
35 |         example = parse_dict(example_encoded)
36 |         plot_overlay(
37 |             example, save_dir=plot_path, save_file=f"{example['folder_name']}_overlay.png"
38 |         )
39 | 
40 |         # check if plot was saved
41 |         assert os.path.exists(os.path.join(plot_path, f"{example['folder_name']}_overlay.png"))
42 | 
43 | 
44 | def test_plot_together():
45 |     with tempfile.TemporaryDirectory() as temp_dir:
46 |         dataset = iter(prepare_dataset(temp_dir))
47 |         plot_path = os.path.join(temp_dir, "plots")
48 |         os.makedirs(plot_path, exist_ok=True)
49 |         record = next(dataset)
50 |         example_encoded = tf.io.parse_single_example(record, feature_description)
51 |         example = parse_dict(example_encoded)
52 |         plot_together(
53 |             example, ["mplex_img", "nuclei_img", "marker_activity_mask"], save_dir=plot_path,
54 |             save_file=f"{example['folder_name']}_together.png"
55 |         )
56 | 
57 |         # check if plot was saved
58 |         assert os.path.exists(os.path.join(plot_path, f"{example['folder_name']}_together.png"))
59 | 
60 | 
61 | def test_plot_average_roc():
62 |     with tempfile.TemporaryDirectory() as temp_dir:
63 |         pred_list = make_pred_list()
64 |         roc_list = calc_roc(pred_list)
65 |         tprs, mean_tprs, base, std, mean_thresh = average_roc(roc_list)
66 |         plot_path = os.path.join(temp_dir, "plots")
67 |         os.makedirs(plot_path, exist_ok=True)
68 |         plot_average_roc(mean_tprs, std, save_dir=plot_path, save_file="average_roc.png")
69 | 
70 |         # check if plot was saved
71 |         assert os.path.exists(os.path.join(plot_path, "average_roc.png"))
72 | 
73 | 
74 | def test_plot_metrics_against_threshold():
75 |     with tempfile.TemporaryDirectory() as temp_dir:
76 |         pred_list = make_pred_list()
77 |         avg_metrics = calc_metrics(pred_list)
78 |         plot_path = os.path.join(temp_dir, "plots")
79 |         os.makedirs(plot_path, exist_ok=True)
80 |         plot_metrics_against_threshold(
81 |             avg_metrics, metric_keys=['precision', 'recall', 'f1_score'], save_dir=plot_path,
82 |             threshold_key="threshold", save_file="metrics_against_threshold.png"
83 |         )
84 | 
85 |         # check if plot was saved
86 |         assert os.path.exists(os.path.join(plot_path, "metrics_against_threshold.png"))
87 | 
88 | 
89 | def test_collapse_activity_dfs():
90 |     pred_list = make_pred_list()
91 |     df = collapse_activity_dfs(pred_list)
92 | 
93 |     # check if df has the right shape and keys
94 |     assert df.shape == (len(pred_list)*6, 8)
95 |     assert set(df.columns) == set(pred_list[0]["activity_df"].columns)
96 | 
97 | 
98 | def test_subset_activity_df():
99 |     pred_list = make_pred_list()
100 |     df = collapse_activity_dfs(pred_list)
101 |     cd4_subset = subset_activity_df(df, {"marker": "CD4"})
102 |     cd4_tcells_subset = subset_activity_df(df, {"marker": "CD4", "cell_type": "T cell"})
103 | 
104 |     # check if cd4_subset has the right shape and only contains CD4 rows
105 |     assert set(cd4_subset["marker"]) == set(["CD4"])
106 |     assert cd4_subset.shape == df[df.marker == "CD4"].shape
107 | 
108 |     # check if cd4_tcells_subset has the right shape
109 |     assert set(cd4_tcells_subset["marker"]) == set(["CD4"])
110 |     assert set(cd4_tcells_subset["cell_type"]) == set(["T cell"])
111 |     assert cd4_tcells_subset.shape == df[(df.marker == "CD4") & (df.cell_type == "T cell")].shape
112 | 
113 | 
114 | def test_subset_plots():
115 |     pred_list = make_pred_list()
116 |     activity_df = collapse_activity_dfs(pred_list)
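    # subset_plots facets the collapsed activity dataframe by the given keys and
    # writes one figure per call; a single-key split and a two-key split are both
    # exercised below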
117 |     with tempfile.TemporaryDirectory() as temp_dir:
118 |         plot_path = os.path.join(temp_dir, "plots")
119 |         os.makedirs(plot_path, exist_ok=True)
120 |         subset_plots(
121 |             activity_df, subset_list=["marker"], save_dir=plot_path,
122 |             save_file="split_by_marker.png"
123 |         )
124 |         subset_plots(
125 |             activity_df, subset_list=["marker", "cell_type"], save_dir=plot_path,
126 |             save_file="split_by_marker_ct.png"
127 |         )
128 | 
129 |         # check if plots were saved
130 |         assert os.path.exists(os.path.join(plot_path, "split_by_marker.png"))
131 |         assert os.path.exists(os.path.join(plot_path, "split_by_marker_ct.png"))
132 | 
133 | def test_heatmap_plot():
134 |     pred_list = make_pred_list()
135 |     activity_df = collapse_activity_dfs(pred_list)
136 |     with tempfile.TemporaryDirectory() as temp_dir:
137 |         plot_path = os.path.join(temp_dir, "plots")
138 |         os.makedirs(plot_path, exist_ok=True)
139 |         heatmap_plot(
140 |             activity_df, ["marker"], save_dir=plot_path, save_file="heatmap.png"
141 |         )
142 | 
143 |         # check if plot was saved
144 |         assert os.path.exists(os.path.join(plot_path, "heatmap.png"))
145 | 
--------------------------------------------------------------------------------
/tests/post_processing_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | 
4 | from cell_classification.post_processing import (merge_activity_df,
5 |                                                  process_to_cells)
6 | 
7 | 
8 | def test_process_to_cells():
9 |     # prepare data
10 |     instance_mask = np.random.randint(0, 10, size=(256, 256, 1))
11 |     prediction = np.random.rand(256, 256, 1)
12 |     prediction[instance_mask == 1] = 1.0
13 |     prediction[instance_mask == 2] = 0.0
14 |     prediction_mean, activity_df = process_to_cells(instance_mask, prediction)
15 | 
16 |     # check types and shape
17 |     assert isinstance(activity_df, pd.DataFrame)
18 |     assert isinstance(prediction_mean, np.ndarray)
19 |     assert len(activity_df.pred_activity) == 9
20 |     assert prediction_mean.shape == (256, 256, 1)
21 | 
22 |     # check values
23 |     assert prediction_mean[instance_mask == 1].mean() == 1.0
24 |     assert prediction_mean[instance_mask == 2].mean() == 0.0
25 |     assert activity_df.pred_activity[0] == 1.0
26 |     assert activity_df.pred_activity[1] == 0.0
27 | 
28 | 
29 | def test_merge_activity_df():
30 |     # prepare data
31 |     pred_df = pd.DataFrame({"labels": list(range(1, 10)), "pred_activity": np.random.rand(9)})
32 |     gt_df = pd.DataFrame(
33 |         {"labels": list(range(1, 10)), "gt_activity": np.random.randint(0, 2, 9)}
34 |     )
35 |     merged_df = merge_activity_df(gt_df, pred_df)
36 | 
37 |     # check if columns are there and no labels got lost in the merge
38 |     assert np.array_equal(merged_df.labels, gt_df.labels)
39 |     assert set(['labels', 'gt_activity', "pred_activity"]) == set(merged_df.columns)
40 | 
--------------------------------------------------------------------------------
/tests/promix_naive_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | 
4 | import numpy as np
5 | import pandas as pd
6 | import tensorflow as tf
7 | import toml
8 | 
9 | import wandb
10 | from cell_classification.promix_naive import PromixNaive
11 | 
12 | from .segmentation_data_prep_test import prep_object_and_inputs
13 | 
14 | 
15 | def test_reduce_to_cells(config_params):
16 |     config_params["test"] = True
17 |     pred = np.random.rand(16, 256, 266)
18 |     instance_mask = np.random.randint(0, 100, (16, 256, 266))
19 |     instance_mask[-1, instance_mask[-1] == 1] = 0
20 |     marker_activity_mask = np.zeros_like(instance_mask)
21 |     marker_activity_mask[instance_mask > 90] = 1
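    # context: instance labels are drawn from [0, 100), so thresholding at > 90
    # marks roughly the top tenth of labels as positive; tf.map_fn below collapses
    # the pixel-wise predictions into one mean value per cell label, returned as
    # ragged tensors because every image contains a different set of labels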
22 | trainer = PromixNaive(config_params) 23 | uniques, mean_per_cell = tf.map_fn( 24 | trainer.reduce_to_cells, 25 | (pred, instance_mask), 26 | infer_shape=False, 27 | fn_output_signature=[ 28 | tf.RaggedTensorSpec(shape=[None], dtype=tf.int32, ragged_rank=0), 29 | tf.RaggedTensorSpec(shape=[None], dtype=tf.float32, ragged_rank=0), 30 | ], 31 | ) 32 | 33 | # check that the output has the right dimension 34 | assert uniques.shape[0] == instance_mask.shape[0] 35 | assert mean_per_cell.shape[0] == instance_mask.shape[0] 36 | 37 | # check that the output is correct 38 | assert set(np.unique(instance_mask[0])) == set(uniques[0].numpy()) 39 | for i in np.unique(instance_mask[0]): 40 | assert np.isclose( 41 | np.mean(pred[0][instance_mask[0] == i]), 42 | mean_per_cell[0][uniques[0] == i].numpy().max(), 43 | ) 44 | 45 | 46 | def test_matched_high_confidence_selection_thresholds(config_params): 47 | config_params["test"] = True 48 | trainer = PromixNaive(config_params) 49 | trainer.matched_high_confidence_selection_thresholds() 50 | thresholds = trainer.confidence_loss_thresholds 51 | # check that the output has the right dimension 52 | assert len(thresholds) == 2 53 | assert thresholds["positive"] > 0.0 54 | assert thresholds["negative"] > 0.0 55 | 56 | 57 | def test_train(config_params): 58 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 59 | with tempfile.TemporaryDirectory() as temp_dir: 60 | data_prep, _, _, _ = prep_object_and_inputs(temp_dir) 61 | data_prep.tf_record_path = temp_dir 62 | data_prep.make_tf_record() 63 | tf_record_path = os.path.join(data_prep.tf_record_path, data_prep.dataset + ".tfrecord") 64 | config_params["record_path"] = [tf_record_path] 65 | config_params["path"] = temp_dir 66 | config_params["experiment"] = "test" 67 | config_params["dataset_names"] = ["test1"] 68 | config_params["dataset_sample_probs"] = [1.0] 69 | config_params["num_steps"] = 7 70 | config_params["num_validation"] = [2] 71 | config_params["num_test"] = [2] 72 | config_params["batch_size"] = 2 73 | config_params["test"] = True 74 | config_params["weight_decay"] = 1e-4 75 | config_params["snap_steps"] = 5 76 | config_params["val_steps"] = 5 77 | config_params["quantile"] = 0.3 78 | config_params["ema"] = 0.01 79 | config_params["confidence_thresholds"] = [0.1, 0.9] 80 | config_params["mixup_prob"] = 0.5 81 | trainer = PromixNaive(config_params) 82 | trainer.train() 83 | wandb.finish() 84 | # check params.toml is dumped to file and contains the created paths 85 | assert "params.toml" in os.listdir(trainer.params["model_dir"]) 86 | loaded_params = toml.load(os.path.join(trainer.params["model_dir"], "params.toml")) 87 | for key in ["model_dir", "log_dir", "model_path"]: 88 | assert key in list(loaded_params.keys()) 89 | 90 | # check if model can be loaded from file 91 | trainer.model = None 92 | trainer.load_model(trainer.params["model_path"]) 93 | assert isinstance(trainer.model, tf.keras.Model) 94 | 95 | 96 | def test_prep_data(config_params): 97 | with tempfile.TemporaryDirectory() as temp_dir: 98 | data_prep, _, _, _ = prep_object_and_inputs(temp_dir) 99 | data_prep.tf_record_path = temp_dir 100 | data_prep.make_tf_record() 101 | tf_record_path = os.path.join(data_prep.tf_record_path, data_prep.dataset + ".tfrecord") 102 | config_params["record_path"] = tf_record_path 103 | config_params["path"] = temp_dir 104 | config_params["experiment"] = "test" 105 | config_params["dataset_names"] = ["test1"] 106 | config_params["dataset_sample_probs"] = [1.0] 107 | config_params["num_steps"] = 3 108 | 
config_params["num_validation"] = [2] 109 | config_params["num_test"] = [2] 110 | config_params["batch_size"] = 2 111 | trainer = PromixNaive(config_params) 112 | trainer.prep_data() 113 | 114 | # check if train and validation datasets exists and are of the right type 115 | assert isinstance(trainer.validation_datasets[0], tf.data.Dataset) 116 | assert isinstance(trainer.train_dataset, tf.data.Dataset) 117 | 118 | 119 | def prepare_activity_df(): 120 | activity_df_list = [] 121 | for i in range(4): 122 | activity_df = pd.DataFrame( 123 | { 124 | "labels": np.array([1, 2, 5, 7, 9, 11], dtype=np.uint16), 125 | "activity": [1, 0, 0, 0, 0, 1], 126 | "cell_type": ["T cell", "B cell", "T cell", "B cell", "T cell", "B cell"], 127 | "sample": [str(i)] * 6, 128 | "imaging_platform": ["test"] * 6, 129 | "dataset": ["test"] * 6, 130 | "marker": ["CD4"] * 6 if i % 2 == 0 else "CD8", 131 | "prediction": [0.9, 0.1, 0.1, 0.7, 0.7, 0.2], 132 | } 133 | ) 134 | activity_df_list.append(activity_df) 135 | return activity_df_list 136 | 137 | 138 | def test_class_wise_loss_selection(config_params): 139 | config_params["test"] = True 140 | trainer = PromixNaive(config_params) 141 | activity_df_list = prepare_activity_df() 142 | df = activity_df_list[0] 143 | marker = df["marker"][0] 144 | dataset = df["dataset"][0] 145 | 146 | dataset_marker = dataset + "_" + marker 147 | trainer.class_wise_loss_quantiles[dataset_marker] = {"positive": 0.5, "negative": 0.5} 148 | df["loss"] = df.activity * df.prediction + (1 - df.activity) * (1 - df.prediction) 149 | positive_df = df[df["activity"] == 1] 150 | negative_df = df[df["activity"] == 0] 151 | selected_subset = trainer.class_wise_loss_selection(positive_df, negative_df, marker, dataset) 152 | 153 | # check that the output has the right dimension 154 | assert len(selected_subset) == 2 155 | assert len(selected_subset[0]) == 1 156 | 157 | # check that the output is correct and only those cells are selected that have a loss 158 | # smaller than the threshold 159 | assert selected_subset[0].equals( 160 | df[df["activity"] == 1].loc[ 161 | df["loss"] <= trainer.class_wise_loss_quantiles[dataset_marker]["positive"] 162 | ] 163 | ) 164 | assert selected_subset[1].equals( 165 | df[df["activity"] == 0].loc[ 166 | df["loss"] <= trainer.class_wise_loss_quantiles[dataset_marker]["negative"] 167 | ] 168 | ) 169 | 170 | # check if quantiles got updated 171 | assert trainer.class_wise_loss_quantiles[dataset_marker]["positive"] != 0.5 172 | assert trainer.class_wise_loss_quantiles[dataset_marker]["negative"] != 0.5 173 | 174 | 175 | def test_matched_high_confidence_selection(config_params): 176 | trainer = PromixNaive(config_params) 177 | activity_df_list = prepare_activity_df() 178 | df = activity_df_list[0] 179 | df["loss"] = df.activity * df.prediction + (1 - df.activity) * (1 - df.prediction) 180 | positive_df = df[df["activity"] == 1] 181 | negative_df = df[df["activity"] == 0] 182 | mark = df["marker"][0] 183 | trainer.matched_high_confidence_selection_thresholds() 184 | selected_subset = trainer.matched_high_confidence_selection(positive_df, negative_df) 185 | df = pd.concat(selected_subset) 186 | 187 | # check that the output has the right dimension 188 | assert len(df) == 1 189 | 190 | # check that the output is correct and only those cells are selected that have a loss 191 | # smaller than the threshold 192 | gt_activity = "positive" if df["activity"].values[0] == 1 else "negative" 193 | assert (df.loss.values[0] <= trainer.confidence_loss_thresholds[gt_activity]) 194 | 195 
196 | def test_batchwise_loss_selection(config_params):
197 |     config_params["test"] = True
198 |     trainer = PromixNaive(config_params)
199 |     trainer.matched_high_confidence_selection_thresholds()
200 |     activity_df_list = prepare_activity_df()
201 |     instance_mask = np.zeros([256, 256], dtype=np.uint8)
202 |     i = 1
203 |     for h in range(0, 260, 20):
204 |         for w in range(0, 260, 20):
205 |             instance_mask[h: h + 10, w: w + 10] = i
206 |             i += 1
207 |     dfs = activity_df_list[:2]
208 |     mark = [tf.constant(str(df["marker"][0]).encode()) for df in dfs]
209 |     dset = [tf.constant(str(df["dataset"][0]).encode()) for df in dfs]
210 | 
211 |     for df in dfs:
212 |         df["loss"] = df.activity * df.prediction + (1 - df.activity) * (1 - df.prediction)
213 |     instance_mask = instance_mask[np.newaxis, ..., np.newaxis]
214 |     instance_mask = np.concatenate([instance_mask, instance_mask], axis=0)
215 |     loss_mask = trainer.batchwise_loss_selection(dfs, instance_mask, mark, dset)
216 | 
217 |     # check that the output has the right dimension
218 |     assert list(loss_mask.shape) == [2, 256, 256]
219 | 
220 |     # check that they are equal
221 |     assert np.array_equal(loss_mask[0], loss_mask[1])
222 | 
223 | 
224 | def test_quantile_scheduler(config_params):
225 |     config_params["test"] = True
226 |     trainer = PromixNaive(config_params)
227 |     quantile_start = config_params["quantile"]
228 |     quantile_end = config_params["quantile_end"]
229 |     quantile_warmup_steps = config_params["quantile_warmup_steps"]
230 |     step_0_quantile = trainer.quantile_scheduler(0)
231 |     step_half_warmup_quantile = trainer.quantile_scheduler(quantile_warmup_steps // 2)
232 |     step_warmup_quantile = trainer.quantile_scheduler(quantile_warmup_steps)
233 |     step_n_quantile = trainer.quantile_scheduler(quantile_warmup_steps*2)
234 | 
235 |     # check that the output has the expected values
236 |     assert step_0_quantile == quantile_start
237 |     assert step_half_warmup_quantile == (quantile_start + quantile_end) / 2
238 |     assert step_warmup_quantile == quantile_end
239 |     assert step_n_quantile == quantile_end
240 | 
241 | 
242 | def test_gen_prep_batches_promix_fn(config_params):
243 |     with tempfile.TemporaryDirectory() as temp_dir:
244 |         data_prep, _, _, _ = prep_object_and_inputs(temp_dir)
245 |         data_prep.tf_record_path = temp_dir
246 |         data_prep.make_tf_record()
247 |         tf_record_path = os.path.join(data_prep.tf_record_path, data_prep.dataset + ".tfrecord")
248 |         config_params["record_path"] = [tf_record_path]
249 |         config_params["path"] = temp_dir
250 |         config_params["batch_constituents"] = ["mplex_img", "binary_mask", "nuclei_img",
251 |                                                "membrane_img"]
252 |         config_params["experiment"] = "test"
253 |         config_params["dataset_names"] = ["test1"]
254 |         config_params["num_steps"] = 20
255 |         config_params["dataset_sample_probs"] = [1.0]
256 |         config_params["batch_size"] = 1
257 |         config_params["test"] = True
258 |         config_params["num_validation"] = [2]
259 |         config_params["num_test"] = [2]
260 |         trainer = PromixNaive(config_params)
261 |         trainer.prep_data()
262 |         example = next(iter(trainer.train_dataset))
263 |         prep_batches_promix_4 = trainer.prep_batches_promix
264 |         prep_batches_promix_2 = trainer.gen_prep_batches_promix_fn(
265 |             keys=["mplex_img", "binary_mask"]
266 |         )
267 | 
268 |         # check if each batch contains the above specified constituents
269 |         batch_2 = prep_batches_promix_2(example)
270 |         assert batch_2[0].shape[-1] == 1
271 |         assert np.array_equal(batch_2[0], example["mplex_img"])
272 | 
273 |         batch_4 = prep_batches_promix_4(example)
274 |         assert batch_4[0].shape[-1] == 3
275 |         assert np.array_equal(batch_4[0], tf.concat([
276 |             example["mplex_img"], example["nuclei_img"], example["membrane_img"]
277 |         ], axis=-1)
278 |         )
279 | 
--------------------------------------------------------------------------------
/tests/simple_data_prep_test.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import tempfile
4 | 
5 | import numpy as np
6 | import pandas as pd
7 | 
8 | from cell_classification.simple_data_prep import SimpleTFRecords
9 | 
10 | from .segmentation_data_prep_test import prepare_test_data_folders
11 | 
12 | 
13 | def test_get_marker_activity():
14 |     with tempfile.TemporaryDirectory() as temp_dir:
15 |         norm_dict = {"CD11c": 1.0, "CD4": 1.0, "CD56": 1.0, "CD57": 1.0}
16 |         with open(os.path.join(temp_dir, "norm_dict.json"), "w") as f:
17 |             json.dump(norm_dict, f)
18 |         data_folders = prepare_test_data_folders(
19 |             5, temp_dir, list(norm_dict.keys()) + ["XYZ"], random=True,
20 |             scale=[0.5, 1.0, 1.5, 2.0, 5.0]
21 |         )
22 |         cell_table_path = os.path.join(temp_dir, "cell_table.csv")
23 |         cell_table = pd.DataFrame(
24 |             {
25 |                 "SampleID": ["fov_0"] * 15 + ["fov_1"] * 15 + ["fov_2"] * 15 + ["fov_3"] * 15 +
26 |                             ["fov_4"] * 15 + ["fov_5"] * 15,
27 |                 "labels": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] * 6,
28 |                 "CD4_gt": [1, 1, 0, 0, 2] * 3 * 3 +
29 |                           [0, 0, 1, 1, 2] * 3 * 3,
30 |             }
31 |         )
32 | 
33 |         cell_table.to_csv(cell_table_path, index=False)
34 |         data_prep = SimpleTFRecords(
35 |             data_dir=temp_dir,
36 |             tf_record_path=temp_dir,
37 |             cell_table_path=cell_table_path,
38 |             normalization_dict_path=None,
39 |             selected_markers=["CD4"],
40 |             imaging_platform="test",
41 |             dataset="test",
42 |             tissue_type="test",
43 |             tile_size=[256, 256],
44 |             stride=[256, 256],
45 |             nuclei_channels=["CD56"],
46 |             membrane_channels=["CD57"],
47 |         )
48 |         data_prep.load_and_check_input()
49 |         marker = "CD4"
50 |         sample_name = "fov_1"
51 |         fov_1_subset = cell_table[cell_table.SampleID == sample_name]
52 |         data_prep.sample_subset = fov_1_subset
53 |         marker_activity, _ = data_prep.get_marker_activity(sample_name, marker)
54 | 
55 |         # check if we get marker_activity for all labels in the fov_1 subset
56 |         assert np.array_equal(marker_activity.labels, fov_1_subset.labels)
57 | 
58 |         # check if the df has the right marker activity values for a given cell
59 |         assert np.array_equal(marker_activity.activity.values, fov_1_subset.CD4_gt.values)
60 | 
--------------------------------------------------------------------------------
/tests/unet_test.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/jakeret/unet/blob/master/tests/test_unet.py
2 | from unittest.mock import Mock, patch
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow.keras import layers
6 | from tensorflow.keras import losses
7 | import sys
8 | from cell_classification import unet
9 | import pytest
10 | 
11 | 
12 | class TestPad2D:
13 | 
14 |     def test_serialization(self):
15 |         pad2d = unet.Pad2D(padding=(1, 1), data_format="channels_last")
16 |         config = pad2d.get_config()
17 |         new_pad2d = unet.Pad2D.from_config(config)
18 | 
19 |         assert new_pad2d.padding == pad2d.padding
20 |         assert new_pad2d.data_format == pad2d.data_format
21 | 
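    # geometry exercised below: with the default one-pixel border, the
    # CONSTANT/REFLECT/SYMMETRIC modes grow a 10x10 map to 12x12 in either
    # data_format, mode="VALID" leaves the tensor unchanged, and invalid mode or
    # data_format strings are expected to raise a ValueError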
== "channels_last": 28 | input_tensor = np.ones([1, 10, 10, 1]) 29 | output_tensor = pad2d(input_tensor) 30 | assert output_tensor.shape == (1, 12, 12, 1) 31 | else: 32 | input_tensor = np.ones([1, 1, 10, 10]) 33 | output_tensor = pad2d(input_tensor) 34 | assert output_tensor.shape == (1, 1, 12, 12) 35 | 36 | # test for valid padding 37 | pad2d = unet.Pad2D(data_format="channels_last", mode="VALID") 38 | input_tensor = np.ones([1, 10, 10, 1]) 39 | output_tensor = pad2d(input_tensor) 40 | assert output_tensor.shape == (1, 10, 10, 1) 41 | 42 | # check if error is raised when mode is not valid 43 | with pytest.raises(ValueError, match="mode must be"): 44 | pad2d = unet.Pad2D(data_format="channels_last", mode="same") 45 | 46 | # check if error is raised when data_format is not valid 47 | with pytest.raises(ValueError, match="data_format must be"): 48 | pad2d = unet.Pad2D(data_format="channels_in_the_middle", mode="CONSTANT") 49 | 50 | 51 | class TestConvBlock: 52 | 53 | def test_serialization(self): 54 | conv_block = unet.ConvBlock(layer_idx=1, 55 | filters_root=16, 56 | kernel_size=3, 57 | padding="REFLECT", 58 | activation="relu", 59 | name="conv_block_test", 60 | data_format="channels_last") 61 | 62 | config = conv_block.get_config() 63 | new_conv_block = unet.ConvBlock.from_config(config) 64 | 65 | assert new_conv_block.layer_idx == conv_block.layer_idx 66 | assert new_conv_block.filters_root == conv_block.filters_root 67 | assert new_conv_block.kernel_size == conv_block.kernel_size 68 | assert new_conv_block.padding == conv_block.padding 69 | assert new_conv_block.activation == conv_block.activation 70 | assert new_conv_block.activation == conv_block.activation 71 | assert new_conv_block.data_format == conv_block.data_format 72 | 73 | 74 | class TestUpconvBlock: 75 | 76 | def test_serialization(self): 77 | upconv_block = unet.UpconvBlock( 78 | layer_idx=1, filters_root=16, kernel_size=3, pool_size=2, padding="REFLECT", 79 | activation="relu", name="upconv_block_test", data_format="channels_last" 80 | ) 81 | 82 | config = upconv_block.get_config() 83 | new_upconv_block = unet.UpconvBlock.from_config(config) 84 | 85 | assert new_upconv_block.layer_idx == upconv_block.layer_idx 86 | assert new_upconv_block.filters_root == upconv_block.filters_root 87 | assert new_upconv_block.kernel_size == upconv_block.kernel_size 88 | assert new_upconv_block.pool_size == upconv_block.pool_size 89 | assert new_upconv_block.padding == upconv_block.padding 90 | assert new_upconv_block.activation == upconv_block.activation 91 | assert new_upconv_block.activation == upconv_block.activation 92 | assert new_upconv_block.data_format == upconv_block.data_format 93 | 94 | 95 | class TestCropConcatBlock(): 96 | 97 | def test_uneven_concat(self): 98 | layer = unet.CropConcatBlock(data_format="channels_last") 99 | down_tensor = np.ones([1, 61, 61, 32]) 100 | up_tensor = np.ones([1, 52, 52, 32]) 101 | 102 | concat_tensor = layer(up_tensor, down_tensor) 103 | 104 | assert concat_tensor.shape == (1, 52, 52, 64) 105 | 106 | 107 | class TestUnetModel: 108 | 109 | def test_serialization(self, tmpdir): 110 | save_path = str(tmpdir / "unet_model") 111 | unet_model = unet.build_model(layer_depth=3, filters_root=2) 112 | unet_model.save(save_path) 113 | 114 | reconstructed_model = tf.keras.models.load_model(save_path) 115 | assert reconstructed_model is not None 116 | 117 | def test_build_model(self): 118 | nx = 512 119 | ny = 512 120 | channels = 3 121 | num_classes = 2 122 | kernel_size = 3 123 | pool_size = 2 124 | 
124 |         filters_root = 64
125 |         layer_depth = 5
126 |         # same padding
127 |         padding = "CONSTANT"
128 |         model = unet.build_model(nx=nx,
129 |                                  ny=ny,
130 |                                  channels=channels,
131 |                                  num_classes=num_classes,
132 |                                  layer_depth=layer_depth,
133 |                                  filters_root=filters_root,
134 |                                  kernel_size=kernel_size,
135 |                                  pool_size=pool_size,
136 |                                  padding=padding,
137 |                                  data_format="channels_last")
138 | 
139 |         input_shape = model.get_layer("inputs").output.shape
140 |         assert tuple(input_shape) == (None, nx, ny, channels)
141 |         output_shape = model.get_layer("semantic_head").output.shape
142 |         assert tuple(output_shape) == (None, nx, ny, num_classes)
143 | 
144 |         # valid padding
145 |         padding = "VALID"
146 |         nx = 572
147 |         ny = 572
148 |         model = unet.build_model(nx=nx,
149 |                                  ny=ny,
150 |                                  channels=channels,
151 |                                  num_classes=num_classes,
152 |                                  layer_depth=layer_depth,
153 |                                  filters_root=filters_root,
154 |                                  kernel_size=kernel_size,
155 |                                  pool_size=pool_size,
156 |                                  padding=padding,
157 |                                  data_format="channels_last")
158 | 
159 |         input_shape = model.get_layer("inputs").output.shape
160 |         assert tuple(input_shape) == (None, nx, ny, channels)
161 |         output_shape = model.get_layer("semantic_head").output.shape
162 |         assert tuple(output_shape) == (None, 388, 388, num_classes)
163 | 
164 |         filters_per_layer = [filters_root, 128, 256, 512, 1024, 512, 256, 128, filters_root]
165 |         conv2D_layers = _collect_conv2d_layers(model)
166 | 
167 |         assert len(conv2D_layers) == 2 * len(filters_per_layer) + 1
168 | 
169 |         for conv2D_layer in conv2D_layers[:-1]:
170 |             assert conv2D_layer.kernel_size == (kernel_size, kernel_size)
171 | 
172 |         for i, filters in enumerate(filters_per_layer):
173 |             assert conv2D_layers[i*2].filters == filters
174 |             assert conv2D_layers[i*2+1].filters == filters
175 | 
176 |         maxpool_layers = [layer for layer in model.layers if isinstance(layer, layers.MaxPool2D)]
177 | 
178 |         assert len(maxpool_layers) == layer_depth - 1
179 | 
180 |         for maxpool_layer in maxpool_layers[:-1]:
181 |             assert maxpool_layer.pool_size == (pool_size, pool_size)
182 | 
183 | 
184 | def _collect_conv2d_layers(model):
185 |     conv2d_layers = []
186 |     for layer in model.layers:
187 |         if isinstance(layer, layers.Conv2D):
188 |             conv2d_layers.append(layer)
189 |         elif isinstance(layer, unet.ConvBlock):
190 |             conv2d_layers.append(layer.conv2d_1)
191 |             conv2d_layers.append(layer.conv2d_2)
192 | 
193 |     return conv2d_layers
194 | 
--------------------------------------------------------------------------------
/tests/viewer_widget_test.py:
--------------------------------------------------------------------------------
1 | from cell_classification.viewer_widget import NimbusViewer
2 | from .segmentation_data_prep_test import prep_object_and_inputs
3 | import numpy as np
4 | import tempfile
5 | import os
6 | 
7 | 
8 | def test_NimbusViewer():
9 |     with tempfile.TemporaryDirectory() as temp_dir:
10 |         _, _, _, _ = prep_object_and_inputs(temp_dir)
11 |         viewer_widget = NimbusViewer(temp_dir, temp_dir)
12 |         assert isinstance(viewer_widget, NimbusViewer)
13 | 
14 | 
15 | def test_composite_image():
16 |     with tempfile.TemporaryDirectory() as temp_dir:
17 |         _, _, _, _ = prep_object_and_inputs(temp_dir)
18 |         viewer_widget = NimbusViewer(temp_dir, temp_dir)
19 |         path_dict = {
20 |             "red": os.path.join(temp_dir, "fov_0", "CD4.tiff"),
21 |             "green": os.path.join(temp_dir, "fov_0", "CD11c.tiff"),
22 |         }
23 |         composite_image = viewer_widget.create_composite_image(path_dict)
24 |         assert isinstance(composite_image, np.ndarray)
25 |         assert composite_image.shape == (256, 256, 2)
26 | 
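        # with only "red" and "green" supplied, the composite holds two channels;
        # adding a "blue" path below should extend it to a full three-channel
        # RGB image of the same spatial size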
path_dict["blue"] = os.path.join(temp_dir, "fov_0", "CD56.tiff") 28 | composite_image = viewer_widget.create_composite_image(path_dict) 29 | assert composite_image.shape == (256, 256, 3) 30 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Enable line length testing with maximum line length of 99 2 | [pycodestyle] 3 | max-line-length = 99 --------------------------------------------------------------------------------