├── .github
└── workflows
│ ├── build.yml
│ ├── ci.yml
│ ├── pypi_publish.yml
│ ├── release-drafter.yml
│ └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
├── __init__.py
└── halfres_512_checkpoint_160000.h5
├── configs
├── decidua_split.json
├── hickey_split.json
├── msk_colon_split.json
├── msk_pancreas_split.json
├── params.toml
└── tonic_split.json
├── conftest.py
├── pyproject.toml
├── scripts
├── evaluation_script.py
└── hyperparameter_search.py
├── src
├── cell_classification
│ ├── __init__.py
│ ├── application.py
│ ├── augmentation_pipeline.py
│ ├── inference.py
│ ├── loss.py
│ ├── metrics.py
│ ├── model_builder.py
│ ├── plot_utils.py
│ ├── post_processing.py
│ ├── prepare_data_script_MSK_colon.py
│ ├── prepare_data_script_MSK_pancreas.py
│ ├── prepare_data_script_decidua.py
│ ├── prepare_data_script_tonic.py
│ ├── prepare_data_script_tonic_annotated.py
│ ├── promix_naive.py
│ ├── segmentation_data_prep.py
│ ├── simple_data_prep.py
│ ├── unet.py
│ └── viewer_widget.py
└── deepcell
│ ├── __init__.py
│ ├── application.py
│ ├── backbone_utils.py
│ ├── deepcell_toolbox.py
│ ├── fpn.py
│ ├── layers.py
│ ├── panopticnet.py
│ ├── semantic_head.py
│ └── utils.py
├── templates
├── 1_Nimbus_Predict.ipynb
└── 2_Generic_Cell_Clustering.ipynb
├── tests
├── __init__.py
├── application_test.py
├── augmentation_pipeline_test.py
├── deepcell_application_test.py
├── deepcell_backbone_utils_test.py
├── deepcell_layers_test.py
├── deepcell_panopticnet_test.py
├── deepcell_toolbox_test.py
├── deepcell_utils_test.py
├── inference_test.py
├── loss_test.py
├── metrics_test.py
├── model_builder_test.py
├── plot_utils_test.py
├── post_processing_test.py
├── promix_naive_test.py
├── segmentation_data_prep_test.py
├── simple_data_prep_test.py
├── unet_test.py
└── viewer_widget_test.py
└── tox.ini
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Wheel Builder
2 |
3 | on:
4 | workflow_call:
5 |
6 | permissions:
7 | contents: read # to fetch code (actions/checkout)
8 |
9 | jobs:
10 | build:
11 | name: Pure Python Wheel and Source Distribution
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - name: Checkout ${{ github.repository }}
16 | uses: actions/checkout@v3
17 | with:
18 | fetch-depth: 0
19 |
20 | - name: Build Wheels
21 | run: pipx run build
22 |
23 | - name: Check wheel Metadata
24 | run: pipx run twine check dist/*
25 |
26 | - name: Store Wheel Artifacts
27 | uses: actions/upload-artifact@v3
28 | with:
29 | name: distributions
30 | path: dist/*.whl
31 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 | types: [labeled, opened, synchronize, reopened]
9 | workflow_dispatch:
10 | merge_group:
11 | types: [checks_requested]
12 | branches: [main]
13 |
14 | concurrency:
15 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
16 | cancel-in-progress: true
17 |
18 | permissions:
19 | contents: read # to fetch code (actions/checkout)
20 |
21 | jobs:
22 | test:
23 | name: Test
24 | permissions:
25 | contents: read
26 | pull-requests: write
27 | secrets: inherit
28 | uses: ./.github/workflows/test.yml
29 |
30 | build:
31 | name: Build
32 | permissions:
33 | contents: read
34 | pull-requests: write
35 | secrets: inherit
36 | uses: ./.github/workflows/build.yml
37 |
38 | upload_coverage:
39 | needs: [test]
40 | name: Upload Coverage
41 | runs-on: ubuntu-latest
42 | steps:
43 |         - name: Checkout ${{ github.repository }}
44 | uses: actions/checkout@v3
45 | with:
46 | fetch-depth: 0
47 |
48 | - name: Download Coverage Artifact
49 | uses: actions/download-artifact@v3
50 | # if `name` is not specified, all artifacts are downloaded.
51 |
52 | - name: Upload Coverage to Coveralls
53 | uses: coverallsapp/github-action@v2
54 | with:
55 | github-token: ${{ secrets.GITHUB_TOKEN }}
--------------------------------------------------------------------------------
/.github/workflows/pypi_publish.yml:
--------------------------------------------------------------------------------
1 | name: Build Wheels and upload to PyPI
2 |
3 | on:
4 | pull_request:
5 | branches: ["releases/**"]
6 | types: [labeled, opened, synchronize, reopened]
7 | release:
8 | types: [published]
9 | workflow_dispatch:
10 |
11 | concurrency:
12 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
13 | cancel-in-progress: true
14 |
15 | permissions:
16 | contents: read # to fetch code (actions/checkout)
17 |
18 | jobs:
19 | test:
20 | name: Test
21 | permissions:
22 | contents: read
23 | secrets: inherit
24 | uses: ./.github/workflows/test.yml
25 |
26 | build_wheels_sdist:
27 | needs: [test]
28 | name: Build
29 | uses: ./.github/workflows/build.yml
30 | secrets: inherit
31 |
32 | test_pypi_publish:
33 | # Test PyPI publish, requires wheels and source dist (sdist)
34 | name: Publish ${{ github.repository }} to TestPyPI
35 | needs: [test, build_wheels_sdist]
36 | runs-on: ubuntu-latest
37 | steps:
38 | - uses: actions/download-artifact@v3
39 | with:
40 | name: distributions
41 | path: dist
42 |
43 | - uses: pypa/gh-action-pypi-publish@release/v1.6
44 | with:
45 | user: __token__
46 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
47 | repository_url: https://test.pypi.org/legacy/
48 | packages_dir: dist/
49 | verbose: true
50 |
51 | pypi_publish:
52 |     name: Publish ${{ github.repository }} to PyPI
53 | needs: [test, build_wheels_sdist, test_pypi_publish]
54 |
55 | runs-on: ubuntu-latest
56 | # Publish when a GitHub Release is created, use the following rule:
57 | if: github.event_name == 'release' && github.event.action == 'published'
58 | steps:
59 | - name: Download Artifact
60 | uses: actions/download-artifact@v3
61 | with:
62 | name: distributions
63 | path: dist
64 |
65 | - name: PYPI Publish
66 | uses: pypa/gh-action-pypi-publish@release/v1.6
67 | with:
68 | user: __token__
69 | password: ${{ secrets.PYPI_API_TOKEN }}
70 | packages_dir: dist/
71 | verbose: true
72 |
--------------------------------------------------------------------------------
/.github/workflows/release-drafter.yml:
--------------------------------------------------------------------------------
1 | name: Release Drafter
2 |
3 | on:
4 | push:
5 | # branches to consider in the event; optional, defaults to all
6 | branches: ["main"]
7 | # pull_request event is required only for autolabeler
8 | pull_request:
9 | # Only following types are handled by the action, but one can default to all as well
10 | types: [opened, reopened, synchronize]
11 |
12 | issues:
13 | types: [closed]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | update_release_draft:
20 | runs-on: ubuntu-latest
21 | permissions:
22 | contents: write # for release-drafter/release-drafter to create a github release
23 | pull-requests: write # for release-drafter/release-drafter to add label to PR
24 | steps:
25 |       # Drafts your next Release notes as Pull Requests are merged into "main"
26 | - uses: release-drafter/release-drafter@v5.20.1
27 | # Specify config name to use, relative to .github/. Default: release-drafter.yml
28 | with:
29 | config-name: release-drafter.yml
30 | # disable-autolabeler: true
31 | env:
32 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
33 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | workflow_call:
5 |
6 | permissions:
7 | contents: read # to fetch code (actions/checkout)
8 | jobs:
9 | test:
10 | name: ${{ github.repository }}
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - name: Checkout ${{ github.repository }}
15 | uses: actions/checkout@v3
16 | with:
17 | fetch-depth: 0
18 |
19 | - name: Set up Python 3.9
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: 3.9
23 | cache-dependency-path: "**/pyproject.toml"
24 | cache: "pip"
25 |
26 | - name: Install Dependencies and ${{ github.repository }}
27 | run: |
28 | pip install .[test]
29 |
30 | - name: Run Tests
31 | run: |
32 | pytest
33 |
34 | - name: Archive Coverage
35 | uses: actions/upload-artifact@v3
36 | with:
37 | name: coverage
38 | path: |
39 | coverage.lcov
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 | .vscode
162 | .DS_Store
163 | coverage.lcov
164 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Modified Apache License
2 | Version 2.0, January 2004
3 |
4 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
5 |
6 | 1. Definitions.
7 |
8 | "License" shall mean the terms and conditions for use, reproduction,
9 | and distribution as defined by Sections 1 through 9 of this document.
10 |
11 | "Licensor" shall mean the copyright owner or entity authorized by
12 | the copyright owner that is granting the License.
13 |
14 | "Legal Entity" shall mean the union of the acting entity and all
15 | other entities that control, are controlled by, or are under common
16 | control with that entity. For the purposes of this definition,
17 | "control" means (i) the power, direct or indirect, to cause the
18 | direction or management of such entity, whether by contract or
19 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
20 | outstanding shares, or (iii) beneficial ownership of such entity.
21 |
22 | "You" (or "Your") shall mean an individual or Legal Entity
23 | exercising permissions granted by this License.
24 |
25 | "Source" form shall mean the preferred form for making modifications,
26 | including but not limited to software source code, documentation
27 | source, and configuration files.
28 |
29 | "Object" form shall mean any form resulting from mechanical
30 | transformation or translation of a Source form, including but
31 | not limited to compiled object code, generated documentation,
32 | and conversions to other media types.
33 |
34 | "Work" shall mean the work of authorship, whether in Source or
35 | Object form, made available under the License, as indicated by a
36 | copyright notice that is included in or attached to the work
37 | (an example is provided in the Appendix below).
38 |
39 | "Derivative Works" shall mean any work, whether in Source or Object
40 | form, that is based on (or derived from) the Work and for which the
41 | editorial revisions, annotations, elaborations, or other modifications
42 | represent, as a whole, an original work of authorship. For the purposes
43 | of this License, Derivative Works shall not include works that remain
44 | separable from, or merely link (or bind by name) to the interfaces of,
45 | the Work and Derivative Works thereof.
46 |
47 | "Contribution" shall mean any work of authorship, including
48 | the original version of the Work and any modifications or additions
49 | to that Work or Derivative Works thereof, that is intentionally
50 | submitted to Licensor for inclusion in the Work by the copyright owner
51 | or by an individual or Legal Entity authorized to submit on behalf of
52 | the copyright owner. For the purposes of this definition, "submitted"
53 | means any form of electronic, verbal, or written communication sent
54 | to the Licensor or its representatives, including but not limited to
55 | communication on electronic mailing lists, source code control systems,
56 | and issue tracking systems that are managed by, or on behalf of, the
57 | Licensor for the purpose of discussing and improving the Work, but
58 | excluding communication that is conspicuously marked or otherwise
59 | designated in writing by the copyright owner as "Not a Contribution."
60 |
61 | "Contributor" shall mean Licensor and any individual or Legal Entity
62 | on behalf of whom a Contribution has been received by Licensor and
63 | subsequently incorporated within the Work.
64 |
65 | 2. Grant of Copyright License. Subject to the terms and conditions of
66 | this License, each Contributor hereby grants to You a non-commercial,
67 | academic perpetual, worldwide, non-exclusive, no-charge, royalty-free,
68 | irrevocable copyright license to reproduce, prepare Derivative Works
69 | of, publicly display, publicly perform, sublicense, and distribute the
70 | Work and such Derivative Works in Source or Object form. For any other
71 | use, including commercial use, please contact: mangelo0@stanford.edu.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a non-commercial,
75 | academic perpetual, worldwide, non-exclusive, no-charge, royalty-free,
76 | irrevocable (except as stated in this section) patent license to make,
77 | have made, use, offer to sell, sell, import, and otherwise transfer the
78 | Work, where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | 10. Neither the name of Stanford nor the names of its contributors may be
177 | used to endorse or promote products derived from this software without
178 | specific prior written permission.
179 |
180 | END OF TERMS AND CONDITIONS
181 |
182 | APPENDIX: How to apply the Apache License to your work.
183 |
184 | To apply the Apache License to your work, attach the following
185 | boilerplate notice, with the fields enclosed by brackets "[]"
186 | replaced with your own identifying information. (Don't include
187 | the brackets!) The text should be enclosed in the appropriate
188 | comment syntax for the file format. We also recommend that a
189 | file or class name and description of purpose be included on the
190 | same "printed page" as the copyright notice for easier
191 | identification within third-party archives.
192 |
193 | Copyright [yyyy] [name of copyright owner]
194 |
195 | Licensed under the Apache License, Version 2.0 (the "License");
196 | you may not use this file except in compliance with the License.
197 | You may obtain a copy of the License at
198 |
199 | http://www.apache.org/licenses/LICENSE-2.0
200 |
201 | Unless required by applicable law or agreed to in writing, software
202 | distributed under the License is distributed on an "AS IS" BASIS,
203 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
204 | See the License for the specific language governing permissions and
205 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Nimbus
2 |
3 | The Nimbus repo contains code for training and validation of a machine learning model that classifies cells into marker positive/negative for arbitrary markers and different imaging platforms.
4 |
5 | The code for using the model and running inference on your own data can be found here: [Nimbus-Inference](https://github.com/angelolab/Nimbus-Inference). Code for generating the figures in the paper can be found here: [Publication plots](https://github.com/angelolab/publications/tree/main/2024-Rumberger_Greenwald_etal_Nimbus).
6 |
7 | ## Installation instructions
8 |
9 | Clone the repository
10 |
11 | `git clone https://github.com/angelolab/Nimbus.git`
12 |
13 |
14 | Make a conda environment for Nimbus and activate it
15 |
16 | `conda create -n Nimbus python==3.10`
17 |
18 | `conda activate Nimbus`
19 |
20 | Install CUDA libraries if you have a NVIDIA GPU available
21 |
22 | `conda install -c conda-forge cudatoolkit=11.2 cudnn=8.1.0`
23 |
24 | Install the package and all dependencies in the conda environment
25 |
26 | `python -m pip install -e Nimbus`
27 |
28 | Install tensorflow-metal if you have an Apple Silicon GPU
29 |
30 | `python -m pip install tensorflow-metal`
31 |
32 | Navigate to the example notebooks and start jupyter
33 |
34 | `cd Nimbus/templates`
35 |
36 | `jupyter notebook`
37 |
38 | ## Citation
39 |
40 | ```bibtex
41 | @article{rum2024nimbus,
42 | title={Automated classification of cellular expression in multiplexed imaging data with Nimbus},
43 | author={Rumberger, J. Lorenz and Greenwald, Noah F. and Ranek, Jolene S. and Boonrat, Potchara and Walker, Cameron and Franzen, Jannik and Varra, Sricharan Reddy and Kong, Alex and Sowers, Cameron and Liu, Candace C. and Averbukh, Inna and Piyadasa, Hadeesha and Vanguri, Rami and Nederlof, Iris and Wang, Xuefei Julie and Van Valen, David and Kok, Marleen and Hollman, Travis J. and Kainmueller, Dagmar and Angelo, Michael},
44 | journal={bioRxiv},
45 | pages={2024--05},
46 | year={2024},
47 | publisher={Cold Spring Harbor Laboratory}
48 | }
49 | ```
50 |
--------------------------------------------------------------------------------
/checkpoints/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/angelolab/Nimbus/b418eb842ee095afd0a9abe79be98a66ba157127/checkpoints/__init__.py
--------------------------------------------------------------------------------
/checkpoints/halfres_512_checkpoint_160000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/angelolab/Nimbus/b418eb842ee095afd0a9abe79be98a66ba157127/checkpoints/halfres_512_checkpoint_160000.h5
--------------------------------------------------------------------------------
/configs/decidua_split.json:
--------------------------------------------------------------------------------
1 | {"test": ["12_31750_16_12", "14_31758_20_4", "16_31762_16_21", "12_31750_1_10", "14_31759_20_6", "6_31728_15_5", "16_31773_4_9", "20_31797_6_9", "6_31731_11_8", "12_31754_16_14", "10_31745_16_9", "16_31770_14_16", "20_31796_6_8", "10_31745_1_6", "6_31731_11_7", "20_31798_6_10", "16_31771_14_14", "16_31767_3_3", "12_31748_1_8", "6_31729_9_4", "8_31736_12_18", "16_31774_14_13", "18_31783_14_11"], "validation": ["8_31734_12_5", "12_31751_1_12", "6_31730_10_10", "6_31729_10_5", "6_31726_8_8", "10_31747_1_7", "6_31730_10_6", "8_31737_12_22", "16_31771_4_3", "6_31733_11_15", "6_31728_9_3", "8_31734_16_1", "16_31770_14_15", "6_31728_9_1", "6_31731_11_5", "8_31734_12_3", "6_31729_10_4", "18_31785_5_14", "20_31793_13_9", "6_31731_15_7", "6_31732_11_10", "18_31783_5_8", "8_31734_12_1"], "train": ["8_31736_12_16", "12_31750_1_11", "6_31728_9_2", "8_31736_12_19", "6_31726_8_5", "14_31755_20_1", "6_31730_18_15", "10_31744_1_3", "18_31776_14_12", "6_31733_11_16", "16_31774_4_12", "14_31759_16_18", "20_31793_6_3", "6_31728_15_4", "20_31793_13_10", "16_31772_4_6", "18_31784_5_12", "20_31787_5_16", "20_31786_14_5", "14_31758_20_5", "6_31727_8_9", "20_31794_13_7", "6_31733_11_14", "12_31754_1_16", "6_31730_10_9", "8_31737_6_18", "16_31768_3_6", "14_31761_20_7", "10_31740_6_11", "8_31735_12_9", "12_31749_16_11", "6_31733_11_17", "20_31789_5_19", "6_31731_11_2", "6_31730_10_7", "10_31741_1_1", "10_31739_6_16", "16_31767_17_1", "18_31778_5_3", "6_31728_15_6", "16_31773_4_8", "20_31788_5_17", "16_31770_4_2", "8_31738_6_17", "8_31735_12_6", "6_31728_8_13", "14_31758_16_17", "12_31754_18_2", "18_31782_5_6", "6_31731_11_4", "12_31754_16_13", "6_31729_10_1", "6_31725_8_1", "6_31726_8_7", "6_31732_11_9", "8_31734_12_2", "10_31739_6_14", "6_31732_11_12", "6_31729_10_2", "20_31791_18_13", "6_31726_8_4", "8_31735_12_14", "10_31739_6_15", "16_31768_17_2", "8_31737_12_23", "18_31783_5_11", "16_31770_4_1", "20_31797_13_3", "20_31793_6_2", "8_31735_12_8", "18_31776_5_2", "16_31772_4_5", 
"16_31762_20_9", "20_31792_18_14", "18_31784_14_6", "16_31768_3_5", "8_31736_16_3", "6_31727_8_10", "20_31794_6_6", "18_31785_5_13", "20_31791_13_12", "6_31732_11_11", "8_31736_12_17", "18_31785_5_15", "18_31783_5_9", "20_31789_14_1", "20_31791_13_11", "14_31761_16_20", "6_31730_11_1", "16_31763_20_10", "8_31736_16_4", "14_31756_16_16", "20_31793_13_8", "12_31749_1_9", "6_31725_15_1", "8_31735_12_12", "6_31732_15_9", "6_31725_15_2", "20_31795_13_5", "12_31754_1_14", "6_31733_15_10", "8_31735_12_7", "10_31745_16_10", "20_31789_14_2", "20_31798_13_2", "20_31794_6_5", "18_31782_5_7", "16_31772_4_4", "6_31725_8_2", "18_31783_14_8", "6_31730_10_11", "16_31769_3_7", "16_31765_20_11", "20_31795_13_4", "12_31752_1_13", "6_31729_10_3", "14_31756_20_2", "10_31742_1_2", "8_31735_12_11", "10_31745_1_4", "18_31783_14_9", "20_31795_6_7", "20_31787_14_4", "18_31780_5_5", "6_31727_8_12", "16_31773_4_10", "12_31754_1_15", "20_31787_14_3", "20_31798_13_1", "16_31762_20_8", "20_31790_13_14", "6_31725_8_3", "6_31731_15_8", "6_31727_8_11", "6_31733_15_12", "8_31734_11_18", "16_31769_3_8", "16_31765_3_1", "6_31727_15_3", "6_31726_8_6", "10_31740_6_12", "18_31783_14_7", "6_31733_15_11", "10_31745_1_5", "16_31765_19_1", "6_31730_10_8", "10_31743_16_8", "14_31755_18_1", "16_31769_18_11", "18_31779_5_4", "16_31766_3_2", "6_31731_11_6", "8_31735_12_13", "8_31734_12_4", "16_31772_4_7", "20_31793_6_1", "16_31775_5_1", "8_31737_12_20", "14_31760_16_19", "20_31789_5_20", "16_31774_4_11", "20_31791_13_13", "8_31736_12_15", "8_31736_16_5", "10_31740_16_6", "12_31754_16_15", "20_31788_5_18", "8_31734_16_2", "10_31743_16_7", "14_31757_20_3", "8_31737_12_21", "20_31794_13_6", "18_31783_14_10", "6_31731_11_3", "6_31733_11_13", "20_31790_18_12", "20_31794_6_4", "8_31735_12_10", "18_31783_5_10", "10_31739_6_13", "16_31767_3_4", "6_31730_10_12"]}
--------------------------------------------------------------------------------
/configs/hickey_split.json:
--------------------------------------------------------------------------------
1 | {"test": ["B010A_reg003_X01_Y01_Z01", "B011B_reg003_X01_Y01_Z01", "B011B_reg001_X01_Y01_Z01"], "validation": ["B011A_reg001_X01_Y01_Z01", "B012A_reg003_X01_Y01_Z01", "B011A_reg002_X01_Y01_Z01"], "train": ["B010B_reg004_X01_Y01_Z01", "B011A_reg003_X01_Y01_Z01", "B012A_reg004_X01_Y01_Z01", "B009A_reg003_X01_Y01_Z01", "B009A_reg002_X01_Y01_Z01", "B010B_reg003_X01_Y01_Z01", "B010B_reg002_X01_Y01_Z01", "B012B_reg001_X01_Y01_Z01", "B009A_reg004_X01_Y01_Z01", "B011B_reg002_X01_Y01_Z01", "B012A_reg001_X01_Y01_Z01", "B012B_reg004_X01_Y01_Z01", "B011B_reg004_X01_Y01_Z01", "B012B_reg003_X01_Y01_Z01", "B010B_reg001_X01_Y01_Z01", "B010A_reg001_X01_Y01_Z01", "B009B_reg002_X01_Y01_Z01", "B011A_reg004_X01_Y01_Z01", "B009B_reg001_X01_Y01_Z01", "B009B_reg003_X01_Y01_Z01", "B012B_reg002_X01_Y01_Z01", "B009B_reg004_X01_Y01_Z01", "B009A_reg001_X01_Y01_Z01", "B010A_reg004_X01_Y01_Z01", "B012A_reg002_X01_Y01_Z01", "B010A_reg002_X01_Y01_Z01"]}
--------------------------------------------------------------------------------
/configs/params.toml:
--------------------------------------------------------------------------------
1 | record_path = "C:/Users/lorenz/Desktop/angelo_lab/MIBI_test/TNBC_CD45.tfrecord"
2 | path = "C:/Users/lorenz/OneDrive/Desktop/angelo_lab/"
3 | experiment = "test"
4 | project = "Nimbus"
5 | logging_mode = "offline"
6 | model = "ModelBuilder"
7 | num_steps = 20
8 | lr = 1e-3
9 | backbone = "resnet50"
10 | dataset_names = ["TNBC_CD45"]
11 | dataset_sample_probs = [1.0]
12 | input_shape = [256,256,4]
13 | batch_constituents = ["mplex_img", "binary_mask", "nuclei_img", "membrane_img"]
14 | num_validation = [3]
15 | num_test = [3]
16 | shuffle_buffer_size = 2000
17 | flip_prob = 0.5
18 | affine_prob = 0.5
19 | scale_min = 0.8
20 | scale_max = 1.2
21 | shear_angle = 0.2
22 | elastic_prob = 0.5
23 | elastic_alpha = 10
24 | elastic_sigma = 4
25 | rotate_prob = 0.5
26 | rotate_count = 4
27 | gaussian_noise_prob = 0.5
28 | gaussian_noise_min = 0.05
29 | gaussian_noise_max = 0.15
30 | gaussian_blur_prob = 0.5
31 | gaussian_blur_min = 0.05
32 | gaussian_blur_max = 0.15
33 | contrast_prob = 0.5
34 | contrast_min = 0.8
35 | contrast_max = 1.2
36 | mixup_prob = 0.5
37 | mixup_alpha = 4.0
38 | batch_size = 4
39 | loss_fn = "BinaryCrossentropy"
40 | loss_selective_masking = true
41 | quantile = 0.5
42 | quantile_end = 0.9
43 | quantile_warmup_steps = 100000
44 | confidence_thresholds = [0.1, 0.9]
45 | ema = 0.01
46 | location = false
47 | [loss_kwargs]
48 | from_logits = false
49 | label_smoothing = 0.1
50 |
51 | [classes]
52 | marker_positive = 1
53 |
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
from typing import Any, Dict, Iterator
import numpy as np
import pytest
import toml
import os
import sys

# Force CPU-only execution so the test suite never grabs a GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Make src/ (package code) and tests/ (shared helpers) importable from the repo root.
sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
sys.path.append(os.path.join(os.path.dirname(__file__), "tests"))

@pytest.fixture(scope="function")
def config_params() -> Iterator[Dict]:
    """Yield the default training/config parameters from configs/params.toml."""
    yield toml.load("./configs/params.toml")
16 |
@pytest.fixture(scope="function")
def rng() -> Iterator[np.random.Generator]:
    """Yield a deterministic numpy random generator (fixed seed for reproducible tests)."""
    yield np.random.default_rng(seed=42)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools",
4 | "wheel",
5 | "setuptools_scm[toml]>=6.2",
6 | ]
7 | build-backend = "setuptools.build_meta"
8 |
9 | [project]
10 | dependencies = [
11 | "numpy>=1.20",
12 | "ark-analysis>=0.6.0",
13 | "spektral>=1",
14 | "scikit-image>=0.16",
15 | "scikit-learn>=1.2.2",
16 | "tifffile>=2023",
17 | "tqdm>=4",
18 | "imgaug>=0.4.0",
19 | "opencv-python>=4.6",
20 | "toml",
21 | "seaborn>=0.12",
22 | "alpineer>=0.1.5",
23 | "natsort>=7.1",
24 | "tensorflow~=2.8.0",
25 | "tensorflow_addons~=0.16.1",
26 | "pydot>=1.4.2,<2",
27 | "protobuf",
28 | "wandb"
29 | ]
30 | name = "cell_classification"
31 | authors = [{ name = "Angelo Lab", email = "theangelolab@gmail.com" }]
32 | description = "Cell classification tool for classifying cells into marker positive and negative for arbitrary markers."
33 | readme = "README.md"
34 | requires-python = ">=3.9"
35 | license = { text = "Modified Apache License 2.0" }
36 | classifiers = [
37 | "Development Status :: 4 - Beta",
38 | "Programming Language :: Python :: 3",
39 | "Programming Language :: Python :: 3.9",
40 | "License :: OSI Approved :: Apache Software License",
41 | "Topic :: Scientific/Engineering :: Bio-Informatics",
42 | "Topic :: Scientific/Engineering :: Image Processing",
43 | ]
44 | dynamic = ["version"]
45 | urls = { repository = "https://github.com/angelolab/Nimbus" }
46 |
47 | [project.optional-dependencies]
48 | test = [
49 | "attrs",
50 | "coveralls[toml]",
51 | "pytest",
52 | "pytest-cases",
53 | "pytest-cov",
54 | "pytest-mock",
55 | "pytest-pycodestyle",
56 | "pytest-randomly",
57 | ]
58 |
59 | [tool.setuptools_scm]
60 | version_scheme = "release-branch-semver"
61 | local_scheme = "no-local-version"
62 |
63 | # Coverage
64 | [tool.coverage.paths]
65 | source = ["src", "*/site-packages"]
66 |
67 | [tool.coverage.run]
68 | branch = true
69 | source = ["cell_classification"]
70 |
71 | [tool.coverage.report]
72 | exclude_lines = [
73 | "except ImportError",
74 | "raise AssertionError",
75 | "raise NotImplementedError",
76 | ]
77 |
78 | # Pytest Options
79 | [tool.pytest.ini_options]
80 | filterwarnings = [
81 | "ignore::DeprecationWarning",
82 | "ignore::PendingDeprecationWarning",
83 | ]
84 | addopts = [
85 | "-v",
86 | "-s",
87 | "--durations=20",
88 | "--randomly-seed=42",
89 | "--randomly-dont-reorganize",
90 | "--cov=cell_classification",
91 | "--cov-report=lcov",
92 | "--pycodestyle",
93 | ]
94 | console_output_style = "count"
95 | testpaths = ["tests"]
96 |
--------------------------------------------------------------------------------
/scripts/evaluation_script.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import tensorflow as tf
8 | import toml
9 |
10 | from cell_classification.metrics import (average_roc, calc_metrics, calc_roc,
11 | load_model)
12 | from cell_classification.plot_utils import (heatmap_plot, plot_average_roc,
13 | plot_metrics_against_threshold,
14 | plot_together, subset_plots)
15 | from cell_classification.segmentation_data_prep import (feature_description,
16 | parse_dict)
17 |
if __name__ == "__main__":

    def _str_to_bool(value):
        """Parse a CLI flag as boolean.

        argparse's ``type=bool`` treats any non-empty string (including the
        literal "False") as True, so parse the text explicitly instead.
        """
        return str(value).lower() not in ("false", "0", "no", "")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to model weights",
        default=None,
    )
    parser.add_argument(
        "--params_path",
        type=str,
        help="Path to model params",
        default="E:\\angelo_lab\\test\\params.toml",
    )
    parser.add_argument(
        "--worst_n",
        type=int,
        help="Number of worst predictions to plot",
        default=20,
    )
    parser.add_argument(
        "--best_n",
        type=int,
        help="Number of best predictions to plot",
        default=20,
    )
    parser.add_argument(
        "--split_by_marker",
        type=_str_to_bool,
        help="Split best/worst predictions by marker",
        default=True,
    )
    parser.add_argument(
        "--external_datasets",
        type=str,
        help="List of paths to tfrecord datasets",
        nargs='+',
        default=[],
    )
    args = parser.parse_args()
    with open(args.params_path, "r") as f:
        params = toml.load(f)
    if args.model_path is not None:
        params["model_path"] = args.model_path

    model = load_model(params)
    datasets = {name: dataset for name, dataset in zip(model.dataset_names, model.test_datasets)}
    # optionally add external tfrecord datasets, keyed by their file stem
    # (the old `hasattr(args, ...)` check was always True because a default exists)
    if args.external_datasets:
        external_datasets = {
            os.path.split(external_dataset)[-1].split(".")[0]: external_dataset.replace(",", "")
            for external_dataset in args.external_datasets
        }
        external_datasets = {
            name: tf.data.TFRecordDataset(path)
            for name, path in external_datasets.items()
        }
        external_datasets = {
            name: dataset.map(
                lambda x: tf.io.parse_single_example(x, feature_description),
                num_parallel_calls=tf.data.AUTOTUNE,
            ) for name, dataset in external_datasets.items()
        }
        external_datasets = {
            name: dataset.map(
                parse_dict, num_parallel_calls=tf.data.AUTOTUNE
            ) for name, dataset in external_datasets.items()
        }
        external_datasets = {
            name: dataset.batch(
                params["batch_size"], drop_remainder=False
            ) for name, dataset in external_datasets.items()
        }
        datasets.update(external_datasets)

    for name, val_dset in datasets.items():
        # one eval sub-directory per dataset, next to the model weights
        params["eval_dir"] = os.path.join(os.path.dirname(params["model_path"]), "eval", name)
        os.makedirs(params["eval_dir"], exist_ok=True)
        pred_list = model.predict_dataset(val_dset, False)

        # prepare cell_table: one row per cell, tagged with sample metadata
        activity_list = []
        for pred in pred_list:
            activity_df = pred["activity_df"].copy()
            for key in ["dataset", "marker", "folder_name"]:
                activity_df[key] = [pred[key]] * len(activity_df)
            activity_list.append(activity_df)
        activity_df = pd.concat(activity_list)
        activity_df.to_csv(os.path.join(params["eval_dir"], "pred_cell_table.csv"), index=False)

        # cell level evaluation
        roc = calc_roc(pred_list, gt_key="activity", pred_key="pred_activity", cell_level=True)
        with open(os.path.join(params["eval_dir"], "roc_cell_lvl.pkl"), "wb") as f:
            pickle.dump(roc, f)

        # find the n worst (lowest AUC) and n best (highest AUC) predictions
        # and save plots of them
        roc_df = pd.DataFrame(roc)
        if args.split_by_marker:
            worst_idx = []
            best_idx = []
            for marker in np.unique(roc_df.marker):
                marker_df = roc_df[roc_df.marker == marker]
                # BUG FIX: np.argsort(series).index just returned the subset's
                # original index and discarded the sort order entirely;
                # sort_values gives row labels ordered worst -> best AUC
                sort_idx = marker_df["auc"].sort_values().index
                worst_idx.extend(sort_idx[:args.worst_n])
                best_idx.extend(sort_idx[::-1][:args.best_n])
        else:
            # ascending AUC: low values are the worst predictions
            # (previously worst/best were swapped)
            sort_idx = np.argsort(roc["auc"])
            worst_idx = sort_idx[:args.worst_n]
            best_idx = sort_idx[::-1][:args.best_n]
        for idx_list, best_worst in [(best_idx, "best"), (worst_idx, "worst")]:
            for i, idx in enumerate(idx_list):
                pred = pred_list[idx]
                plot_together(
                    pred, keys=["mplex_img", "marker_activity_mask", "prediction"],
                    save_dir=os.path.join(params["eval_dir"], best_worst + "_predictions"),
                    # BUG FIX: the old format string had three placeholders for
                    # four arguments (folder_name was silently dropped) and
                    # hard-coded "worst" even for best predictions
                    save_file="{}_{}_{}_{}_{}.png".format(
                        best_worst, i, pred["marker"], pred["dataset"], pred["folder_name"]
                    )
                )

        tprs, mean_tprs, fpr, std, mean_thresh = average_roc(roc)
        plot_average_roc(
            mean_tprs, std, save_dir=params["eval_dir"], save_file="avg_roc_cell_lvl.png"
        )
        print("AUC: {}".format(np.mean(roc["auc"])))

        print("Calculate precision, recall, f1_score and accuracy on the cell level")
        avg_metrics = calc_metrics(
            pred_list, gt_key="activity", pred_key="pred_activity", cell_level=True
        )
        pd.DataFrame(avg_metrics).to_csv(
            os.path.join(params["eval_dir"], "cell_metrics.csv"), index=False
        )

        plot_metrics_against_threshold(
            avg_metrics,
            metric_keys=["precision", "recall", "f1_score", "specificity"],
            threshold_key="threshold",
            save_dir=params["eval_dir"],
            save_file="precision_recall_f1_cell_lvl.png",
        )

        print("Plot activity predictions split by markers and cell types")
        subset_plots(
            activity_df, subset_list=["marker"],
            save_dir=params["eval_dir"],
            save_file="split_by_marker.png",
            gt_key="activity",
            pred_key="pred_activity",
        )
        if "cell_type" in activity_df.columns:
            subset_plots(
                activity_df, subset_list=["cell_type"],
                save_dir=params["eval_dir"],
                save_file="split_by_cell_type.png",
                gt_key="activity",
                pred_key="pred_activity",
            )
        heatmap_plot(
            activity_df, subset_list=["marker"],
            save_dir=params["eval_dir"],
            save_file="heatmap_split_by_marker.png",
            gt_key="activity",
            pred_key="pred_activity",
        )
186 |
--------------------------------------------------------------------------------
/scripts/hyperparameter_search.py:
--------------------------------------------------------------------------------
1 | from cell_classification.model_builder import ModelBuilder
2 | from cell_classification.metrics import calc_scores
3 | from joblib import Parallel, delayed
4 | from tqdm import tqdm
5 | import numpy as np
6 | import argparse
7 | import toml
8 | import os
9 |
10 |
def hyperparameter_search(params, n_jobs=10, checkpoint=None):
    """Search the best positive/negative decision threshold per marker and dataset.

    Runs validation-set inference, sweeps 101 thresholds in [0, 1] and keeps
    the one maximizing the F1 score for each (dataset, marker) pair.
    Args:
        params: configs comparable to Nimbus/configs/params.toml
        n_jobs: number of jobs for parallelization
        checkpoint: name of checkpoint
    Returns:
        tuple: (dict mapping dataset -> marker -> optimal threshold,
            dataframe of validation predictions with assigned classes)
    """
    model = ModelBuilder(params)
    model.prep_data()
    ckpt_path = os.path.join(params["model_dir"], checkpoint) if checkpoint else None
    df = model.predict_dataset_list(
        model.validation_datasets, save_predictions=False, ckpt_path=ckpt_path
    )
    print("Run hyperparameter search")
    # candidate thresholds 0.00, 0.01, ..., 1.00
    candidate_thresholds = [np.round(t, 2) for t in np.linspace(0, 1, 101)]
    optimal_thresh_dict = {}
    for dataset_name in df.dataset.unique():
        dataset_df = df[df.dataset == dataset_name]
        optimal_thresh_dict.setdefault(dataset_name, {})
        for marker in tqdm(dataset_df.marker.unique()):
            marker_df = dataset_df[dataset_df.marker == marker]
            # score every candidate threshold in parallel
            metrics = Parallel(n_jobs=n_jobs)(
                delayed(calc_scores)(
                    marker_df["activity"].astype(np.int32), marker_df["prediction"],
                    threshold=thresh
                ) for thresh in candidate_thresholds
            )
            f1_scores = [metric["f1_score"] for metric in metrics]
            optimal_thresh_dict[dataset_name][marker] = candidate_thresholds[
                int(np.argmax(f1_scores))
            ]
    # assign classes based on optimal thresholds
    df["pred_class"] = df.apply(
        lambda row: 1 if row["prediction"] >= optimal_thresh_dict[row["dataset"]][row["marker"]] else 0,
        axis=1
    )
    df.to_csv(os.path.join(
        params["path"], params["experiment"], "{}_validation_predictions.csv".format(checkpoint)
    ))
    # save as toml
    fpath = os.path.join(
        params["path"], params["experiment"], "{}_optimal_thresholds.toml".format(checkpoint)
    )
    with open(fpath, "w") as f:
        toml.dump(optimal_thresh_dict, f)
    return optimal_thresh_dict, df
62 |
63 |
def prepare_testset_predictions(params, optimal_thresh_dict, checkpoint=None):
    """Run test-set inference and score it using the supplied thresholds.

    Args:
        params: configs comparable to Nimbus/configs/params.toml
        optimal_thresh_dict: optimal thresholds per marker and dataset
        checkpoint: name of checkpoint
    Returns:
        pd.DataFrame: test-set predictions with assigned classes
    """
    model = ModelBuilder(params)
    model.prep_data()
    ckpt_path = os.path.join(params["model_dir"], checkpoint) if checkpoint else None
    df = model.predict_dataset_list(model.test_datasets, save_predictions=False, ckpt_path=ckpt_path)
    # assign classes based on optimal thresholds
    df["pred_class"] = df.apply(
        lambda row: 1 if row["prediction"] >= optimal_thresh_dict[row["dataset"]][row["marker"]] else 0,
        axis=1
    )
    df.to_csv(os.path.join(
        params["path"], params["experiment"], "{}_test_predictions.csv".format(checkpoint)
    ))
    # calculate test set metrics per (dataset, marker) pair
    metrics_dict = {}
    for dataset_name in df.dataset.unique():
        dataset_df = df[df.dataset == dataset_name]
        metrics_dict.setdefault(dataset_name, {})
        for marker in dataset_df.marker.unique():
            marker_df = dataset_df[dataset_df.marker == marker]
            metrics_dict[dataset_name][marker] = calc_scores(
                marker_df["activity"].astype(np.int32), marker_df["prediction"],
                threshold=optimal_thresh_dict[dataset_name][marker]
            )
    # average the marker-level scores per dataset and save as toml
    score_keys = ["precision", "recall", "f1_score", "accuracy", "specificity"]
    dataset_scores = {
        dataset_name: {
            key: np.mean([scores[key] for scores in marker_metrics.values()])
            for key in score_keys
        }
        for dataset_name, marker_metrics in metrics_dict.items()
    }
    fpath = os.path.join(
        params["path"], params["experiment"], "{}_test_scores.toml".format(checkpoint)
    )
    with open(fpath, "w") as f:
        toml.dump(dataset_scores, f)
    return df
116 |
117 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--params", type=str, default="configs/params.toml")
    parser.add_argument("--ckpt", type=str, default=None)
    args = parser.parse_args()
    params = toml.load(args.params)
    # thresholds are saved per checkpoint as "{ckpt}_optimal_thresholds.toml";
    # the old check looked for "optimal_thresholds.toml" (never written) and,
    # when a file did exist, went on to use an undefined variable (NameError)
    thresh_path = os.path.join(
        params["path"], params["experiment"],
        "{}_optimal_thresholds.toml".format(args.ckpt)
    )
    if os.path.exists(thresh_path):
        # reuse previously computed thresholds instead of re-running the search
        optimal_thresh_dict = toml.load(thresh_path)
    else:
        optimal_thresh_dict, _ = hyperparameter_search(params, checkpoint=args.ckpt)
    prepare_testset_predictions(params, optimal_thresh_dict, checkpoint=args.ckpt)
129 |
--------------------------------------------------------------------------------
/src/cell_classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/angelolab/Nimbus/b418eb842ee095afd0a9abe79be98a66ba157127/src/cell_classification/__init__.py
--------------------------------------------------------------------------------
/src/cell_classification/application.py:
--------------------------------------------------------------------------------
1 | from deepcell.panopticnet import PanopticNet
2 | from deepcell.semantic_head import create_semantic_head
3 | from deepcell.application import Application
4 | from alpineer import io_utils
5 | from cell_classification.inference import prepare_normalization_dict, predict_fovs
6 | import cell_classification
7 | from pathlib import Path
8 | from glob import glob
9 | import tensorflow as tf
10 | import numpy as np
11 | import json
12 | import os
13 |
14 |
def nimbus_preprocess(image, **kwargs):
    """Preprocess input data for Nimbus model.
    Args:
        image: 4D array (batch, h, w, channels) with the marker image in
            channel 0 and the binary cell mask in channel 1
        **kwargs: normalize (bool, default True), marker (str),
            normalization_dict (dict mapping marker name -> norm factor)
    Returns:
        np.array: processed image array, marker channel normalized and the
            whole array clipped to [0, 1]
    Raises:
        ValueError: if image is not 4D
    """
    # validate before copying so invalid input fails fast without allocation
    if len(image.shape) != 4:
        raise ValueError("Image data must be 4D, got image of shape {}".format(image.shape))
    output = np.copy(image)

    normalize = kwargs.get('normalize', True)
    if normalize:
        marker = kwargs.get('marker', None)
        normalization_dict = kwargs.get('normalization_dict', {})
        if marker in normalization_dict.keys():
            norm_factor = normalization_dict[marker]
        else:
            print("Norm_factor not found for marker {}, calculating directly from the image. \
            ".format(marker))
            norm_factor = np.quantile(output[..., 0], 0.999)
        if norm_factor == 0:
            # a fully-zero channel would otherwise produce NaN/inf on division
            norm_factor = 1.0
        # normalize only marker channel in chan 0 not binary mask in chan 1
        output[..., 0] /= norm_factor
        output = output.clip(0, 1)
    return output
40 |
41 |
def nimbus_postprocess(model_output):
    """Identity post-processing hook; Nimbus output needs no transformation."""
    return model_output
44 |
45 |
def format_output(model_output):
    """Unwrap the single semantic-head prediction from the model's output list."""
    first_head = model_output[0]
    return first_head
48 |
49 |
def prep_deepcell_naming_convention(deepcell_output_dir):
    """Prepares the naming convention for the segmentation data
    Args:
        deepcell_output_dir (str): path to directory where segmentation data is saved
    Returns:
        segmentation_naming_convention (function): function that returns the path to the
            segmentation data for a given fov
    """
    def segmentation_naming_convention(fov_path):
        """Map a single fov path to its whole-cell segmentation tiff."""
        fov_name = os.path.basename(fov_path)
        return os.path.join(deepcell_output_dir, fov_name + "_whole_cell.tiff")
    return segmentation_naming_convention
70 |
71 |
class Nimbus(Application):
    """Nimbus application class for predicting marker activity for cells in multiplexed images.
    """
    def __init__(
        self, fov_paths, segmentation_naming_convention, output_dir,
        save_predictions=True, exclude_channels=None, half_resolution=True,
        batch_size=4, test_time_aug=True, input_shape=(1024, 1024)
    ):
        """Initializes a Nimbus Application.
        Args:
            fov_paths (list): List of paths to fovs to be analyzed.
            segmentation_naming_convention (function): Function that returns the path to the
                segmentation mask for a given fov path.
            output_dir (str): Path to directory to save output.
            save_predictions (bool): Whether to save predictions.
            exclude_channels (list): List of channels to exclude from analysis.
            half_resolution (bool): Whether to run model on half resolution images.
            batch_size (int): Batch size for model inference.
            test_time_aug (bool): Whether to use test time augmentation.
            input_shape (list): Shape of input images.
        """
        self.fov_paths = fov_paths
        # copy into a fresh list: the old mutable default argument ([]) was
        # appended to below, leaking excluded channels across Nimbus instances
        # and mutating lists passed in by callers
        self.exclude_channels = list(exclude_channels) if exclude_channels else []
        self.segmentation_naming_convention = segmentation_naming_convention
        self.output_dir = output_dir
        self.half_resolution = half_resolution
        self.save_predictions = save_predictions
        self._batch_size = batch_size
        self.checked_inputs = False
        self.test_time_aug = test_time_aug
        self.input_shape = list(input_shape)
        # exclude segmentation channel from analysis
        seg_name = os.path.basename(self.segmentation_naming_convention(self.fov_paths[0]))
        self.exclude_channels.append(seg_name.split(".")[0])
        if self.output_dir != '':
            os.makedirs(self.output_dir, exist_ok=True)

        # initialize model and parent class
        self.initialize_model()

        super(Nimbus, self).__init__(
            model=self.model,
            model_image_shape=self.model.input_shape[1:],
            preprocessing_fn=nimbus_preprocess,
            postprocessing_fn=nimbus_postprocess,
            format_model_output_fn=format_output,
        )

    def check_inputs(self):
        """ check inputs for Nimbus model
        """
        # check if all paths in fov_paths exists
        io_utils.validate_paths(self.fov_paths)

        # check if segmentation_naming_convention returns valid paths
        path_to_segmentation = self.segmentation_naming_convention(self.fov_paths[0])
        if not os.path.exists(path_to_segmentation):
            raise FileNotFoundError("Function segmentation_naming_convention does not return valid\
                                    path. Segmentation path {} does not exist."\
                                    .format(path_to_segmentation))
        # check if output_dir exists
        io_utils.validate_paths([self.output_dir])

        if isinstance(self.exclude_channels, str):
            self.exclude_channels = [self.exclude_channels]
        self.checked_inputs = True
        print("All inputs are valid.")

    def initialize_model(self):
        """Initializes the model and load weights.
        """
        backbone = "efficientnetv2bs"
        input_shape = self.input_shape + [2]
        model = PanopticNet(
            backbone=backbone, input_shape=input_shape,
            norm_method="std", num_semantic_classes=[1],
            create_semantic_head=create_semantic_head, location=False,
        )
        # try several locations for the checkpoint so the path resolves on any
        # OS and regardless of where the package is imported from
        self.checkpoint_path = os.path.normpath(
            "../cell_classification/checkpoints/halfres_512_checkpoint_160000.h5"
        )
        if not os.path.exists(self.checkpoint_path):
            # fall back to the checkpoints dir next to the installed package
            path = os.path.abspath(cell_classification.__file__)
            path = Path(path).resolve()
            self.checkpoint_path = os.path.join(
                *path.parts[:-3], 'checkpoints', 'halfres_512_checkpoint_160000.h5'
            )
        if not os.path.exists(self.checkpoint_path):
            # BUG FIX: '**' only recurses with recursive=True, and
            # os.path.abspath(*[]) raised TypeError when glob found nothing
            matches = glob('**/halfres_512_checkpoint_160000.h5', recursive=True)
            if matches:
                self.checkpoint_path = os.path.abspath(matches[0])

        if not os.path.exists(self.checkpoint_path):
            self.checkpoint_path = os.path.join(
                os.getcwd(), 'checkpoints', 'halfres_512_checkpoint_160000.h5'
            )

        if os.path.exists(self.checkpoint_path):
            model.load_weights(self.checkpoint_path)
            print("Loaded weights from {}".format(self.checkpoint_path))
        else:
            raise FileNotFoundError("Could not find Nimbus weights at {ckpt_path}. \
                Current path is {current_path} and directory contains {dir_c},\
                path to cell_classification is {p}".format(
                    ckpt_path=self.checkpoint_path,
                    current_path=os.getcwd(),
                    dir_c=os.listdir(os.getcwd()),
                    p=os.path.abspath(cell_classification.__file__)
                )
            )
        self.model = model

    def prepare_normalization_dict(
            self, quantile=0.999, n_subset=10, multiprocessing=False, overwrite=False,
        ):
        """Load or prepare and save normalization dictionary for Nimbus model.
        Args:
            quantile (float): Quantile to use for normalization.
            n_subset (int): Number of fovs to use for normalization.
            multiprocessing (bool): Whether to use multiprocessing.
            overwrite (bool): Whether to overwrite existing normalization dict.
        Returns:
            dict: Dictionary of normalization factors.
        """
        self.normalization_dict_path = os.path.join(self.output_dir, "normalization_dict.json")
        if os.path.exists(self.normalization_dict_path) and not overwrite:
            # use a context manager so the file handle is closed deterministically
            with open(self.normalization_dict_path) as f:
                self.normalization_dict = json.load(f)
        else:
            n_jobs = os.cpu_count() if multiprocessing else 1
            self.normalization_dict = prepare_normalization_dict(
                self.fov_paths, self.output_dir, quantile, self.exclude_channels, n_subset, n_jobs
            )

    def predict_fovs(self):
        """Predicts cell classification for input data.
        Returns:
            np.array: Predicted cell classification.
        """
        if not self.checked_inputs:
            self.check_inputs()
        if not hasattr(self, "normalization_dict"):
            self.prepare_normalization_dict()
        # check if GPU is available
        print("Available GPUs: ", tf.config.list_physical_devices('GPU'))
        print("Predictions will be saved in {}".format(self.output_dir))
        print("Iterating through fovs will take a while...")
        self.cell_table = predict_fovs(
            self.fov_paths, self.output_dir, self, self.normalization_dict,
            self.segmentation_naming_convention, self.exclude_channels, self.save_predictions,
            self.half_resolution, batch_size=self._batch_size,
            test_time_augmentation=self.test_time_aug,
        )
        self.cell_table.to_csv(
            os.path.join(self.output_dir, "nimbus_cell_table.csv"), index=False
        )
        return self.cell_table
228 |
229 | Nimbus(["none_path"], lambda x: x, "")
--------------------------------------------------------------------------------
/src/cell_classification/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import random
5 | import numpy as np
6 | import pandas as pd
7 | from skimage import io
8 | import tensorflow as tf
9 | import matplotlib.pyplot as plt
10 | from tqdm.autonotebook import tqdm
11 | from joblib import Parallel, delayed
12 | from skimage.segmentation import find_boundaries
13 |
14 |
def calculate_normalization(channel_path, quantile):
    """Calculates the normalization value for a given channel
    Args:
        channel_path (str): path to channel
        quantile (float): quantile to use for normalization
    Returns:
        normalization_value (float): normalization value
    """
    channel_name = os.path.basename(channel_path).split(".")[0]
    channel_img = io.imread(channel_path)
    return channel_name, np.quantile(channel_img, quantile)
27 |
28 |
def prepare_normalization_dict(
    fov_paths, output_dir, quantile=0.999, exclude_channels=None, n_subset=10, n_jobs=1,
    output_name="normalization_dict.json"
):
    """Prepares the normalization dict for a list of fovs
    Args:
        fov_paths (list): list of paths to fovs
        output_dir (str): path to output directory
        quantile (float): quantile to use for normalization
        exclude_channels (list): list of channels to exclude
        n_subset (int): number of fovs to use for normalization
        n_jobs (int): number of jobs to use for joblib multiprocessing
        output_name (str): name of output file
    Returns:
        normalization_dict (dict): dict with channel names as keys and norm factors as values
    """
    if exclude_channels is None:
        # avoid a shared mutable default argument
        exclude_channels = []
    normalization_dict = {}
    if n_subset is not None:
        # BUG FIX: random.shuffle reordered the caller's fov_paths list in
        # place; sample a random subset into a new list instead
        fov_paths = random.sample(fov_paths, min(n_subset, len(fov_paths)))
    print("Iterate over fovs...")
    for fov_path in tqdm(fov_paths):
        channels = [
            channel for channel in os.listdir(fov_path)
            if channel.split(".")[0] not in exclude_channels
        ]
        channel_paths = [os.path.join(fov_path, channel) for channel in channels]
        if n_jobs > 1:
            normalization_values = Parallel(n_jobs=n_jobs)(
                delayed(calculate_normalization)(channel_path, quantile)
                for channel_path in channel_paths
            )
        else:
            normalization_values = [
                calculate_normalization(channel_path, quantile)
                for channel_path in channel_paths
            ]
        for channel, normalization_value in normalization_values:
            normalization_dict.setdefault(channel, []).append(normalization_value)
    for channel in normalization_dict.keys():
        # average per-fov factors into one value per channel
        normalization_dict[channel] = np.mean(normalization_dict[channel])
    # save normalization dict
    with open(os.path.join(output_dir, output_name), 'w') as f:
        json.dump(normalization_dict, f)
    return normalization_dict
76 |
77 |
def prepare_input_data(mplex_img, instance_mask):
    """Prepares the input data for the segmentation model
    Args:
        mplex_img (np.array): multiplex image
        instance_mask (np.array): instance mask
    Returns:
        input_data (np.array): input data for segmentation model
    """
    # erode instances by one pixel: interior pixels are cell pixels that are
    # not on an instance boundary
    boundaries = find_boundaries(instance_mask, mode="inner").astype(np.uint8)
    interior = np.logical_and(boundaries == 0, instance_mask > 0)
    binary_mask = interior.astype(np.float32)
    return np.stack([mplex_img, binary_mask], axis=-1)[np.newaxis, ...]  # bhwc
90 |
91 |
def segment_mean(instance_mask, prediction):
    """Calculates the mean prediction per instance
    Args:
        instance_mask (np.array): instance mask
        prediction (np.array): prediction
    Returns:
        uniques (np.array): unique instance ids
        mean_per_cell (np.array): mean prediction per instance
    """
    instance_mask_flat = tf.cast(tf.reshape(instance_mask, -1), tf.int32)  # (h*w)
    pred_flat = tf.cast(tf.reshape(prediction, -1), tf.float32)
    # tf.math.segment_mean requires sorted segment ids, so sort both tensors
    # by instance id before reducing
    sort_order = tf.argsort(instance_mask_flat)
    instance_mask_flat = tf.gather(instance_mask_flat, sort_order)
    uniques, _ = tf.unique(instance_mask_flat)
    pred_flat = tf.gather(pred_flat, sort_order)
    mean_per_cell = tf.math.segment_mean(pred_flat, instance_mask_flat)
    # segment_mean yields one entry per id in [0, max_id]; keep only ids that
    # actually occur in the mask
    mean_per_cell = tf.gather(mean_per_cell, uniques)
    # NOTE(review): the [1:] slice assumes label 0 (background) is present in
    # every mask — confirm masks always contain background pixels
    return [uniques.numpy()[1:], mean_per_cell.numpy()[1:]] # discard background
110 |
111 |
def test_time_aug(
    input_data, channel, app, normalization_dict, rotate=True, flip=True, batch_size=4
):
    """Performs test time augmentation
    Args:
        input_data (np.array): input data for segmentation model, mplex_img and binary mask
        channel (str): channel name
        app (tf.keras.Model): segmentation model
        normalization_dict (dict): dict with channel names as keys and norm factors as values
        rotate (bool): whether to rotate
        flip (bool): whether to flip
        batch_size (int): batch size
    Returns:
        seg_map (np.array): predicted segmentation map, averaged over augmentations
    """
    forward_augmentations = []
    backward_augmentations = []
    if rotate:
        for k in [0, 1, 2, 3]:
            # BUG FIX: bind k as a default argument; plain closures captured
            # the loop variable late, so all four lambdas rotated by k=3 and
            # the four "different" rotations were identical
            forward_augmentations.append(lambda x, k=k: tf.image.rot90(x, k=k))
            backward_augmentations.append(lambda x, k=k: tf.image.rot90(x, k=-k))
    if flip:
        # flips are self-inverse, so forward and backward transforms coincide
        forward_augmentations += [
            lambda x: tf.image.flip_left_right(x),
            lambda x: tf.image.flip_up_down(x)
        ]
        backward_augmentations += [
            lambda x: tf.image.flip_left_right(x),
            lambda x: tf.image.flip_up_down(x)
        ]
    input_batch = []
    for forw_aug in forward_augmentations:
        input_data_tmp = forw_aug(input_data).numpy()  # bhwc
        input_batch.append(np.concatenate(input_data_tmp))
    input_batch = np.stack(input_batch, 0)
    seg_map = app._predict_segmentation(
        input_batch,
        batch_size=batch_size,
        preprocess_kwargs={
            "normalize": True,
            "marker": channel,
            "normalization_dict": normalization_dict},
    )
    tmp = []
    # undo each augmentation so predictions align spatially before averaging
    for backw_aug, seg_map_tmp in zip(backward_augmentations, seg_map):
        seg_map_tmp = backw_aug(seg_map_tmp[np.newaxis, ...])
        seg_map_tmp = np.squeeze(seg_map_tmp)
        tmp.append(seg_map_tmp)
    seg_map = np.stack(tmp, -1)
    seg_map = np.mean(seg_map, axis=-1, keepdims=True)
    return seg_map
163 |
164 |
def predict_fovs(
    fov_paths, cell_classification_output_dir, app, normalization_dict,
    segmentation_naming_convention, exclude_channels=None, save_predictions=True,
    half_resolution=False, batch_size=4, test_time_augmentation=True
):
    """Predicts the segmentation map for each mplex image in each fov
    Args:
        fov_paths (list): list of fov paths
        cell_classification_output_dir (str): path to cell classification output dir
        app (deepcell.applications.Application): segmentation model
        normalization_dict (dict): dict with channel names as keys and norm factors as values
        segmentation_naming_convention (function): function to get instance mask path from fov path
        exclude_channels (list): list of channels to exclude; None means exclude nothing
        save_predictions (bool): whether to save predictions
        half_resolution (bool): whether to use half resolution
        batch_size (int): batch size
        test_time_augmentation (bool): whether to use test time augmentation
    Returns:
        cell_table (pd.DataFrame): cell table with predicted confidence scores per fov and cell
    """
    # avoid a shared mutable default argument
    if exclude_channels is None:
        exclude_channels = []
    fov_dict_list = []
    for fov_path in tqdm(fov_paths):
        out_fov_path = os.path.join(
            os.path.normpath(cell_classification_output_dir), os.path.basename(fov_path)
        )
        fov_dict = {}
        for channel in os.listdir(fov_path):
            channel_path = os.path.join(fov_path, channel)
            if not channel.endswith(".tiff"):
                continue
            if channel[:2] == "._":  # skip macOS resource-fork files
                continue
            channel = channel.split(".")[0]
            if channel in exclude_channels:
                continue
            mplex_img = np.squeeze(io.imread(channel_path))
            instance_path = segmentation_naming_convention(fov_path)
            instance_mask = np.squeeze(io.imread(instance_path))
            input_data = prepare_input_data(mplex_img, instance_mask)
            if half_resolution:
                scale = 0.5
                input_data = np.squeeze(input_data)
                h, w, _ = input_data.shape
                # cv2.resize takes dsize as (width, height), not (height, width)
                img = cv2.resize(input_data[..., 0], (int(w * scale), int(h * scale)))
                binary_mask = cv2.resize(
                    input_data[..., 1], (int(w * scale), int(h * scale)), interpolation=0
                )
                input_data = np.stack([img, binary_mask], axis=-1)[np.newaxis, ...]
            if test_time_augmentation:
                prediction = test_time_aug(
                    input_data, channel, app, normalization_dict, batch_size=batch_size
                )
            else:
                prediction = app._predict_segmentation(
                    input_data,
                    preprocess_kwargs={
                        "normalize": True, "marker": channel,
                        "normalization_dict": normalization_dict
                    },
                    batch_size=batch_size
                )
            prediction = np.squeeze(prediction)
            if half_resolution:
                # upscale back to the original size; again dsize is (width, height)
                prediction = cv2.resize(prediction, (w, h))
            instance_mask = np.expand_dims(instance_mask, axis=-1)
            labels, mean_per_cell = segment_mean(instance_mask, prediction)
            if "label" not in fov_dict.keys():
                # fov name and labels only need to be stored once per fov
                fov_dict["fov"] = [os.path.basename(fov_path)] * len(labels)
                fov_dict["label"] = labels
            fov_dict[channel + "_pred"] = mean_per_cell
            if save_predictions:
                os.makedirs(out_fov_path, exist_ok=True)
                pred_int = tf.cast(prediction * 255.0, tf.uint8).numpy()
                io.imsave(
                    os.path.join(out_fov_path, channel + ".tiff"), pred_int,
                    photometric="minisblack", compression="zlib"
                )
        fov_dict_list.append(pd.DataFrame(fov_dict))
    cell_table = pd.concat(fov_dict_list, ignore_index=True)
    return cell_table
245 |
--------------------------------------------------------------------------------
/src/cell_classification/loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
class Loss():
    """ Wrapper for loss functions that allows for selective masking of the loss
    """
    def __init__(self, loss_name, selective_masking, **kwargs):
        """ Initialize the loss function
        Args:
            loss_name (str):
                name of the loss function class in tf.keras.losses
            selective_masking (bool):
                whether to use selective masking
            **kwargs:
                additional arguments for the loss function
        """
        # Reduction.NONE keeps the per-pixel loss so it can be masked afterwards
        self.loss_fn = getattr(tf.keras.losses, loss_name)(
            reduction=tf.keras.losses.Reduction.NONE, **kwargs
        )
        self.selective_masking = selective_masking

    def mask_out(self, loss_img, y_true):
        """ Selectively mask the loss by setting it to zero where y_true == 2
        (2 marks pixels excluded from supervision; see also calc_scores in metrics.py)
        Args:
            loss_img (tf.Tensor):
                loss image
            y_true (tf.Tensor):
                ground truth image
        Returns:
            tf.Tensor:
                masked loss image
        """
        y_true = tf.reshape(y_true, tf.shape(loss_img))
        return tf.where(y_true == 2, tf.zeros_like(loss_img), loss_img)

    def __call__(self, y_true, y_pred):
        """ Call the loss function
        Args:
            y_true (tf.Tensor):
                ground truth image
            y_pred (tf.Tensor):
                prediction image
        Returns:
            tf.Tensor:
                loss image
        """
        # clip labels to [0, 1] so the mask value 2 does not distort the raw loss;
        # masked positions are zeroed out below
        loss_img = self.loss_fn(y_true=tf.clip_by_value(y_true, 0, 1), y_pred=y_pred)
        if self.selective_masking:
            loss_img = self.mask_out(loss_img, y_true)
        return loss_img

    def get_config(self):
        """ Get the configuration of the loss function
        Returns:
            dict:
                configuration of the loss function
        """
        return {
            "loss_name": self.loss_fn.__class__.__name__,
            "selective_masking": self.selective_masking,
            **self.loss_fn.get_config(),
        }
63 |
--------------------------------------------------------------------------------
/src/cell_classification/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 |
4 | import h5py
5 | import numpy as np
6 | import pandas as pd
7 | from sklearn.metrics import auc, confusion_matrix, roc_curve
8 |
9 |
def calc_roc(pred_list, gt_key="marker_activity_mask", pred_key="prediction", cell_level=False):
    """Calculate ROC curve
    Args:
        pred_list (list):
            list of samples with predictions
        gt_key (str):
            key for ground truth labels
        pred_key (str):
            key for predictions
        cell_level (bool):
            evaluate per cell via "activity_df" instead of per pixel
    Returns:
        roc (dict):
            dictionary containing ROC curve data
    """
    roc = {"fpr": [], "tpr": [], "thresholds": [], "auc": [], "marker": []}
    for sample in pred_list:
        if cell_level:
            # cells with gt activity == 2 are ambiguous and excluded
            activity = sample["activity_df"].copy()
            activity = activity[activity[gt_key] != 2]
            gt = activity[gt_key].to_numpy()
            pred = activity[pred_key].to_numpy()
        else:
            fg = sample["binary_mask"] > 0
            gt = sample[gt_key][fg].flatten()
            pred = sample[pred_key][fg].flatten()
        # roc_curve is only defined when both classes are present
        if gt.size == 0 or gt.min() != 0 or gt.max() <= 0:
            continue
        fpr, tpr, thresholds = roc_curve(gt, pred)
        roc["fpr"].append(fpr)
        roc["tpr"].append(tpr)
        roc["thresholds"].append(thresholds)
        roc["auc"].append(auc(fpr, tpr))
        roc["marker"].append(sample["marker"])
    return roc
43 |
44 |
def calc_scores(gt, pred, threshold):
    """Calculate scores for a given threshold
    Args:
        gt (np.array):
            ground truth labels
        pred (np.array):
            predictions
        threshold (float):
            threshold for predictions
    Returns:
        scores (dict):
            dictionary containing scores
    """
    # drop masked-out positions (gt == 2) before computing any metric
    valid = gt < 2
    pred = pred[valid]
    gt = gt[valid]
    pred_binary = (pred >= threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true=gt, y_pred=pred_binary, labels=[0, 1]).ravel()
    eps = 1e-8  # guards against division by zero for empty classes
    return {
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "accuracy": (tp + tn) / (tp + tn + fp + fn + eps),
        "precision": tp / (tp + fp + eps),
        "recall": tp / (tp + fn + eps),
        "specificity": tn / (tn + fp + eps),
        "f1_score": 2 * tp / (2 * tp + fp + fn + eps),
    }
73 |
74 |
def calc_metrics(
    pred_list, gt_key="marker_activity_mask", pred_key="prediction", cell_level=False
):
    """Calculate metrics
    Args:
        pred_list (list):
            list of samples with predictions
        gt_key (str):
            key of ground truth in pred_list
        pred_key (str):
            key of prediction in pred_list
        cell_level (bool):
            if True, evaluate per cell via sample["activity_df"]; otherwise per pixel
            restricted to sample["binary_mask"] > 0
    Returns:
        avg_metrics (dict):
            dictionary containing metrics averaged over all samples
    """
    # template for per-threshold accumulators; deep-copied before each use
    metrics_dict = {
        "accuracy": [], "precision": [], "recall": [], "specificity": [], "f1_score": [], "tp": [],
        "tn": [], "fp": [], "fn": [],
    }

    def _calc_metrics(threshold):
        """Helper function to calculate metrics for a given threshold in parallel"""
        metrics = deepcopy(metrics_dict)

        for sample in pred_list:
            if cell_level:
                df = sample["activity_df"]
                # filter out cells with gt activity == 2
                df = df[df[gt_key] != 2]
                gt = np.array(df[gt_key])
                pred = np.array(df[pred_key])
            else:
                foreground = sample["binary_mask"] > 0
                gt = sample[gt_key][foreground].flatten()
                pred = sample[pred_key][foreground].flatten()
            if gt.size == 0:
                continue
            scores = calc_scores(gt, pred, threshold)

            # only add specificity for samples that have no positives
            if np.sum(gt) == 0:
                keys = ["specificity"]
            else:
                keys = scores.keys()
            for key in keys:
                metrics[key].append(scores[key])
        metrics["threshold"] = threshold
        # NOTE(review): metadata is taken from the last sample iterated above; this
        # assumes all samples share dataset/imaging_platform/marker, and raises
        # NameError if pred_list is empty — confirm callers guarantee both
        for key in ["dataset", "imaging_platform", "marker"]:
            metrics[key] = sample[key]
        return metrics

    # calculate metrics for all thresholds in parallel
    thresholds = np.linspace(0.01, 1, 50)
    # metric_list = Parallel(n_jobs=8)(delayed(_calc_metrics)(i) for i in thresholds)
    metric_list = [_calc_metrics(i) for i in thresholds]
    # reduce metrics over all samples for each threshold
    avg_metrics = deepcopy(metrics_dict)
    for key in ["dataset", "imaging_platform", "marker", "threshold"]:
        avg_metrics[key] = []
    for metrics in metric_list:
        # rate metrics are averaged over samples, count metrics are summed
        for key in ["accuracy", "precision", "recall", "specificity", "f1_score"]:
            avg_metrics[key].append(np.mean(metrics[key]))
        for key in ["tp", "tn", "fp", "fn"]:  # sum fn, fp, tn, tp
            avg_metrics[key].append(np.sum(metrics[key]))
        for key in ["dataset", "imaging_platform", "marker", "threshold"]:  # copy strings
            avg_metrics[key].append(metrics[key])
    return avg_metrics
142 |
143 |
def average_roc(roc_list):
    """Average ROC curves
    Args:
        roc_list (dict):
            ROC data as produced by calc_roc ("fpr", "tpr", "thresholds" lists)
    Returns:
        tprs (np.array):
            standardized true positive rates for each sample
        mean_tprs (np.array):
            mean true positive rates over all samples
        base (np.array):
            fpr values used as the common interpolation grid
        std (np.array):
            standard deviation of true positive rates over all samples
        mean_thresh (np.array):
            mean of the threshold values over all samples
    """
    # common fpr grid so curves with different numbers of points can be averaged
    base = np.linspace(0, 1, 101)
    tpr_list = []
    thresh_list = []
    for i in range(len(roc_list["tpr"])):
        tpr_list.append(np.interp(base, roc_list["fpr"][i], roc_list["tpr"][i]))
        thresh_list.append(np.interp(base, roc_list["tpr"][i], roc_list["thresholds"][i]))

    tprs = np.array(tpr_list)
    thresh_list = np.array(thresh_list)
    mean_thresh = np.mean(thresh_list, axis=0)
    mean_tprs = tprs.mean(axis=0)
    std = tprs.std(axis=0)
    return tprs, mean_tprs, base, std, mean_thresh
174 |
175 |
class HDF5Loader(object):
    """HDF5 iterator for loading data from HDF5 files"""

    def __init__(self, folder):
        """Initialize HDF5 generator
        Args:
            folder (str):
                path to folder containing HDF5 files
        """
        self.folder = folder
        self.files = os.listdir(folder)
        # keep only .hdf files, stored as full paths
        self.files = [
            os.path.join(folder, name) for name in self.files if name.endswith(".hdf")
        ]
        self.file_idx = 0

    def __len__(self):
        return len(self.files)

    def load_hdf(self, file):
        """Load HDF5 file
        Args:
            file (str):
                path to HDF5 file
        Returns:
            data (dict):
                dictionary containing data from HDF5 file
        """
        out_dict = {}
        with h5py.File(file, "r") as f:
            for key in f.keys():
                if key == "activity_df":
                    continue
                value = f[key][()]
                # scalar strings come back as bytes; decode them to str
                out_dict[key] = value.decode("utf-8") if isinstance(value, bytes) else value
            # the activity table is stored as a JSON string
            out_dict["activity_df"] = pd.read_json(f["activity_df"][()].decode())
        return out_dict

    def __iter__(self):
        # restart iteration from the first file
        self.file_idx = 0
        return self

    def __next__(self):
        if self.file_idx >= len(self.files):
            raise StopIteration
        current = self.files[self.file_idx]
        self.file_idx += 1
        return self.load_hdf(current)
224 |
--------------------------------------------------------------------------------
/src/cell_classification/post_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
def process_to_cells(instance_mask, prediction):
    """Process predictions from pixel level to cell level
    Args:
        instance_mask (np.array):
            2D array of instance mask
        prediction (np.array):
            2D array of pixel level prediction
    Returns:
        np.array:
            2D array of cell level averaged prediction
        pd.DataFrame:
            DataFrame of cell level predictions with columns "labels" and "pred_activity"
    """
    unique_labels = np.unique(instance_mask)
    unique_labels = unique_labels[unique_labels != 0]  # 0 is background
    mean_per_cell_mask = np.zeros_like(instance_mask, dtype=np.float32)
    mean_activities = []
    for unique_label in unique_labels:
        mask = instance_mask == unique_label
        mean_pred = prediction[mask].mean()
        mean_per_cell_mask[mask] = mean_pred
        mean_activities.append(mean_pred)
    # build the DataFrame once instead of concatenating row-by-row (which is O(n^2))
    df = pd.DataFrame({'labels': unique_labels, 'pred_activity': mean_activities})
    return mean_per_cell_mask, df
35 |
36 |
def merge_activity_df(gt_df, pred_df):
    """Merge ground truth and prediction dataframes over labels
    Args:
        gt_df (pd.DataFrame):
            DataFrame of ground truth, with a "labels" column
        pred_df (pd.DataFrame):
            DataFrame of prediction, with a "labels" column
    Returns:
        pd.DataFrame:
            DataFrame of merged ground truth and prediction
    """
    # cast labels to int on copies so the caller's dataframes are not mutated
    gt_df = gt_df.assign(labels=gt_df.labels.astype(int))
    pred_df = pred_df.assign(labels=pred_df.labels.astype(int))
    return gt_df.merge(pred_df, on='labels', how='left')
52 |
--------------------------------------------------------------------------------
/src/cell_classification/prepare_data_script_MSK_colon.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from cell_classification.simple_data_prep import SimpleTFRecords
4 |
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
6 |
7 |
def naming_convention(fname):
    """Returns the path to the segmentation mask for the given fov name."""
    seg_dir = "E:/angelo_lab/data/MSKCC_colon/segmentation"
    return os.path.join(seg_dir, fname + "feature_0.ome.tif")
13 |
14 |
# Build TFRecords for the MSK colon (Vectra) dataset; per-marker ground-truth columns
# are read directly from the cell table, so no conversion matrix is needed.
data_prep = SimpleTFRecords(
    data_dir=os.path.normpath(
        "E:/angelo_lab/data/MSKCC_colon/raw_structured"
    ),
    cell_table_path=os.path.normpath(
        "E:/angelo_lab/data/MSKCC_colon/cell_table.csv"
    ),
    imaging_platform="Vectra",
    dataset="MSK_colon",
    tissue_type="colon",
    nuclei_channels=["DAPI"],
    membrane_channels=["CD3", "CD8", "ICOS", "panCK+CK7+CAM5.2"],
    # 256 px tiles with 240 px stride -> 16 px overlap between neighboring tiles
    tile_size=[256, 256],
    stride=[240, 240],
    tf_record_path=os.path.normpath("E:/angelo_lab/data/MSKCC_colon"),
    normalization_quantile=0.999,
    selected_markers=["CD3", "CD8", "Foxp3", "ICOS", "panCK+CK7+CAM5.2", "PD-L1"],
    sample_key="fov",
    segment_label_key="labels",
    segmentation_naming_convention=naming_convention,
    exclude_background_tiles=True,
    img_suffix=".ome.tif",
    # reuse a precomputed normalization dict by uncommenting:
    # normalization_dict_path=os.path.normpath(
    #     "E:/angelo_lab/data/MSKCC_colon/normalization_dict.json"
    # ),
)

data_prep.make_tf_record()
43 |
--------------------------------------------------------------------------------
/src/cell_classification/prepare_data_script_MSK_pancreas.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from cell_classification.simple_data_prep import SimpleTFRecords
4 |
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
6 |
7 |
def naming_convention(fname):
    """Returns the path to the segmentation mask for the given fov name."""
    seg_dir = "E:/angelo_lab/data/MSKCC_pancreas/segmentation"
    return os.path.join(seg_dir, fname + "feature_0.ome.tif")
13 |
14 |
# Build TFRecords for the MSK pancreas (Vectra) dataset; ground-truth activity columns
# in the cell table carry no suffix (gt_suffix="").
data_prep = SimpleTFRecords(
    data_dir=os.path.normpath(
        "E:/angelo_lab/data/MSKCC_pancreas/raw_structured"
    ),
    cell_table_path=os.path.normpath(
        "E:/angelo_lab/data/MSKCC_pancreas/cell_table.csv"
    ),
    imaging_platform="Vectra",
    dataset="MSK_pancreas",
    tissue_type="pancreas",
    nuclei_channels=["DAPI"],
    membrane_channels=["CD8", "CD40", "CD40-L", "panCK"],
    # 256 px tiles with 240 px stride -> 16 px overlap between neighboring tiles
    tile_size=[256, 256],
    stride=[240, 240],
    tf_record_path=os.path.normpath("E:/angelo_lab/data/MSKCC_pancreas"),
    normalization_quantile=0.999,
    selected_markers=["CD8", "CD40", "CD40-L", "panCK", "PD-1", "PD-L1"],
    sample_key="fov",
    segment_label_key="labels",
    segmentation_naming_convention=naming_convention,
    exclude_background_tiles=True,
    img_suffix=".ome.tif",
    gt_suffix="",
    normalization_dict_path=os.path.normpath(
        "E:/angelo_lab/data/MSKCC_pancreas/normalization_dict.json"
    ),
)

data_prep.make_tf_record()
44 |
--------------------------------------------------------------------------------
/src/cell_classification/prepare_data_script_decidua.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from cell_classification.segmentation_data_prep import SegmentationTFRecords
4 |
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
6 |
7 |
def naming_convention(fname):
    """Returns the path to the segmentation label image for the given point name."""
    seg_dir = "E:/angelo_lab/data/decidua/segmentation_data/"
    return os.path.join(seg_dir, fname + "_segmentation_labels.tiff")
13 |
14 |
# Build TFRecords for the decidua MIBI dataset; marker activity is derived from
# cell-type annotations ("lineage") via the conversion matrix.
data_prep = SegmentationTFRecords(
    data_dir=os.path.normpath("E:/angelo_lab/data/decidua/image_data"),
    cell_table_path=os.path.normpath(
        "E:/angelo_lab/data/decidua/"
        "Supplementary_table_3_single_cells_updated.csv"
    ),
    conversion_matrix_path=os.path.normpath(
        "E:/angelo_lab/data/decidua/conversion_matrix.csv"
    ),
    imaging_platform="MIBI",
    dataset="decidua_erin",
    tissue_type="decidua",
    nuclei_channels=["H3"],
    membrane_channels=["VIM", "HLAG", "CD3", "CD14", "CD56"],
    # 256 px tiles with 240 px stride -> 16 px overlap between neighboring tiles
    tile_size=[256, 256],
    stride=[240, 240],
    tf_record_path=os.path.normpath("E:/angelo_lab/data/decidua"),
    normalization_dict_path=os.path.normpath(
        "E:/angelo_lab/data/decidua/normalization_dict.json"
    ),
    selected_markers=[
        "CD45", "CD14", "HLADR", "CD11c", "DCSIGN", "CD68", "CD206", "CD163", "CD3", "Ki67", "IDO",
        "CD8", "CD4", "CD16", "CD56", "CD57", "SMA", "VIM", "CD31", "CK7", "HLAG", "FoxP3", "PDL1",
    ],
    normalization_quantile=0.999,
    cell_type_key="lineage",
    sample_key="Point",
    segmentation_naming_convention=naming_convention,
    segment_label_key="cell_ID_in_Point",
    exclude_background_tiles=True,
    img_suffix=".tif",
)

data_prep.make_tf_record()
49 |
--------------------------------------------------------------------------------
/src/cell_classification/prepare_data_script_tonic.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from cell_classification.segmentation_data_prep import SegmentationTFRecords
4 |
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
6 |
7 |
def naming_convention(fname):
    """Returns the path to the deepcell segmentation output for the given fov name."""
    seg_dir = "E:/angelo_lab/data/TONIC/raw/segmentation_data/deepcell_output"
    return os.path.join(seg_dir, fname + "_feature_0.tif")
13 |
14 |
# Build TFRecords for the TONIC MIBI dataset; marker activity is derived from
# cell-type annotations ("cell_meta_cluster") via the conversion matrix.
data_prep = SegmentationTFRecords(
    data_dir=os.path.normpath(
        "E:/angelo_lab/data/TONIC/raw/image_data/samples"
    ),
    cell_table_path=os.path.normpath(
        "E:/angelo_lab/data/TONIC/raw/" +
        "combined_cell_table_normalized_cell_labels_updated.csv"
    ),
    conversion_matrix_path=os.path.normpath(
        "E:/angelo_lab/data/TONIC/raw/TONIC_conversion_matrix.csv"
    ),
    imaging_platform="MIBI",
    dataset="TONIC",
    tissue_type="TNBC",
    nuclei_channels=["H3K27me3", "H3K9ac"],
    membrane_channels=["CD45", "ECAD", "CD14", "CD38", "CK17"],
    # 256 px tiles with 240 px stride -> 16 px overlap between neighboring tiles
    tile_size=[256, 256],
    stride=[240, 240],
    tf_record_path=os.path.normpath("E:/angelo_lab/data/TONIC"),
    # reuse a precomputed normalization dict by uncommenting:
    # normalization_dict_path=os.path.normpath(
    #     "E:/angelo_lab/data/TONIC/normalization_dict.json"
    # ),
    normalization_quantile=0.999,
    cell_type_key="cell_meta_cluster",
    sample_key="fov",
    segmentation_fname="cell_segmentation",
    segmentation_naming_convention=naming_convention,
    segment_label_key="label",
    exclude_background_tiles=True,
)

data_prep.make_tf_record()
47 |
--------------------------------------------------------------------------------
/src/cell_classification/prepare_data_script_tonic_annotated.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from cell_classification.simple_data_prep import SimpleTFRecords
4 |
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
6 |
7 |
def naming_convention(fname):
    """Returns the path to the deepcell segmentation output for the given fov name."""
    seg_dir = "E:/angelo_lab/data/TONIC/raw/segmentation_data/deepcell_output"
    return os.path.join(seg_dir, fname + "_feature_0.tif")
13 |
14 |
# Build TFRecords for the manually annotated TONIC subset; ground-truth activity
# columns in ground_truth.csv carry no suffix (gt_suffix="").
data_prep = SimpleTFRecords(
    data_dir=os.path.normpath(
        "C:/Users/Lorenz/OneDrive - Charité - Universitätsmedizin Berlin/cell_classification/" +
        "data_annotation/tonic/raw"
    ),
    cell_table_path=os.path.normpath(
        "C:/Users/Lorenz/OneDrive - Charité - Universitätsmedizin Berlin/cell_classification/" +
        "data_annotation/tonic/ground_truth.csv"
    ),
    imaging_platform="MIBI",
    dataset="TONIC",
    tissue_type="TNBC",
    nuclei_channels=["H3K27me3", "H3K9ac"],
    membrane_channels=["CD45", "ECAD", "CD14", "CD38", "CK17"],
    selected_markers=[
        "Calprotectin", "CD14", "CD163", "CD20", "CD3", "CD31", "CD4", "CD45", "CD56", "CD68",
        "CD8", "ChyTr", "CK17", "Collagen1", "ECAD", "FAP", "Fibronectin", "FOXP3", "HLADR", "SMA",
        "VIM"
    ],
    # 256 px tiles with 240 px stride -> 16 px overlap between neighboring tiles
    tile_size=[256, 256],
    stride=[240, 240],
    tf_record_path=os.path.normpath("E:/angelo_lab/data/TONIC/annotated"),
    normalization_dict_path=os.path.normpath(
        "E:/angelo_lab/data/TONIC/normalization_dict.json"
    ),
    normalization_quantile=0.999,
    sample_key="fov",
    segmentation_fname="cell_segmentation",
    segmentation_naming_convention=naming_convention,
    segment_label_key="labels",
    exclude_background_tiles=True,
    gt_suffix="",
)

data_prep.make_tf_record()
50 |
--------------------------------------------------------------------------------
/src/cell_classification/simple_data_prep.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from cell_classification.segmentation_data_prep import SegmentationTFRecords
4 |
5 |
class SimpleTFRecords(SegmentationTFRecords):
    """Prepares the data for the segmentation model.

    Unlike SegmentationTFRecords, marker activity is read directly from per-marker
    columns in the cell table (named marker + gt_suffix), so no conversion matrix
    is used (conversion_matrix_path is passed as None to the base class).
    """
    def __init__(
        self, data_dir, cell_table_path, imaging_platform, dataset, tissue_type, tile_size, stride,
        tf_record_path, nuclei_channels, membrane_channels, selected_markers=None,
        normalization_dict_path=None, normalization_quantile=0.999,
        segmentation_naming_convention=None, segmentation_fname="cell_segmentation",
        exclude_background_tiles=False, resize=None, img_suffix=".tiff", sample_key="SampleID",
        segment_label_key="labels", gt_suffix="_gt",

    ):
        """Initializes SegmentationTFRecords and loads everything except the images

        Args:
            data_dir str:
                Path where the data is stored
            cell_table_path (str):
                Path to the cell table
            imaging_platform (str):
                The imaging platform used to generate the multiplexed imaging data
            dataset (str):
                The dataset where the imaging data comes from
            tissue_type (str):
                The tissue type of the imaging data
            tile_size list [int,int]:
                The size of the tiles to use for the segmentation model
            stride list [int,int]:
                The stride to tile the data
            tf_record_path (str):
                The path to the tf record to make
            nuclei_channels (list):
                The channels that contain nuclear signal
            membrane_channels (list):
                The channels that contain membrane signal
            selected_markers (list):
                The markers of interest for generating the tf record. If None, all markers
                mentioned in the conversion_matrix are used
            normalization_dict_path (str):
                Path to the normalization dict json
            normalization_quantile (float):
                The quantile to use for normalization of multiplexed data
            segmentation_naming_convention (Function):
                Function that takes in the sample name and returns the path to the segmentation
                .tiff file. Default is None, then it is assumed that the segmentation file is in
                the sample folder and is named $segmentation_fname.tiff
            segmentation_fname (str):
                The filename of the segmentation file inside the sample folder
            exclude_background_tiles (bool):
                Whether to exclude the all tiles that only contain background
            resize (float):
                The resize factor to use for the images
            img_suffix (str):
                The suffix of the image files
            sample_key (str):
                The key in the cell table that contains the sample name
            segment_label_key (str):
                The key in the cell table that contains the segmentation labels
            gt_suffix (str):
                The suffix of the ground truth column in the cell_table
        """
        # conversion_matrix_path=None: ground truth comes from cell-table columns instead
        super().__init__(
            data_dir=data_dir, cell_table_path=cell_table_path, imaging_platform=imaging_platform,
            dataset=dataset, tile_size=tile_size, stride=stride, tf_record_path=tf_record_path,
            nuclei_channels=nuclei_channels, membrane_channels=membrane_channels,
            selected_markers=selected_markers, normalization_dict_path=normalization_dict_path,
            normalization_quantile=normalization_quantile, segmentation_fname=segmentation_fname,
            segmentation_naming_convention=segmentation_naming_convention, resize=resize,
            exclude_background_tiles=exclude_background_tiles, img_suffix=img_suffix,
            sample_key=sample_key, segment_label_key=segment_label_key, tissue_type=tissue_type,
            conversion_matrix_path=None,
        )
        # NOTE(review): most of these assignments presumably duplicate attributes already
        # set by super().__init__ above (gt_suffix being the only subclass-specific one) —
        # confirm against SegmentationTFRecords before removing any.
        self.selected_markers = selected_markers
        self.data_dir = data_dir
        self.normalization_dict_path = normalization_dict_path
        self.normalization_quantile = normalization_quantile
        self.segmentation_fname = segmentation_fname
        self.segment_label_key = segment_label_key
        self.sample_key = sample_key
        self.dataset = dataset
        self.tissue_type = tissue_type
        self.imaging_platform = imaging_platform
        self.tf_record_path = tf_record_path
        self.cell_table_path = cell_table_path
        self.tile_size = tile_size
        self.stride = stride
        self.segmentation_naming_convention = segmentation_naming_convention
        self.exclude_background_tiles = exclude_background_tiles
        self.resize = resize
        self.img_suffix = img_suffix
        self.gt_suffix = gt_suffix
        self.nuclei_channels = nuclei_channels
        self.membrane_channels = membrane_channels

    def get_marker_activity(self, sample_name, marker):
        """Gets the marker activity for the given labels
        Args:
            sample_name (str):
                The name of the sample (unused here; rows are pre-selected into
                self.sample_subset before this is called)
            marker (str, list):
                The markers to get the activity for
        Returns:
            pd.DataFrame:
                DataFrame with columns "labels" and "activity", read from the cell-table
                column named marker + self.gt_suffix
            None:
                second return value is always None here (kept for interface compatibility
                with SegmentationTFRecords — confirm against base class)
        """

        df = pd.DataFrame(
            {
                "labels": self.sample_subset[self.segment_label_key],
                "activity": self.sample_subset[marker + self.gt_suffix],
            }
        )
        return df, None

    def check_additional_inputs(self):
        """Checks additional inputs for correctness"""
        # intentionally a no-op: this subclass has no conversion matrix or cell-type
        # inputs to validate (presumably overriding a base-class check — confirm)
118 |
--------------------------------------------------------------------------------
/src/cell_classification/viewer_widget.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ipywidgets as widgets
3 | from IPython.display import display
4 | from io import BytesIO
5 | from skimage import io
6 | from copy import copy
7 | import numpy as np
8 | from natsort import natsorted
9 |
10 |
class NimbusViewer(object):
    """Interactive viewer showing Nimbus input channels next to predictions.

    The user selects a FOV and up to three channels (mapped to red, green and
    blue); the raw input composite and the corresponding Nimbus output
    composite are rendered side by side as ipywidgets images.
    """

    def __init__(self, input_dir, output_dir, img_width='600px'):
        """Viewer for Nimbus application.

        Args:
            input_dir (str): Path to directory containing individual channels of
                multiplexed images.
            output_dir (str): Path to directory containing output of Nimbus application.
            img_width (str): Width of images in viewer.
        """
        self.image_width = img_width
        self.input_dir = input_dir
        self.output_dir = output_dir
        # every subdirectory of output_dir is assumed to hold one FOV
        self.fov_names = [
            os.path.basename(p) for p in os.listdir(output_dir)
            if os.path.isdir(os.path.join(output_dir, p))
        ]
        self.fov_names = natsorted(self.fov_names)
        self.update_button = widgets.Button(description="Update Image")
        self.update_button.on_click(self.update_button_click)

        self.fov_select = widgets.Select(
            options=self.fov_names,
            description='FOV:',
            disabled=False
        )
        self.fov_select.observe(self.select_fov, names='value')

        self.red_select = widgets.Select(
            options=[],
            description='Red:',
            disabled=False
        )
        self.green_select = widgets.Select(
            options=[],
            description='Green:',
            disabled=False
        )
        self.blue_select = widgets.Select(
            options=[],
            description='Blue:',
            disabled=False
        )
        self.input_image = widgets.Image()
        self.output_image = widgets.Image()

    def select_fov(self, change):
        """Populates the channel selectors with the channels of the chosen FOV.

        Args:
            change (dict): Change dictionary from ipywidgets.
        """
        fov_path = os.path.join(self.output_dir, self.fov_select.value)
        channels = [
            ch for ch in os.listdir(fov_path) if os.path.isfile(os.path.join(fov_path, ch))
        ]
        self.red_select.options = natsorted(channels)
        self.green_select.options = natsorted(channels)
        self.blue_select.options = natsorted(channels)

    def create_composite_image(self, path_dict):
        """Creates composite image from input paths.

        Channels whose path is None are filled with zeros so the stack always
        has one plane per key in ``path_dict``.

        Args:
            path_dict (dict): Maps channel name ("red"/"green"/"blue") to an
                image path or None.
        Returns:
            composite_image (np.array): Composite image of shape (h, w, n_channels).
        """
        output_image = []
        img = None
        for p in path_dict.values():
            if p:
                img = io.imread(p)
                output_image.append(img)
            if p is None:
                non_none = [path for path in path_dict.values() if path]
                # bugfix: `if not img` raises ValueError once img holds a
                # multi-element numpy array (ambiguous truth value); only the
                # "no image loaded yet" case is meant here
                if img is None:
                    img = io.imread(non_none[0])
                # zero plane with the same shape/dtype as a real channel
                output_image.append(img * 0)

        composite_image = np.stack(output_image, axis=-1)
        return composite_image

    def layout(self):
        """Creates layout for viewer."""
        channel_selectors = widgets.VBox([
            self.red_select,
            self.green_select,
            self.blue_select
        ])
        self.input_image.layout.width = self.image_width
        self.output_image.layout.width = self.image_width
        # NOTE(review): the HTML markup in these literals was garbled in this
        # copy of the file (tags stripped); reconstructed as simple headings —
        # confirm against upstream before relying on exact markup.
        viewer_html = widgets.HTML(value="<h2>Select files</h2>")
        input_html = widgets.HTML(value="<h2>Input</h2>")
        output_html = widgets.HTML(value="<h2>Nimbus Output</h2>")

        layout = widgets.HBox([
            widgets.VBox([
                viewer_html,
                self.fov_select,
                channel_selectors,
                self.update_button
            ]),
            widgets.VBox([
                input_html,
                self.input_image
            ]),
            widgets.VBox([
                output_html,
                self.output_image
            ])
        ])
        display(layout)

    def search_for_similar(self, select_value):
        """Searches for similar filename in input directory.

        Args:
            select_value (str): Filename to search for.
        Returns:
            similar_path (str): Path to similar filename, or None if none found.
        """
        in_f_path = os.path.join(self.input_dir, self.fov_select.value)
        # search for a file whose stem matches the selected output channel
        in_f_files = [
            f for f in os.listdir(in_f_path) if os.path.isfile(os.path.join(in_f_path, f))
        ]
        similar_path = None
        for f in in_f_files:
            if select_value.split(".")[0] + "." in f:
                similar_path = os.path.join(self.input_dir, self.fov_select.value, f)
        return similar_path

    def update_img(self, image_viewer, composite_image):
        """Updates image in viewer by saving it as png and loading it with the viewer widget.

        Args:
            image_viewer (ipywidgets.Image): Image widget to update.
            composite_image (np.array): Composite image to display.
        """
        # Convert composite image to png bytes and hand them to the widget
        with BytesIO() as output_buffer:
            io.imsave(output_buffer, composite_image, format="png")
            output_buffer.seek(0)
            image_viewer.value = output_buffer.read()

    def update_composite(self):
        """Updates composite image in viewer."""
        path_dict = {
            "red": None,
            "green": None,
            "blue": None
        }
        in_path_dict = copy(path_dict)
        if self.red_select.value:
            path_dict["red"] = os.path.join(
                self.output_dir, self.fov_select.value, self.red_select.value
            )
            in_path_dict["red"] = self.search_for_similar(self.red_select.value)
        if self.green_select.value:
            path_dict["green"] = os.path.join(
                self.output_dir, self.fov_select.value, self.green_select.value
            )
            in_path_dict["green"] = self.search_for_similar(self.green_select.value)
        if self.blue_select.value:
            path_dict["blue"] = os.path.join(
                self.output_dir, self.fov_select.value, self.blue_select.value
            )
            in_path_dict["blue"] = self.search_for_similar(self.blue_select.value)
        non_none = [p for p in path_dict.values() if p]
        if not non_none:
            # nothing selected yet; keep the current images
            return
        composite_image = self.create_composite_image(path_dict)
        in_composite_image = self.create_composite_image(in_path_dict)
        # stretch input intensities to the 99.9th percentile per channel
        in_composite_image = in_composite_image / np.quantile(
            in_composite_image, 0.999, axis=(0, 1)
        )
        in_composite_image = np.clip(in_composite_image * 255, 0, 255).astype(np.uint8)
        # update image viewers
        self.update_img(self.input_image, in_composite_image)
        self.update_img(self.output_image, composite_image)

    def update_button_click(self, button):
        """Updates composite image in viewer when update button is clicked."""
        self.update_composite()

    def display(self):
        """Displays viewer."""
        self.select_fov(None)
        self.layout()
        self.update_composite()
194 |
--------------------------------------------------------------------------------
/src/deepcell/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/angelolab/Nimbus/b418eb842ee095afd0a9abe79be98a66ba157127/src/deepcell/__init__.py
--------------------------------------------------------------------------------
/src/deepcell/panopticnet.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of
3 | # Technology (Caltech), with support from the Paul Allen Family Foundation,
4 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
5 | # All rights reserved.
6 | #
7 | # Licensed under a modified Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
12 | #
13 | # The Work provided may be used for non-commercial academic purposes only.
14 | # For any other use of the Work, including commercial use, please contact:
15 | # vanvalenlab@gmail.com
16 | #
17 | # Neither the name of Caltech nor the names of its contributors may be used
18 | # to endorse or promote products derived from this software without specific
19 | # prior written permission.
20 | #
21 | # Unless required by applicable law or agreed to in writing, software
22 | # distributed under the License is distributed on an "AS IS" BASIS,
23 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 | # See the License for the specific language governing permissions and
25 | # limitations under the License.
26 | # ==============================================================================
27 | """Feature pyramid network utility functions"""
28 |
29 |
30 | import math
31 | import re
32 |
33 | from tensorflow.keras import backend as K
34 | from tensorflow.keras.models import Model
35 | from tensorflow.keras.layers import Conv2D, Conv3D
36 | from tensorflow.keras.layers import TimeDistributed, ConvLSTM2D
37 | from tensorflow.keras.layers import Input, Concatenate
38 | from tensorflow.keras.layers import Activation, BatchNormalization
39 |
40 | from deepcell.layers import ImageNormalization2D, Location2D
41 | from deepcell.fpn import __create_pyramid_features
42 | from deepcell.fpn import __create_semantic_head
43 | from deepcell.backbone_utils import get_backbone
44 |
45 |
def __merge_temporal_features(feature, mode='conv', feature_size=256,
                              frames_per_batch=1):
    """Applies a temporal transformation to ``feature``.

    Depending on ``mode``, the feature is passed through a 3D convolution
    ('conv'), a convolutional LSTM ('lstm'), or returned unchanged (``None``).

    NOTE(review): the upstream docstring describes an additive merge
    (``y = x + x'``) but no addition happens in this implementation — the
    transformed feature alone is returned.

    Args:
        feature (tensorflow.keras.Layer): Input layer.
        mode (str): Mode of temporal convolution. One of
            ``{'conv', 'lstm', None}``.
        feature_size (int): Number of filters of the temporal convolution.
        frames_per_batch (int): Size of z axis in generated batches.
            If equal to 1, assumes 2D data.

    Raises:
        ValueError: ``mode`` not 'conv', 'lstm' or ``None``

    Returns:
        tensorflow.keras.Layer: The temporally transformed feature. If mode
        is ``None``, the output is exactly the input.
    """
    acceptable_modes = {'conv', 'lstm', None}
    if mode is not None:
        mode = str(mode).lower()
    if mode not in acceptable_modes:
        raise ValueError(f'Mode {mode} not supported. Please choose '
                         f'from {str(acceptable_modes)}.')

    # first two characters of the layer name (e.g. 'P3') tag the new layers
    prefix = str(feature.name)[:2]

    if mode == 'conv':
        out = Conv3D(feature_size,
                     (frames_per_batch, 3, 3),
                     strides=(1, 1, 1),
                     padding='same',
                     name=f'conv3D_mtf_{prefix}',
                     )(feature)
        out = BatchNormalization(axis=-1, name=f'bnorm_mtf_{prefix}')(out)
        out = Activation('relu', name=f'acti_mtf_{prefix}')(out)
    elif mode == 'lstm':
        out = ConvLSTM2D(feature_size,
                         (3, 3),
                         padding='same',
                         activation='relu',
                         return_sequences=True,
                         name=f'convLSTM_mtf_{prefix}')(feature)
    else:
        out = feature

    return out
103 |
104 |
def PanopticNet(backbone,
                input_shape,
                inputs=None,
                backbone_levels=['C3', 'C4', 'C5'],
                pyramid_levels=['P3', 'P4', 'P5', 'P6', 'P7'],
                create_pyramid_features=__create_pyramid_features,
                create_semantic_head=__create_semantic_head,
                frames_per_batch=1,
                temporal_mode=None,
                num_semantic_classes=[3],
                required_channels=3,
                norm_method=None,
                pooling=None,
                location=True,
                use_imagenet=True,
                lite=False,
                upsample_type='upsampling2d',
                interpolation='bilinear',
                name='panopticnet',
                z_axis_convolutions=False,
                **kwargs):
    """Constructs a feature-pyramid ("PanopticNet") model using a backbone
    from ``keras-applications`` with optional semantic segmentation transforms.

    Args:
        backbone (str): Name of backbone to use.
        input_shape (tuple): The shape of the input data.
        backbone_levels (list): The backbone levels to be used.
            to create the feature pyramid.
        pyramid_levels (list): Pyramid levels to use.
        create_pyramid_features (function): Function to get the pyramid
            features from the backbone.
        create_semantic_head (function): Function to build a semantic head
            submodel.
        frames_per_batch (int): Size of z axis in generated batches.
            If equal to 1, assumes 2D data.
        temporal_mode: Mode of temporal convolution. Choose from
            ``{'conv','lstm', None}``.
        num_semantic_classes (list or dict): Number of semantic classes
            for each semantic head. If a ``dict``, keys will be used as
            head names and values will be the number of classes.
        norm_method (str): Normalization method to use with the
            :mod:`deepcell.layers.normalization.ImageNormalization2D` layer.
        location (bool): Whether to include a
            :mod:`deepcell.layers.location.Location2D` layer.
        use_imagenet (bool): Whether to load imagenet-based pretrained weights.
        lite (bool): Whether to use a depthwise conv in the feature pyramid
            rather than regular conv.
        upsample_type (str): Choice of upsampling layer to use from
            ``['upsamplelike', 'upsampling2d', 'upsampling3d']``.
        interpolation (str): Choice of interpolation mode for upsampling
            layers from ``['bilinear', 'nearest']``.
        pooling (str): optional pooling mode for feature extraction
            when ``include_top`` is ``False``.

            - None means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - 'avg' means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - 'max' means that global max pooling will
                be applied.

        z_axis_convolutions (bool): Whether or not to do convolutions on
            3D data across the z axis.
        required_channels (int): The required number of channels of the
            backbone.  3 is the default for all current backbones.
        kwargs (dict): Other standard inputs for ``retinanet_mask``.

    Raises:
        ValueError: ``temporal_mode`` not 'conv', 'lstm' or ``None``

    Returns:
        tensorflow.keras.Model: Panoptic model with a backbone.
    """
    # channels_first data puts the channel dimension at axis 1
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
    # 3D (temporal) inputs use the volumetric conv variant
    conv = Conv3D if frames_per_batch > 1 else Conv2D
    conv_kernel = (1, 1, 1) if frames_per_batch > 1 else (1, 1)

    # Check input to __merge_temporal_features
    acceptable_modes = {'conv', 'lstm', None}
    if temporal_mode is not None:
        temporal_mode = str(temporal_mode).lower()
        if temporal_mode not in acceptable_modes:
            raise ValueError(f'temporal_mode {temporal_mode} not supported. Please choose '
                             f'from {acceptable_modes}.')

    # TODO only works for 2D: do we check for 3D as well?
    # What are the requirements for 3D data?
    img_shape = input_shape[1:] if channel_axis == 1 else input_shape[:-1]
    if img_shape[0] != img_shape[1]:
        raise ValueError(f'Input data must be square, got dimensions {img_shape}')

    # power-of-2 sizes are required so repeated 2x down/upsampling lines up
    if not math.log(img_shape[0], 2).is_integer():
        raise ValueError('Input data dimensions must be a power of 2, '
                         f'got {img_shape[0]}')

    # Check input to interpolation
    acceptable_interpolation = {'bilinear', 'nearest'}
    if interpolation not in acceptable_interpolation:
        raise ValueError(f'Interpolation mode "{interpolation}" not supported. '
                         f'Choose from {list(acceptable_interpolation)}.')

    if inputs is None:
        if frames_per_batch > 1:
            # prepend (or insert after channels) the time dimension
            if channel_axis == 1:
                input_shape_with_time = tuple(
                    [input_shape[0], frames_per_batch] + list(input_shape)[1:])
            else:
                input_shape_with_time = tuple(
                    [frames_per_batch] + list(input_shape))
            inputs = Input(shape=input_shape_with_time, name='input_0')
        else:
            inputs = Input(shape=input_shape, name='input_0')

    # Normalize input images
    if norm_method is None:
        norm = inputs
    else:
        if frames_per_batch > 1:
            # apply the 2D normalization to each frame independently
            norm = TimeDistributed(ImageNormalization2D(
                norm_method=norm_method, name='norm'), name='td_norm')(inputs)
        else:
            norm = ImageNormalization2D(norm_method=norm_method,
                                        name='norm')(inputs)

    # Add location layer
    if location:
        if frames_per_batch > 1:
            # TODO: TimeDistributed is incompatible with channels_first
            loc = TimeDistributed(Location2D(name='location'),
                                  name='td_location')(norm)
        else:
            loc = Location2D(name='location')(norm)
        concat = Concatenate(axis=channel_axis,
                             name='concatenate_location')([norm, loc])
    else:
        concat = norm

    # Force the channel size for backbone input to be `required_channels`
    fixed_inputs = conv(required_channels, conv_kernel, strides=1,
                        padding='same', name='conv_channels')(concat)

    # Force the input shape
    axis = 0 if K.image_data_format() == 'channels_first' else -1
    fixed_input_shape = list(input_shape)
    fixed_input_shape[axis] = required_channels
    fixed_input_shape = tuple(fixed_input_shape)

    model_kwargs = {
        'include_top': False,
        'weights': None,
        'input_shape': fixed_input_shape,
        'pooling': pooling
    }

    _, backbone_dict = get_backbone(backbone, fixed_inputs,
                                    use_imagenet=use_imagenet,
                                    frames_per_batch=frames_per_batch,
                                    return_dict=True,
                                    **model_kwargs)

    # keep only the backbone levels requested for the pyramid
    backbone_dict_reduced = {k: backbone_dict[k] for k in backbone_dict
                             if k in backbone_levels}

    ndim = 2 if frames_per_batch == 1 else 3

    pyramid_dict = create_pyramid_features(backbone_dict_reduced,
                                           ndim=ndim,
                                           lite=lite,
                                           interpolation=interpolation,
                                           upsample_type=upsample_type,
                                           z_axis_convolutions=z_axis_convolutions)

    features = [pyramid_dict[key] for key in pyramid_levels]

    # for temporal data, run each pyramid feature through a temporal merge
    if frames_per_batch > 1:
        temporal_features = [__merge_temporal_features(f, mode=temporal_mode,
                                                       frames_per_batch=frames_per_batch)

                             for f in features]
        for f, k in zip(temporal_features, pyramid_levels):
            pyramid_dict[k] = f

    # target level = the finest (smallest-numbered) pyramid level present
    semantic_levels = [int(re.findall(r'\d+', k)[0]) for k in pyramid_dict]
    target_level = min(semantic_levels)

    semantic_head_list = []
    # allow a plain list of class counts; head ids are then the list indices
    if not isinstance(num_semantic_classes, dict):
        num_semantic_classes = {
            k: v for k, v in enumerate(num_semantic_classes)
        }

    for k, v in num_semantic_classes.items():
        semantic_head_list.append(create_semantic_head(
            pyramid_dict, n_classes=v,
            input_target=inputs, target_level=target_level,
            semantic_id=k, ndim=ndim, upsample_type=upsample_type,
            interpolation=interpolation, **kwargs))

    outputs = semantic_head_list

    model = Model(inputs=inputs, outputs=outputs, name=name)
    return model
311 |
--------------------------------------------------------------------------------
/src/deepcell/semantic_head.py:
--------------------------------------------------------------------------------
1 | from deepcell.fpn import semantic_upsample
2 | from deepcell.utils import get_sorted_keys
3 | from tensorflow.keras import backend as K
4 | from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
5 | Conv3D, Softmax)
6 |
7 |
def create_semantic_head(
    pyramid_dict, input_target=None, n_classes=3, n_filters=128, n_dense=128, semantic_id=0,
    ndim=2, include_top=True, target_level=2, upsample_type="upsamplelike",
    interpolation="bilinear", **kwargs
):
    """Builds a semantic segmentation head on top of a feature pyramid.

    The finest pyramid level is upsampled to the target level, run through a
    conv/batch-norm/relu stack, projected to ``n_classes`` channels and
    finished with either a softmax (``include_top``) or a sigmoid activation.

    Args:
        pyramid_dict (dict): Dictionary of pyramid names and features.
        input_target (tensor): Optional tensor with the input image.
        n_classes (int): The number of classes to be predicted.
        n_filters (int): The number of convolutional filters (currently unused).
        n_dense (int): Number of dense filters.
        semantic_id (int): ID of the semantic head.
        ndim (int): The spatial dimensions of the input data. Must be 2 or 3.
        include_top (bool): Whether to include the final softmax layer.
        target_level (int): The level we need to reach; 2x upsampling is
            applied until we're at the target level.
        upsample_type (str): Choice of upsampling layer to use from
            ``['upsamplelike', 'upsampling2d', 'upsampling3d']``.
        interpolation (str): Choice of interpolation mode for upsampling
            layers from ``['bilinear', 'nearest']``.
    Raises:
        ValueError: ``ndim`` must be 2 or 3
        ValueError: ``interpolation`` not in ``['bilinear', 'nearest']``
        ValueError: ``upsample_type`` not in
            ``['upsamplelike','upsampling2d', 'upsampling3d']``
    Returns:
        tensorflow.keras.Layer: The semantic segmentation head
    """
    # validate spatial dimensionality
    if ndim not in {2, 3}:
        raise ValueError(f"ndim must be either 2 or 3. Received ndim = {ndim}")

    # validate interpolation mode
    acceptable_interpolation = {"bilinear", "nearest"}
    if interpolation not in acceptable_interpolation:
        raise ValueError(
            f'Interpolation mode "{interpolation}" not supported. '
            f"Choose from {list(acceptable_interpolation)}."
        )

    # validate upsampling layer choice
    acceptable_upsample = {"upsamplelike", "upsampling2d", "upsampling3d"}
    if upsample_type not in acceptable_upsample:
        raise ValueError(
            f'Upsample method "{upsample_type}" not supported. '
            f"Choose from {list(acceptable_upsample)}."
        )

    # upsamplelike needs a reference tensor to match the shape of
    if upsample_type == "upsamplelike" and input_target is None:
        raise ValueError("upsamplelike requires an input_target.")

    conv_layer = Conv2D if ndim == 2 else Conv3D
    kernel = (1,) * ndim

    channel_axis = 1 if K.image_data_format() == "channels_first" else -1

    # a single-class head has no use for a softmax top
    if n_classes == 1:
        include_top = False

    # take the finest pyramid level (smallest index after sorting keys)
    sorted_names = get_sorted_keys(pyramid_dict)
    semantic_sum = pyramid_dict[sorted_names[0]]

    # Final upsampling to the requested target level
    n_upsample = target_level
    x = semantic_upsample(
        semantic_sum,
        n_upsample,
        # n_filters=n_filters, # TODO: uncomment and retrain
        target=input_target,
        ndim=ndim,
        upsample_type=upsample_type,
        semantic_id=semantic_id,
        interpolation=interpolation,
    )

    # conv -> batch norm -> relu in place of the previous tensor product
    x = conv_layer(
        n_dense,
        kernel,
        strides=1,
        padding="same",
        name=f"conv_0_semantic_{semantic_id}",
    )(x)
    x = BatchNormalization(
        axis=channel_axis, name=f"batch_normalization_0_semantic_{semantic_id}"
    )(x)
    x = Activation("relu", name=f"relu_0_semantic_{semantic_id}")(x)

    # project to the class dimension
    x = conv_layer(
        n_classes,
        kernel,
        strides=1,
        padding="same",
        name=f"conv_1_semantic_{semantic_id}",
    )(x)

    if include_top:
        x = Softmax(axis=channel_axis, dtype=K.floatx(), name=f"semantic_{semantic_id}")(x)
    else:
        x = Activation("sigmoid", dtype=K.floatx(), name=f"semantic_{semantic_id}")(x)

    return x
149 |
--------------------------------------------------------------------------------
/src/deepcell/utils.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of
3 | # Technology (Caltech), with support from the Paul Allen Family Foundation,
4 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
5 | # All rights reserved.
6 | #
7 | # Licensed under a modified Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
12 | #
13 | # The Work provided may be used for non-commercial academic purposes only.
14 | # For any other use of the Work, including commercial use, please contact:
15 | # vanvalenlab@gmail.com
16 | #
17 | # Neither the name of Caltech nor the names of its contributors may be used
18 | # to endorse or promote products derived from this software without specific
19 | # prior written permission.
20 | #
21 | # Unless required by applicable law or agreed to in writing, software
22 | # distributed under the License is distributed on an "AS IS" BASIS,
23 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 | # See the License for the specific language governing permissions and
25 | # limitations under the License.
26 | # ==============================================================================
27 | from tensorflow.python.client import device_lib
28 |
29 |
def count_gpus():
    """Get the number of available GPUs.

    Returns:
        int: count of GPUs as integer
    """
    local_devices = device_lib.list_local_devices()
    # a GPU device name looks like '/device:GPU:0'
    return sum(
        1 for dev in local_devices
        if dev.name.lower().startswith('/device:gpu')
    )
39 |
40 |
def get_sorted_keys(dict_to_sort):
    """Gets the keys from a dict and sorts them in ascending order.
    Assumes keys are of the form ``Ni``, where ``N`` is a letter and ``i``
    is an integer.

    Args:
        dict_to_sort (dict): dict whose keys need sorting

    Returns:
        list: list of sorted keys from ``dict_to_sort``
    """
    # sort numerically by the integer that follows the single-letter prefix
    return sorted(dict_to_sort, key=lambda name: int(name[1:]))
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/angelolab/Nimbus/b418eb842ee095afd0a9abe79be98a66ba157127/tests/__init__.py
--------------------------------------------------------------------------------
/tests/application_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import keras
3 | import pytest
4 | import tempfile
5 | import numpy as np
6 | from segmentation_data_prep_test import prep_object_and_inputs
7 | from cell_classification.application import Nimbus, nimbus_preprocess
8 | from cell_classification.application import prep_deepcell_naming_convention
9 |
10 |
def test_prep_deepcell_naming_convention():
    """The factory returns a function mapping a fov name to its whole-cell mask path."""
    naming_fn = prep_deepcell_naming_convention("test_dir")
    seg_path = naming_fn("fov0")
    assert isinstance(naming_fn, type(lambda x: x))
    assert isinstance(seg_path, str)
    assert seg_path == os.path.join("test_dir", "fov0_whole_cell.tiff")
17 |
18 |
def test_nimbus_preprocess():
    """nimbus_preprocess normalizes the marker channel as configured."""
    raw = np.random.rand(1, 1024, 1024, 2)
    expected = np.copy(raw)
    expected[..., 0] = raw[..., 0] / 1.2

    normalized = nimbus_preprocess(
        raw, normalize=True, marker="test", normalization_dict={"test": 1.2}
    )
    # shape preserved and values in the expected [0, 1] range
    assert normalized.shape == (1, 1024, 1024, 2)
    assert normalized.max() <= 1.0
    assert normalized.min() >= 0.0
    # normalization was applied
    assert np.allclose(expected, normalized, atol=1e-5)

    # without a dict, the 0.999 quantile is used as the divisor
    normalized = nimbus_preprocess(raw, normalize=True, marker="test")
    expected[..., 0] = raw[..., 0] / np.quantile(raw[..., 0], 0.999)
    expected = expected.clip(0, 1)

    assert np.allclose(expected, normalized, atol=1e-5)

    # markers missing from the dict fall back to quantile normalization
    normalized = nimbus_preprocess(
        raw, normalize=True, marker="test2", normalization_dict={"test": 1.2}
    )
    assert np.allclose(expected, normalized, atol=1e-5)

    # normalize=False leaves the data untouched
    normalized = nimbus_preprocess(raw, normalize=False)
    assert np.array_equal(raw, normalized)
50 |
51 |
def test_initialize_model():
    """Nimbus builds a keras functional model on construction."""
    with tempfile.TemporaryDirectory() as temp_dir:
        def segmentation_naming_convention(fov_path):
            return os.path.join(fov_path, "cell_segmentation.tiff")

        app = Nimbus(["none_path"], segmentation_naming_convention, temp_dir)
        assert isinstance(app.model, keras.engine.functional.Functional)
59 |
60 |
def test_check_inputs():
    """check_inputs raises FileNotFoundError for each kind of invalid input."""
    with tempfile.TemporaryDirectory() as temp_dir:
        def segmentation_naming_convention(fov_path):
            return os.path.join(fov_path, "cell_segmentation.tiff")
        _, fov_paths, _, _ = prep_object_and_inputs(temp_dir)
        output_dir = temp_dir

        # check if no errors are raised when all inputs are valid
        nimbus = Nimbus(fov_paths, segmentation_naming_convention, output_dir)
        nimbus.check_inputs()

        # check if error is raised when a path in fov_paths does not exist on disk
        # bugfix: the original appended to nimbus.fov_paths, mutating the shared
        # fov_paths list in place, so the later "restore" kept the invalid path
        # and the naming-convention case below never actually ran
        nimbus.fov_paths = fov_paths + ["invalid_path"]
        with pytest.raises(FileNotFoundError, match="invalid_path"):
            nimbus.check_inputs()

        # check if error is raised when segmentation_naming_convention does not
        # return a valid path
        nimbus.fov_paths = fov_paths

        def invalid_naming_convention(fov_path):
            return os.path.join(fov_path, "invalid_path.tiff")

        # bugfix: the invalid convention must be installed BEFORE the raises
        # block that is meant to exercise it
        nimbus.segmentation_naming_convention = invalid_naming_convention
        with pytest.raises(FileNotFoundError, match="invalid_path"):
            nimbus.check_inputs()

        # check if error is raised when output_dir does not exist on disk;
        # restore a valid convention so only output_dir triggers the error
        nimbus.segmentation_naming_convention = segmentation_naming_convention
        nimbus.output_dir = "invalid_path"
        with pytest.raises(FileNotFoundError, match="invalid_path"):
            nimbus.check_inputs()
91 |
92 |
def test_prepare_normalization_dict():
    # in-depth tests for the dict computation live in inference_test.py
    with tempfile.TemporaryDirectory() as temp_dir:
        def segmentation_naming_convention(fov_path):
            return os.path.join(fov_path, "cell_segmentation.tiff")

        _, fov_paths, _, _ = prep_object_and_inputs(temp_dir)
        app = Nimbus(
            fov_paths, segmentation_naming_convention, temp_dir, exclude_channels=["CD57"]
        )
        # the dict is computed, saved to disk and excludes CD57
        app.prepare_normalization_dict(overwrite=True)
        assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json"))
        assert "CD57" not in app.normalization_dict.keys()

        # a second instance loads the stored dict instead of recomputing it
        app_reloaded = Nimbus(
            fov_paths, segmentation_naming_convention, temp_dir, exclude_channels=["CD57"]
        )
        app_reloaded.prepare_normalization_dict()
        assert app_reloaded.normalization_dict == app.normalization_dict
114 |
115 |
def test_predict_fovs():
    """predict_fovs writes per-channel predictions and a cell table to disk."""
    with tempfile.TemporaryDirectory() as temp_dir:
        def segmentation_naming_convention(fov_path):
            return os.path.join(fov_path, "cell_segmentation.tiff")

        _, fov_paths, _, _ = prep_object_and_inputs(temp_dir)
        os.remove(os.path.join(temp_dir, 'normalization_dict.json'))
        output_dir = os.path.join(temp_dir, "nimbus_output")
        fov_paths = fov_paths[:1]
        excluded = ["CD57", "CD11c", "XYZ"]
        app = Nimbus(
            fov_paths, segmentation_naming_convention, output_dir,
            exclude_channels=excluded
        )
        cell_table = app.predict_fovs()

        # every normalized channel produced a prediction column
        for channel in app.normalization_dict.keys():
            assert channel + "_pred" in cell_table.columns
        # the cell table was written to disk
        assert os.path.exists(os.path.join(output_dir, "nimbus_cell_table.csv"))

        # per-fov output folders exist with one tiff per kept channel
        for fov_path in fov_paths:
            fov = os.path.basename(fov_path)
            assert os.path.exists(os.path.join(output_dir, fov))
            for channel in app.normalization_dict.keys():
                assert os.path.exists(os.path.join(output_dir, fov, channel + ".tiff"))
            # excluded channels produced no output
            for channel in excluded:
                assert not os.path.exists(os.path.join(output_dir, fov, channel + ".tiff"))
148 |
--------------------------------------------------------------------------------
/tests/augmentation_pipeline_test.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import imgaug.augmenters as iaa
4 | import numpy as np
5 | import pytest
6 | import tensorflow as tf
7 |
8 | from cell_classification.augmentation_pipeline import (
9 | Flip, GaussianBlur, GaussianNoise, LinearContrast, MixUp, Rot90, Zoom,
10 | augment_images, get_augmentation_pipeline, prepare_keras_aug,
11 | prepare_tf_aug, py_aug)
12 |
parametrize = pytest.mark.parametrize  # short alias used by the test decorators below
14 |
15 |
def get_params():
    """Return an augmentation-parameter dict for the tests.

    Spatial transforms (flip, affine, elastic, rotate) and gaussian noise
    fire with probability 1.0; blur and contrast are disabled (prob 0.0).
    """
    flip = {"flip_prob": 1.0}
    affine = {
        "affine_prob": 1.0,
        "scale_min": 0.5,
        "scale_max": 1.5,
        "shear_angle": 10,
    }
    elastic = {
        "elastic_prob": 1.0,
        "elastic_alpha": [0, 5.0],
        "elastic_sigma": 0.5,
    }
    rotate = {"rotate_prob": 1.0, "rotate_count": 3}
    noise = {
        "gaussian_noise_prob": 1.0,
        "gaussian_noise_min": 0.1,
        "gaussian_noise_max": 0.5,
    }
    blur = {
        "gaussian_blur_prob": 0.0,
        "gaussian_blur_min": 0.1,
        "gaussian_blur_max": 0.5,
    }
    contrast = {
        "contrast_prob": 0.0,
        "contrast_min": 0.1,
        "contrast_max": 2.0,
    }
    return {**flip, **affine, **elastic, **rotate, **noise, **blur, **contrast}
45 |
46 |
def test_get_augmentation_pipeline():
    """The pipeline factory hands back an imgaug Sequential."""
    pipeline = get_augmentation_pipeline(get_params())
    assert isinstance(pipeline, iaa.Sequential)
51 |
52 |
@parametrize("batch_num", [1, 2, 3])
@parametrize("chan_num", [1, 2, 3])
def test_augment_images(batch_num, chan_num):
    """augment_images keeps dtypes/shapes and applies matched spatial
    transforms to images and masks."""
    pipeline = get_augmentation_pipeline(get_params())
    shape = [batch_num, 100, 100, chan_num]
    images = np.zeros(shape, dtype=np.float32)
    masks = np.zeros(shape, dtype=np.int32)
    images[0, :50, :50, :] = 10.1
    images[0, 50:, 50:, :] = 201.12
    masks[0, :50, :50] = 1
    masks[0, 50:, 50:] = 2
    aug_images, aug_masks = augment_images(images, masks, pipeline)

    # dtypes and shapes survive the round trip through imgaug
    assert isinstance(aug_images, np.ndarray)
    assert isinstance(aug_masks, np.ndarray)
    assert aug_images.dtype == np.float32
    assert aug_masks.dtype == np.int32
    assert aug_images.shape == images.shape
    assert aug_masks.shape == masks.shape

    # something actually changed
    assert not np.array_equal(aug_images, images)
    assert not np.array_equal(aug_masks, masks)

    # label set is preserved (no interpolation artifacts in the masks)
    assert list(np.unique(aug_masks)) == [0, 1, 2]

    # images and masks saw (approximately) the same spatial transform
    for label, tol in ((0, 1.5), (1, 1.5), (2, 5)):
        diff = aug_images[aug_masks == label].mean() - images[masks == label].mean()
        assert np.abs(diff) < tol

    # masks without a channel dimension are handled as well
    masks = np.zeros([batch_num, 100, 100], dtype=np.int32)
    _, aug_masks = augment_images(images, masks, pipeline)
    assert aug_masks.shape == masks.shape
90 |
91 |
def prepare_data(batch_num, return_tensor=False):
    """Build a toy (mplex_img, binary_mask, marker_activity_mask) triplet.

    Sample 0 and the last sample carry non-zero rectangular regions (marker
    activity classes 1 and 2); any samples in between stay all zero. With
    ``return_tensor=True`` the arrays are converted to tf tensors.
    """
    mplex_img = np.zeros([batch_num, 100, 100, 2], dtype=np.float32)
    binary_mask = np.zeros([batch_num, 100, 100, 1], dtype=np.int32)
    marker_activity_mask = np.zeros([batch_num, 100, 100, 1], dtype=np.int32)

    # sample 0: two disjoint rectangles with classes 1 and 2
    mplex_img[0, :30, :50, :] = 10.1
    binary_mask[0, :30, :50] = 1
    marker_activity_mask[0, :30, :50] = 1
    mplex_img[0, 50:, 50:, :] = 21.12
    binary_mask[0, 50:, 50:] = 1
    marker_activity_mask[0, 50:, 50:] = 2

    # last sample: shifted rectangles with the same class layout
    # (regions never overlap sample 0's, so this is safe for batch_num == 1)
    mplex_img[-1, 30:60, :50, :] = 14.11
    binary_mask[-1, 30:60, :50, :] = 1
    marker_activity_mask[-1, 30:60, :50, :] = 1
    mplex_img[-1, :50, 50:, :] = 18.12
    binary_mask[-1, :50, 50:, :] = 1
    marker_activity_mask[-1, :50, 50:, :] = 2

    if return_tensor:
        mplex_img = tf.constant(mplex_img, tf.float32)
        binary_mask = tf.constant(binary_mask, tf.int32)
        marker_activity_mask = tf.constant(marker_activity_mask, tf.int32)
    return mplex_img, binary_mask, marker_activity_mask
114 |
115 |
@parametrize("batch_num", [1, 2, 3])
def test_prepare_tf_aug(batch_num):
    """prepare_tf_aug wraps the imgaug pipeline so it accepts tf tensors and
    returns numpy arrays with unchanged dtypes and shapes."""
    tf_aug = prepare_tf_aug(get_augmentation_pipeline(get_params()))
    mplex_img, binary_mask, marker_activity_mask = prepare_data(batch_num)
    mplex_aug, mask_out, marker_activity_aug = tf_aug(
        tf.constant(mplex_img, tf.float32), binary_mask, marker_activity_mask
    )
    # regenerate pristine inputs for the shape comparison below
    mplex_img, binary_mask, marker_activity_mask = prepare_data(batch_num)

    # numpy outputs with the expected dtypes and matching shapes
    for out, ref, dtype in (
        (mplex_aug, mplex_img, np.float32),
        (mask_out, binary_mask, np.int32),
        (marker_activity_aug, marker_activity_mask, np.int32),
    ):
        assert isinstance(out, np.ndarray)
        assert out.dtype == dtype
        assert out.shape == ref.shape
136 |
137 |
@parametrize("batch_num", [1, 2, 3])
def test_py_aug(batch_num):
    """py_aug augments the tensor entries of a batch dict in place-like
    fashion while keeping every entry's python type."""
    tf_aug = prepare_tf_aug(get_augmentation_pipeline(get_params()))
    mplex_img, binary_mask, marker_activity_mask = prepare_data(batch_num)
    batch = {
        "mplex_img": tf.constant(mplex_img, tf.float32),
        "binary_mask": tf.constant(binary_mask, tf.int32),
        "marker_activity_mask": tf.constant(marker_activity_mask, tf.int32),
        "dataset": "test_dataset",
        "marker": "test_marker",
        "imaging_platform": "test_platform",
    }
    batch_aug = py_aug(deepcopy(batch), tf_aug)

    # every entry keeps its python type
    for key, value in batch.items():
        assert isinstance(batch_aug[key], type(value))

    # image-like entries keep their shape but are actually augmented
    for key in ("mplex_img", "binary_mask", "marker_activity_mask"):
        assert batch_aug[key].shape == batch[key].shape
        assert not np.array_equal(batch_aug[key], batch[key])
161 |
162 |
@parametrize("batch_num", [2, 4, 8])
def test_prepare_keras_aug(batch_num):
    """prepare_keras_aug builds a callable keras augmentation pipeline that
    preserves dtypes and shapes."""
    pipeline = prepare_keras_aug(get_params())
    images, _, masks = prepare_data(batch_num, True)
    aug_images, aug_masks = pipeline(images, masks)

    # dtypes and shapes are preserved
    assert aug_images.dtype == tf.float32
    assert aug_masks.dtype == tf.int32
    assert aug_images.shape == images.shape
    assert aug_masks.shape == masks.shape
175 |
176 |
@parametrize("batch_num", [2, 4, 8])
def test_flip(batch_num):
    """Flip rearranges pixels in image and mask together without creating or
    destroying any mass."""
    images, _, masks = prepare_data(batch_num, True)
    aug_img, aug_mask = Flip(prob=1.0)(images, masks)

    # dtypes and shapes are untouched
    assert aug_img.dtype == images.dtype
    assert aug_mask.dtype == masks.dtype
    assert aug_img.shape == images.shape
    assert aug_mask.shape == masks.shape

    # pixels moved, but their sums are invariant under flipping
    assert not np.array_equal(aug_img, images)
    assert not np.array_equal(aug_mask, masks)
    assert np.sum(aug_img) == np.sum(images)
    assert np.sum(aug_mask) == np.sum(masks)
194 |
195 |
@parametrize("batch_num", [2, 4, 8])
def test_rot90(batch_num):
    """Rot90 rotates images and masks in lockstep; pixel sums are invariant."""
    images, _, masks = prepare_data(batch_num, True)
    aug_img, aug_mask = Rot90(prob=1.0, rotate_count=2)(images, masks)

    # dtypes and shapes are untouched
    assert aug_img.dtype == images.dtype
    assert aug_mask.dtype == masks.dtype
    assert aug_img.shape == images.shape
    assert aug_mask.shape == masks.shape

    # rotation moved the pixels but did not change their sums
    assert not np.array_equal(aug_img, images)
    assert not np.array_equal(aug_mask, masks)
    assert np.sum(aug_img) == np.sum(images)
    assert np.sum(aug_mask) == np.sum(masks)
213 |
214 |
@parametrize("batch_num", [2, 4, 8])
def test_gaussian_noise(batch_num):
    """GaussianNoise perturbs the images, leaves masks untouched, and keeps
    the overall mean roughly stable."""
    images, _, masks = prepare_data(batch_num, True)
    aug_img, aug_mask = GaussianNoise(prob=1.0)(images, masks)

    # dtypes and shapes are untouched
    assert aug_img.dtype == images.dtype
    assert aug_mask.dtype == masks.dtype
    assert aug_img.shape == images.shape
    assert aug_mask.shape == masks.shape

    # only the images change; near-zero-mean noise keeps the average stable
    assert not np.array_equal(aug_img, images)
    assert np.array_equal(aug_mask, masks)
    assert np.isclose(np.mean(aug_img), np.mean(images), atol=0.1)
231 |
232 |
@parametrize("batch_num", [2, 4, 8])
def test_gaussian_blur(batch_num):
    """GaussianBlur smooths the images but does not touch the masks."""
    images, _, masks = prepare_data(batch_num, True)
    aug_img, aug_mask = GaussianBlur(1.0, 0.5, 1.5, 5)(images, masks)

    # dtypes and shapes are untouched
    assert aug_img.dtype == images.dtype
    assert aug_mask.dtype == masks.dtype
    assert aug_img.shape == images.shape
    assert aug_mask.shape == masks.shape

    # only the images change; blurring roughly preserves the mean
    assert not np.array_equal(aug_img, images)
    assert np.array_equal(aug_mask, masks)
    assert np.isclose(np.mean(aug_img), np.mean(images), atol=0.2)
249 |
250 |
@parametrize("batch_num", [2, 4, 8])
def test_zoom(batch_num):
    """Zoom with a fixed 0.5 factor shrinks image and mask content to a
    quarter of its original mass."""
    images, _, masks = prepare_data(batch_num, True)
    aug_img, aug_mask = Zoom(1.0, 0.5, 0.5)(images, masks)

    # dtypes and shapes are untouched
    assert aug_img.dtype == images.dtype
    assert aug_mask.dtype == masks.dtype
    assert aug_img.shape == images.shape
    assert aug_mask.shape == masks.shape

    # both images and masks are transformed
    assert not np.array_equal(aug_img, images)
    assert not np.array_equal(aug_mask, masks)

    # a 0.5x zoom leaves exactly a quarter of the original sums
    assert np.sum(aug_img) == np.sum(images) / 4
    assert np.sum(aug_mask) == np.sum(masks) / 4
270 |
271 |
@parametrize("batch_num", [2, 4, 8])
def test_linear_contrast(batch_num):
    """LinearContrast with factor 0.75 scales image intensities by 0.75 and
    leaves the masks untouched."""
    images, _, masks = prepare_data(batch_num, True)
    aug_img, aug_mask = LinearContrast(1.0, 0.75, 0.75)(images, masks)

    # dtypes and shapes are untouched
    assert aug_img.dtype == images.dtype
    assert aug_mask.dtype == masks.dtype
    assert aug_img.shape == images.shape
    assert aug_mask.shape == masks.shape

    # only the images change, scaled by the fixed contrast factor
    assert not np.array_equal(aug_img, images)
    assert np.array_equal(aug_mask, masks)
    assert np.isclose(np.mean(aug_img), np.mean(images) * 0.75, atol=0.01)
288 |
289 |
@parametrize("batch_num", [2, 4, 8])
def test_mixup(batch_num):
    """MixUp blends each sample with its batch-reversed counterpart while
    preserving shapes and batch statistics."""
    images, _, labels = prepare_data(batch_num, True)
    x_mplex, x_binary = tf.split(images, 2, axis=-1)
    loss_mask = tf.cast(labels, tf.float32)

    x_mplex_aug, x_binary_aug, labels_aug, loss_mask_aug = MixUp(1.0, 0.5)(
        x_mplex, x_binary, labels, loss_mask
    )

    # dtypes are preserved (labels are promoted to float by the mixing)
    assert x_mplex_aug.dtype == x_mplex.dtype
    assert x_binary_aug.dtype == x_binary.dtype
    assert labels_aug.dtype == tf.float32
    assert loss_mask_aug.dtype == loss_mask.dtype

    # shapes are preserved for every tensor
    for aug, ref in (
        (x_mplex_aug, x_mplex),
        (x_binary_aug, x_binary),
        (labels_aug, labels),
        (loss_mask_aug, loss_mask),
    ):
        assert aug.shape == ref.shape

    # every tensor was actually modified
    assert not np.array_equal(x_mplex_aug, x_mplex)
    assert not np.array_equal(x_binary_aug, x_binary)
    assert not np.array_equal(labels_aug, labels)
    assert not np.array_equal(loss_mask_aug, loss_mask)

    # mixing keeps the batch means; the loss mask becomes (approximately) the
    # elementwise product of the mask with its batch-reversed counterpart
    assert np.isclose(np.mean(x_mplex_aug), np.mean(x_mplex), atol=0.1)
    assert np.isclose(np.mean(x_binary_aug), np.mean(x_binary), atol=0.1)
    assert np.isclose(np.mean(labels_aug), np.mean(labels), atol=0.1)
    assert np.isclose(
        np.mean(loss_mask_aug), np.mean(loss_mask * tf.reverse(loss_mask, [0])), atol=0.1
    )
324 |
--------------------------------------------------------------------------------
/tests/deepcell_application_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of
2 | # Technology (Caltech), with support from the Paul Allen Family Foundation,
3 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
4 | # All rights reserved.
5 | #
6 | # Licensed under a modified Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
11 | #
12 | # The Work provided may be used for non-commercial academic purposes only.
13 | # For any other use of the Work, including commercial use, please contact:
14 | # vanvalenlab@gmail.com
15 | #
16 | # Neither the name of Caltech nor the names of its contributors may be used
17 | # to endorse or promote products derived from this software without specific
18 | # prior written permission.
19 | #
20 | # Unless required by applicable law or agreed to in writing, software
21 | # distributed under the License is distributed on an "AS IS" BASIS,
22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 | # See the License for the specific language governing permissions and
24 | # limitations under the License.
25 | # ==============================================================================
26 | """Tests for Application"""
27 |
28 |
29 | from itertools import product
30 | from unittest.mock import Mock
31 |
32 | import pytest
33 | import numpy as np
34 |
35 | from tensorflow.python.platform import test
36 |
37 | from deepcell.application import Application
38 |
39 |
class DummyModel:
    """Stand-in for a keras model: ``predict`` returns ``n_out`` outputs,
    each random and shaped like the input batch."""

    def __init__(self, n_out=1):
        # number of prediction heads to emulate
        self.n_out = n_out

    def predict(self, x, batch_size=4):
        """Return a list of ``n_out`` references to one random array shaped
        like ``x`` (``batch_size`` is accepted for interface parity only)."""
        prediction = np.random.rand(*x.shape)
        return [prediction for _ in range(self.n_out)]
48 |
49 |
class TestApplication(test.TestCase):
    """Unit tests for the generic ``Application`` wrapper: preprocessing,
    resizing, tiling/untiling and batched prediction around a DummyModel."""

    def test_predict_notimplemented(self):
        """The base class's public ``predict`` must raise NotImplementedError."""
        model = DummyModel()
        kwargs = {'model_mpp': 0.65,
                  'model_image_shape': (128, 128, 1)}
        app = Application(model, **kwargs)

        x = np.random.rand(1, 500, 500, 1)

        with self.assertRaises(NotImplementedError):
            app.predict(x)

    def test_resize(self):
        """_resize_input rescales by image_mpp relative to model_mpp and is a
        no-op when image_mpp is None or equal to model_mpp."""
        model = DummyModel()
        kwargs = {'model_mpp': 0.65,
                  'model_image_shape': (128, 128, 1)}
        app = Application(model, **kwargs)

        x = np.random.rand(1, 500, 500, 1)

        # image_mpp = None --> No resize
        y = app._resize_input(x, image_mpp=None)
        self.assertEqual(x.shape, y.shape)

        # image_mpp = model_mpp --> No resize
        y = app._resize_input(x, image_mpp=kwargs['model_mpp'])
        self.assertEqual(x.shape, y.shape)

        # image_mpp > model_mpp --> resize
        y = app._resize_input(x, image_mpp=2.1 * kwargs['model_mpp'])
        self.assertEqual(2.1, np.round(y.shape[1] / x.shape[1], decimals=1))

        # image_mpp < model_mpp --> resize
        y = app._resize_input(x, image_mpp=0.7 * kwargs['model_mpp'])
        self.assertEqual(0.7, np.round(y.shape[1] / x.shape[1], decimals=1))

    def test_preprocess(self):
        """_preprocess applies preprocessing_fn when given; a non-callable
        preprocessing_fn raises ValueError at construction time."""

        def _preprocess(x):
            # toy preprocessing: replace the input with all ones
            y = np.ones(x.shape)
            return y

        model = DummyModel()
        x = np.random.rand(1, 30, 30, 1)

        # Test no preprocess input
        app = Application(model)
        y = app._preprocess(x)
        self.assertAllEqual(x, y)

        # Test ones function
        kwargs = {'preprocessing_fn': _preprocess}
        app = Application(model, **kwargs)
        y = app._preprocess(x)
        self.assertAllEqual(np.ones(x.shape), y)

        # Test bad input
        kwargs = {'preprocessing_fn': 'x'}
        with self.assertRaises(ValueError):
            app = Application(model, **kwargs)

    def test_tile_input(self):
        """_tile_input cuts inputs into model_image_shape-sized tiles and
        returns a tile_info dict for later untiling."""
        model = DummyModel()
        kwargs = {'model_mpp': 0.65,
                  'model_image_shape': (128, 128, 1)}
        app = Application(model, **kwargs)

        # No tiling
        x = np.random.rand(1, 128, 128, 1)
        y, tile_info = app._tile_input(x)
        self.assertEqual(x.shape, y.shape)
        self.assertIsInstance(tile_info, dict)

        # Tiling square
        x = np.random.rand(1, 400, 400, 1)
        y, tile_info = app._tile_input(x)
        self.assertEqual(kwargs['model_image_shape'][:-1], y.shape[1:-1])
        self.assertIsInstance(tile_info, dict)

        # Tiling rectangle
        x = np.random.rand(1, 300, 500, 1)
        y, tile_info = app._tile_input(x)
        self.assertEqual(kwargs['model_image_shape'][:-1], y.shape[1:-1])
        self.assertIsInstance(tile_info, dict)

        # Smaller than expected
        x = np.random.rand(1, 100, 100, 1)
        y, tile_info = app._tile_input(x)
        self.assertEqual(kwargs['model_image_shape'][:-1], y.shape[1:-1])
        self.assertIsInstance(tile_info, dict)

    def test_postprocess(self):
        """_postprocess applies postprocessing_fn when given; a non-callable
        postprocessing_fn raises ValueError at construction time."""

        def _postprocess(Lx):
            # toy postprocessing: ones shaped like the first model output
            y = np.ones(Lx[0].shape)
            return y

        model = DummyModel()
        x = np.random.rand(1, 30, 30, 1)

        # No input
        app = Application(model)
        y = app._postprocess(x)
        self.assertAllEqual(x, y)

        # Ones
        kwargs = {'postprocessing_fn': _postprocess}
        app = Application(model, **kwargs)
        y = app._postprocess([x])
        self.assertAllEqual(np.ones(x.shape), y)

        # Bad input
        kwargs = {'postprocessing_fn': 'x'}
        with self.assertRaises(ValueError):
            app = Application(model, **kwargs)

    def test_untile_output(self):
        """_untile_output inverts _tile_input for each tiling scenario."""
        model = DummyModel()
        kwargs = {'model_image_shape': (128, 128, 1)}
        app = Application(model, **kwargs)

        # No tiling
        x = np.random.rand(1, 128, 128, 1)
        tiles, tile_info = app._tile_input(x)
        y = app._untile_output(tiles, tile_info)
        self.assertEqual(x.shape, y.shape)

        # Tiling square
        x = np.random.rand(1, 400, 400, 1)
        tiles, tile_info = app._tile_input(x)
        y = app._untile_output(tiles, tile_info)
        self.assertEqual(x.shape, y.shape)

        # Tiling rectangle
        x = np.random.rand(1, 300, 500, 1)
        tiles, tile_info = app._tile_input(x)
        y = app._untile_output(tiles, tile_info)
        self.assertEqual(x.shape, y.shape)

        # Smaller than expected
        x = np.random.rand(1, 100, 100, 1)
        tiles, tile_info = app._tile_input(x)
        y = app._untile_output(tiles, tile_info)
        self.assertEqual(x.shape, y.shape)

    def test_resize_output(self):
        """_resize_output restores original_shape for single outputs and for
        lists of outputs."""
        model = DummyModel()
        kwargs = {'model_image_shape': (128, 128, 1)}
        app = Application(model, **kwargs)

        x = np.random.rand(1, 128, 128, 1)

        # x.shape = original_shape --> no resize
        y = app._resize_output(x, x.shape)
        self.assertEqual(x.shape, y.shape)

        # x.shape != original_shape --> resize
        original_shape = (1, 500, 500, 1)
        y = app._resize_output(x, original_shape)
        self.assertEqual(original_shape, y.shape)

        # test multiple outputs are also resized
        x_list = [x, x]

        # x.shape = original_shape --> no resize
        y = app._resize_output(x_list, x.shape)
        self.assertIsInstance(y, list)
        for y_sub in y:
            self.assertEqual(x.shape, y_sub.shape)

        # x.shape != original_shape --> resize
        original_shape = (1, 500, 500, 1)
        y = app._resize_output(x_list, original_shape)
        self.assertIsInstance(y, list)
        for y_sub in y:
            self.assertEqual(original_shape, y_sub.shape)

    def test_format_model_output(self):
        """_format_model_output applies format_model_output_fn when given."""
        def _format_model_output(Lx):
            # toy formatter: wrap the raw output in a semantic dict
            return {'inner-distance': Lx}

        model = DummyModel()
        x = np.random.rand(1, 30, 30, 1)

        # No function
        app = Application(model)
        y = app._format_model_output(x)
        self.assertAllEqual(x, y)

        # single image
        kwargs = {'format_model_output_fn': _format_model_output}
        app = Application(model, **kwargs)
        y = app._format_model_output(x)
        self.assertAllEqual(x, y['inner-distance'])

    def test_batch_predict(self):
        """_batch_predict calls model.predict once per batch and stitches the
        per-head outputs back to full-size arrays."""

        def predict1(x, batch_size=4):
            # mock for a single prediction head
            y = np.random.rand(*x.shape)
            return [y]

        def predict2(x, batch_size=4):
            # mock for two prediction heads
            y = np.random.rand(*x.shape)
            return [y] * 2

        num_images = [4, 8, 10]
        num_pred_heads = [1, 2]
        batch_sizes = [1, 4, 5]
        prod = product(num_images, num_pred_heads, batch_sizes)

        for num_image, num_pred_head, batch_size in prod:
            model = DummyModel(n_out=num_pred_head)
            app = Application(model)

            x = np.random.rand(num_image, 128, 128, 1)

            if num_pred_head == 1:
                app.model.predict = Mock(side_effect=predict1)
            else:
                app.model.predict = Mock(side_effect=predict2)
            y = app._batch_predict(x, batch_size=batch_size)

            # one model.predict call per batch (last batch may be partial)
            assert app.model.predict.call_count == np.ceil(num_image / batch_size)

            self.assertEqual(x.shape, y[0].shape)
            if num_pred_head == 2:
                self.assertEqual(x.shape, y[1].shape)

    def test_run_model(self):
        """_run_model returns one full-size output per prediction head."""
        model = DummyModel(n_out=2)
        app = Application(model)

        x = np.random.rand(1, 128, 128, 1)
        y = app._run_model(x)
        self.assertEqual(x.shape, y[0].shape)
        self.assertEqual(x.shape, y[1].shape)

    def test_predict_segmentation(self):
        """_predict_segmentation preserves the input shape, with and without
        an explicit image_mpp."""
        model = DummyModel()
        app = Application(model)

        x = np.random.rand(1, 128, 128, 1)
        y = app._predict_segmentation(x)
        self.assertEqual(x.shape, y.shape)

        # test with different MPP
        model = DummyModel()
        app = Application(model)

        x = np.random.rand(1, 128, 128, 1)
        y = app._predict_segmentation(x, image_mpp=1.3)
        self.assertEqual(x.shape, y.shape)
303 |
304 |
@pytest.mark.parametrize(
    "orig_img_shape",
    (
        (1, 216, 256, 1),  # x < model_img_shape
        (1, 256, 64, 1),  # y < model_img_shape
    ),
)
def test_untile_non_empty_slices(orig_img_shape):
    """Test corner cases when tile_info["padding"] is True, but the padding in
    at least one of the dimensions is (0, 0). See gh-665."""
    # Padding is "activated" whenever either of the input image dimensions is
    # smaller than app.model_image_shape.
    app = Application(DummyModel, model_image_shape=(256, 256, 1))

    image = np.ones(orig_img_shape)
    # _tile_input yields a valid tile_info describing how to undo the tiling
    tiles, tile_info = app._tile_input(image)
    # sanity-check the fixture: padding is active and at least one axis pad
    # is (0, 0) (short-circuit so pads are only read when padding is set)
    assert tile_info.get("padding") and (
        tile_info["x_pad"] == (0, 0) or tile_info["y_pad"] == (0, 0)
    )
    # untiling must restore the exact original shape
    assert app._untile_output(tiles, tile_info).shape == orig_img_shape
329 |
--------------------------------------------------------------------------------
/tests/deepcell_backbone_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of
2 | # Technology (Caltech), with support from the Paul Allen Family Foundation,
3 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
4 | # All rights reserved.
5 | #
6 | # Licensed under a modified Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
11 | #
12 | # The Work provided may be used for non-commercial academic purposes only.
13 | # For any other use of the Work, including commercial use, please contact:
14 | # vanvalenlab@gmail.com
15 | #
16 | # Neither the name of Caltech nor the names of its contributors may be used
17 | # to endorse or promote products derived from this software without specific
18 | # prior written permission.
19 | #
20 | # Unless required by applicable law or agreed to in writing, software
21 | # distributed under the License is distributed on an "AS IS" BASIS,
22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 | # See the License for the specific language governing permissions and
24 | # limitations under the License.
25 | # ==============================================================================
26 | """Tests for backbone_utils"""
27 |
28 |
29 | from absl.testing import parameterized
30 |
31 | from tensorflow.python.framework import test_util as tf_test_util
32 | from tensorflow.python.platform import test
33 |
34 | from tensorflow.keras import backend as K
35 | from tensorflow.keras.layers import Input
36 | from tensorflow.keras.models import Model
37 | from keras import keras_parameterized
38 |
39 | from deepcell import backbone_utils
40 |
41 |
class TestBackboneUtils(keras_parameterized.TestCase):
    """Tests for ``deepcell.backbone_utils.get_backbone`` across the
    supported backbone families."""

    @keras_parameterized.run_with_all_model_types
    @keras_parameterized.run_all_keras_modes
    @parameterized.named_parameters(
        *tf_test_util.generate_combinations_with_testcase_name(
            data_format=[
                # 'channels_first',
                'channels_last']))
    def test_get_featurenet_backbone(self, data_format):
        """featurenet builds a Model plus a dict of C-prefixed feature maps;
        it has no imagenet weights."""
        backbone = 'featurenet'
        input_shape = (256, 256, 3)
        inputs = Input(shape=input_shape)
        with self.cached_session():
            K.set_image_data_format(data_format)
            model, output_dict = backbone_utils.get_backbone(
                backbone, inputs, return_dict=True)
            assert isinstance(output_dict, dict)
            assert all(k.startswith('C') for k in output_dict)
            assert isinstance(model, Model)

            # No imagenet weights for featurenet backbone
            with self.assertRaises(ValueError):
                backbone_utils.get_backbone(backbone, inputs, use_imagenet=True)

    # @keras_parameterized.run_all_keras_modes
    @parameterized.named_parameters(
        *tf_test_util.generate_combinations_with_testcase_name(
            data_format=[
                # 'channels_first',
                'channels_last']))
    def test_get_featurenet3d_backbone(self, data_format):
        """featurenet3d behaves like featurenet for 5D (z-stack) inputs."""
        backbone = 'featurenet3d'
        input_shape = (40, 256, 256, 3)
        inputs = Input(shape=input_shape)
        with self.cached_session():
            K.set_image_data_format(data_format)
            model, output_dict = backbone_utils.get_backbone(
                backbone, inputs, return_dict=True)
            assert isinstance(output_dict, dict)
            assert all(k.startswith('C') for k in output_dict)
            assert isinstance(model, Model)

            # No imagenet weights for featurenet backbone
            with self.assertRaises(ValueError):
                backbone_utils.get_backbone(backbone, inputs, use_imagenet=True)

    # @keras_parameterized.run_with_all_model_types
    # @keras_parameterized.run_all_keras_modes
    @parameterized.named_parameters(
        *tf_test_util.generate_combinations_with_testcase_name(
            backbone=[
                'resnet50',
                'resnet101',
                'resnet152',
                'resnet50v2',
                'resnet101v2',
                'resnet152v2',
                # 'resnext50',
                # 'resnext101',
                'vgg16',
                'vgg19',
                'densenet121',
                'densenet169',
                'densenet201',
                'mobilenet',
                'mobilenetv2',
                'efficientnetb0',
                'efficientnetb1',
                'efficientnetb2',
                'efficientnetb3',
                'efficientnetb4',
                'efficientnetb5',
                'efficientnetb6',
                'efficientnetb7',
                'nasnet_large',
                'nasnet_mobile']))
    def test_get_backbone(self, backbone):
        """Each supported keras-applications backbone builds successfully and
        exposes C-prefixed feature outputs."""
        with self.cached_session():
            K.set_image_data_format('channels_last')
            inputs = Input(shape=(256, 256, 3))
            model, output_dict = backbone_utils.get_backbone(
                backbone, inputs, return_dict=True)
            assert isinstance(output_dict, dict)
            assert all(k.startswith('C') for k in output_dict)
            assert isinstance(model, Model)

    def test_invalid_backbone(self):
        """An unknown backbone name raises ValueError."""
        inputs = Input(shape=(4, 2, 3))
        with self.assertRaises(ValueError):
            backbone_utils.get_backbone('bad', inputs, return_dict=True)
133 |
134 |
# Allow running this test module directly via the TensorFlow test runner.
if __name__ == '__main__':
    test.main()
137 |
--------------------------------------------------------------------------------
/tests/deepcell_layers_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of
2 | # Technology (Caltech), with support from the Paul Allen Family Foundation,
3 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
4 | # All rights reserved.
5 | #
6 | # Licensed under a modified Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
11 | #
12 | # The Work provided may be used for non-commercial academic purposes only.
13 | # For any other use of the Work, including commercial use, please contact:
14 | # vanvalenlab@gmail.com
15 | #
16 | # Neither the name of Caltech nor the names of its contributors may be used
17 | # to endorse or promote products derived from this software without specific
18 | # prior written permission.
19 | #
20 | # Unless required by applicable law or agreed to in writing, software
21 | # distributed under the License is distributed on an "AS IS" BASIS,
22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 | # See the License for the specific language governing permissions and
24 | # limitations under the License.
25 | # ==============================================================================
26 | """Tests for the upsampling layers"""
27 |
28 | import numpy as np
29 | import tensorflow as tf
30 | from tensorflow.keras import backend as K
31 | from absl.testing import parameterized
32 | from keras import keras_parameterized
33 | from keras import testing_utils
34 | from tensorflow.python.framework import test_util as tf_test_util
35 | from tensorflow.python.platform import test
36 | from tensorflow.keras.utils import custom_object_scope
37 |
38 | from deepcell import layers
39 |
40 |
@keras_parameterized.run_all_keras_modes
class TestUpsampleLike(keras_parameterized.TestCase):
    """Tests for layers.UpsampleLike in 2D/3D and both data formats."""

    def _check_upsample(self, source_shape, target_shape, data_format=None):
        """Upsample a zero source to a zero target and verify shape/values."""
        kwargs = {} if data_format is None else {'data_format': data_format}
        layer = layers.UpsampleLike(**kwargs)

        src = K.variable(np.zeros(source_shape, dtype=K.floatx()))
        expected = np.zeros(target_shape, dtype=K.floatx())
        tgt = K.variable(expected)

        # shape inference must agree with the actual computed output
        computed_shape = layer.compute_output_shape([src.shape, tgt.shape])
        actual = K.get_value(layer.call([src, tgt]))

        self.assertEqual(actual.shape, computed_shape)
        self.assertAllEqual(actual, expected)

    def test_simple(self):
        # channels_last (default), then channels_first
        self._check_upsample((1, 2, 2, 1), (1, 5, 5, 1))
        self._check_upsample((1, 1, 2, 2), (1, 1, 5, 5),
                             data_format='channels_first')

    def test_simple_3d(self):
        # channels_last (default), then channels_first
        self._check_upsample((1, 2, 2, 2, 1), (1, 5, 5, 5, 1))
        self._check_upsample((1, 1, 2, 2, 2), (1, 1, 5, 5, 5),
                             data_format='channels_first')

    def test_mini_batch(self):
        # batch size > 1; only values are checked here, not inferred shape
        layer = layers.UpsampleLike()
        src = K.variable(np.zeros((2, 2, 2, 1), dtype=K.floatx()))
        expected = np.zeros((2, 5, 5, 1), dtype=K.floatx())
        tgt = K.variable(expected)
        actual = K.get_value(layer.call([src, tgt]))
        self.assertAllEqual(actual, expected)
145 |
146 |
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(
    *tf_test_util.generate_combinations_with_testcase_name(
        norm_method=[None, 'std', 'max', 'whole_image']))
class ImageNormalizationTest(keras_parameterized.TestCase):
    """Tests layers.ImageNormalization2D across every norm_method option."""

    def test_normalize_2d(self, norm_method):
        custom_objects = {'ImageNormalization2D': layers.ImageNormalization2D}
        with tf.keras.utils.custom_object_scope(custom_objects):
            # standard layer_test in both data formats
            for data_format, input_shape in (
                    ('channels_last', (3, 5, 6, 4)),
                    ('channels_first', (3, 4, 5, 6))):
                testing_utils.layer_test(
                    layers.ImageNormalization2D,
                    kwargs={'norm_method': norm_method,
                            'filter_size': 3,
                            'data_format': data_format},
                    input_shape=input_shape)
            # constraints and bias: build and run the layer once
            k_constraint = tf.keras.constraints.max_norm(0.01)
            b_constraint = tf.keras.constraints.max_norm(0.01)
            layer = layers.ImageNormalization2D(
                use_bias=True,
                kernel_constraint=k_constraint,
                bias_constraint=b_constraint)
            layer(tf.keras.backend.variable(np.ones((3, 5, 6, 4))))
            # invalid norm_method is rejected
            with self.assertRaises(ValueError):
                layer = layers.ImageNormalization2D(norm_method='invalid')
            # 5D input is rejected
            with self.assertRaises(ValueError):
                layer = layers.ImageNormalization2D()
                layer.build([3, 10, 11, 12, 4])
            # undefined channel axis is rejected
            with self.assertRaises(ValueError):
                layer = layers.ImageNormalization2D()
                layer.build([3, 5, 6, None])
189 |
190 |
@keras_parameterized.run_all_keras_modes
class LocationTest(keras_parameterized.TestCase):
    """Smoke tests for layers.Location2D in both data formats."""

    def test_location_2d(self):
        custom_objects = {'Location2D': layers.Location2D}
        with custom_object_scope(custom_objects):
            # channels_last: no in_shape required
            testing_utils.layer_test(
                layers.Location2D,
                kwargs={'data_format': 'channels_last'},
                input_shape=(3, 5, 6, 4))
            # channels_first: in_shape passed explicitly
            testing_utils.layer_test(
                layers.Location2D,
                kwargs={'in_shape': (4, 5, 6),
                        'data_format': 'channels_first'},
                input_shape=(3, 4, 5, 6))
205 |
--------------------------------------------------------------------------------
/tests/deepcell_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016-2023 The Van Valen Lab at the California Institute of
2 | # Technology (Caltech), with support from the Paul Allen Family Foundation,
3 | # Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
4 | # All rights reserved.
5 | #
6 | # Licensed under a modified Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
11 | #
12 | # The Work provided may be used for non-commercial academic purposes only.
13 | # For any other use of the Work, including commercial use, please contact:
14 | # vanvalenlab@gmail.com
15 | #
16 | # Neither the name of Caltech nor the names of its contributors may be used
17 | # to endorse or promote products derived from this software without specific
18 | # prior written permission.
19 | #
20 | # Unless required by applicable law or agreed to in writing, software
21 | # distributed under the License is distributed on an "AS IS" BASIS,
22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 | # See the License for the specific language governing permissions and
24 | # limitations under the License.
25 | # ==============================================================================
26 | """Tests for misc_utils"""
27 |
28 | from tensorflow.python.platform import test
29 |
30 | from deepcell.utils import get_sorted_keys
31 |
32 |
class MiscUtilsTest(test.TestCase):
    """Tests for deepcell.utils.get_sorted_keys."""

    def test_get_sorted_keys(self):
        # keys come back sorted regardless of insertion order
        unsorted = {'C1': 1, 'C3': 2, 'C2': 3}
        self.assertListEqual(get_sorted_keys(unsorted), ['C1', 'C2', 'C3'])
38 |
--------------------------------------------------------------------------------
/tests/inference_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pytest
4 | import tempfile
5 | import numpy as np
6 | from skimage import io
7 | from cell_classification.application import Nimbus
8 | from segmentation_data_prep_test import prepare_test_data_folders, prep_object_and_inputs
9 | from cell_classification.inference import calculate_normalization, prepare_normalization_dict
10 | from cell_classification.inference import prepare_input_data, segment_mean, predict_fovs
11 | from cell_classification.inference import test_time_aug as tt_aug
12 |
13 |
def test_calculate_normalization():
    """calculate_normalization returns the channel name and its quantile value."""
    with tempfile.TemporaryDirectory() as temp_dir:
        channel = "CD4"
        fov_paths = prepare_test_data_folders(
            1, temp_dir, [channel], random=True,
            scale=[0.5]
        )
        channel_path = os.path.join(fov_paths[0], channel + ".tiff")
        channel_out, norm_val = calculate_normalization(channel_path, 0.999)
        # the channel name is reported back unchanged
        assert channel_out == channel
        # the 0.999 quantile approximates the scale used to generate the data
        assert np.isclose(norm_val, 0.5, 0.01)
26 |
27 |
def test_prepare_normalization_dict():
    """prepare_normalization_dict saves to disk, honors exclusions, and is
    stable under multiprocessing."""
    with tempfile.TemporaryDirectory() as temp_dir:
        channel_scales = {
            "CD4": 0.5, "CD11c": 1.0, "CD14": 1.5, "CD56": 2.0, "CD57": 5.0,
        }
        fov_paths = prepare_test_data_folders(
            5, temp_dir, list(channel_scales), random=True,
            scale=list(channel_scales.values())
        )
        normalization_dict = prepare_normalization_dict(
            fov_paths, temp_dir, quantile=0.999, exclude_channels=["CD57"],
            n_subset=10, n_jobs=1, output_name="normalization_dict.json"
        )
        dict_path = os.path.join(temp_dir, "normalization_dict.json")
        # the dict was written to disk and round-trips through json
        assert os.path.exists(dict_path)
        assert normalization_dict == json.load(open(dict_path))
        # excluded channel is absent
        assert "CD57" not in normalization_dict.keys()
        # each value approximates the scale used to generate that channel
        for channel, scale in channel_scales.items():
            if channel == "CD57":
                continue
            assert np.isclose(normalization_dict[channel], scale, 0.01)

        # multiprocessing yields approximately the same values
        normalization_dict_mp = prepare_normalization_dict(
            fov_paths, temp_dir, quantile=0.999, exclude_channels=["CD57"],
            n_subset=10, n_jobs=2, output_name="normalization_dict.json"
        )
        for key in normalization_dict.keys():
            assert np.isclose(normalization_dict[key], normalization_dict_mp[key], 1e-6)
60 |
61 |
def test_prepare_input_data():
    """prepare_input_data stacks the mplex image with a binarized, eroded mask.

    Fix: `np.alltrue` was deprecated and removed in NumPy 2.0; use `np.all`.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        scales = [0.5]
        channels = ["CD4"]
        fov_paths = prepare_test_data_folders(
            1, temp_dir, channels, random=True,
            scale=scales
        )
        mplex_img = io.imread(os.path.join(fov_paths[0], "CD4.tiff"))
        instance_mask = io.imread(os.path.join(fov_paths[0], "cell_segmentation.tiff"))
        input_data = prepare_input_data(mplex_img, instance_mask)
        # check shape: batched, two channels (image + mask)
        assert input_data.shape == (1, 256, 256, 2)
        # check if instance mask got binarized and eroded
        assert np.all(np.unique(input_data[..., 1]) == np.array([0, 1]))
        assert np.sum(input_data[..., 1]) < np.sum(instance_mask)
        # check if mplex image is the same as before
        assert np.all(input_data[0, ..., 0] == mplex_img)
80 |
81 |
def test_segment_mean():
    """segment_mean averages the prediction over each instance label."""
    with tempfile.TemporaryDirectory() as temp_dir:
        fov_paths = prepare_test_data_folders(
            1, temp_dir, ["CD4"], random=True,
            scale=[0.5]
        )
        mplex_img = io.imread(os.path.join(fov_paths[0], "CD4.tiff"))
        prediction = (mplex_img > 0.25).astype(np.float32)
        instance_mask = io.imread(os.path.join(fov_paths[0], "cell_segmentation.tiff"))
        instance_ids, mean_per_cell = segment_mean(instance_mask, prediction)
        labels = np.unique(instance_mask)[1:]  # drop background label 0
        # one entry per foreground cell
        assert len(instance_ids) == len(labels)
        # per-cell means match a direct masked mean over the prediction
        for label in labels:
            assert mean_per_cell[label - 1] == np.mean(prediction[instance_mask == label])
99 |
100 |
def test_tt_aug():
    """test_time_aug predictions should closely match un-augmented output."""
    with tempfile.TemporaryDirectory() as temp_dir:
        def segmentation_naming_convention(fov_path):
            return os.path.join(fov_path, "cell_segmentation.tiff")

        _, fov_paths, _, _ = prep_object_and_inputs(temp_dir)
        # force Nimbus to recompute its own normalization dict
        os.remove(os.path.join(temp_dir, 'normalization_dict.json'))
        output_dir = os.path.join(temp_dir, "nimbus_output")
        fov_paths = fov_paths[:1]
        nimbus = Nimbus(
            fov_paths, segmentation_naming_convention, output_dir,
            exclude_channels=["CD57", "CD11c", "XYZ"]
        )
        nimbus.prepare_normalization_dict()
        channel = "CD4"
        mplex_img = io.imread(os.path.join(fov_paths[0], channel + ".tiff"))
        instance_mask = io.imread(segmentation_naming_convention(fov_paths[0]))
        input_data = prepare_input_data(mplex_img, instance_mask)

        # run tt_aug with three augmentation configurations
        aug_preds = [
            tt_aug(
                input_data, channel, nimbus, nimbus.normalization_dict,
                rotate=rotate, flip=flip, batch_size=32
            )
            for rotate, flip in ((True, True), (False, True), (True, False))
        ]
        # output keeps the spatial shape with a single channel
        assert aug_preds[0].shape == (256, 256, 1)

        # baseline prediction without test-time augmentation
        pred_map_no_tt_aug = nimbus._predict_segmentation(
            input_data,
            batch_size=1,
            preprocess_kwargs={
                "normalize": True,
                "marker": channel,
                "normalization_dict": nimbus.normalization_dict},
        )
        # every augmented prediction stays close to the baseline
        for pred_map in aug_preds:
            assert np.allclose(pred_map, pred_map_no_tt_aug, atol=0.05)
146 |
147 |
def test_predict_fovs():
    """predict_fovs builds a cell table and optionally saves prediction images.

    Fixes: `np.alltrue` was removed in NumPy 2.0 — the plain set comparison is
    already a bool, so no numpy wrapper is needed; also drop a duplicated
    `output_dir` assignment.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        def segmentation_naming_convention(fov_path):
            return os.path.join(fov_path, "cell_segmentation.tiff")

        exclude_channels = ["CD57", "CD11c", "XYZ"]
        _, fov_paths, _, _ = prep_object_and_inputs(temp_dir)
        # force Nimbus to recompute its own normalization dict
        os.remove(os.path.join(temp_dir, 'normalization_dict.json'))
        output_dir = os.path.join(temp_dir, "nimbus_output")
        fov_paths = fov_paths[:1]
        nimbus = Nimbus(
            fov_paths, segmentation_naming_convention, output_dir,
            exclude_channels=exclude_channels
        )
        nimbus.prepare_normalization_dict()
        cell_table = predict_fovs(
            fov_paths, output_dir, nimbus, nimbus.normalization_dict,
            segmentation_naming_convention, exclude_channels=exclude_channels,
            save_predictions=False, half_resolution=True,
        )
        # check if we get the correct number of cells
        assert len(cell_table) == 15
        # check if we get the correct columns (fov, label, CD4_pred, CD56_pred)
        assert set(cell_table.columns) == set(["fov", "label", "CD4_pred", "CD56_pred"])
        # check if predictions don't get written to output_dir
        assert not os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff"))
        assert not os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff"))

        # run again with save_predictions=True and check if predictions get written to output_dir
        cell_table = predict_fovs(
            fov_paths, output_dir, nimbus, nimbus.normalization_dict,
            segmentation_naming_convention, exclude_channels=exclude_channels,
            save_predictions=True, half_resolution=True, test_time_augmentation=False,
        )
        assert os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff"))
        assert os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff"))
187 |
--------------------------------------------------------------------------------
/tests/loss_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from cell_classification.loss import Loss
5 |
6 |
def test_loss():
    """Exercise the Loss wrapper: shape, range, smoothing, masking, config."""
    pred = tf.constant(np.random.rand(1, 256, 256, 1))
    target = tf.constant(np.random.randint(0, 2, size=(1, 256, 256, 1)))

    # plain binary cross-entropy yields a per-pixel, non-negative tensor
    loss_fn = Loss("BinaryCrossentropy", False)
    loss = loss_fn(target, pred)
    assert isinstance(loss, tf.Tensor)
    assert loss.shape == (1, 256, 256)
    assert tf.reduce_mean(loss).numpy() >= 0

    # label smoothing keeps shape/range and stays near the unsmoothed loss
    smoothed_fn = Loss("BinaryCrossentropy", False, label_smoothing=0.1)
    loss_smoothed = smoothed_fn(target, pred)
    assert isinstance(loss_smoothed, tf.Tensor)
    assert loss_smoothed.shape == (1, 256, 256)
    assert tf.reduce_mean(loss_smoothed).numpy() >= 0
    assert np.isclose(loss_smoothed.numpy().mean(), loss.numpy().mean(), atol=0.1)

    # a perfect prediction gives exactly zero loss
    loss_fn = Loss("BinaryCrossentropy", False)
    loss = loss_fn(target, tf.cast(target, tf.float32))
    assert loss.numpy().mean() == 0

    # selective masking: target value 2 contributes zero loss
    loss_fn = Loss("BinaryCrossentropy", True)
    loss = loss_fn(2 * np.ones_like(target), pred)
    assert loss.numpy().mean() == 0

    # the wrapper also works with focal loss
    loss_fn = Loss("BinaryFocalCrossentropy", True, label_smoothing=0.1, gamma=2)
    loss = loss_fn(target, pred)
    assert loss.numpy().mean() >= 0

    # get_config round-trips the constructor arguments
    config = loss_fn.get_config()
    assert config["loss_name"] == "BinaryFocalCrossentropy"
    assert config["label_smoothing"] == 0.1
    assert config["gamma"] == 2
    assert config["selective_masking"]
52 |
--------------------------------------------------------------------------------
/tests/metrics_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import wandb
7 |
8 | from cell_classification.metrics import (HDF5Loader, average_roc, calc_metrics,
9 | calc_roc)
10 | from cell_classification.model_builder import ModelBuilder
11 |
12 | from .segmentation_data_prep_test import prep_object_and_inputs
13 |
14 |
def make_pred_list():
    """Build 10 synthetic prediction dicts for the metrics tests.

    Markers alternate CD4/CD8; the final sample's activity mask is zeroed so
    one sample has no positive marker activity at all.
    """
    pred_list = []
    for i in range(10):
        marker = "CD4" if i % 2 == 0 else "CD8"
        instance_mask = np.random.randint(0, 10, size=(256, 256, 1))
        binary_mask = (instance_mask > 0).astype(np.uint8)
        activity_df = pd.DataFrame(
            {
                "labels": np.array([1, 2, 5, 7, 9, 11], dtype=np.uint16),
                "activity": [1, 0, 0, 2, 0, 1],
                "cell_type": ["T cell", "B cell"] * 3,
                "sample": [str(i)] * 6,
                "imaging_platform": ["test"] * 6,
                "dataset": ["test"] * 6,
                "marker": [marker] * 6,
                "prediction": np.random.rand(6),
            }
        )
        pred_list.append(
            {
                "marker_activity_mask": np.random.randint(0, 2, (256, 256, 1)) * binary_mask,
                "prediction": np.random.rand(256, 256, 1),
                "instance_mask": instance_mask,
                "binary_mask": binary_mask,
                "dataset": "test",
                "imaging_platform": "test",
                "marker": marker,
                "activity_df": activity_df,
            }
        )
    # last sample: no positive activity anywhere
    pred_list[-1]["marker_activity_mask"] = np.zeros((256, 256, 1))
    return pred_list
45 |
46 |
def test_calc_roc():
    """calc_roc returns per-sample ROC curves keyed by metric name."""
    roc = calc_roc(make_pred_list())

    # expected result keys
    assert set(roc.keys()) == {"fpr", "tpr", "auc", "thresholds", "marker"}

    # one curve per sample with non-empty marker activity (9 of 10)
    assert {len(roc[key]) for key in ("fpr", "tpr", "thresholds", "auc")} == {9}
56 |
57 |
def test_calc_metrics():
    """calc_metrics returns per-threshold aggregate metrics with fixed keys."""
    avg_metrics = calc_metrics(make_pred_list())
    expected_keys = {
        "accuracy", "precision", "recall", "specificity", "f1_score",
        "tp", "tn", "fp", "fn",
        "dataset", "imaging_platform", "marker", "threshold",
    }

    # exactly the expected keys are present
    assert set(avg_metrics.keys()) == expected_keys

    # every checked metric list has one entry per threshold (50 total)
    checked = expected_keys - {"specificity"}
    assert {len(avg_metrics[key]) for key in checked} == {50}
78 |
79 |
def test_average_roc():
    """average_roc interpolates and averages ROC curves across samples."""
    roc_list = calc_roc(make_pred_list())
    tprs, mean_tprs, base, std, mean_thresh = average_roc(roc_list)

    # all aggregate arrays share the interpolation grid length
    grid_len = tprs.shape[1]
    assert len(mean_tprs) == grid_len
    assert len(base) == grid_len
    assert len(std) == grid_len
    assert len(mean_thresh) == grid_len

    # mean/std really are the column-wise mean and std of tprs
    assert np.array_equal(np.mean(tprs, axis=0), mean_tprs)
    assert np.array_equal(np.std(tprs, axis=0), std)
91 |
92 |
def test_HDF5Generator(config_params):
    """Train a tiny model, dump validation predictions, and iterate them.

    Args:
        config_params: config fixture (see conftest.py) mutated in place with
            the tfrecord path and minimal training settings for a fast run.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # write a small tfrecord dataset into temp_dir
        data_prep, _, _, _ = prep_object_and_inputs(temp_dir)
        data_prep.tf_record_path = temp_dir
        data_prep.make_tf_record()
        tf_record_path = os.path.join(data_prep.tf_record_path, data_prep.dataset + ".tfrecord")
        # minimal config so train() finishes in a couple of steps
        config_params["record_path"] = [tf_record_path]
        config_params["path"] = temp_dir
        config_params["experiment"] = "test"
        config_params["dataset_names"] = ["test1"]
        config_params["num_steps"] = 2
        config_params["dataset_sample_probs"] = [1.0]
        config_params["num_validation"] = [2]
        config_params["num_test"] = [2]
        config_params["snap_steps"] = 100
        config_params["val_steps"] = 100
        model = ModelBuilder(config_params)
        model.train()
        # write per-sample prediction HDF5 files to the eval dir
        model.predict_dataset(model.validation_datasets[0], save_predictions=True)
        wandb.finish()
        generator = HDF5Loader(model.params['eval_dir'])

        # check if generator has the right number of items
        assert [len(generator)] == config_params['num_validation']

        # check if generator returns the right items
        for sample in generator:
            assert isinstance(sample, dict)
            assert set(list(sample.keys())) == set([
                'binary_mask', 'dataset', 'folder_name', 'imaging_platform', 'instance_mask',
                'marker', 'marker_activity_mask', 'membrane_img', 'mplex_img', 'nuclei_img',
                'prediction', 'prediction_mean', 'activity_df', 'tissue_type'
            ])
126 |
--------------------------------------------------------------------------------
/tests/plot_utils_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import tensorflow as tf
5 |
6 | from cell_classification.metrics import average_roc, calc_metrics, calc_roc
7 | from cell_classification.plot_utils import (collapse_activity_dfs,
8 | heatmap_plot, plot_average_roc,
9 | plot_metrics_against_threshold,
10 | plot_overlay, plot_together,
11 | subset_activity_df, subset_plots)
12 | from cell_classification.segmentation_data_prep import (feature_description,
13 | parse_dict)
14 |
15 | from .metrics_test import make_pred_list
16 | from .segmentation_data_prep_test import prep_object_and_inputs
17 |
18 |
def prepare_dataset(temp_dir):
    """Write a test tfrecord into temp_dir and open it as a TFRecordDataset."""
    data_prep, _, _, _ = prep_object_and_inputs(temp_dir)
    data_prep.tf_record_path = temp_dir
    data_prep.make_tf_record()
    record_file = os.path.join(temp_dir, data_prep.dataset + ".tfrecord")
    return tf.data.TFRecordDataset(record_file)
26 |
27 |
def test_plot_overlay():
    """plot_overlay writes an overlay png for a parsed example."""
    with tempfile.TemporaryDirectory() as temp_dir:
        plot_path = os.path.join(temp_dir, "plots")
        os.makedirs(plot_path, exist_ok=True)
        record = next(iter(prepare_dataset(temp_dir)))
        example = parse_dict(tf.io.parse_single_example(record, feature_description))
        out_name = f"{example['folder_name']}_overlay.png"
        plot_overlay(example, save_dir=plot_path, save_file=out_name)

        # the figure landed on disk
        assert os.path.exists(os.path.join(plot_path, out_name))
42 |
43 |
def test_plot_together():
    """plot_together writes a combined figure of selected example channels."""
    with tempfile.TemporaryDirectory() as temp_dir:
        plot_path = os.path.join(temp_dir, "plots")
        os.makedirs(plot_path, exist_ok=True)
        record = next(iter(prepare_dataset(temp_dir)))
        example = parse_dict(tf.io.parse_single_example(record, feature_description))
        out_name = f"{example['folder_name']}_together.png"
        plot_together(
            example, ["mplex_img", "nuclei_img", "marker_activity_mask"],
            save_dir=plot_path, save_file=out_name
        )

        # the figure landed on disk
        assert os.path.exists(os.path.join(plot_path, out_name))
59 |
60 |
def test_plot_average_roc():
    """plot_average_roc saves a figure of the mean TPR with a std band."""
    with tempfile.TemporaryDirectory() as temp_dir:
        _, mean_tprs, _, std, _ = average_roc(calc_roc(make_pred_list()))
        plot_path = os.path.join(temp_dir, "plots")
        os.makedirs(plot_path, exist_ok=True)
        plot_average_roc(mean_tprs, std, save_dir=plot_path, save_file="average_roc.png")

        # the figure landed on disk
        assert os.path.exists(os.path.join(plot_path, "average_roc.png"))
72 |
73 |
def test_plot_metrics_against_threshold():
    """plot_metrics_against_threshold saves a metrics-vs-threshold figure."""
    with tempfile.TemporaryDirectory() as temp_dir:
        avg_metrics = calc_metrics(make_pred_list())
        plot_path = os.path.join(temp_dir, "plots")
        os.makedirs(plot_path, exist_ok=True)
        plot_metrics_against_threshold(
            avg_metrics, metric_keys=['precision', 'recall', 'f1_score'],
            save_dir=plot_path, threshold_key="threshold",
            save_file="metrics_against_threshold.png"
        )

        # the figure landed on disk
        assert os.path.exists(os.path.join(plot_path, "metrics_against_threshold.png"))
87 |
88 |
def test_collapse_activity_dfs():
    """collapse_activity_dfs concatenates all per-sample activity dataframes."""
    pred_list = make_pred_list()
    collapsed = collapse_activity_dfs(pred_list)

    # 6 rows per sample, the original 8 columns preserved
    assert collapsed.shape == (6 * len(pred_list), 8)
    assert set(collapsed.columns) == set(pred_list[0]["activity_df"].columns)
96 |
97 |
def test_subset_activity_df():
    """subset_activity_df filters by one or several column/value pairs."""
    df = collapse_activity_dfs(make_pred_list())
    cd4_subset = subset_activity_df(df, {"marker": "CD4"})
    cd4_tcells_subset = subset_activity_df(df, {"marker": "CD4", "cell_type": "T cell"})

    # single-key subset keeps only CD4 rows
    assert set(cd4_subset["marker"]) == {"CD4"}
    assert cd4_subset.shape == df[df.marker == "CD4"].shape

    # two-key subset keeps only CD4 T-cell rows
    assert set(cd4_tcells_subset["marker"]) == {"CD4"}
    assert set(cd4_tcells_subset["cell_type"]) == {"T cell"}
    assert cd4_tcells_subset.shape == df[(df.marker == "CD4") & (df.cell_type == "T cell")].shape
112 |
113 |
def test_subset_plots():
    """subset_plots writes one figure per subset split of the activity df.

    Fix: the second figure (split by marker and cell_type) was generated but
    its existence was never asserted — check it as well.
    """
    pred_list = make_pred_list()
    activity_df = collapse_activity_dfs(pred_list)
    with tempfile.TemporaryDirectory() as temp_dir:
        plot_path = os.path.join(temp_dir, "plots")
        os.makedirs(plot_path, exist_ok=True)
        subset_plots(
            activity_df, subset_list=["marker"], save_dir=plot_path,
            save_file="split_by_marker.png"
        )
        subset_plots(
            activity_df, subset_list=["marker", "cell_type"], save_dir=plot_path,
            save_file="split_by_marker_ct.png"
        )

        # check if plots were saved
        assert os.path.exists(os.path.join(plot_path, "split_by_marker.png"))
        assert os.path.exists(os.path.join(plot_path, "split_by_marker_ct.png"))
131 |
132 |
def test_heatmap_plot():
    """heatmap_plot saves a heatmap figure for the activity dataframe."""
    activity_df = collapse_activity_dfs(make_pred_list())
    with tempfile.TemporaryDirectory() as temp_dir:
        plot_path = os.path.join(temp_dir, "plots")
        os.makedirs(plot_path, exist_ok=True)
        heatmap_plot(activity_df, ["marker"], save_dir=plot_path, save_file="heatmap.png")

        # the figure landed on disk
        assert os.path.exists(os.path.join(plot_path, "heatmap.png"))
145 |
--------------------------------------------------------------------------------
/tests/post_processing_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | from cell_classification.post_processing import (merge_activity_df,
5 | process_to_cells)
6 |
7 |
def test_process_to_cells():
    """process_to_cells averages predictions per instance and builds a df."""
    # random labels 0..9; pin cells 1 and 2 to known activity values
    instance_mask = np.random.randint(0, 10, size=(256, 256, 1))
    pred_img = np.random.rand(256, 256, 1)
    pred_img[instance_mask == 1] = 1.0
    pred_img[instance_mask == 2] = 0.0
    prediction_mean, activity_df = process_to_cells(instance_mask, pred_img)

    # types and shapes
    assert isinstance(activity_df, pd.DataFrame)
    assert isinstance(prediction_mean, np.ndarray)
    assert len(activity_df.pred_activity) == 9
    assert prediction_mean.shape == (256, 256, 1)

    # the pinned cells keep their exact mean activity
    assert prediction_mean[instance_mask == 1].mean() == 1.0
    assert prediction_mean[instance_mask == 2].mean() == 0.0
    assert activity_df.pred_activity[0] == 1.0
    assert activity_df.pred_activity[1] == 0.0
27 |
28 |
def test_merge_activity_df():
    """merge_activity_df joins gt and pred activity on the labels column."""
    labels = list(range(1, 10))
    pred_df = pd.DataFrame({"labels": labels, "pred_activity": np.random.rand(9)})
    gt_df = pd.DataFrame({"labels": labels, "gt_activity": np.random.randint(0, 2, 9)})
    merged_df = merge_activity_df(gt_df, pred_df)

    # no labels are lost in the merge and all three columns are present
    assert np.array_equal(merged_df.labels, gt_df.labels)
    assert set(merged_df.columns) == {'labels', 'gt_activity', 'pred_activity'}
40 |
--------------------------------------------------------------------------------
/tests/promix_naive_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import tensorflow as tf
7 | import toml
8 |
9 | import wandb
10 | from cell_classification.promix_naive import PromixNaive
11 |
12 | from segmentation_data_prep_test import prep_object_and_inputs
13 |
14 |
def test_reduce_to_cells(config_params):
    """reduce_to_cells yields per-sample unique labels and per-cell mean preds.

    Args:
        config_params: config fixture mutated in place ("test" mode).
    """
    config_params["test"] = True
    # NOTE(review): 256x266 looks like a typo for 256x256, but the test is
    # shape-agnostic so it passes either way — confirm intent
    pred = np.random.rand(16, 256, 266)
    instance_mask = np.random.randint(0, 100, (16, 256, 266))
    # remove label 1 from the last sample so cell sets differ across the batch
    instance_mask[-1, instance_mask[-1] == 1] = 0
    # NOTE(review): marker_activity_mask is built but never used below
    marker_activity_mask = np.zeros_like(instance_mask)
    marker_activity_mask[instance_mask > 90] = 1
    trainer = PromixNaive(config_params)
    # map over the batch; ragged outputs because cell counts differ per sample
    uniques, mean_per_cell = tf.map_fn(
        trainer.reduce_to_cells,
        (pred, instance_mask),
        infer_shape=False,
        fn_output_signature=[
            tf.RaggedTensorSpec(shape=[None], dtype=tf.int32, ragged_rank=0),
            tf.RaggedTensorSpec(shape=[None], dtype=tf.float32, ragged_rank=0),
        ],
    )

    # check that the output has the right dimension
    assert uniques.shape[0] == instance_mask.shape[0]
    assert mean_per_cell.shape[0] == instance_mask.shape[0]

    # check that the output is correct
    assert set(np.unique(instance_mask[0])) == set(uniques[0].numpy())
    for i in np.unique(instance_mask[0]):
        # per-label mean from reduce_to_cells matches a direct masked mean
        assert np.isclose(
            np.mean(pred[0][instance_mask[0] == i]),
            mean_per_cell[0][uniques[0] == i].numpy().max(),
        )
44 |
45 |
def test_matched_high_confidence_selection_thresholds(config_params):
    """Confidence thresholds exist for both classes and are positive."""
    config_params["test"] = True
    trainer = PromixNaive(config_params)
    trainer.matched_high_confidence_selection_thresholds()
    thresholds = trainer.confidence_loss_thresholds
    # one threshold per class, both strictly positive
    assert len(thresholds) == 2
    for key in ("positive", "negative"):
        assert thresholds[key] > 0.0
55 |
56 |
def test_train(config_params):
    """End-to-end PromixNaive training: params dump and model reload.

    Args:
        config_params: config fixture mutated in place with tfrecord paths and
            tiny settings so train() completes in a few steps.
    """
    # force CPU so the test does not require a GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    with tempfile.TemporaryDirectory() as temp_dir:
        # write a small tfrecord dataset into temp_dir
        data_prep, _, _, _ = prep_object_and_inputs(temp_dir)
        data_prep.tf_record_path = temp_dir
        data_prep.make_tf_record()
        tf_record_path = os.path.join(data_prep.tf_record_path, data_prep.dataset + ".tfrecord")
        # minimal training configuration for a fast run
        config_params["record_path"] = [tf_record_path]
        config_params["path"] = temp_dir
        config_params["experiment"] = "test"
        config_params["dataset_names"] = ["test1"]
        config_params["dataset_sample_probs"] = [1.0]
        config_params["num_steps"] = 7
        config_params["num_validation"] = [2]
        config_params["num_test"] = [2]
        config_params["batch_size"] = 2
        config_params["test"] = True
        config_params["weight_decay"] = 1e-4
        config_params["snap_steps"] = 5
        config_params["val_steps"] = 5
        # PromixNaive-specific hyperparameters
        config_params["quantile"] = 0.3
        config_params["ema"] = 0.01
        config_params["confidence_thresholds"] = [0.1, 0.9]
        config_params["mixup_prob"] = 0.5
        trainer = PromixNaive(config_params)
        trainer.train()
        wandb.finish()
        # check params.toml is dumped to file and contains the created paths
        assert "params.toml" in os.listdir(trainer.params["model_dir"])
        loaded_params = toml.load(os.path.join(trainer.params["model_dir"], "params.toml"))
        for key in ["model_dir", "log_dir", "model_path"]:
            assert key in list(loaded_params.keys())

        # check if model can be loaded from file
        trainer.model = None
        trainer.load_model(trainer.params["model_path"])
        assert isinstance(trainer.model, tf.keras.Model)
94 |
95 |
def test_prep_data(config_params):
    """prep_data produces train and validation tf.data.Dataset objects."""
    with tempfile.TemporaryDirectory() as temp_dir:
        data_prep, _, _, _ = prep_object_and_inputs(temp_dir)
        data_prep.tf_record_path = temp_dir
        data_prep.make_tf_record()
        record_file = os.path.join(data_prep.tf_record_path, data_prep.dataset + ".tfrecord")
        # NOTE: record_path is intentionally a bare string here (not a list)
        config_params.update({
            "record_path": record_file,
            "path": temp_dir,
            "experiment": "test",
            "dataset_names": ["test1"],
            "dataset_sample_probs": [1.0],
            "num_steps": 3,
            "num_validation": [2],
            "num_test": [2],
            "batch_size": 2,
        })
        trainer = PromixNaive(config_params)
        trainer.prep_data()

        # train and validation datasets exist and are of the right type
        assert isinstance(trainer.validation_datasets[0], tf.data.Dataset)
        assert isinstance(trainer.train_dataset, tf.data.Dataset)
117 |
118 |
def prepare_activity_df():
    """Build four small single-marker activity DataFrames for selection tests.

    Each frame holds six cells with fixed labels, activity, cell types and
    predictions; even-indexed frames use marker CD4, odd-indexed frames CD8.

    Returns:
        list of pd.DataFrame: four frames, one per simulated sample.
    """
    activity_df_list = []
    for i in range(4):
        activity_df = pd.DataFrame(
            {
                "labels": np.array([1, 2, 5, 7, 9, 11], dtype=np.uint16),
                "activity": [1, 0, 0, 0, 0, 1],
                "cell_type": ["T cell", "B cell", "T cell", "B cell", "T cell", "B cell"],
                "sample": [str(i)] * 6,
                "imaging_platform": ["test"] * 6,
                "dataset": ["test"] * 6,
                # previously the CD8 branch passed a bare string that pandas
                # broadcast to all rows; use an explicit list for consistency
                "marker": ["CD4"] * 6 if i % 2 == 0 else ["CD8"] * 6,
                "prediction": [0.9, 0.1, 0.1, 0.7, 0.7, 0.2],
            }
        )
        activity_df_list.append(activity_df)
    return activity_df_list
136 |
137 |
def test_class_wise_loss_selection(config_params):
    """Cells below the class-wise quantile threshold are selected and the
    running quantiles get updated by the call."""
    config_params["test"] = True
    trainer = PromixNaive(config_params)
    df = prepare_activity_df()[0]
    marker = df["marker"][0]
    dataset = df["dataset"][0]
    key = dataset + "_" + marker

    # seed both quantiles so the update after selection is observable
    trainer.class_wise_loss_quantiles[key] = {"positive": 0.5, "negative": 0.5}
    df["loss"] = df.activity * df.prediction + (1 - df.activity) * (1 - df.prediction)
    pos_df = df[df["activity"] == 1]
    neg_df = df[df["activity"] == 0]
    selected = trainer.class_wise_loss_selection(pos_df, neg_df, marker, dataset)

    # one (positive, negative) pair comes back, positives narrowed to one cell
    assert len(selected) == 2
    assert len(selected[0]) == 1

    # only cells whose loss is at or below the (updated) quantile survive
    quantiles = trainer.class_wise_loss_quantiles[key]
    assert selected[0].equals(
        df[df["activity"] == 1].loc[df["loss"] <= quantiles["positive"]]
    )
    assert selected[1].equals(
        df[df["activity"] == 0].loc[df["loss"] <= quantiles["negative"]]
    )

    # the seeded quantiles must have moved
    assert quantiles["positive"] != 0.5
    assert quantiles["negative"] != 0.5
173 |
174 |
def test_matched_high_confidence_selection(config_params):
    """Only cells whose loss is below the confidence threshold are selected."""
    trainer = PromixNaive(config_params)
    activity_df_list = prepare_activity_df()
    df = activity_df_list[0]
    # NOTE(review): this "loss" is the probability assigned to the labeled
    # class (higher = more confident) — verify naming against PromixNaive
    df["loss"] = df.activity * df.prediction + (1 - df.activity) * (1 - df.prediction)
    positive_df = df[df["activity"] == 1]
    negative_df = df[df["activity"] == 0]
    trainer.matched_high_confidence_selection_thresholds()
    selected_subset = trainer.matched_high_confidence_selection(positive_df, negative_df)
    df = pd.concat(selected_subset)

    # check that the output has the right dimension
    assert len(df) == 1

    # check that the output is correct and only those cells are selected that have a loss
    # smaller than the threshold
    gt_activity = "positive" if df["activity"].values[0] == 1 else "negative"
    assert (df.loss.values[0] <= trainer.confidence_loss_thresholds[gt_activity])
194 |
195 |
def test_batchwise_loss_selection(config_params):
    """Identical samples in a batch yield identical per-pixel loss masks."""
    config_params["test"] = True
    trainer = PromixNaive(config_params)
    trainer.matched_high_confidence_selection_thresholds()
    activity_df_list = prepare_activity_df()

    # tile a 256x256 mask with 10x10 cells, labels counted row-major;
    # numpy slicing silently clips the windows that overhang the border
    instance_mask = np.zeros([256, 256], dtype=np.uint8)
    label = 1
    for top in range(0, 260, 20):
        for left in range(0, 260, 20):
            instance_mask[top: top + 10, left: left + 10] = label
            label += 1

    dfs = activity_df_list[:2]
    markers = [tf.constant(str(df["marker"][0]).encode()) for df in dfs]
    datasets = [tf.constant(str(df["dataset"][0]).encode()) for df in dfs]
    for df in dfs:
        df["loss"] = df.activity * df.prediction + (1 - df.activity) * (1 - df.prediction)

    # stack the same mask twice into a (2, 256, 256, 1) batch
    single = instance_mask[np.newaxis, ..., np.newaxis]
    batch_mask = np.concatenate([single, single], axis=0)
    loss_mask = trainer.batchwise_loss_selection(dfs, batch_mask, markers, datasets)

    # one 256x256 mask per batch element
    assert list(loss_mask.shape) == [2, 256, 256]

    # both elements were built from the same inputs, so the masks match
    assert np.array_equal(loss_mask[0], loss_mask[1])
222 |
223 |
def test_quantile_scheduler(config_params):
    """Quantile ramps linearly from start to end over the warmup steps and
    stays at the end value afterwards."""
    config_params["test"] = True
    trainer = PromixNaive(config_params)
    q_start = config_params["quantile"]
    q_end = config_params["quantile_end"]
    warmup_steps = config_params["quantile_warmup_steps"]

    # sample the schedule at the start, midpoint, end of warmup and beyond
    assert trainer.quantile_scheduler(0) == q_start
    assert trainer.quantile_scheduler(warmup_steps // 2) == (q_start + q_end) / 2
    assert trainer.quantile_scheduler(warmup_steps) == q_end
    assert trainer.quantile_scheduler(warmup_steps * 2) == q_end
240 |
241 |
def test_gen_prep_batches_promix_fn(config_params):
    """gen_prep_batches_promix_fn assembles batches from the requested keys."""
    with tempfile.TemporaryDirectory() as temp_dir:
        data_prep, _, _, _ = prep_object_and_inputs(temp_dir)
        data_prep.tf_record_path = temp_dir
        data_prep.make_tf_record()
        record_file = os.path.join(data_prep.tf_record_path, data_prep.dataset + ".tfrecord")
        config_params.update({
            "record_path": [record_file],
            "path": temp_dir,
            "batch_constituents": ["mplex_img", "binary_mask", "nuclei_img", "membrane_img"],
            "experiment": "test",
            "dataset_names": ["test1"],
            "num_steps": 20,
            "dataset_sample_probs": [1.0],
            "batch_size": 1,
            "test": True,
            "num_validation": [2],
            "num_test": [2],
        })
        trainer = PromixNaive(config_params)
        trainer.prep_data()
        example = next(iter(trainer.train_dataset))
        prep_four = trainer.prep_batches_promix
        prep_two = trainer.gen_prep_batches_promix_fn(keys=["mplex_img", "binary_mask"])

        # two-key prep keeps only the multiplexed image as input channel
        batch_two = prep_two(example)
        assert batch_two[0].shape[-1] == 1
        assert np.array_equal(batch_two[0], example["mplex_img"])

        # four-key prep concatenates mplex, nuclei and membrane channels
        batch_four = prep_four(example)
        assert batch_four[0].shape[-1] == 3
        assert np.array_equal(
            batch_four[0],
            tf.concat(
                [example["mplex_img"], example["nuclei_img"], example["membrane_img"]],
                axis=-1,
            ),
        )
279 |
--------------------------------------------------------------------------------
/tests/simple_data_prep_test.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import tempfile
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from cell_classification.simple_data_prep import SimpleTFRecords
9 |
10 | from .segmentation_data_prep_test import prepare_test_data_folders
11 |
12 |
def test_get_marker_activity():
    """get_marker_activity returns per-label activity for a single fov."""
    with tempfile.TemporaryDirectory() as temp_dir:
        norm_dict = {"CD11c": 1.0, "CD4": 1.0, "CD56": 1.0, "CD57": 1.0}
        with open(os.path.join(temp_dir, "norm_dict.json"), "w") as f:
            json.dump(norm_dict, f)
        prepare_test_data_folders(
            5, temp_dir, list(norm_dict.keys()) + ["XYZ"], random=True,
            scale=[0.5, 1.0, 1.5, 2.0, 5.0]
        )

        # six fovs with 15 cells each; CD4 ground truth flips for fovs 3-5
        cell_table = pd.DataFrame(
            {
                "SampleID": [f"fov_{n}" for n in range(6) for _ in range(15)],
                "labels": list(range(1, 16)) * 6,
                "CD4_gt": [1, 1, 0, 0, 2] * 9 + [0, 0, 1, 1, 2] * 9,
            }
        )
        cell_table_path = os.path.join(temp_dir, "cell_table.csv")
        cell_table.to_csv(cell_table_path, index=False)

        data_prep = SimpleTFRecords(
            data_dir=temp_dir,
            tf_record_path=temp_dir,
            cell_table_path=cell_table_path,
            normalization_dict_path=None,
            selected_markers=["CD4"],
            imaging_platform="test",
            dataset="test",
            tissue_type="test",
            tile_size=[256, 256],
            stride=[256, 256],
            nuclei_channels=["CD56"],
            membrane_channels=["CD57"],
        )
        data_prep.load_and_check_input()

        sample_name = "fov_1"
        fov_1_subset = cell_table[cell_table.SampleID == sample_name]
        data_prep.sample_subset = fov_1_subset
        marker_activity, _ = data_prep.get_marker_activity(sample_name, "CD4")

        # marker activity covers every label in the fov_1 subset
        assert np.array_equal(marker_activity.labels, fov_1_subset.labels)

        # activity values match the CD4 ground truth for each cell
        assert np.array_equal(marker_activity.activity.values, fov_1_subset.CD4_gt.values)
60 |
--------------------------------------------------------------------------------
/tests/unet_test.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/jakeret/unet/blob/master/tests/test_unet.py
2 | from unittest.mock import Mock, patch
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow.keras import layers
6 | from tensorflow.keras import losses
7 | import sys
8 | from cell_classification import unet
9 | import pytest
10 |
11 |
class TestPad2D:
    """Tests for the Pad2D layer."""

    def test_serialization(self):
        """Pad2D survives a get_config / from_config round trip."""
        original = unet.Pad2D(padding=(1, 1), data_format="channels_last")
        restored = unet.Pad2D.from_config(original.get_config())

        assert restored.padding == original.padding
        assert restored.data_format == original.data_format

    def test_padding(self):
        """Each padding mode grows both spatial dims by 2; VALID is a no-op."""
        shapes = {
            "channels_last": ([1, 10, 10, 1], (1, 12, 12, 1)),
            "channels_first": ([1, 1, 10, 10], (1, 1, 12, 12)),
        }
        for data_format, (in_shape, out_shape) in shapes.items():
            for mode in ["CONSTANT", "REFLECT", "SYMMETRIC"]:
                print(mode)
                layer = unet.Pad2D(data_format=data_format, mode=mode)
                assert layer(np.ones(in_shape)).shape == out_shape

        # VALID padding leaves the spatial dims untouched
        layer = unet.Pad2D(data_format="channels_last", mode="VALID")
        assert layer(np.ones([1, 10, 10, 1])).shape == (1, 10, 10, 1)

        # an unknown mode is rejected at construction time
        with pytest.raises(ValueError, match="mode must be"):
            unet.Pad2D(data_format="channels_last", mode="same")

        # an unknown data_format is rejected at construction time
        with pytest.raises(ValueError, match="data_format must be"):
            unet.Pad2D(data_format="channels_in_the_middle", mode="CONSTANT")
49 |
50 |
class TestConvBlock:
    """Tests for the ConvBlock layer."""

    def test_serialization(self):
        """ConvBlock survives a get_config / from_config round trip."""
        conv_block = unet.ConvBlock(layer_idx=1,
                                    filters_root=16,
                                    kernel_size=3,
                                    padding="REFLECT",
                                    activation="relu",
                                    name="conv_block_test",
                                    data_format="channels_last")

        config = conv_block.get_config()
        new_conv_block = unet.ConvBlock.from_config(config)

        # each constructor argument must be restored exactly
        # (the original asserted `activation` twice; duplicate removed)
        assert new_conv_block.layer_idx == conv_block.layer_idx
        assert new_conv_block.filters_root == conv_block.filters_root
        assert new_conv_block.kernel_size == conv_block.kernel_size
        assert new_conv_block.padding == conv_block.padding
        assert new_conv_block.activation == conv_block.activation
        assert new_conv_block.data_format == conv_block.data_format
72 |
73 |
class TestUpconvBlock:
    """Tests for the UpconvBlock layer."""

    def test_serialization(self):
        """UpconvBlock survives a get_config / from_config round trip."""
        upconv_block = unet.UpconvBlock(
            layer_idx=1, filters_root=16, kernel_size=3, pool_size=2, padding="REFLECT",
            activation="relu", name="upconv_block_test", data_format="channels_last"
        )

        config = upconv_block.get_config()
        new_upconv_block = unet.UpconvBlock.from_config(config)

        # each constructor argument must be restored exactly
        # (the original asserted `activation` twice; duplicate removed)
        assert new_upconv_block.layer_idx == upconv_block.layer_idx
        assert new_upconv_block.filters_root == upconv_block.filters_root
        assert new_upconv_block.kernel_size == upconv_block.kernel_size
        assert new_upconv_block.pool_size == upconv_block.pool_size
        assert new_upconv_block.padding == upconv_block.padding
        assert new_upconv_block.activation == upconv_block.activation
        assert new_upconv_block.data_format == upconv_block.data_format
93 |
94 |
class TestCropConcatBlock():
    """Tests for the CropConcatBlock layer."""

    def test_uneven_concat(self):
        """Mismatched spatial dims are cropped before channel concatenation."""
        block = unet.CropConcatBlock(data_format="channels_last")
        skip_tensor = np.ones([1, 61, 61, 32])
        upsampled_tensor = np.ones([1, 52, 52, 32])

        result = block(upsampled_tensor, skip_tensor)

        # spatial dims follow the smaller input; channels are concatenated
        assert result.shape == (1, 52, 52, 64)
105 |
106 |
class TestUnetModel:
    """Tests for unet.build_model serialization and architecture."""

    def test_serialization(self, tmpdir):
        """A small model can be saved to disk and loaded back."""
        save_path = str(tmpdir / "unet_model")
        model = unet.build_model(layer_depth=3, filters_root=2)
        model.save(save_path)

        assert tf.keras.models.load_model(save_path) is not None

    def test_build_model(self):
        """Input/output shapes and layer counts match the configuration."""
        channels = 3
        num_classes = 2
        kernel_size = 3
        pool_size = 2
        filters_root = 64
        layer_depth = 5
        shared_kwargs = dict(
            channels=channels,
            num_classes=num_classes,
            layer_depth=layer_depth,
            filters_root=filters_root,
            kernel_size=kernel_size,
            pool_size=pool_size,
            data_format="channels_last",
        )

        # "same"-style (CONSTANT) padding: output spatial dims equal input dims
        nx = ny = 512
        model = unet.build_model(nx=nx, ny=ny, padding="CONSTANT", **shared_kwargs)
        assert tuple(model.get_layer("inputs").output.shape) == (None, nx, ny, channels)
        assert tuple(model.get_layer("semantic_head").output.shape) == \
            (None, nx, ny, num_classes)

        # VALID padding: a 572x572 input shrinks to a 388x388 output
        nx = ny = 572
        model = unet.build_model(nx=nx, ny=ny, padding="VALID", **shared_kwargs)
        assert tuple(model.get_layer("inputs").output.shape) == (None, nx, ny, channels)
        assert tuple(model.get_layer("semantic_head").output.shape) == \
            (None, 388, 388, num_classes)

        # filters double down the encoder and halve back up the decoder
        filters_per_layer = [filters_root, 128, 256, 512, 1024, 512, 256, 128, filters_root]
        conv_layers = _collect_conv2d_layers(model)

        assert len(conv_layers) == 2 * len(filters_per_layer) + 1

        for conv in conv_layers[:-1]:
            assert conv.kernel_size == (kernel_size, kernel_size)

        for idx, filters in enumerate(filters_per_layer):
            assert conv_layers[idx * 2].filters == filters
            assert conv_layers[idx * 2 + 1].filters == filters

        pools = [lyr for lyr in model.layers if isinstance(lyr, layers.MaxPool2D)]

        # one pooling layer between consecutive encoder levels
        assert len(pools) == layer_depth - 1

        for pool in pools[:-1]:
            assert pool.pool_size == (pool_size, pool_size)
182 |
183 |
def _collect_conv2d_layers(model):
    """Gather every Conv2D layer, including those nested inside ConvBlocks."""
    collected = []
    for lyr in model.layers:
        if isinstance(lyr, layers.Conv2D):
            collected.append(lyr)
        elif isinstance(lyr, unet.ConvBlock):
            # each ConvBlock wraps exactly two Conv2D sublayers
            collected.extend([lyr.conv2d_1, lyr.conv2d_2])

    return collected
194 |
--------------------------------------------------------------------------------
/tests/viewer_widget_test.py:
--------------------------------------------------------------------------------
1 | from cell_classification.viewer_widget import NimbusViewer
2 | from segmentation_data_prep_test import prep_object_and_inputs
3 | import numpy as np
4 | import tempfile
5 | import os
6 |
7 |
def test_NimbusViewer():
    """NimbusViewer can be constructed on a freshly prepared temp dataset."""
    with tempfile.TemporaryDirectory() as temp_dir:
        prep_object_and_inputs(temp_dir)
        viewer = NimbusViewer(temp_dir, temp_dir)
        assert isinstance(viewer, NimbusViewer)
13 |
14 |
def test_composite_image():
    """Composite image stacks one channel per entry in the path dict."""
    with tempfile.TemporaryDirectory() as temp_dir:
        prep_object_and_inputs(temp_dir)
        viewer = NimbusViewer(temp_dir, temp_dir)
        channel_paths = {
            color: os.path.join(temp_dir, "fov_0", marker + ".tiff")
            for color, marker in [("red", "CD4"), ("green", "CD11c")]
        }
        composite = viewer.create_composite_image(channel_paths)
        assert isinstance(composite, np.ndarray)
        assert composite.shape == (256, 256, 2)

        # a third channel yields a three-channel composite
        channel_paths["blue"] = os.path.join(temp_dir, "fov_0", "CD56.tiff")
        composite = viewer.create_composite_image(channel_paths)
        assert composite.shape == (256, 256, 3)
30 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | # Enable line length testing with maximum line length of 99
2 | [pycodestyle]
3 | max-line-length = 99
--------------------------------------------------------------------------------