├── .editorconfig ├── .flake8 ├── .github └── workflows │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── codecov.yml ├── docs ├── Makefile ├── _static │ └── css │ │ └── custom.css ├── _templates │ ├── autosummary │ │ └── class.rst │ └── class_no_inherited.rst ├── api │ └── index.rst ├── conf.py ├── extensions │ └── typed_returns.py ├── index.rst ├── installation.rst ├── make.bat ├── notebooks │ ├── hierarchical_selsction_prototype.ipynb │ └── marker_selection_example.ipynb ├── references.rst └── release_notes │ ├── index.rst │ ├── v0.1.0.rst │ └── v0.1.1.rst ├── pyproject.toml ├── readthedocs.yml ├── schierarchy ├── __init__.py ├── base │ ├── __init__.py │ └── _pyro_base_regression_module.py ├── logistic │ ├── __init__.py │ ├── _logistic_model.py │ └── _logistic_module.py ├── regression │ ├── __init__.py │ ├── _reference_model.py │ └── _reference_module.py └── utils │ ├── __init__.py │ ├── data_transformation.py │ └── simulation.py ├── setup.py └── tests ├── __init__.py ├── test_hierarchical_logist.py ├── test_hierarchical_logist_nohierarchy.py └── test_regression.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, W503, W605, N812 3 | exclude = .git,docs 4 | max-line-length = 119 5 | 6 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: schierarchy 5 | 6 | on: 7 | push: 8 | branches: [main] 9 | pull_request: 10 | branches: [main] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.9] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Cache pip 27 | uses: actions/cache@v2 28 | with: 29 | path: ~/.cache/pip 30 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 31 | restore-keys: | 32 | ${{ runner.os }}-pip- 33 | - name: Install dependencies 34 | run: | 35 | pip install pytest-cov 36 | pip install .[dev] 37 | - name: Lint with flake8 38 | run: | 39 | flake8 40 | - name: Format with black 41 | run: | 42 | black --check . 
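      # pytest-cov (installed above via pip) provides the --cov and --cov-report
      # flags used in the next step; the coverage.xml report it writes is what
      # the Codecov bash uploader in the final step picks up.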
43 |     - name: Test with pytest
44 |       run: |
45 |         pytest --cov-report=xml --cov=schierarchy
46 |     - name: After success
47 |       run: |
48 |         bash <(curl -s https://codecov.io/bash)
49 |         pip list
50 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project-specific
2 | .idea/
3 | 
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 | 
9 | # C extensions
10 | *.so
11 | 
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 | 
33 | # PyInstaller
34 | #  Usually these files are written by a python script from a template
35 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 | 
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 | 
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | 
57 | # Translations
58 | *.mo
59 | *.pot
60 | 
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 | 
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 | 
71 | # Scrapy stuff:
72 | .scrapy
73 | 
74 | # Sphinx documentation
75 | docs/_build/
76 | 
77 | # PyBuilder
78 | target/
79 | 
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 | 
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 | 
87 | # pyenv
88 | .python-version
89 | 
90 | # pipenv
91 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | #   install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # vscode 135 | .vscode/settings.json 136 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - repo: https://github.com/python/black 7 | rev: 22.3.0 8 | hooks: 9 | - id: black 10 | - repo: https://gitlab.com/pycqa/flake8 11 | rev: 3.8.4 12 | hooks: 13 | - id: flake8 14 | - repo: https://github.com/pycqa/isort 15 | rev: 5.7.0 16 | hooks: 17 | - id: isort 18 | name: isort (python) 19 | additional_dependencies: [toml] 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # scHierarchy: Hierarchical logistic regression model for marker gene selection using hierarchical cell annotations
2 | 
3 | [![Stars](https://img.shields.io/github/stars/dissatisfaction-ai/scHierarchy?logo=GitHub&color=yellow)](https://github.com/vitkl/scHierarchy/stargazers)
4 | [![Documentation Status](https://readthedocs.org/projects/scHierarchy/badge/?version=latest)](https://scHierarchy.readthedocs.io/en/stable/?badge=stable)
5 | ![Build Status](https://github.com/dissatisfaction-ai/scHierarchy/actions/workflows/test.yml/badge.svg?event=push)
6 | [![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black)
7 | 
8 | image
9 | 
10 | ## Installation
11 | 
12 | Linux installation
13 | ```bash
14 | conda create -y -n schierarchy-env python=3.9
15 | conda activate schierarchy-env
16 | pip install git+https://github.com/dissatisfaction-ai/scHierarchy.git
17 | ```
18 | 
19 | Mac installation
20 | ```bash
21 | conda create -y -n schierarchy-env python=3.8
22 | conda activate schierarchy-env
23 | pip install git+https://github.com/pyro-ppl/pyro.git@dev
24 | conda install -y -c anaconda hdf5 pytables netcdf4
25 | pip install git+https://github.com/dissatisfaction-ai/scHierarchy.git
26 | ```
27 | 
28 | ## Usage example notebook
29 | 
30 | https://github.com/dissatisfaction-ai/scHierarchy/blob/main/docs/notebooks/marker_selection_example.ipynb
31 | 
32 | This will be updated to use a publicly available dataset on Colab; however, it is challenging to find a published dataset that has several annotation levels yet is small enough to run on Colab.
33 | 
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | # Run to check if valid
2 | # curl --data-binary @codecov.yml https://codecov.io/validate
3 | coverage:
4 |   status:
5 |     project:
6 |       default:
7 |         target: 80%
8 |         threshold: 1%
9 |     patch: off
10 | 
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line.
5 | SPHINXOPTS    =
6 | SPHINXBUILD   = python -msphinx
7 | SPHINXPROJ    = scvi
8 | SOURCEDIR     = .
9 | BUILDDIR      = _build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
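# With the defaults above, e.g. "make html" expands to
#   python -msphinx -M html . _build
# so any builder name Sphinx understands can be passed as the target.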
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* influenced by and borrowed from: https://github.com/cvxgrp/pymde/blob/main/docs_src/source/_static/css/custom.css */ 2 | 3 | @import url('https://fonts.googleapis.com/css2?family=Roboto:ital,wght@0,300;0,400;0,600;1,300;1,400;1,600&display=swap'); 4 | 5 | :root { 6 | --sidebarcolor: #003262; 7 | --sidebarfontcolor: #ffffff; 8 | --sidebarhover: #295e97; 9 | 10 | --bodyfontcolor: #333; 11 | --webfont: 'Roboto'; 12 | 13 | --contentwidth: 1000px; 14 | } 15 | 16 | /* Fonts and text */ 17 | h1, h2, h3, h4, h5, h6 { 18 | font-family: var(--webfont), 'Helvetica Neue', Helvetica, Arial, sans-serif; 19 | font-weight: 400; 20 | } 21 | 22 | h2, h3, h4, h5, h6 { 23 | padding-top: 0.25em; 24 | margin-bottom: 0.5em; 25 | } 26 | 27 | h1 { 28 | font-size: 225%; 29 | } 30 | 31 | body { 32 | font-family: var(--webfont), 'Helvetica Neue', Helvetica, Arial, sans-serif; 33 | color: var(--bodyfontcolor); 34 | } 35 | 36 | p { 37 | font-size: 1em; 38 | line-height: 150%; 39 | } 40 | 41 | 42 | /* Sidebar */ 43 | .wy-side-nav-search { 44 | background-color: var(--sidebarcolor); 45 | } 46 | 47 | .wy-nav-side { 48 | background: var(--sidebarcolor); 49 | } 50 | 51 | .wy-menu-vertical header, .wy-menu-vertical p.caption { 52 | color: var(--sidebarfontcolor); 53 | } 54 | 55 | .wy-menu-vertical a { 56 | color: var(--sidebarfontcolor); 57 | } 58 | 59 | .wy-side-nav-search > div.version { 60 | color: var(--sidebarfontcolor); 61 | } 62 | 63 | .wy-menu-vertical a:hover { 64 | background-color: var(--sidebarhover); 65 | } 66 | 67 | /* Main content */ 68 | .wy-nav-content { 69 | max-width: var(--contentwidth); 70 | } 71 | 72 | 73 | html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list) > dt{ 74 | margin-bottom: 6px; 75 | border-left: none; 76 | background: none; 77 | color: #555; 78 | } 79 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | .. rubric:: Attributes 12 | 13 | .. autosummary:: 14 | :toctree: . 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | .. rubric:: Methods 24 | 25 | .. autosummary:: 26 | :toctree: . 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | -------------------------------------------------------------------------------- /docs/_templates/class_no_inherited.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block methods %} 10 | {% if methods %} 11 | .. 
rubric:: Methods 12 | 13 | .. autosummary:: 14 | :toctree: . 15 | {% for item in methods %} 16 | {%- if item != '__init__' and item not in inherited_members%} 17 | ~{{ fullname }}.{{ item }} 18 | {%- endif -%} 19 | 20 | {%- endfor %} 21 | {% endif %} 22 | {% endblock %} 23 | -------------------------------------------------------------------------------- /docs/api/index.rst: -------------------------------------------------------------------------------- 1 | === 2 | API 3 | === 4 | 5 | .. currentmodule:: mypackage 6 | 7 | .. note:: External functions can be imported from other packages like `scvi-tools` and be displayed on your docs website. As an example, `setup_anndata` here is directly from `scvi-tools`. 8 | 9 | Data 10 | ~~~~ 11 | .. autosummary:: 12 | :toctree: reference/ 13 | 14 | setup_anndata 15 | 16 | MyModel 17 | ~~~~~~~ 18 | 19 | .. autosummary:: 20 | :toctree: reference/ 21 | 22 | MyModel 23 | MyPyroModel 24 | 25 | MyModule 26 | ~~~~~~~~ 27 | .. autosummary:: 28 | :toctree: reference/ 29 | :template: class_no_inherited.rst 30 | 31 | MyModule 32 | MyPyroModule 33 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | import sys 10 | from pathlib import Path 11 | 12 | HERE = Path(__file__).parent 13 | sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")] 14 | 15 | import mypackage # noqa 16 | 17 | 18 | # -- General configuration --------------------------------------------- 19 | 20 | # If your documentation needs a minimal Sphinx version, state it here. 21 | # 22 | needs_sphinx = "3.0" # Nicer param docs 23 | 24 | # Add any Sphinx extension module names here, as strings. They can be 25 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 26 | extensions = [ 27 | "sphinx.ext.autodoc", 28 | "sphinx.ext.viewcode", 29 | "nbsphinx", 30 | "nbsphinx_link", 31 | "sphinx.ext.mathjax", 32 | "sphinx.ext.napoleon", 33 | "sphinx_autodoc_typehints", # needs to be after napoleon 34 | "sphinx.ext.intersphinx", 35 | "sphinx.ext.autosummary", 36 | "scanpydoc.elegant_typehints", 37 | "scanpydoc.definition_list_typed_field", 38 | "scanpydoc.autosummary_generate_imported", 39 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 40 | ] 41 | 42 | # nbsphinx specific settings 43 | exclude_patterns = ["_build", "**.ipynb_checkpoints"] 44 | nbsphinx_execute = "never" 45 | 46 | # Add any paths that contain templates here, relative to this directory. 47 | templates_path = ["_templates"] 48 | 49 | # The suffix(es) of source filenames. 
50 | # You can specify multiple suffix as a list of string: 51 | # 52 | # source_suffix = ['.rst', '.md'] 53 | source_suffix = ".rst" 54 | 55 | # Generate the API documentation when building 56 | autosummary_generate = True 57 | autodoc_member_order = "bysource" 58 | napoleon_google_docstring = False 59 | napoleon_numpy_docstring = True 60 | napoleon_include_init_with_doc = False 61 | napoleon_use_rtype = True 62 | napoleon_use_param = True 63 | napoleon_custom_sections = [("Params", "Parameters")] 64 | todo_include_todos = False 65 | numpydoc_show_class_members = False 66 | annotate_defaults = True 67 | # The master toctree document. 68 | master_doc = "index" 69 | 70 | 71 | intersphinx_mapping = dict( 72 | anndata=("https://anndata.readthedocs.io/en/stable/", None), 73 | ipython=("https://ipython.readthedocs.io/en/stable/", None), 74 | matplotlib=("https://matplotlib.org/", None), 75 | numpy=("https://docs.scipy.org/doc/numpy/", None), 76 | pandas=("https://pandas.pydata.org/pandas-docs/stable/", None), 77 | python=("https://docs.python.org/3", None), 78 | scipy=("https://docs.scipy.org/doc/scipy/reference/", None), 79 | sklearn=("https://scikit-learn.org/stable/", None), 80 | torch=("https://pytorch.org/docs/master/", None), 81 | scanpy=("https://scanpy.readthedocs.io/en/stable/", None), 82 | pytorch_lightning=("https://pytorch-lightning.readthedocs.io/en/stable/", None), 83 | ) 84 | 85 | 86 | # General information about the project. 87 | project = "schierarchy" 88 | copyright = "2021, Yosef Lab, UC Berkeley" 89 | author = "Adam Gayoso" 90 | 91 | # The version info for the project you're documenting, acts as replacement 92 | # for |version| and |release|, also used in various other places throughout 93 | # the built documents. 94 | # 95 | # The short X.Y version. 96 | version = mypackage.__version__ 97 | # The full version, including alpha/beta/rc tags. 98 | release = mypackage.__version__ 99 | 100 | # The language for content autogenerated by Sphinx. Refer to documentation 101 | # for a list of supported languages. 102 | # 103 | # This is also used if you do content translation via gettext catalogs. 104 | # Usually you set "language" from the command line for these cases. 105 | language = None 106 | 107 | # List of patterns, relative to source directory, that match files and 108 | # directories to ignore when looking for source files. 109 | # This patterns also effect to html_static_path and html_extra_path 110 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 111 | 112 | # The name of the Pygments (syntax highlighting) style to use. 113 | pygments_style = "tango" 114 | 115 | # If true, `todo` and `todoList` produce output, else they produce nothing. 116 | todo_include_todos = False 117 | 118 | # -- Options for HTML output ------------------------------------------------- 119 | 120 | # The theme to use for HTML and HTML Help pages. See the documentation for 121 | # a list of builtin themes. 122 | # 123 | html_theme = "sphinx_rtd_theme" 124 | 125 | html_show_sourcelink = False 126 | 127 | html_show_copyright = False 128 | 129 | display_version = True 130 | 131 | # Add any paths that contain custom static files (such as style sheets) here, 132 | # relative to this directory. They are copied after the builtin static files, 133 | # so a file named "default.css" will overwrite the builtin "default.css". 
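# Everything under html_static_path is copied into the built site's _static/
# directory; each entry in html_css_files below is then linked relative to
# that directory (here resolving to _static/css/custom.css).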
134 | html_static_path = ["_static"]
135 | 
136 | html_css_files = [
137 |     "css/custom.css",
138 | ]
139 | 
140 | html_favicon = "favicon.ico"
141 | 
--------------------------------------------------------------------------------
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | import re
4 | 
5 | from sphinx.application import Sphinx
6 | from sphinx.ext.napoleon import NumpyDocstring
7 | 
8 | 
9 | def process_return(lines):
10 |     for line in lines:
11 |         m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line)
12 |         if m:
13 |             # Once this is in scanpydoc, we can use the fancy hover stuff
14 |             yield f'-{m["param"]} (:class:`~{m["type"]}`)'
15 |         else:
16 |             yield line
17 | 
18 | 
19 | def scanpy_parse_returns_section(self, section):
20 |     lines_raw = list(process_return(self._dedent(self._consume_to_next_section())))
21 |     lines = self._format_block(":returns: ", lines_raw)
22 |     if lines and lines[-1]:
23 |         lines.append("")
24 |     return lines
25 | 
26 | 
27 | def setup(app: Sphinx):
28 |     NumpyDocstring._parse_returns_section = scanpy_parse_returns_section
29 | 
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | ========================
2 | mypackage documentation
3 | ========================
4 | 
5 | Welcome! This is the corresponding documentation website for the `scvi-tools-skeleton
6 | `_. The purpose of this website is to demonstrate some of the functionality available for scvi-tools developer API users.
7 | 
8 | We recommend building your own package by first using this repository as a `template `_. Subsequently, the info in the `/docs` directory can be updated to contain information relevant to your package.
9 | 
10 | 
11 | .. toctree::
12 |    :maxdepth: 1
13 |    :hidden:
14 | 
15 |    installation
16 |    api/index
17 |    release_notes/index
18 |    references
19 | 
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 | 
4 | Prerequisites
5 | ~~~~~~~~~~~~~~
6 | 
7 | my_package can be installed via PyPI.
8 | 
9 | conda prerequisites
10 | ###################
11 | 
12 | 1. Install Conda. We typically use the Miniconda_ Python distribution. Use Python version >=3.9.
13 | 
14 | 2. Create a new conda environment::
15 | 
16 |     conda create -n scvi-env python=3.9
17 | 
18 | 3. Activate your environment::
19 | 
20 |     source activate scvi-env
21 | 
22 | pip prerequisites:
23 | ##################
24 | 
25 | 1. Install Python_; we prefer the `pyenv `_ version management system, along with `pyenv-virtualenv `_.
26 | 
27 | 2. Install PyTorch_. If you have an Nvidia GPU, be sure to install a version of PyTorch that supports it -- scvi-tools runs much faster with a discrete GPU.
28 | 
29 | .. _Miniconda: https://conda.io/miniconda.html
30 | .. _Python: https://www.python.org/downloads/
31 | .. _PyTorch: http://pytorch.org
32 | 
33 | my_package installation
34 | ~~~~~~~~~~~~~~~~~~~~~~~
35 | 
36 | Install my_package in one of the following ways:
37 | 
38 | Through **pip**::
39 | 
40 |     pip install 
41 | 
42 | Through pip with packages to run notebooks. 
This installs scanpy, etc.:: 43 | 44 | pip install [tutorials] 45 | 46 | Nightly version - clone this repo and run:: 47 | 48 | pip install . 49 | 50 | For development - clone this repo and run:: 51 | 52 | pip install -e .[dev,docs] 53 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=scvi 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/notebooks/hierarchical_selsction_prototype.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "1e96289b-7640-4294-9e10-ba8f0ef5eae9", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import pandas as pd\n", 13 | "import torch\n", 14 | "from torch.nn import Parameter\n", 15 | "import numpy as np\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "\n", 18 | "from pyro.infer import MCMC, NUTS, Predictive\n", 19 | "import pyro\n", 20 | "from pyro import poutine\n", 21 | "from pyro import distributions as dist\n", 22 | "from pyro.optim import Adam\n", 23 | "from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO\n", 24 | "from pyro.infer.autoguide import AutoDelta, AutoNormal\n", 25 | "pyro.enable_validation(True)\n", 26 | "import tqdm\n", 27 | "\n", 28 | "import scipy\n", 29 | "from scipy import ndimage\n", 30 | "from scipy.sparse import load_npz\n", 31 | "from scipy.sparse import coo_matrix\n", 32 | "#import skimage.color\n", 33 | "\n", 34 | "from sklearn.neighbors import KDTree\n", 35 | "\n", 36 | "from collections import Counter\n", 37 | "\n", 38 | "import scanpy as sc\n", 39 | "\n", 40 | "import seaborn as sns" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "810ef782-d010-41ea-bd35-869d65244d8d", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "def subsample_clusters(obj, group_label = 'celltype_subset', subsample_factor = 0.2):\n", 51 | " chosen = []\n", 52 | " indxs = np.arange(obj.n_obs)\n", 53 | " for group in obj.obs[group_label].unique():\n", 54 | " group_indxs = np.where(obj.obs[group_label] == group)[0]\n", 55 | " group_chosen = np.random.choice(group_indxs,\n", 56 | " size=np.maximum(int(len(group_indxs)*subsample_factor), [len(group_indxs),250][int(len(group_indxs) > 250)]),\n", 57 | " replace=False)\n", 58 | " #print(len(group_chosen))\n", 59 | " chosen += list(group_chosen)\n", 60 | "\n", 61 | " return np.array(chosen)\n", 62 
| "\n", 63 | "def data_to_zero_truncated_cdf(x):\n", 64 | " x_sorted = np.sort(x)\n", 65 | " indx_sorted = np.argsort(x)\n", 66 | " x_sorted = x[indx_sorted]\n", 67 | " zero_ind = np.where(x_sorted == 0)[0][-1]\n", 68 | " p = np.concatenate([np.zeros(zero_ind), 1. * np.arange(len(x_sorted) - zero_ind) / (len(x_sorted) - zero_ind - 1)])\n", 69 | " cdfs = np.zeros_like(x)\n", 70 | " cdfs[indx_sorted] = p\n", 71 | " return cdfs" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "ccbe68b4-97d5-425a-a019-4a62576a201f", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "data = sc.read_h5ad('../data_atlas/atals_processed.h5ad')\n", 82 | "data.layers['raw_data'] = ((data.X.expm1() / 10000).multiply(data.obs[['nCount_RNA']].values)).tocsr() #real data \n", 83 | "\n", 84 | "data.obs['celltype_subset_alt'] = list(data.obs['celltype_subset'].values)\n", 85 | "#data.obs.loc[data.obs['gene_module'] != 'no_gene_module', 'celltype_subset_alt'] = data.obs.loc[data.obs['gene_module'] != 'no_gene_module', 'gene_module']\n", 86 | "data.obs.loc[data.obs['celltype_major'] == 'CAFs', 'celltype_subset_alt'] = data.obs.loc[data.obs['celltype_major'] == 'CAFs', 'celltype_minor']\n", 87 | "\n", 88 | "group_today = 'celltype_subset_alt'\n", 89 | "\n", 90 | "data_sub = data[subsample_clusters(data, subsample_factor=0.1, group_label=group_today)]\n", 91 | "\n", 92 | "mean_exp_vals = []\n", 93 | "for g in data_sub.obs[group_today].unique():\n", 94 | " #print(g)\n", 95 | " tmp = data_sub.X[np.where(data_sub.obs[group_today] == g)[0],:]\n", 96 | " tmp = tmp.toarray().mean(axis=0)\n", 97 | " mean_exp_vals.append(tmp)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "id": "e15d27dc-05e0-457b-bd0d-df923c423c3b", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "2d326189-e48c-4b4e-9588-1dcaddb110bd", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "expressed_genes = np.where((np.exp(np.stack(mean_exp_vals, axis=0).max(axis=0)) - 1 > \n", 116 | " (np.exp(np.stack(mean_exp_vals, axis=0).max(axis=0))[np.where(data_sub.var.index == 'RPLP0')[0]] - 1)/20))[0]\n", 117 | "\n", 118 | "data_sub = data_sub[:, expressed_genes].copy()" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "215c94c8-27a2-4e68-8680-5bd2ac933233", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "58613c69-f272-46c0-a40a-ed4442c422b1", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "def multi_logist_hierarchical(D, X, label_graph, class_size, device='cuda'):\n", 137 | " '''\n", 138 | " D ~ Multinomial(softmax(Xw))\n", 139 | " \n", 140 | " Parameters\n", 141 | " ----------\n", 142 | " D : torch.tensor\n", 143 | " Cell type assignement (cells x levels) \n", 144 | " X : torch.tensor\n", 145 | " Expression matrix (cells x genes)\n", 146 | " label_graph : dictionary\n", 147 | " level mapping binary (parent x children)\n", 148 | " class_size : dict \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " device: torch.device\n", 153 | " Device specification (default - CUDA) for CPU use torch.device('cpu')\n", 154 | "\n", 155 | " '''\n", 156 | " i_cells, g_genes = X.shape\n", 157 | " c_class_levels = [class_size[i].size()[0] for i in range(2)]\n", 158 | " \n", 159 | " #top level weights\n", 160 | " w_top = 
pyro.sample('w_top', dist.Laplace(torch.tensor([0.]).to(device),\n", 161 | " torch.tensor([0.5]).to(device)).expand([g_genes, c_class_levels[0]]).to_event(2)) \n", 162 | " f_top = torch.nn.functional.softmax(torch.matmul(X, w_top / (class_size[0])**0.5), dim=1)\n", 163 | " #bottom levels\n", 164 | " w_bottom = pyro.sample('w_bottom', dist.Laplace(torch.tensor([0.]).to(device),\n", 165 | " torch.tensor([0.5]).to(device)).expand([g_genes, c_class_levels[1]]).to_event(2)) \n", 166 | " \n", 167 | " \n", 168 | " #w_top_actual = torch.zeros_like(w_top)\n", 169 | " f_bottom = torch.zeros((i_cells, c_class_levels[1])).to(device)\n", 170 | " #for parent in range(c_class_levels[0]):\n", 171 | " #w_top_actual[:, parent] = torch.maximum(w_top[:, parent] - w_bottom[:, label_graph[parent]].sum(axis=1), torch.zeros_like(w_top_actual[:, parent]))\n", 172 | " #f_bottom[:, label_graph[parent]] += torch.nn.functional.softmax(torch.matmul(X, w_bottom[:, label_graph[parent]] / (class_size[1][label_graph[parent]])[None,:]**0.5), dim=1) / len(label_graph[parent])\n", 173 | " #w_top_actual = pyro.deterministic('w_top_actual', w_top_actual)\n", 174 | " #f_top = torch.nn.functional.softmax(torch.matmul(X, w_top_actual / (class_size[0])**0.5), dim=1)\n", 175 | "\n", 176 | " for parent in range(c_class_levels[0]):\n", 177 | " f_bottom[:, label_graph[parent]] += torch.nn.functional.softmax(torch.matmul(X, w_bottom[:, label_graph[parent]] / (class_size[1][label_graph[parent]])[None,:]**0.5), dim=1) * f_top[:,parent,None]\n", 178 | " \n", 179 | " \n", 180 | " #bottom level\n", 181 | " obs_top = pyro.sample('likelihood_top', dist.Categorical(f_top).to_event(1), obs=D[:,0])\n", 182 | " obs_bottom = pyro.sample('likelihood_bottom', dist.Categorical(f_bottom).to_event(1), obs=D[:,1]) " 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "id": "c50759be-a918-470f-b8f4-9c170d6dd605", 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "id": "cac2b31a-10e0-458e-ac6d-ed772d0637dd", 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "1eae2906-27dd-4319-ae6a-a5929e40168b", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "#hierarchical labeling \n", 209 | "np.random.seed(42)\n", 210 | "pyro.set_rng_seed(42)\n", 211 | "torch.manual_seed(42)\n", 212 | "\n", 213 | "device='cpu'\n", 214 | "D = torch.tensor(np.float32(np.stack([data_sub.obs['celltype_major'].cat.codes.values, data_sub.obs[group_today].cat.codes.values], axis=1))).to(device=device)\n", 215 | "X = np.apply_along_axis(data_to_zero_truncated_cdf, 0, data_sub.X.toarray()) #/ (data_sub.layers['raw_data'].toarray().max(axis=0)**0.5)[None,:]\n", 216 | "X = torch.tensor(np.float32(X)).to(device=device)\n", 217 | "class_size = [torch.tensor(np.float32(data_sub.obs.groupby(group).size().values)).to(device=device) for group in ['celltype_major', group_today]]\n", 218 | "label_graph = {}\n", 219 | "for i, g in enumerate(data_sub.obs['celltype_major'].cat.categories):\n", 220 | " label_graph[i] = data_sub.obs[data_sub.obs['celltype_major'] == g][group_today].unique().codes\n", 221 | "\n", 222 | "pyro.clear_param_store()\n", 223 | "\n", 224 | "model = multi_logist_hierarchical\n", 225 | "guide = AutoNormal(model)\n", 226 | "\n", 227 | "adam_params = {\"lr\": 0.01, \"betas\": (0.95, 0.999)}\n", 228 | "optimizer = Adam(adam_params)\n", 229 | 
"\n", 230 | "svi = SVI(model, guide, optimizer, loss=Trace_ELBO())\n", 231 | "\n", 232 | "n_steps = 10000\n", 233 | "# do gradient steps\n", 234 | "loss_list = []\n", 235 | "for step in tqdm.tqdm(range(n_steps)):\n", 236 | " loss = svi.step(D, X, label_graph, class_size, 'cpu')\n", 237 | " loss_list.append(loss)\n" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "id": "730b52b9-9c49-43f4-aa68-faebd9e0168b", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "8cef7858-7757-4e68-be57-57c550bf03d4", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "trace = Predictive(model, guide=guide, num_samples=100)(D, X, label_graph, class_size, 'cpu')" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "id": "f054f8d2-478f-4ce3-b5bd-4671e611cd52", 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "gene_names = data_sub.var.index.values\n", 266 | "\n", 267 | "selected_dcit_top = {}\n", 268 | "for i, name in enumerate(data_sub.obs['celltype_major'].cat.categories):\n", 269 | " weights = trace['w_top'].mean(axis=0)[:, i].cpu().numpy()\n", 270 | " top4 = np.argpartition(weights, -4)[-4:]\n", 271 | " selected_dcit_top[name] = gene_names[top4]\n", 272 | "selected_dcit_top['hk'] = np.array(['RPLP0'])" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "id": "41a0b2e6-91e4-4761-8c20-7387e09689c1", 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "fig = sc.pl.dotplot(data, selected_dcit_top, 'celltype_major')" 283 | ] 284 | } 285 | ], 286 | "metadata": { 287 | "kernelspec": { 288 | "display_name": "Environment (test_schierarchy)", 289 | "language": "python", 290 | "name": "test_schierarchy" 291 | }, 292 | "language_info": { 293 | "codemirror_mode": { 294 | "name": "ipython", 295 | "version": 3 296 | }, 297 | "file_extension": ".py", 298 | "mimetype": "text/x-python", 299 | "name": "python", 300 | "nbconvert_exporter": "python", 301 | "pygments_lexer": "ipython3", 302 | "version": "3.9.7" 303 | } 304 | }, 305 | "nbformat": 4, 306 | "nbformat_minor": 5 307 | } 308 | -------------------------------------------------------------------------------- /docs/notebooks/marker_selection_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "7f51c100-e7c5-499e-8d55-82145a997d3c", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import sys\n", 11 | "\n", 12 | "#if installed somewhere else\n", 13 | "sys.path.insert(1, '/nfs/team205/vk7/sanger_projects/BayraktarLab/cell2location/')\n", 14 | "sys.path.insert(1, '/lustre/scratch119/casm/team299ly/al15/projects/scHierarchy/')\n", 15 | "\n", 16 | "import scanpy as sc\n", 17 | "import anndata\n", 18 | "import pandas as pd\n", 19 | "import numpy as np\n", 20 | "import matplotlib.pyplot as plt \n", 21 | "import matplotlib as mpl\n", 22 | "\n", 23 | "import cell2location\n", 24 | "import scvi\n", 25 | "import schierarchy\n", 26 | "\n", 27 | "from matplotlib import rcParams\n", 28 | "rcParams['pdf.fonttype'] = 42 # enables correct plotting of text\n", 29 | "import seaborn as sns" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "db0a5b58-f850-4825-b7a5-344f2eb53375", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "#location 
of scRNA data\n",
40 |     "sc_data_folder = '/nfs/casm/team299ly/al15/projects/sc-breast/data_atlas/'\n",
41 |     "#location where the results are stored\n",
42 |     "results_folder = '/nfs/casm/team299ly/al15/projects/sc-breast/data_atlas/results/'\n",
43 |     "\n",
44 |     "#prefix for experiment\n",
45 |     "ref_run_name = f'{results_folder}hierarchical_logist/'"
46 |    ]
47 |   },
48 |   {
49 |    "cell_type": "markdown",
50 |    "id": "0a892146-efb1-4c2b-a4b9-39da51888bbb",
51 |    "metadata": {},
52 |    "source": [
53 |     "## Load Breast cancer scRNA dataset"
54 |    ]
55 |   },
56 |   {
57 |    "cell_type": "code",
58 |    "execution_count": null,
59 |    "id": "8d257eb3-2399-447c-ae54-ad0a795afe08",
60 |    "metadata": {},
61 |    "outputs": [],
62 |    "source": [
63 |     "## read data\n",
64 |     "adata_ref = anndata.read_h5ad(sc_data_folder + \"atals_processed.h5ad\")\n",
65 |     "adata_ref.layers['processed'] = adata_ref.X\n",
66 |     "#revert log transformation (if the data was originally log-transformed)\n",
67 |     "adata_ref.X = ((adata_ref.layers['processed'].expm1() / 10000).multiply(adata_ref.obs[['nCount_RNA']].values)).tocsr() #real data \n",
68 |     "\n",
69 |     "# mitochondrial genes\n",
70 |     "adata_ref.var['mt'] = adata_ref.var_names.str.startswith('MT-') \n",
71 |     "# ribosomal genes\n",
72 |     "adata_ref.var['ribo'] = adata_ref.var_names.str.startswith((\"RPS\",\"RPL\"))\n",
73 |     "# hemoglobin genes\n",
74 |     "adata_ref.var['hb'] = adata_ref.var_names.str.contains((\"^HB[^(P)]\"))\n",
75 |     "\n",
76 |     "#remove ribo, mt and hb genes\n",
77 |     "adata_ref = adata_ref[:, np.logical_and(np.logical_and(~adata_ref.var['mt'], ~adata_ref.var['ribo']), ~adata_ref.var['hb'])]"
78 |    ]
79 |   },
80 |   {
81 |    "cell_type": "markdown",
82 |    "id": "e88ed26c-d3b4-464e-8513-82d1102bf65c",
83 |    "metadata": {},
84 |    "source": [
85 |     "### Process single cell data"
86 |    ]
87 |   },
88 |   {
89 |    "cell_type": "code",
90 |    "execution_count": null,
91 |    "id": "9250e711-e373-4aea-8856-e28696ca1591",
92 |    "metadata": {},
93 |    "outputs": [],
94 |    "source": [
95 |     "# before we estimate the reference cell type signatures we recommend performing very permissive gene selection;\n",
96 |     "# in this 2D histogram the orange rectangle lies over the excluded genes.\n",
97 |     "# In this case, the downloaded dataset was already filtered using this method,\n",
98 |     "# hence there is no density under the orange rectangle\n",
99 |     "from cell2location.utils.filtering import filter_genes\n",
100 |     "selected = filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)\n",
101 |     "adata_ref = adata_ref[:, selected].copy()\n",
102 |     "\n",
103 |     "#remove genes which are expressed almost everywhere\n",
104 |     "max_cutoff = (adata_ref.var['n_cells'] / adata_ref.n_obs) > 0.8\n",
105 |     "print(f'fraction of genes expressed in more than 80% of cells {max_cutoff.mean()}')\n",
106 |     "\n",
107 |     "# filter the object\n",
108 |     "adata_ref = adata_ref[:, ~max_cutoff].copy()"
109 |    ]
110 |   },
111 |   {
112 |    "cell_type": "code",
113 |    "execution_count": null,
114 |    "id": "c9540a8d-892e-4364-a0b4-d182e241bf83",
115 |    "metadata": {},
116 |    "outputs": [],
117 |    "source": [
118 |     "%%time\n",
119 |     "#quantile-normalise log-transformed data; this could be replaced with a transformation of your choice\n",
120 |     "from schierarchy.utils.data_transformation import data_to_zero_truncated_cdf\n",
121 |     "adata_ref.layers[\"cdf\"] = np.apply_along_axis(\n",
122 |     "    data_to_zero_truncated_cdf, 0, adata_ref.layers[\"processed\"].toarray()\n",
123 |     ")"
124 |    ]
125 |   },
126 |   {
127 |    "cell_type": "markdown",
128 |    "id": 
"3cf6c2f7-c5ee-4fc4-8ae0-76d1c8b7b235", 129 | "metadata": {}, 130 | "source": [ 131 | "## Initialise and run the model" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "43a42209-900d-4a1c-9be6-031aa9028f3a", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "from schierarchy import LogisticModel\n", 142 | "\n", 143 | "#names of label columns from the most coarse to the most fine\n", 144 | "level_keys = ['celltype_major', 'celltype_minor', 'celltype_subset']\n", 145 | "\n", 146 | "LogisticModel.setup_anndata(adata_ref, layer=\"cdf\", level_keys=level_keys)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "id": "edfb39e4-f98a-4a2d-b4e7-f16291ed7a11", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "# train regression model to get signatures of cell types\n", 157 | "from schierarchy import LogisticModel\n", 158 | "learning_mode = 'fixed-sigma'\n", 159 | "mod = LogisticModel(adata_ref, laplace_learning_mode=learning_mode)\n", 160 | "\n", 161 | "# Use all data for training (validation not implemented yet, train_size=1)\n", 162 | "mod.train(max_epochs=600, batch_size=2500, train_size=1, lr=0.01, use_gpu=True)\n", 163 | "\n", 164 | "# plot ELBO loss history during training, removing first 20 epochs from the plot\n", 165 | "mod.plot_history(50)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "8b3fb8ca-c349-4a1f-8308-490877955c18", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "%%time\n", 176 | "\n", 177 | "# In this section, we export the estimated gene weights and per-cell probabilities \n", 178 | "# (summary of the posterior distribution).\n", 179 | "adata_ref = mod.export_posterior(\n", 180 | " adata_ref, sample_kwargs={'num_samples': 50, 'batch_size': 2500, 'use_gpu': True}\n", 181 | ")\n", 182 | "\n", 183 | "# Save model\n", 184 | "mod.save(f\"{ref_run_name}\", overwrite=True)\n", 185 | "\n", 186 | "# Save anndata object with results\n", 187 | "adata_file = f\"{ref_run_name}/sc.h5ad\"\n", 188 | "adata_ref.write(adata_file)\n", 189 | "adata_file" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "id": "206c4872-c283-48c0-8e10-0a5f06113f27", 195 | "metadata": {}, 196 | "source": [ 197 | "## Load model" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "1fec3185-f973-494d-879c-106acc2ac45d", 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "#if you're not making predictions - just work with adata_file, it already has stored results \n", 208 | "model = LogisticModel.load(ref_run_name, adata_ref)\n", 209 | "adata_ref = model.export_posterior(\n", 210 | " adata_ref, sample_kwargs={'num_samples': 50, 'batch_size': 2500, 'use_gpu': True}\n", 211 | ")" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "id": "504b907f-b12a-438d-a96a-51f3524d5b1d", 217 | "metadata": {}, 218 | "source": [ 219 | "## Visualise hierarchy of marker genes" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "6423dfe8-8688-4567-bbaa-138438d50338", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "adata_file = f\"{ref_run_name}/sc.h5ad\"\n", 230 | "adata_ref = sc.read(adata_file)" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "id": "d41bc1c8-a183-4e16-a38f-944e4b9b407f", 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "#complete 
selected gene plots \n",
241 |     "gene_names = adata_ref.var['gene_ids'].values\n",
242 |     "observed_labels = []\n",
243 |     "\n",
244 |     "for level in level_keys:\n",
245 |     "    selected_dict = {}\n",
246 |     "    for i, name in enumerate(adata_ref.obs[level].cat.categories):\n",
247 |     "        weights = adata_ref.varm[f'means_weight_{level}'][f'means_weight_{level}_{name}'].values\n",
248 |     "        top_n = np.argpartition(weights, -3)[-3:]\n",
249 |     "        if name not in observed_labels:\n",
250 |     "            selected_dict[name] = gene_names[top_n]\n",
251 |     "    fig = sc.pl.dotplot(adata_ref, selected_dict, level, log=True, gene_symbols='gene_ids')"
252 |    ]
253 |   },
254 |   {
255 |    "cell_type": "code",
256 |    "execution_count": null,
257 |    "id": "4fdd013a-2f7e-40f4-8065-2d86d6ca5262",
258 |    "metadata": {},
259 |    "outputs": [],
260 |    "source": [
261 |     "ind = adata_ref.obs[level_keys[0]].isin(['T-cells'])\n",
262 |     "adata_ref_subset = adata_ref[ind, :]\n",
263 |     "\n",
264 |     "gene_names = adata_ref.var['gene_ids'].values\n",
265 |     "observed_labels = []\n",
266 |     "\n",
267 |     "\n",
268 |     "for level in level_keys:\n",
269 |     "    selected_dict = {}\n",
270 |     "    for i, name in enumerate(adata_ref_subset.obs[level].cat.categories):\n",
271 |     "        weights = adata_ref_subset.varm[f'means_weight_{level}'][f'means_weight_{level}_{name}'].values\n",
272 |     "        top_n = np.argpartition(weights, -5)[-5:]\n",
273 |     "        if name not in observed_labels:\n",
274 |     "            selected_dict[name] = gene_names[top_n]\n",
275 |     "            observed_labels.append(name)\n",
276 |     "    ind = adata_ref_subset.obs[level].isin(list(selected_dict.keys()))\n",
277 |     "    adata_ref_subset_v2 = adata_ref_subset[ind, :]\n",
278 |     "    if adata_ref_subset_v2.n_obs > 0:\n",
279 |     "        fig = sc.pl.dotplot(adata_ref_subset_v2, selected_dict, level, log=True, gene_symbols='gene_ids')"
280 |    ]
281 |   },
282 |   {
283 |    "cell_type": "code",
284 |    "execution_count": null,
285 |    "id": "0184e2b9-758e-498a-88a6-d62799739ecb",
286 |    "metadata": {},
287 |    "outputs": [],
288 |    "source": []
289 |   }
290 |  ],
291 |  "metadata": {
292 |   "kernelspec": {
293 |    "display_name": "Environment (schierarchy)",
294 |    "language": "python",
295 |    "name": "schierarchy"
296 |   },
297 |   "language_info": {
298 |    "codemirror_mode": {
299 |     "name": "ipython",
300 |     "version": 3
301 |    },
302 |    "file_extension": ".py",
303 |    "mimetype": "text/x-python",
304 |    "name": "python",
305 |    "nbconvert_exporter": "python",
306 |    "pygments_lexer": "ipython3",
307 |    "version": "3.9.7"
308 |   }
309 |  },
310 |  "nbformat": 4,
311 |  "nbformat_minor": 5
312 | }
313 | 
--------------------------------------------------------------------------------
/docs/references.rst:
--------------------------------------------------------------------------------
1 | References
2 | ==========
3 | 
4 | .. note:: The reference below can be referenced in docstrings. See :class:`~mypackage.MyModule` for an example.
5 | 
6 | .. [Lopez18] Romain Lopez, Jeffrey Regier, Michael Cole, Michael I. Jordan, Nir Yosef (2018),
7 |    *Deep generative modeling for single-cell transcriptomics*,
8 |    `Nature Methods `__.
9 | 
--------------------------------------------------------------------------------
/docs/release_notes/index.rst:
--------------------------------------------------------------------------------
1 | Release notes
2 | =============
3 | 
4 | This is the list of changes to scHierarchy between each release. Full commit history
5 | is available in the `commit logs `_.
6 | 
7 | Version 0.1
8 | -----------
9 | .. 
9 | .. toctree::
10 | :maxdepth: 2
11 | 
12 | v0.1.1
13 | v0.1.0
-------------------------------------------------------------------------------- /docs/release_notes/v0.1.0.rst: --------------------------------------------------------------------------------
1 | New in 0.1.0 (2020-01-25)
2 | -------------------------
3 | Initial release of the skeleton. Complies with `scvi-tools>=0.9.0a0`.
-------------------------------------------------------------------------------- /docs/release_notes/v0.1.1.rst: --------------------------------------------------------------------------------
1 | New in 0.1.1 (2020-01-25)
2 | -------------------------
3 | Add more documentation.
-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
1 | [tool.isort]
2 | include_trailing_comma = true
3 | multi_line_output = 3
4 | profile = "black"
5 | skip_glob = ["docs/*", "schierarchy/__init__.py"]
6 | 
7 | [tool.poetry]
8 | authors = ["Artem Lomakin ", "Vitalii Kleshchevnikov "]
9 | classifiers = [
10 | "Development Status :: 4 - Beta",
11 | "Intended Audience :: Science/Research",
12 | "Natural Language :: English",
13 | "Programming Language :: Python :: 3.9",
14 | "Programming Language :: Python :: 3.10",
15 | "Operating System :: MacOS :: MacOS X",
16 | "Operating System :: Microsoft :: Windows",
17 | "Operating System :: POSIX :: Linux",
18 | "Topic :: Scientific/Engineering :: Bio-Informatics",
19 | ]
20 | description = "Hierarchical cell identity toolkit."
21 | documentation = "https://github.com/vitkl/schierarchy"
22 | homepage = "https://github.com/vitkl/schierarchy"
23 | license = "Apache License, Version 2.0"
24 | name = "schierarchy"
25 | packages = [
26 | {include = "schierarchy"},
27 | ]
28 | readme = "README.md"
29 | version = "0.0.2"
30 | 
31 | [tool.poetry.dependencies]
32 | anndata = ">=0.7.5"
33 | black = {version = "==22.3.0", optional = true}
34 | codecov = {version = ">=2.0.8", optional = true}
35 | flake8 = {version = ">=3.7.7", optional = true}
36 | importlib-metadata = {version = "^1.0", python = "<3.8"}
37 | ipython = {version = ">=7.1.1", optional = true}
38 | isort = {version = ">=5.7", optional = true}
39 | jupyter = {version = ">=1.0", optional = true}
40 | leidenalg = {version = "*", optional = true}
41 | loompy = {version = ">=3.0.6", optional = true}
42 | nbconvert = {version = ">=5.4.0", optional = true}
43 | nbformat = {version = ">=4.4.0", optional = true}
44 | nbsphinx = {version = "*", optional = true}
45 | nbsphinx-link = {version = "*", optional = true}
46 | pre-commit = {version = ">=2.7.1", optional = true}
47 | pydata-sphinx-theme = {version = ">=0.4.0", optional = true}
48 | pytest = {version = ">=4.4", optional = true}
49 | python = ">=3.9.18"
50 | python-igraph = {version = "*", optional = true}
51 | scanpy = {version = ">=1.6", optional = true}
52 | scanpydoc = {version = ">=0.5", optional = true}
53 | scikit-misc = {version = ">=0.1.3", optional = true}
54 | cell2location = {git = "https://github.com/BayraktarLab/cell2location.git@hires_sliding_window"}
55 | pyro-ppl = {version = ">=1.8.0"}
56 | scvi-tools = ">=1.0.0"
57 | sphinx = {version = "^3.0", optional = true}
58 | sphinx-autodoc-typehints = {version = "*", optional = true}
59 | sphinx-rtd-theme = {version = "*", optional = true}
60 | typing_extensions = {version = "*", python = "<3.8"}
61 | 
62 | [tool.poetry.extras]
"jupyter", "nbformat", "nbconvert", "pre-commit", "isort"] 64 | docs = [ 65 | "sphinx", 66 | "scanpydoc", 67 | "nbsphinx", 68 | "nbsphinx-link", 69 | "ipython", 70 | "pydata-sphinx-theme", 71 | "typing_extensions", 72 | "sphinx-autodoc-typehints", 73 | "sphinx-rtd-theme", 74 | ] 75 | tutorials = ["scanpy", "leidenalg", "python-igraph", "loompy", "scikit-misc"] 76 | 77 | [tool.poetry.dev-dependencies] 78 | 79 | [build-system] 80 | build-backend = "poetry.masonry.api" 81 | requires = [ 82 | "poetry>=1.0", 83 | "setuptools", # keep it here or "pip install -e" would fail 84 | ] 85 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | build: 3 | image: latest 4 | sphinx: 5 | configuration: docs/conf.py 6 | python: 7 | version: 3.9 8 | install: 9 | - method: pip 10 | path: . 11 | extra_requirements: 12 | - docs 13 | -------------------------------------------------------------------------------- /schierarchy/__init__.py: -------------------------------------------------------------------------------- 1 | """scvi-tools-skeleton.""" 2 | 3 | import logging 4 | 5 | from rich.console import Console 6 | from rich.logging import RichHandler 7 | 8 | from .regression._reference_model import RegressionModel 9 | from .regression._reference_module import RegressionBackgroundDetectionTechPyroModel 10 | 11 | from .logistic._logistic_model import LogisticModel 12 | from .logistic._logistic_module import HierarchicalLogisticPyroModel 13 | 14 | # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094 15 | # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302 16 | try: 17 | import importlib.metadata as importlib_metadata 18 | except ModuleNotFoundError: 19 | import importlib_metadata 20 | 21 | package_name = "schierarchy" 22 | __version__ = importlib_metadata.version(package_name) 23 | 24 | logger = logging.getLogger(__name__) 25 | # set the logging level 26 | logger.setLevel(logging.INFO) 27 | 28 | # nice logging outputs 29 | console = Console(force_terminal=True) 30 | if console.is_jupyter is True: 31 | console.is_jupyter = False 32 | ch = RichHandler(show_path=False, console=console, show_time=False) 33 | formatter = logging.Formatter("scHierarchy: %(message)s") 34 | ch.setFormatter(formatter) 35 | logger.addHandler(ch) 36 | 37 | # this prevents double outputs 38 | logger.propagate = False 39 | 40 | __all__ = [ 41 | "RegressionModel", 42 | "LogisticModel", 43 | "HierarchicalLogisticPyroModel", 44 | "RegressionBackgroundDetectionTechPyroModel", 45 | ] 46 | -------------------------------------------------------------------------------- /schierarchy/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/schierarchy/base/__init__.py -------------------------------------------------------------------------------- /schierarchy/base/_pyro_base_regression_module.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from cell2location.models.base._pyro_mixin import AutoGuideMixinModule, init_to_value 4 | from scvi._compat import Literal 5 | from scvi.module.base import PyroBaseModuleClass 6 | 7 | 8 | class RegressionBaseModule(PyroBaseModuleClass, AutoGuideMixinModule): 9 | def __init__( 10 | self, 11 | model, 12 | 
amortised: bool = False, 13 | encoder_mode: Literal["single", "multiple", "single-multiple"] = "single", 14 | encoder_kwargs=None, 15 | create_autoguide_kwargs: Optional[dict] = None, 16 | **kwargs, 17 | ): 18 | """ 19 | Module class which defines AutoGuide given model. Supports multiple model architectures. 20 | 21 | Parameters 22 | ---------- 23 | amortised 24 | boolean, use a Neural Network to approximate posterior distribution of location-specific (local) parameters? 25 | encoder_mode 26 | Use single encoder for all variables ("single"), one encoder per variable ("multiple") 27 | or a single encoder in the first step and multiple encoders in the second step ("single-multiple"). 28 | encoder_kwargs 29 | arguments for Neural Network construction (scvi.nn.FCLayers) 30 | kwargs 31 | arguments for specific model class - e.g. number of genes, values of the prior distribution 32 | """ 33 | super().__init__() 34 | self.hist = [] 35 | 36 | self._model = model(**kwargs) 37 | self._amortised = amortised 38 | if create_autoguide_kwargs is None: 39 | create_autoguide_kwargs = dict() 40 | 41 | self._guide = self._create_autoguide( 42 | model=self.model, 43 | amortised=self.is_amortised, 44 | encoder_kwargs=encoder_kwargs, 45 | encoder_mode=encoder_mode, 46 | init_loc_fn=self.init_to_value, 47 | n_cat_list=[kwargs["n_batch"]], 48 | **create_autoguide_kwargs, 49 | ) 50 | 51 | self._get_fn_args_from_batch = self._model._get_fn_args_from_batch 52 | 53 | @property 54 | def model(self): 55 | return self._model 56 | 57 | @property 58 | def guide(self): 59 | return self._guide 60 | 61 | @property 62 | def is_amortised(self): 63 | return self._amortised 64 | 65 | @property 66 | def list_obs_plate_vars(self): 67 | return self.model.list_obs_plate_vars() 68 | 69 | def init_to_value(self, site): 70 | 71 | if getattr(self.model, "np_init_vals", None) is not None: 72 | init_vals = { 73 | k: getattr(self.model, f"init_val_{k}") 74 | for k in self.model.np_init_vals.keys() 75 | } 76 | else: 77 | init_vals = dict() 78 | return init_to_value(site=site, values=init_vals) 79 | -------------------------------------------------------------------------------- /schierarchy/logistic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/schierarchy/logistic/__init__.py -------------------------------------------------------------------------------- /schierarchy/logistic/_logistic_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import date 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from anndata import AnnData 8 | from cell2location.models.base._pyro_mixin import ( 9 | AutoGuideMixinModule, 10 | PltExportMixin, 11 | QuantileMixin, 12 | init_to_value, 13 | ) 14 | from pyro import clear_param_store 15 | from pyro.infer.autoguide import AutoNormalMessenger, init_to_feasible, init_to_mean 16 | from scvi import REGISTRY_KEYS 17 | from scvi.data import AnnDataManager 18 | from scvi.data.fields import CategoricalJointObsField, LayerField, NumericalObsField 19 | from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin 20 | from scvi.module.base import PyroBaseModuleClass 21 | from scvi.utils import setup_anndata_dsp 22 | 23 | from ._logistic_module import HierarchicalLogisticPyroModel 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | 
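# Illustrative note (a sketch, not part of the original module): for a two-level
# annotation, `infer_tree` below returns a list with one {parent: [children]} dict
# per adjacent pair of levels. With a hypothetical labels_df:
#
#   labels_df = pd.DataFrame({"major": ["T", "T", "B"], "minor": ["CD4", "CD8", "B naive"]})
#   infer_tree(labels_df, ["major", "minor"])
#   # -> [{"T": ["CD4", "CD8"], "B": ["B naive"]}]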
def infer_tree(labels_df, level_keys):
29 | """
30 | 
31 | Parameters
32 | ----------
33 | labels_df
34 | DataFrame with annotations
35 | level_keys
36 | List of column names from the top to the bottom level (from least detailed to most detailed)
37 | 
38 | Returns
39 | -------
40 | List of edge dicts (parent: [children]) between adjacent levels, of length len(level_keys) - 1
41 | 
42 | """
43 | # for multiple layers of hierarchy
44 | if len(level_keys) > 1:
45 | tree_inferred = [{} for i in range(len(level_keys) - 1)]
46 | for i in range(len(level_keys) - 1):
47 | layer_p = labels_df.loc[:, level_keys[i]]
48 | layer_ch = labels_df.loc[:, level_keys[i + 1]]
49 | for j in range(labels_df.shape[0]):
50 | if layer_p[j] not in tree_inferred[i].keys():
51 | tree_inferred[i][layer_p[j]] = [layer_ch[j]]
52 | else:
53 | if layer_ch[j] not in tree_inferred[i][layer_p[j]]:
54 | tree_inferred[i][layer_p[j]].append(layer_ch[j])
55 | # if only one level
56 | else:
57 | tree_inferred = [list(labels_df[level_keys[0]].unique())]
58 | 
59 | return tree_inferred
60 | 
61 | 
62 | """def _setup_summary_stats(adata, level_keys):
63 | n_cells = adata.shape[0]
64 | n_vars = adata.shape[1]
65 | n_cells_per_label_per_level = [
66 | adata.obs.groupby(group).size().values.astype(int) for group in level_keys
67 | ]
68 | 
69 | n_levels = len(level_keys)
70 | 
71 | summary_stats = {
72 | "n_cells": n_cells,
73 | "n_vars": n_vars,
74 | "n_levels": n_levels,
75 | "n_cells_per_label_per_level": n_cells_per_label_per_level,
76 | }
77 | 
78 | adata.uns["_scvi"]["summary_stats"] = summary_stats
79 | adata.uns["tree"] = infer_tree(adata.obsm["_scvi_extra_categoricals"], level_keys)
80 | 
81 | logger.info(
82 | "Successfully registered anndata object containing {} cells, {} vars, "
83 | "{} cell annotation levels.".format(n_cells, n_vars, n_levels)
84 | )
85 | return summary_stats"""
86 | 
87 | 
88 | class LogisticBaseModule(PyroBaseModuleClass, AutoGuideMixinModule):
89 | def __init__(
90 | self,
91 | model,
92 | init_loc_fn=init_to_mean(fallback=init_to_feasible),
93 | guide_class=AutoNormalMessenger,
94 | guide_kwargs: Optional[dict] = None,
95 | **kwargs,
96 | ):
97 | """
98 | Module class which defines AutoGuide given model. Supports multiple model architectures.
99 | 
100 | Parameters
101 | ----------
102 | init_loc_fn
103 | function used to initialise the location of AutoGuide parameters
104 | guide_class
105 | AutoGuide class that defines the variational posterior (default AutoNormalMessenger)
106 | guide_kwargs
107 | additional keyword arguments for the guide constructor
108 | kwargs
109 | arguments for the specific model class - e.g.
110 | number of genes, values of the prior distribution
111 | """
112 | super().__init__()
113 | self.hist = []
114 | self._model = model(**kwargs)
115 | 
116 | if guide_kwargs is None:
117 | guide_kwargs = dict()
118 | 
119 | self._guide = guide_class(
120 | self.model,
121 | init_loc_fn=init_loc_fn,
122 | **guide_kwargs
123 | # create_plates=model.create_plates,
124 | )
125 | 
126 | self._get_fn_args_from_batch = self._model._get_fn_args_from_batch
127 | 
128 | @property
129 | def model(self):
130 | return self._model
131 | 
132 | @property
133 | def guide(self):
134 | return self._guide
135 | 
136 | @property
137 | def list_obs_plate_vars(self):
138 | return self.model.list_obs_plate_vars()
139 | 
140 | def init_to_value(self, site):
141 | 
142 | if getattr(self.model, "np_init_vals", None) is not None:
143 | init_vals = {
144 | k: getattr(self.model, f"init_val_{k}")
145 | for k in self.model.np_init_vals.keys()
146 | }
147 | else:
148 | init_vals = dict()
149 | return init_to_value(site=site, values=init_vals)
150 | 
151 | 
152 | class LogisticModel(
153 | QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, PltExportMixin, BaseModelClass
154 | ):
155 | """
156 | Hierarchical logistic regression model which identifies marker genes for each cell type and predicts cell labels at every annotation level. User-end model class.
157 | 
158 | https://github.com/BayraktarLab/cell2location
159 | 
160 | Parameters
161 | ----------
162 | adata
163 | single-cell AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
164 | laplace_learning_mode
165 | prior on the Laplace weight scale: 'fixed-sigma', 'learn-sigma-single', 'learn-sigma-gene', 'learn-sigma-celltype' or 'learn-sigma-gene-celltype'
166 | model_class
167 | Pyro model class, defaults to :class:`~schierarchy.logistic._logistic_module.HierarchicalLogisticPyroModel`
168 | **model_kwargs
169 | Keyword args for :class:`~schierarchy.logistic._logistic_module.HierarchicalLogisticPyroModel`
170 | 
171 | Examples
172 | --------
173 | TODO add example
174 | >>>
175 | """
176 | 
177 | def __init__(
178 | self,
179 | adata: AnnData,
180 | laplace_learning_mode: str = "fixed-sigma",
181 | # tree: list,
182 | model_class=None,
183 | **model_kwargs,
184 | ):
185 | # in case any other model was created before that shares the same parameter names.
186 | clear_param_store()
187 | 
188 | super().__init__(adata)
189 | 
190 | level_keys = self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)[
191 | "field_keys"
192 | ]
193 | self.n_cells_per_label_per_level_ = [
194 | self.adata.obs.groupby(group).size().values.astype(int)
195 | for group in level_keys
196 | ]
197 | self.n_levels_ = len(level_keys)
198 | self.level_keys_ = level_keys
199 | self.tree_ = infer_tree(
200 | self.adata_manager.get_from_registry(REGISTRY_KEYS.CAT_COVS_KEY), level_keys
201 | )
202 | self.laplace_learning_mode_ = laplace_learning_mode
203 | 
204 | if model_class is None:
205 | model_class = HierarchicalLogisticPyroModel
206 | self.module = LogisticBaseModule(
207 | model=model_class,
208 | n_obs=self.summary_stats["n_cells"],
209 | n_vars=self.summary_stats["n_vars"],
210 | n_levels=self.n_levels_,
211 | n_cells_per_label_per_level=self.n_cells_per_label_per_level_,
212 | tree=self.tree_,
213 | laplace_learning_mode=self.laplace_learning_mode_,
214 | **model_kwargs,
215 | )
216 | self.samples = dict()
217 | self.init_params_ = self._get_init_params(locals())
218 | 
219 | @classmethod
220 | @setup_anndata_dsp.dedent
221 | def setup_anndata(
222 | cls,
223 | adata: AnnData,
224 | layer: Optional[str] = None,
225 | level_keys: Optional[list] = None,
226 | **kwargs,
227 | ):
228 | """
229 | %(summary)s.
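Example (illustrative, mirroring the tutorial notebook; column names depend on your data)::

    LogisticModel.setup_anndata(adata_ref, layer="cdf", level_keys=["celltype_major", "celltype_minor", "celltype_subset"])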
230 | 
231 | Parameters
232 | ----------
233 | %(param_layer)s
234 | level_keys
235 | List of obs columns with cell labels, ordered from the coarsest to the finest level
236 | kwargs
237 | additional arguments passed to ``adata_manager.register_fields``
238 | """
239 | setup_method_args = cls._get_setup_method_args(**locals())
240 | adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
241 | anndata_fields = [
242 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
243 | CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, level_keys),
244 | NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
245 | ]
246 | adata_manager = AnnDataManager(
247 | fields=anndata_fields, setup_method_args=setup_method_args
248 | )
249 | adata_manager.register_fields(adata, **kwargs)
250 | cls.register_manager(adata_manager)
251 | 
252 | def _export2adata(self, samples):
253 | r"""
254 | Export key model variables and samples
255 | 
256 | Parameters
257 | ----------
258 | samples
259 | dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by ``.sample_posterior()``
260 | 
261 | Returns
262 | -------
263 | Dictionary with additional details that is saved to ``adata.uns['mod']``.
264 | """
265 | # add factor filter and samples of all parameters to unstructured data
266 | results = {
267 | "model_name": str(self.module.__class__.__name__),
268 | "date": str(date.today()),
269 | "var_names": self.adata.var_names.tolist(),
270 | "obs_names": self.adata.obs_names.tolist(),
271 | "post_sample_means": samples["post_sample_means"] if "post_sample_means" in samples else None,
272 | "post_sample_stds": samples["post_sample_stds"] if "post_sample_stds" in samples else None,
273 | }
274 | # add posterior quantiles
275 | for k, v in samples.items():
276 | if k.startswith("post_sample_"):
277 | results[k] = v
278 | 
279 | return results
280 | 
281 | def export_posterior(
282 | self,
283 | adata,
284 | prediction: bool = False,
285 | use_quantiles: bool = False,
286 | sample_kwargs: Optional[dict] = None,
287 | export_slot: str = "mod",
288 | add_to_varm: list = ["means", "stds", "q05", "q95"],
289 | ):
290 | """
291 | Summarise posterior distribution and export results (gene weights and per-cell label probabilities) to anndata object:
292 | 1. adata.varm / adata.obsm: estimated gene weights per label and per-cell label probabilities,
293 | as pd.DataFrames for each posterior distribution summary `add_to_varm`,
294 | posterior mean, sd, 5% and 95% quantiles (['means', 'stds', 'q05', 'q95']).
295 | If export to adata.varm fails with error, results are saved to adata.var instead.
296 | 2. adata.uns: Posterior of all parameters, model name, date,
297 | cell type names ('factor_names'), obs and var names.
298 | 
299 | Parameters
300 | ----------
301 | adata
302 | anndata object where results should be saved
303 | prediction
304 | Prediction mode predicts cell labels on new data.
305 | sample_kwargs
306 | arguments for self.sample_posterior (generating and summarising posterior samples), namely:
307 | num_samples - number of samples to use (Default = 1000).
308 | batch_size - data batch size (keep low enough to fit on GPU, default 2048).
309 | use_gpu - use gpu for generating samples?
310 | export_slot
311 | adata.uns slot where to export results
312 | add_to_varm
313 | posterior distribution summary to export in adata.varm (['means', 'stds', 'q05', 'q95']).
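Example (illustrative; the varm/obsm keys depend on your ``level_keys``)::

    adata_ref = mod.export_posterior(adata_ref, sample_kwargs={"num_samples": 50, "batch_size": 2500})
    weights = adata_ref.varm["means_weight_celltype_major"]  # genes x labels
    probs = adata_ref.obsm["means_label_prob_celltype_major"]  # cells x labels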
314 | Returns
315 | -------
316 | 
317 | """
318 | 
319 | sample_kwargs = sample_kwargs if isinstance(sample_kwargs, dict) else dict()
320 | 
321 | label_keys = list(
322 | self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)[
323 | "field_keys"
324 | ]
325 | )
326 | 
327 | # in prediction mode, switch to evaluation mode and swap the adata object
328 | if prediction:
329 | self.module.eval()
330 | self.module.model.prediction = True
331 | # use version of this function for prediction
332 | self.module._get_fn_args_from_batch = (
333 | self.module.model._get_fn_args_from_batch
334 | )
335 | # resize plates according to the validation object
336 | self.module.model.n_obs = adata.n_obs
337 | # create index column
338 | adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
339 | # for minibatch learning, selected indices lie in "ind_x"
340 | # scvi.data.register_tensor_from_anndata(
341 | # adata,
342 | # registry_key="ind_x",
343 | # adata_attr_name="obs",
344 | # adata_key_name="_indices",
345 | # )
346 | # if all columns with labels don't exist, create them and fill with 0s
347 | if np.all(~np.isin(label_keys, adata.obs.columns)):
348 | adata.obs.loc[:, label_keys] = 0
349 | # substitute adata object
350 | adata_train = self.adata.copy()
351 | self.adata = self._validate_anndata(adata)
352 | # self.adata = adata
353 | 
354 | if use_quantiles:
355 | add_to_varm = [i for i in add_to_varm if (i not in ["means", "stds"]) and ("q" in i)]
356 | if len(add_to_varm) == 0:
357 | raise ValueError("No quantiles to export - please add add_to_varm=['q05', 'q50', 'q95'].")
358 | self.samples = dict()
359 | for i in add_to_varm:
360 | q = float(f"0.{i[1:]}")
361 | self.samples[f"post_sample_{i}"] = self.posterior_quantile(q=q, **sample_kwargs)
362 | else:
363 | # generate samples from posterior distributions for all parameters
364 | # and compute mean, 5%/95% quantiles and standard deviation
365 | self.samples = self.sample_posterior(**sample_kwargs)
366 | 
367 | # revert adata object substitution
368 | self.adata = adata_train
369 | self.module.eval()
370 | self.module.model.prediction = False
371 | # re-set default version of this function
372 | self.module._get_fn_args_from_batch = (
373 | self.module.model._get_fn_args_from_batch
374 | )
375 | obs_names = adata.obs_names
376 | else:
377 | if use_quantiles:
378 | add_to_varm = [i for i in add_to_varm if (i not in ["means", "stds"]) and ("q" in i)]
379 | if len(add_to_varm) == 0:
380 | raise ValueError("No quantiles to export - please add add_to_varm=['q05', 'q50', 'q95'].")
381 | self.samples = dict()
382 | for i in add_to_varm:
383 | q = float(f"0.{i[1:]}")
384 | self.samples[f"post_sample_{i}"] = self.posterior_quantile(q=q, **sample_kwargs)
385 | else:
386 | # generate samples from posterior distributions for all parameters
387 | # and compute mean, 5%/95% quantiles and standard deviation
388 | self.samples = self.sample_posterior(**sample_kwargs)
389 | obs_names = self.adata.obs_names
390 | 
391 | # export posterior distribution summary for all parameters and
392 | # annotation (model, date, var, obs and cell type names) to anndata object
393 | adata.uns[export_slot] = self._export2adata(self.samples)
394 | 
395 | # export estimated gene weights and label probabilities for each level
396 | # first convert np.arrays to pd.DataFrames with cell type and observation names
397 | # data frames contain mean, 5%/95% quantiles and standard deviation, denoted by a prefix
398 | for i in range(self.n_levels_):
399 | categories = list(
400 | list(
401 | self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)[
402 | "mappings"
403 | ].values()
404 | )[i]
405 | )
406 | for k in add_to_varm:
407 | sample_df = pd.DataFrame(
408 | self.samples[f"post_sample_{k}"].get(f"weight_level_{i}", None),
409 | columns=[f"{k}_weight_{label_keys[i]}_{c}" for c in categories],
410 | index=self.adata.var_names,
411 | )
412 | try:
413 | adata.varm[f"{k}_weight_{label_keys[i]}"] = sample_df.loc[
414 | adata.var_names, :
415 | ]
416 | except ValueError:
417 | # Catching weird error with varm: `ValueError: value.index does not match parent’s axis 1 names`
418 | adata.var[sample_df.columns] = sample_df.loc[adata.var_names, :]
419 | 
420 | sample_df = pd.DataFrame(
421 | self.samples[f"post_sample_{k}"].get(f"label_prob_{i}", None),
422 | columns=obs_names,
423 | index=[f"{k}_label_{label_keys[i]}_{c}" for c in categories],
424 | ).T
425 | try:
426 | # TODO change to user input name
427 | adata.obsm[f"{k}_label_prob_{label_keys[i]}"] = sample_df.loc[
428 | adata.obs_names, :
429 | ]
430 | except ValueError:
431 | # Catching weird error with obsm: `ValueError: value.index does not match parent’s axis 1 names`
432 | adata.obs[sample_df.columns] = sample_df.loc[adata.obs_names, :]
433 | 
434 | return adata
435 | 
436 | 
437 | # TODO plot QC - prediction accuracy curve
-------------------------------------------------------------------------------- /schierarchy/logistic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/schierarchy/logistic/__init__.py -------------------------------------------------------------------------------- /schierarchy/logistic/_logistic_module.py: --------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import pyro
4 | import pyro.distributions as dist
5 | import torch
6 | from pyro.nn import PyroModule
7 | from scvi import REGISTRY_KEYS
8 | 
9 | 
10 | class HierarchicalLogisticPyroModel(PyroModule):
11 | 
12 | prediction = False
13 | 
14 | def __init__(
15 | self,
16 | n_obs,
17 | n_vars,
18 | n_levels,
19 | n_cells_per_label_per_level,
20 | tree,
21 | laplace_prior={"mu": 0.0, "sigma": 0.5, "exp_rate": 3.0},
22 | laplace_learning_mode="fixed-sigma",
23 | init_vals: Optional[dict] = None,
24 | dropout_p: float = 0.1,
25 | use_dropout: bool = True,
26 | use_gene_dropout: bool = False,
27 | ):
28 | r"""
29 | 
30 | Parameters
31 | ----------
32 | n_obs
33 | n_vars
34 | n_levels
35 | n_cells_per_label_per_level
36 | tree
37 | laplace_prior
38 | laplace_learning_mode = 'fixed-sigma', 'learn-sigma-single', 'learn-sigma-gene', 'learn-sigma-celltype',
39 | 'learn-sigma-gene-celltype'
40 | """
41 | 
42 | ############# Initialise parameters ################
43 | super().__init__()
44 | 
45 | self.dropout = torch.nn.Dropout(p=dropout_p)
46 | 
47 | self.n_obs = n_obs
48 | self.n_vars = n_vars
49 | self.n_levels = n_levels
50 | self.n_cells_per_label_per_level = n_cells_per_label_per_level
51 | self.tree = tree
52 | self.laplace_prior = laplace_prior
53 | self.laplace_learning_mode = laplace_learning_mode
54 | self.use_dropout = use_dropout
55 | self.use_gene_dropout = use_gene_dropout
56 | 
57 | if self.laplace_learning_mode not in [
58 | "fixed-sigma",
59 | "learn-sigma-single",
60 | "learn-sigma-gene",
61 | "learn-sigma-celltype",
62 | "learn-sigma-gene-celltype",
63 | ]:
64 | raise NotImplementedError
65 | 
66 | if (init_vals is not None) & (type(init_vals) is dict):
67 | self.np_init_vals = init_vals
68 | for k in init_vals.keys():
69 | self.register_buffer(f"init_val_{k}", torch.tensor(init_vals[k]))
70 | 
71 | for i in range(self.n_levels):
72 | self.register_buffer(
73 | f"n_cells_per_label_per_level_{i}",
74 | torch.tensor(n_cells_per_label_per_level[i]),
75 | )
76 | 
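# NOTE: the prior hyperparameters below are registered as buffers (rather than
# plain attributes) so that they move to the correct device together with the
# module, e.g. when training on GPU
77 | 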
self.register_buffer( 77 | "laplace_prior_mu", 78 | torch.tensor(self.laplace_prior["mu"]), 79 | ) 80 | 81 | self.register_buffer( 82 | "laplace_prior_sigma", 83 | torch.tensor(self.laplace_prior["sigma"]), 84 | ) 85 | 86 | self.register_buffer( 87 | "exponential_prior_rate", 88 | torch.tensor(self.laplace_prior["exp_rate"]), 89 | ) 90 | 91 | self.register_buffer("ones", torch.ones((1, 1))) 92 | 93 | @property 94 | def layers_size(self): 95 | if self.tree is not None: 96 | if len(self.tree) > 0 and type(self.tree[0]) is not list: 97 | return [len(x) for x in self.tree] + [ 98 | len( 99 | [item for sublist in self.tree[-1].values() for item in sublist] 100 | ) 101 | ] 102 | else: 103 | return [len(self.tree[0])] 104 | else: 105 | return None 106 | 107 | @property 108 | def _get_fn_args_from_batch(self): 109 | if self.prediction: 110 | return self._get_fn_args_from_batch_prediction 111 | else: 112 | return self._get_fn_args_from_batch_training 113 | 114 | @staticmethod 115 | def _get_fn_args_from_batch_training(tensor_dict): 116 | x_data = tensor_dict[REGISTRY_KEYS.X_KEY] 117 | idx = tensor_dict["ind_x"].long().squeeze() 118 | levels = tensor_dict[REGISTRY_KEYS.CAT_COVS_KEY] 119 | return (x_data, idx, levels), {} 120 | 121 | @staticmethod 122 | def _get_fn_args_from_batch_prediction(tensor_dict): 123 | x_data = tensor_dict[REGISTRY_KEYS.X_KEY] 124 | idx = tensor_dict["ind_x"].long().squeeze() 125 | return (x_data, idx, idx), {} 126 | 127 | ############# Define the model ################ 128 | 129 | def create_plates(self, x_data, idx, levels): 130 | return pyro.plate("obs_plate", size=self.n_obs, dim=-1, subsample=idx) 131 | 132 | def list_obs_plate_vars(self): 133 | """Create a dictionary with the name of observation/minibatch plate, 134 | indexes of model args to provide to encoder, 135 | variable names that belong to the observation plate 136 | and the number of dimensions in non-plate axis of each variable""" 137 | 138 | return { 139 | "name": "obs_plate", 140 | "input": [], # expression data + (optional) batch index 141 | "input_transform": [], # how to transform input data before passing to NN 142 | "sites": {}, 143 | } 144 | 145 | def forward(self, x_data, idx, levels): 146 | obs_plate = self.create_plates(x_data, idx, levels) 147 | 148 | if self.use_dropout: 149 | x_data = self.dropout(x_data) 150 | if self.use_gene_dropout: 151 | x_data = x_data * self.dropout(self.ones.expand([1, self.n_vars]).clone()) 152 | 153 | f = [] 154 | for i in range(self.n_levels): 155 | # create weights for level i 156 | if self.laplace_learning_mode == "fixed-sigma": 157 | w_i = pyro.sample( 158 | f"weight_level_{i}", 159 | dist.Laplace(self.laplace_prior_mu, self.laplace_prior_sigma) 160 | .expand([self.n_vars, self.layers_size[i]]) 161 | .to_event(2), 162 | ) 163 | elif self.laplace_learning_mode == "learn-sigma-single": 164 | sigma_i = pyro.sample( 165 | f"sigma_level_{i}", 166 | dist.Exponential(self.exponential_prior_rate).expand([1, 1]), 167 | ) 168 | w_i = pyro.sample( 169 | f"weight_level_{i}", 170 | dist.Laplace(self.laplace_prior_mu, sigma_i) 171 | .expand([self.n_vars, self.layers_size[i]]) 172 | .to_event(2), 173 | ) 174 | elif self.laplace_learning_mode == "learn-sigma-gene": 175 | sigma_ig = pyro.sample( 176 | f"sigma_ig_level_{i}", 177 | dist.Exponential(self.exponential_prior_rate) 178 | .expand([self.n_vars]) 179 | .to_event(1), 180 | ) 181 | w_i = pyro.sample( 182 | f"weight_level_{i}", 183 | dist.Laplace(self.laplace_prior_mu, sigma_ig[:, None]) 184 | .expand([self.n_vars, 
self.layers_size[i]])
185 | .to_event(2),
186 | )
187 | elif self.laplace_learning_mode == "learn-sigma-celltype":
188 | sigma_ic = pyro.sample(
189 | f"sigma_ic_level_{i}",
190 | dist.Exponential(self.exponential_prior_rate)
191 | .expand([self.layers_size[i]])
192 | .to_event(1),
193 | )
194 | w_i = pyro.sample(
195 | f"weight_level_{i}",
196 | dist.Laplace(self.laplace_prior_mu, sigma_ic[None, :])
197 | .expand([self.n_vars, self.layers_size[i]])
198 | .to_event(2),
199 | )
200 | elif self.laplace_learning_mode == "learn-sigma-gene-celltype":
201 | sigma_ig = pyro.sample(
202 | f"sigma_ig_level_{i}",
203 | dist.Exponential(self.exponential_prior_rate)
204 | .expand([self.n_vars])
205 | .to_event(1),
206 | )
207 | 
208 | sigma_ic = pyro.sample(
209 | f"sigma_ic_level_{i}",
210 | dist.Exponential(self.exponential_prior_rate)
211 | .expand([self.layers_size[i]])
212 | .to_event(1),
213 | )
214 | w_i = pyro.sample(
215 | f"weight_level_{i}",
216 | dist.Laplace(
217 | self.laplace_prior_mu, sigma_ig[:, None] @ sigma_ic[None, :]
218 | ).to_event(2),
219 | )
220 | # parameter for cluster size weight normalisation w / sqrt(n_cells per cluster)
221 | n_cells_per_label = self.get_buffer(f"n_cells_per_label_per_level_{i}")
222 | if i == 0:
223 | # compute f for level 0 (it does not depend on a previous level, as none exists)
224 | f_i = torch.nn.functional.softmax(
225 | torch.matmul(x_data, w_i / n_cells_per_label**0.5), dim=1
226 | )
227 | else:
228 | # initialise f for level > 0
229 | f_i = torch.ones((x_data.shape[0], self.layers_size[i])).to(
230 | x_data.device
231 | )
232 | # compute f as f_(i) * f_(i-1) for each cluster group under the parent node
233 | # multiplication can handle non-tree structures (multiple parents for one child cluster)
234 | for parent, children in self.tree[i - 1].items():
235 | f_i[:, children] *= (
236 | torch.nn.functional.softmax(
237 | torch.matmul(
238 | x_data,
239 | w_i[:, children]
240 | / (n_cells_per_label[children])[None, :] ** 0.5,
241 | ),
242 | dim=1,
243 | )
244 | * f[i - 1][:, parent, None]
245 | )
246 | # record level i probabilities as level i+1 depends on them
247 | f.append(f_i)
248 | with obs_plate:
249 | pyro.deterministic(f"label_prob_{i}", f_i.T)
250 | if not self.prediction:
251 | pyro.sample(
252 | f"likelihood_{i}", dist.Categorical(f_i), obs=levels[:, i]
253 | )
-------------------------------------------------------------------------------- /schierarchy/regression/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/schierarchy/regression/__init__.py -------------------------------------------------------------------------------- /schierarchy/regression/_reference_model.py: --------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | 
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from anndata import AnnData
7 | from cell2location.cluster_averages import compute_cluster_averages
8 | from cell2location.models.base._pyro_base_reference_module import RegressionBaseModule
9 | from cell2location.models.base._pyro_mixin import PltExportMixin, QuantileMixin
10 | from pyro import clear_param_store
11 | from scvi import REGISTRY_KEYS
12 | from scvi.data import AnnDataManager
13 | from scvi.data.fields import (
14 | CategoricalJointObsField,
15 | CategoricalObsField,
16 | LayerField,
17 | NumericalJointObsField,
18 | NumericalObsField,
19 | )
20 | from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin
21 | from scvi.utils import setup_anndata_dsp
22 | 
23 | from ._reference_module import RegressionBackgroundDetectionTechPyroModel
24 | 
25 | 
26 | class RegressionModel(
27 | QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, PltExportMixin, BaseModelClass
28 | ):
29 | """
30 | Model which estimates per cluster average mRNA count, accounting for batch effects. User-end model class.
31 | 
32 | https://github.com/BayraktarLab/cell2location
33 | 
34 | Parameters
35 | ----------
36 | adata
37 | single-cell AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
38 | use_average_as_initial
39 | use the average expression of each gene in each cluster as the initial value for per-cluster signatures?
40 | **model_kwargs
41 | Keyword args for :class:`~schierarchy.regression._reference_module.RegressionBackgroundDetectionTechPyroModel`
42 | 
43 | Examples
44 | --------
45 | TODO add example
46 | >>>
47 | """
48 | 
49 | def __init__(
50 | self,
51 | adata: AnnData,
52 | model_class=None,
53 | use_average_as_initial: bool = True,
54 | **model_kwargs,
55 | ):
56 | # in case any other model was created before that shares the same parameter names.
57 | clear_param_store()
58 | 
59 | super().__init__(adata)
60 | 
61 | if model_class is None:
62 | model_class = RegressionBackgroundDetectionTechPyroModel
63 | 
64 | # annotations for cell types
65 | self.n_factors_ = self.summary_stats["n_labels"]
66 | self.factor_names_ = self.adata_manager.get_state_registry(
67 | REGISTRY_KEYS.LABELS_KEY
68 | ).categorical_mapping
69 | # annotations for extra categorical covariates
70 | if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry:
71 | self.extra_categoricals_ = self.adata_manager.get_state_registry(
72 | REGISTRY_KEYS.CAT_COVS_KEY
73 | )
74 | self.n_extra_categoricals_ = self.extra_categoricals_.n_cats_per_key
75 | model_kwargs["n_extra_categoricals"] = self.n_extra_categoricals_
76 | 
77 | # use per class average as initial value
78 | if use_average_as_initial:
79 | # compute cluster average expression
80 | aver = self._compute_cluster_averages(key=REGISTRY_KEYS.LABELS_KEY)
81 | model_kwargs["init_vals"] = {
82 | "per_cluster_mu_fg": aver.values.T.astype("float32") + 0.0001
83 | }
84 | 
85 | self.module = RegressionBaseModule(
86 | model=model_class,
87 | n_obs=self.summary_stats["n_cells"],
88 | n_vars=self.summary_stats["n_vars"],
89 | n_factors=self.n_factors_,
90 | n_batch=self.summary_stats["n_batch"],
91 | **model_kwargs,
92 | )
93 | self._model_summary_string = f'RegressionBackgroundDetectionTech model with the following params: \nn_factors: {self.n_factors_} \nn_batch: {self.summary_stats["n_batch"]} '
94 | self.init_params_ = self._get_init_params(locals())
95 | 
96 | @classmethod
97 | @setup_anndata_dsp.dedent
98 | def setup_anndata(
99 | cls,
100 | adata: AnnData,
101 | layer: Optional[str] = None,
102 | batch_key: Optional[str] = None,
103 | labels_key: Optional[str] = None,
104 | categorical_covariate_keys: Optional[List[str]] = None,
105 | continuous_covariate_keys: Optional[List[str]] = None,
106 | **kwargs,
107 | ):
108 | """
109 | %(summary)s.
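Example (illustrative; the obs column names are hypothetical and depend on your data)::

    RegressionModel.setup_anndata(adata, batch_key="sample", labels_key="celltype")
    mod = RegressionModel(adata)
    mod.train(max_epochs=250)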
110 | 
111 | Parameters
112 | ----------
113 | %(param_layer)s
114 | %(param_batch_key)s
115 | %(param_labels_key)s
116 | %(param_cat_cov_keys)s
117 | %(param_cont_cov_keys)s
118 | """
119 | setup_method_args = cls._get_setup_method_args(**locals())
120 | adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
121 | anndata_fields = [
122 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
123 | CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
124 | CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
125 | CategoricalJointObsField(
126 | REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
127 | ),
128 | NumericalJointObsField(
129 | REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
130 | ),
131 | NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
132 | ]
133 | adata_manager = AnnDataManager(
134 | fields=anndata_fields, setup_method_args=setup_method_args
135 | )
136 | adata_manager.register_fields(adata, **kwargs)
137 | cls.register_manager(adata_manager)
138 | 
139 | def train(
140 | self,
141 | max_epochs: Optional[int] = None,
142 | batch_size: int = 2500,
143 | train_size: float = 1,
144 | lr: float = 0.002,
145 | **kwargs,
146 | ):
147 | """Train the model with useful defaults
148 | 
149 | Parameters
150 | ----------
151 | max_epochs
152 | Number of passes through the dataset. If `None`, defaults to
153 | `np.min([round((20000 / n_cells) * 400), 400])`
154 | train_size
155 | Size of training set in the range [0.0, 1.0].
156 | batch_size
157 | Minibatch size to use during training. If `None`, no minibatching occurs and all
158 | data is copied to device (e.g., GPU).
159 | lr
160 | Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`).
161 | Specifying optimiser via plan_kwargs overrides this choice of lr.
162 | kwargs
163 | Other arguments to scvi.model.base.PyroSviTrainMixin().train() method
164 | """
165 | 
166 | kwargs["max_epochs"] = max_epochs
167 | kwargs["batch_size"] = batch_size
168 | kwargs["train_size"] = train_size
169 | kwargs["lr"] = lr
170 | 
171 | super().train(**kwargs)
172 | 
173 | def _compute_cluster_averages(self, key=REGISTRY_KEYS.LABELS_KEY):
174 | """
175 | Compute average per cluster (key=REGISTRY_KEYS.LABELS_KEY) or per batch (key=REGISTRY_KEYS.BATCH_KEY).
176 | 
177 | Returns
178 | -------
179 | pd.DataFrame with variables in rows and labels in columns
180 | """
181 | # find cell label column
182 | label_col = self.adata_manager.get_state_registry(key).original_key
183 | 
184 | # find data slot
185 | x_dict = self.adata_manager.data_registry["X"]
186 | if x_dict["attr_name"] == "X":
187 | use_raw = False
188 | else:
189 | use_raw = True
190 | if x_dict["attr_name"] == "layers":
191 | layer = x_dict["attr_key"]
192 | else:
193 | layer = None
194 | 
195 | # compute mean expression of each gene in each cluster/batch
196 | aver = compute_cluster_averages(
197 | self.adata, labels=label_col, use_raw=use_raw, layer=layer
198 | )
199 | 
200 | return aver
201 | 
202 | def export_posterior(
203 | self,
204 | adata,
205 | sample_kwargs: Optional[dict] = None,
206 | export_slot: str = "mod",
207 | add_to_varm: list = ["means", "stds", "q05", "q95"],
208 | scale_average_detection: bool = True,
209 | ):
210 | """
211 | Summarise posterior distribution and export results (reference expression signatures) to anndata object:
212 | 1. adata.varm: estimated reference expression signatures (average mRNA count in each cell type),
213 | as pd.DataFrames for each posterior distribution summary `add_to_varm`,
214 | posterior mean, sd, 5% and 95% quantiles (['means', 'stds', 'q05', 'q95']).
215 | If export to adata.varm fails with error, results are saved to adata.var instead.
216 | 2. adata.uns: Posterior of all parameters, model name, date,
217 | cell type names ('factor_names'), obs and var names.
218 | 
219 | Parameters
220 | ----------
221 | adata
222 | anndata object where results should be saved
223 | sample_kwargs
224 | arguments for self.sample_posterior (generating and summarising posterior samples), namely:
225 | num_samples - number of samples to use (Default = 1000).
226 | batch_size - data batch size (keep low enough to fit on GPU, default 2048).
227 | use_gpu - use gpu for generating samples?
228 | export_slot
229 | adata.uns slot where to export results
230 | add_to_varm
231 | posterior distribution summary to export in adata.varm (['means', 'stds', 'q05', 'q95']).
232 | Returns
233 | -------
234 | 
235 | """
236 | 
237 | sample_kwargs = sample_kwargs if isinstance(sample_kwargs, dict) else dict()
238 | 
239 | # generate samples from posterior distributions for all parameters
240 | # and compute mean, 5%/95% quantiles and standard deviation
241 | self.samples = self.sample_posterior(**sample_kwargs)
242 | 
243 | # export posterior distribution summary for all parameters and
244 | # annotation (model, date, var, obs and cell type names) to anndata object
245 | adata.uns[export_slot] = self._export2adata(self.samples)
246 | 
247 | # export estimated expression in each cluster
248 | # first convert np.arrays to pd.DataFrames with cell type and observation names
249 | # data frames contain mean, 5%/95% quantiles and standard deviation, denoted by a prefix
250 | for k in add_to_varm:
251 | sample_df = self.sample2df_vars(
252 | self.samples,
253 | site_name="per_cluster_mu_fg",
254 | summary_name=k,
255 | name_prefix="",
256 | )
257 | if scale_average_detection and (
258 | "detection_y_c" in list(self.samples[f"post_sample_{k}"].keys())
259 | ):
260 | sample_df = (
261 | sample_df * self.samples[f"post_sample_{k}"]["detection_y_c"].mean()
262 | )
263 | try:
264 | adata.varm[f"{k}_per_cluster_mu_fg"] = sample_df.loc[adata.var.index, :]
265 | except ValueError:
266 | # Catching weird error with varm: `ValueError: value.index does not match parent’s axis 1 names`
267 | adata.var[sample_df.columns] = sample_df.loc[adata.var.index, :]
268 | 
269 | return adata
270 | 
271 | def plot_QC(
272 | self,
273 | summary_name: str = "means",
274 | use_n_obs: int = 1000,
275 | scale_average_detection: bool = True,
276 | ):
277 | """
278 | Show quality control plots:
279 | 1. Reconstruction accuracy to assess if there are any issues with model training.
280 | The plot should be roughly diagonal; strong deviations signal problems that need to be investigated.
281 | Plotting is slow because the expected value of the mRNA count needs to be computed from model parameters. Random
282 | observations are used to speed up computation.
283 | 
284 | 2. Estimated reference expression signatures (accounting for batch effect)
285 | compared to average expression in each cluster. We expect the signatures to be different
286 | from average when batch effects are present; however, when this plot is very different from
287 | a perfect diagonal (such as very low values on the Y-axis, or non-zero density everywhere),
288 | it indicates problems with signature estimation.
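The second plot is computed from the posterior summary of ``per_cluster_mu_fg``
(scaled by the posterior mean of ``detection_y_c`` when ``scale_average_detection=True``)
plotted against per-cluster average expression, both on a log10(x + 1) scale.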
289 | 
290 | Parameters
291 | ----------
292 | summary_name
293 | posterior distribution summary to use ('means', 'stds', 'q05', 'q95')
294 | 
295 | Returns
296 | -------
297 | 
298 | """
299 | 
300 | super().plot_QC(summary_name=summary_name, use_n_obs=use_n_obs)
301 | plt.show()
302 | 
303 | inf_aver = self.samples[f"post_sample_{summary_name}"]["per_cluster_mu_fg"].T
304 | if scale_average_detection and (
305 | "detection_y_c" in list(self.samples[f"post_sample_{summary_name}"].keys())
306 | ):
307 | inf_aver = (
308 | inf_aver
309 | * self.samples[f"post_sample_{summary_name}"]["detection_y_c"].mean()
310 | )
311 | aver = self._compute_cluster_averages(key=REGISTRY_KEYS.LABELS_KEY)
312 | aver = aver[self.factor_names_]
313 | 
314 | plt.hist2d(
315 | np.log10(aver.values.flatten() + 1),
316 | np.log10(inf_aver.flatten() + 1),
317 | bins=50,
318 | norm=matplotlib.colors.LogNorm(),
319 | )
320 | plt.xlabel("Mean expression for every gene in every cluster")
321 | plt.ylabel("Estimated expression for every gene in every cluster")
322 | plt.show()
-------------------------------------------------------------------------------- /schierarchy/regression/_reference_module.py: --------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | import pyro
6 | import pyro.distributions as dist
7 | import torch
8 | from pyro.nn import PyroModule
9 | from scvi import REGISTRY_KEYS
10 | from scvi.nn import one_hot
11 | 
12 | 
13 | class RegressionBackgroundDetectionTechPyroModel(PyroModule):
14 | r"""
15 | Given cell type annotation for each cell, the corresponding reference cell type signatures :math:`g_{f,g}`,
16 | which represent the average mRNA count of each gene `g` in each cell type `f={1, .., F}`,
17 | are estimated from sc/snRNA-seq data using Negative Binomial regression,
18 | which makes it possible to robustly combine data across technologies and batches.
19 | 
20 | This model combines batches, and treats data :math:`D` as Negative Binomial distributed,
21 | given mean :math:`\mu` and overdispersion :math:`\alpha`:
22 | 
23 | .. math::
24 | D_{c,g} \sim \mathtt{NB}(alpha=\alpha_{g}, mu=\mu_{c,g})
25 | .. math::
26 | \mu_{c,g} = (\mu_{f,g} + s_{e,g}) \cdot y_e \cdot y_{t,g}
27 | 
28 | Which is equivalent to:
29 | 
30 | .. math::
31 | D_{c,g} \sim \mathtt{Poisson}(\mathtt{Gamma}(\alpha_{g}, \alpha_{g} / \mu_{c,g}))
32 | 
33 | Here, :math:`\mu_{f,g}` denotes average mRNA count in each cell type :math:`f` for each gene :math:`g`;
34 | :math:`y_e` denotes normalisation for each experiment :math:`e` to account for sequencing depth;
35 | :math:`y_{t,g}` denotes per gene :math:`g` detection efficiency normalisation for each technology :math:`t`.
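In the implementation the overdispersion is parameterised via its inverse square root,
:math:`\alpha_{g} = 1 / (\alpha^{inv}_{g})^2` with an Exponential prior on :math:`\alpha^{inv}_{g}`,
and :math:`y_e` is applied to every cell of experiment :math:`e` via one-hot batch membership.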
36 | 37 | """ 38 | 39 | def __init__( 40 | self, 41 | n_obs, 42 | n_vars, 43 | n_factors, 44 | n_batch, 45 | n_extra_categoricals=None, 46 | alpha_g_phi_hyp_prior={"alpha": 9.0, "beta": 3.0}, 47 | gene_add_alpha_hyp_prior={"alpha": 9.0, "beta": 3.0}, 48 | gene_add_mean_hyp_prior={ 49 | "alpha": 1.0, 50 | "beta": 100.0, 51 | }, 52 | detection_hyp_prior={"mean_alpha": 1.0, "mean_beta": 1.0}, 53 | gene_tech_prior={"mean": 1, "alpha": 200}, 54 | init_vals: Optional[dict] = None, 55 | ): 56 | """ 57 | 58 | Parameters 59 | ---------- 60 | n_obs 61 | n_vars 62 | n_factors 63 | n_batch 64 | n_extra_categoricals 65 | alpha_g_phi_hyp_prior 66 | gene_add_alpha_hyp_prior 67 | gene_add_mean_hyp_prior 68 | detection_hyp_prior 69 | gene_tech_prior 70 | """ 71 | 72 | ############# Initialise parameters ################ 73 | super().__init__() 74 | 75 | self.n_obs = n_obs 76 | self.n_vars = n_vars 77 | self.n_factors = n_factors 78 | self.n_batch = n_batch 79 | self.n_extra_categoricals = n_extra_categoricals 80 | 81 | self.alpha_g_phi_hyp_prior = alpha_g_phi_hyp_prior 82 | self.gene_add_alpha_hyp_prior = gene_add_alpha_hyp_prior 83 | self.gene_add_mean_hyp_prior = gene_add_mean_hyp_prior 84 | self.detection_hyp_prior = detection_hyp_prior 85 | self.gene_tech_prior = gene_tech_prior 86 | 87 | if (init_vals is not None) & (type(init_vals) is dict): 88 | self.np_init_vals = init_vals 89 | for k in init_vals.keys(): 90 | self.register_buffer(f"init_val_{k}", torch.tensor(init_vals[k])) 91 | 92 | self.register_buffer( 93 | "detection_mean_hyp_prior_alpha", 94 | torch.tensor(self.detection_hyp_prior["mean_alpha"]), 95 | ) 96 | self.register_buffer( 97 | "detection_mean_hyp_prior_beta", 98 | torch.tensor(self.detection_hyp_prior["mean_beta"]), 99 | ) 100 | self.register_buffer( 101 | "gene_tech_prior_alpha", 102 | torch.tensor(self.gene_tech_prior["alpha"]), 103 | ) 104 | self.register_buffer( 105 | "gene_tech_prior_beta", 106 | torch.tensor(self.gene_tech_prior["alpha"] / self.gene_tech_prior["mean"]), 107 | ) 108 | 109 | self.register_buffer( 110 | "alpha_g_phi_hyp_prior_alpha", 111 | torch.tensor(self.alpha_g_phi_hyp_prior["alpha"]), 112 | ) 113 | self.register_buffer( 114 | "alpha_g_phi_hyp_prior_beta", 115 | torch.tensor(self.alpha_g_phi_hyp_prior["beta"]), 116 | ) 117 | self.register_buffer( 118 | "gene_add_alpha_hyp_prior_alpha", 119 | torch.tensor(self.gene_add_alpha_hyp_prior["alpha"]), 120 | ) 121 | self.register_buffer( 122 | "gene_add_alpha_hyp_prior_beta", 123 | torch.tensor(self.gene_add_alpha_hyp_prior["beta"]), 124 | ) 125 | self.register_buffer( 126 | "gene_add_mean_hyp_prior_alpha", 127 | torch.tensor(self.gene_add_mean_hyp_prior["alpha"]), 128 | ) 129 | self.register_buffer( 130 | "gene_add_mean_hyp_prior_beta", 131 | torch.tensor(self.gene_add_mean_hyp_prior["beta"]), 132 | ) 133 | 134 | self.register_buffer("ones", torch.ones((1, 1))) 135 | self.register_buffer("eps", torch.tensor(1e-8)) 136 | 137 | ############# Define the model ################ 138 | @staticmethod 139 | def _get_fn_args_from_batch_no_cat(tensor_dict): 140 | x_data = tensor_dict[REGISTRY_KEYS.X_KEY] 141 | ind_x = tensor_dict["ind_x"].long().squeeze() 142 | batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY] 143 | label_index = tensor_dict[REGISTRY_KEYS.LABELS_KEY] 144 | return (x_data, ind_x, batch_index, label_index, label_index), {} 145 | 146 | @staticmethod 147 | def _get_fn_args_from_batch_cat(tensor_dict): 148 | x_data = tensor_dict[REGISTRY_KEYS.X_KEY] 149 | ind_x = tensor_dict["ind_x"].long().squeeze() 150 | 
batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY] 151 | label_index = tensor_dict[REGISTRY_KEYS.LABELS_KEY] 152 | extra_categoricals = tensor_dict[REGISTRY_KEYS.CAT_COVS_KEY] 153 | return (x_data, ind_x, batch_index, label_index, extra_categoricals), {} 154 | 155 | @property 156 | def _get_fn_args_from_batch(self): 157 | if self.n_extra_categoricals is not None: 158 | return self._get_fn_args_from_batch_cat 159 | else: 160 | return self._get_fn_args_from_batch_no_cat 161 | 162 | def create_plates(self, x_data, idx, batch_index, label_index, extra_categoricals): 163 | return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=idx) 164 | 165 | def list_obs_plate_vars(self): 166 | """Create a dictionary with the name of observation/minibatch plate, 167 | indexes of model args to provide to encoder, 168 | variable names that belong to the observation plate 169 | and the number of dimensions in non-plate axis of each variable""" 170 | 171 | return { 172 | "name": "obs_plate", 173 | "input": [], # expression data + (optional) batch index 174 | "input_transform": [], # how to transform input data before passing to NN 175 | "sites": {}, 176 | } 177 | 178 | def forward(self, x_data, idx, batch_index, label_index, extra_categoricals): 179 | 180 | obs2sample = one_hot(batch_index, self.n_batch) 181 | obs2label = one_hot(label_index, self.n_factors) 182 | if self.n_extra_categoricals is not None: 183 | obs2extra_categoricals = torch.cat( 184 | [ 185 | one_hot( 186 | extra_categoricals[:, i].view((extra_categoricals.shape[0], 1)), 187 | n_cat, 188 | ) 189 | for i, n_cat in enumerate(self.n_extra_categoricals) 190 | ], 191 | dim=1, 192 | ) 193 | 194 | obs_plate = self.create_plates( 195 | x_data, idx, batch_index, label_index, extra_categoricals 196 | ) 197 | 198 | # =====================Per-cluster average mRNA count ======================= # 199 | # \mu_{f,g} 200 | per_cluster_mu_fg = pyro.sample( 201 | "per_cluster_mu_fg", 202 | dist.Gamma(self.ones, self.ones) 203 | .expand([self.n_factors, self.n_vars]) 204 | .to_event(2), 205 | ) 206 | 207 | # =====================Gene-specific multiplicative component ======================= # 208 | # `y_{t, g}` per gene multiplicative effect that explains the difference 209 | # in sensitivity between genes in each technology or covariate effect 210 | if self.n_extra_categoricals is not None: 211 | detection_tech_gene_tg = pyro.sample( 212 | "detection_tech_gene_tg", 213 | dist.Gamma( 214 | self.ones * self.gene_tech_prior_alpha, 215 | self.ones * self.gene_tech_prior_beta, 216 | ) 217 | .expand([np.sum(self.n_extra_categoricals), self.n_vars]) 218 | .to_event(2), 219 | ) 220 | 221 | # =====================Cell-specific detection efficiency ======================= # 222 | # y_c with hierarchical mean prior 223 | detection_mean_y_e = pyro.sample( 224 | "detection_mean_y_e", 225 | dist.Gamma( 226 | self.ones * self.detection_mean_hyp_prior_alpha, 227 | self.ones * self.detection_mean_hyp_prior_beta, 228 | ) 229 | .expand([self.n_batch, 1]) 230 | .to_event(2), 231 | ) 232 | detection_y_c = obs2sample @ detection_mean_y_e # (self.n_obs, 1) 233 | 234 | # =====================Gene-specific additive component ======================= # 235 | # s_{e,g} accounting for background, free-floating RNA 236 | s_g_gene_add_alpha_hyp = pyro.sample( 237 | "s_g_gene_add_alpha_hyp", 238 | dist.Gamma( 239 | self.ones * self.gene_add_alpha_hyp_prior_alpha, 240 | self.ones * self.gene_add_alpha_hyp_prior_beta, 241 | ), 242 | ) 243 | s_g_gene_add_mean = pyro.sample( 244 | 
"s_g_gene_add_mean", 245 | dist.Gamma( 246 | self.gene_add_mean_hyp_prior_alpha, 247 | self.gene_add_mean_hyp_prior_beta, 248 | ) 249 | .expand([self.n_batch, 1]) 250 | .to_event(2), 251 | ) # (self.n_batch) 252 | s_g_gene_add_alpha_e_inv = pyro.sample( 253 | "s_g_gene_add_alpha_e_inv", 254 | dist.Exponential(s_g_gene_add_alpha_hyp) 255 | .expand([self.n_batch, 1]) 256 | .to_event(2), 257 | ) # (self.n_batch) 258 | s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2) 259 | 260 | s_g_gene_add = pyro.sample( 261 | "s_g_gene_add", 262 | dist.Gamma(s_g_gene_add_alpha_e, s_g_gene_add_alpha_e / s_g_gene_add_mean) 263 | .expand([self.n_batch, self.n_vars]) 264 | .to_event(2), 265 | ) # (self.n_batch, n_vars) 266 | 267 | # =====================Gene-specific overdispersion ======================= # 268 | alpha_g_phi_hyp = pyro.sample( 269 | "alpha_g_phi_hyp", 270 | dist.Gamma( 271 | self.ones * self.alpha_g_phi_hyp_prior_alpha, 272 | self.ones * self.alpha_g_phi_hyp_prior_beta, 273 | ), 274 | ) 275 | alpha_g_inverse = pyro.sample( 276 | "alpha_g_inverse", 277 | dist.Exponential(alpha_g_phi_hyp).expand([1, self.n_vars]).to_event(2), 278 | ) # (self.n_batch or 1, self.n_vars) 279 | 280 | # =====================Expected expression ======================= # 281 | 282 | # overdispersion 283 | alpha = self.ones / alpha_g_inverse.pow(2) 284 | # biological expression 285 | mu = ( 286 | obs2label @ per_cluster_mu_fg 287 | + obs2sample @ s_g_gene_add # contaminating RNA 288 | ) * detection_y_c # cell-specific normalisation 289 | if self.n_extra_categoricals is not None: 290 | # gene-specific normalisation for covatiates 291 | mu = mu * (obs2extra_categoricals @ detection_tech_gene_tg) 292 | # total_count, logits = _convert_mean_disp_to_counts_logits( 293 | # mu, alpha, eps=self.eps 294 | # ) 295 | 296 | # =====================DATA likelihood ======================= # 297 | # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial 298 | with obs_plate: 299 | pyro.sample( 300 | "data_target", 301 | dist.GammaPoisson(concentration=alpha, rate=alpha / mu), 302 | # dist.NegativeBinomial(total_count=total_count, logits=logits), 303 | obs=x_data, 304 | ) 305 | 306 | # =====================Other functions======================= # 307 | def compute_expected(self, samples, adata_manager, ind_x=None): 308 | r"""Compute expected expression of each gene in each cell. Useful for evaluating how well 309 | the model learned expression pattern of all genes in the data. 
310 | 
311 | Parameters
312 | ----------
313 | samples
314 | dictionary with values of the posterior
315 | adata_manager
316 | registered anndata manager
317 | ind_x
318 | indices of cells to use (to reduce data size)
319 | """
320 | if ind_x is None:
321 | ind_x = np.arange(adata_manager.adata.n_obs).astype(int)
322 | else:
323 | ind_x = ind_x.astype(int)
324 | obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY)
325 | obs2sample = (
326 | pd.get_dummies(obs2sample.flatten()).values[ind_x, :].astype("float32")
327 | )
328 | obs2label = adata_manager.get_from_registry(REGISTRY_KEYS.LABELS_KEY)
329 | obs2label = (
330 | pd.get_dummies(obs2label.flatten()).values[ind_x, :].astype("float32")
331 | )
332 | if self.n_extra_categoricals is not None:
333 | extra_categoricals = adata_manager.get_from_registry(
334 | REGISTRY_KEYS.CAT_COVS_KEY
335 | )
336 | obs2extra_categoricals = np.concatenate(
337 | [
338 | pd.get_dummies(extra_categoricals.iloc[ind_x, i]).astype("float32")
339 | for i, n_cat in enumerate(self.n_extra_categoricals)
340 | ],
341 | axis=1,
342 | )
343 | 
344 | alpha = 1 / np.power(samples["alpha_g_inverse"], 2)
345 | 
346 | mu = (
347 | np.dot(obs2label, samples["per_cluster_mu_fg"])
348 | + np.dot(obs2sample, samples["s_g_gene_add"])
349 | ) * np.dot(
350 | obs2sample, samples["detection_mean_y_e"]
351 | ) # samples["detection_y_c"][ind_x, :]
352 | if self.n_extra_categoricals is not None:
353 | mu = mu * np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"])
354 | 
355 | return {"mu": mu, "alpha": alpha}
356 | 
357 | def compute_expected_subset(self, samples, adata_manager, fact_ind, cell_ind):
358 | r"""Compute expected expression of each gene in each cell that comes from
359 | a subset of factors (cell types) or cells.
360 | 
361 | Useful for evaluating how well the model learned the expression pattern of all genes in the data.
362 | 
363 | Parameters
364 | ----------
365 | samples
366 | dictionary with values of the posterior
367 | adata_manager
368 | registered anndata manager
369 | fact_ind
370 | indices of factors/cell types to use
371 | cell_ind
372 | indices of cells to use
373 | """
374 | obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY)
375 | obs2sample = pd.get_dummies(obs2sample.flatten()).values
376 | obs2label = adata_manager.get_from_registry(REGISTRY_KEYS.LABELS_KEY)
377 | obs2label = pd.get_dummies(obs2label.flatten()).values
378 | if self.n_extra_categoricals is not None:
379 | extra_categoricals = adata_manager.get_from_registry(
380 | REGISTRY_KEYS.CAT_COVS_KEY
381 | )
382 | obs2extra_categoricals = np.concatenate(
383 | [
384 | pd.get_dummies(extra_categoricals.iloc[:, i])
385 | for i, n_cat in enumerate(self.n_extra_categoricals)
386 | ],
387 | axis=1,
388 | )
389 | 
390 | alpha = 1 / np.power(samples["alpha_g_inverse"], 2)
391 | 
392 | mu = (
393 | np.dot(
394 | obs2label[np.ix_(cell_ind, fact_ind)], samples["per_cluster_mu_fg"][fact_ind, :]
395 | )
396 | + np.dot(obs2sample[cell_ind, :], samples["s_g_gene_add"])
397 | ) * np.dot(
398 | obs2sample[cell_ind, :], samples["detection_mean_y_e"]
399 | ) # samples["detection_y_c"]
400 | if self.n_extra_categoricals is not None:
401 | mu = mu * np.dot(
402 | obs2extra_categoricals[cell_ind, :], samples["detection_tech_gene_tg"]
403 | )
404 | 
405 | return {"mu": mu, "alpha": alpha}
406 | 
407 | def normalise(self, samples, adata_manager, adata):
408 | r"""Normalise expression data by estimated technical variables.
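In summary, the counts are divided by the per-experiment detection efficiency
(``detection_mean_y_e``) and the per-gene technical effect (``detection_tech_gene_tg``),
the additive background (``s_g_gene_add``) is subtracted, and values are shifted so
that the minimum is zero.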
    def normalise(self, samples, adata_manager, adata):
        r"""Normalise expression data by the estimated technical variables.

        Parameters
        ----------
        samples
            dictionary with posterior mean values
        adata_manager
            AnnDataManager with the registered anndata
        adata
            anndata object with the expression data
        """
        obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY)
        obs2sample = pd.get_dummies(obs2sample.flatten())
        if self.n_extra_categoricals is not None:
            extra_categoricals = adata_manager.get_from_registry(
                REGISTRY_KEYS.CAT_COVS_KEY
            )
            obs2extra_categoricals = np.concatenate(
                [
                    pd.get_dummies(extra_categoricals.iloc[:, i])
                    for i, n_cat in enumerate(self.n_extra_categoricals)
                ],
                axis=1,
            )
        # get counts matrix
        corrected = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
        # normalise per-sample scaling
        corrected = corrected / np.dot(obs2sample, samples["detection_mean_y_e"])
        # normalise per-gene effects
        if self.n_extra_categoricals is not None:
            corrected = corrected / np.dot(
                obs2extra_categoricals, samples["detection_tech_gene_tg"]
            )

        # remove additive sample effects
        corrected = corrected - np.dot(obs2sample, samples["s_g_gene_add"])

        # set the minimum value to 0 for each gene (a hack to avoid negative values)
        corrected = corrected - corrected.min(axis=0)

        return corrected
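    # Editor's note, an illustrative (hypothetical) usage sketch: the corrected
    # matrix can be stored as a new layer next to the raw counts, e.g.:
    #
    #     adata.layers["technical_corrected"] = module.normalise(
    #         samples, adata_manager, adata
    #     )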
--------------------------------------------------------------------------------
/schierarchy/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/schierarchy/utils/__init__.py
--------------------------------------------------------------------------------
/schierarchy/utils/data_transformation.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy import stats


def data_to_zero_truncated_cdf(x):
    """
    Quantile transformation of expression.
    To convert a matrix, apply along the cell dimension.

    Parameters
    ----------
    x: log1p(expression / total_expression * 1e4) transformed expression vector (for 1 gene)

    Returns
    -------
    quantile-normalised expression vector
    """
    indx_sorted = np.argsort(x)
    x_sorted = x[indx_sorted]
    try:
        # place all zeros at CDF value 0 and spread non-zero values on (0, 1]
        zero_ind = np.where(x_sorted == 0)[0][-1]
        p = np.concatenate(
            [
                np.zeros(zero_ind),
                1.0
                * np.arange(len(x_sorted) - zero_ind)
                / (len(x_sorted) - zero_ind - 1),
            ]
        )
    except IndexError:
        # no zeros in the vector: use the plain empirical CDF
        p = 1.0 * np.arange(len(x_sorted)) / len(x_sorted)
    cdfs = np.zeros_like(x)
    cdfs[indx_sorted] = p
    return cdfs


def cdf_to_pseudonorm(cdf, clip=1e-3):
    """
    Parameters
    ----------
    cdf: Quantile-transformed expression
    clip: Clipping of the normal distribution (position of zeros)

    Returns
    -------
    Pseudonormal expression
    """
    return stats.norm.ppf(np.maximum(np.minimum(cdf, 1.0 - clip), clip))


def variance_normalisation(x, std, clip=10):
    """
    Normalisation by standard deviation.

    Parameters
    ----------
    x: Expression data
    std: standard deviation, broadcastable to the shape of ``x``
    clip: clipping threshold for the standard deviation

    Returns
    -------
    Variance-normalised expression data
    """
    if clip is not None:
        return x / np.minimum(std, clip)
    else:
        return x / std
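# Editor's note, an illustrative (hypothetical `counts` matrix) end-to-end sketch
# of this transformation pipeline, applied per gene across cells:
#
#     import numpy as np
#
#     x = np.log1p(counts / counts.sum(axis=1, keepdims=True) * 1e4)
#     cdf = np.apply_along_axis(data_to_zero_truncated_cdf, 0, x)
#     pseudo = cdf_to_pseudonorm(cdf)
#     normed = variance_normalisation(pseudo, pseudo.std(axis=0))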
--------------------------------------------------------------------------------
/schierarchy/utils/simulation.py:
--------------------------------------------------------------------------------
from re import sub

import matplotlib.pyplot as plt
import numpy as np
from scvi.data import synthetic_iid


def random_tree(hlevels):
    """
    Generate a random tree from the list of level widths and return it as a list of edge dicts.

    Parameters
    ----------
    hlevels
        List of the number of classes per level

    Returns
    -------
    List of edges (parent: [children]) between levels


    Examples
    --------
    >>> hlevels = [4, 19, 20]
    >>> tree = random_tree(hlevels)
    """

    edge_dicts = []
    for i_level in range(len(hlevels) - 1):
        level1 = np.arange(hlevels[i_level])
        level2 = np.arange(hlevels[i_level + 1])

        edge_dict = {k: [] for k in level1}
        # give every parent at least one guaranteed child
        guaranteed = np.random.choice(level2, size=len(level1), replace=False)
        for i in range(len(level1)):
            edge_dict[i].append(guaranteed[i])
        # distribute the remaining children among random parents
        rest = np.random.choice(level1, size=len(level2))
        for i in range(len(level2)):
            if i not in guaranteed:
                edge_dict[rest[i]].append(i)

        edge_dicts.append(edge_dict)
    return edge_dicts


def invert_dict(d):
    inv_d = {}
    for k, v_list in d.items():
        for v in v_list:
            inv_d[v] = k
    return inv_d


def plot_tree(hlevels, tree):
    """
    Plot the tree as a layered graph.

    Parameters
    ----------
    hlevels
        List of the number of classes per level
    tree
        List of edges between levels
    """
    ylim = max(hlevels)  # width of the widest level
    xlim = len(hlevels)
    x_offset = [(ylim - hlevels[i]) / 2 for i in range(xlim)]
    x_points = []
    y_points = []

    x_lines = [[], []]
    y_lines = [[], []]

    for i in range(xlim):
        for j in range(hlevels[i]):
            y_points.append(i)
            x_points.append(j + x_offset[i])

    for i in range(xlim - 1):
        for k, v in tree[i].items():
            for v_i in v:
                x_lines[0].append(k + x_offset[i])
                x_lines[1].append(v_i + x_offset[i + 1])

                y_lines[0].append(i)
                y_lines[1].append(i + 1)

    plt.scatter(x_points, y_points)
    plt.plot(x_lines, y_lines, color="black", alpha=0.5)
    plt.gca().invert_yaxis()


def hierarchical_iid(hlevels, *args, **kwargs):
    """
    Wrapper around scvi.data.synthetic_iid that produces hierarchical labels.

    Parameters
    ----------
    hlevels
        List of the number of classes per level

    Returns
    -------
    AnnData with batch info (``.obs['batch']``), label info (``.obs['labels']``)
    on level i (``.obs['level_i']``) and the list of edges (parent: [children])
    between levels (``.uns['tree']``)
    """
    tree = random_tree(hlevels)

    bottom_level_n_labels = hlevels[-1]
    synthetic_data = synthetic_iid(n_labels=bottom_level_n_labels, *args, **kwargs)

    synthetic_data.obs["labels"] = [
        int(sub("label_", "", i)) for i in synthetic_data.obs["labels"]
    ]

    levels = ["labels"] + [f"level_{i}" for i in range(len(hlevels) - 2, -1, -1)]
    for i in range(len(levels) - 1):
        level_up = synthetic_data.obs[levels[i]].apply(
            lambda x: invert_dict(tree[len(tree) - 1 - i])[x]
        )
        synthetic_data.obs[levels[i + 1]] = level_up

    synthetic_data.obs = synthetic_data.obs.rename(
        columns={"labels": f"level_{len(levels) - 1}"}
    )
    synthetic_data.uns["tree"] = tree
    return synthetic_data
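# Editor's note, a minimal usage sketch (hypothetical level widths):
#
#     hlevels = [3, 6, 12]  # 3 coarse types -> 6 -> 12 fine types
#     adata = hierarchical_iid(hlevels)
#     plot_tree(hlevels, adata.uns["tree"])  # visualise the generated hierarchy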
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

# This is a shim that allows GitHub to detect the package; the build is done with poetry

import setuptools

if __name__ == "__main__":
    setuptools.setup(name="schierarchy")
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_hierarchical_logist.py:
--------------------------------------------------------------------------------
import os

import numpy as np
from scvi.data import synthetic_iid

from schierarchy import LogisticModel
from schierarchy.utils.data_transformation import data_to_zero_truncated_cdf
from schierarchy.utils.simulation import hierarchical_iid


def test_hierarchical_logist():
    save_path = "./cell2location_model_test"
    hlevels = [4, 10, 20]
    dataset = hierarchical_iid(hlevels)
    level_keys = [f"level_{i}" for i in range(len(hlevels))]
    # tree = dataset.uns["tree"]
    del dataset.uns["tree"]
    dataset.layers["cdf"] = np.apply_along_axis(
        data_to_zero_truncated_cdf, 0, dataset.X
    )

    for learning_mode in [
        "fixed-sigma",
        "learn-sigma-single",
        "learn-sigma-gene",
        "learn-sigma-celltype",
        "learn-sigma-gene-celltype",
    ]:
        for use_dropout in [
            {"use_dropout": True},
            {"use_dropout": False},
            {"use_gene_dropout": True},
            {"use_gene_dropout": False},
        ]:
            LogisticModel.setup_anndata(dataset, layer="cdf", level_keys=level_keys)
            # train the logistic model
            sc_model = LogisticModel(
                dataset,
                laplace_learning_mode=learning_mode,
                **use_dropout,
            )
            # test full-data training
            sc_model.train(max_epochs=10, batch_size=None)
            # test minibatch training
            sc_model.train(max_epochs=10, batch_size=100)
            # export the estimated parameters (summary of the posterior distribution)
            dataset = sc_model.export_posterior(
                dataset, sample_kwargs={"num_samples": 10}
            )
            # test plot_QC
            # sc_model.plot_QC()
            # test save/load
            sc_model.save(save_path, overwrite=True, save_anndata=True)
            sc_model = LogisticModel.load(save_path)
            os.system(f"rm -rf {save_path}")


def test_hierarchical_logist_prediction():
    hlevels = [4, 10, 20]
    dataset = hierarchical_iid(hlevels)
    level_keys = [f"level_{i}" for i in range(len(hlevels))]
    del dataset.uns["tree"]
    dataset.layers["cdf"] = np.apply_along_axis(
        data_to_zero_truncated_cdf, 0, dataset.X
    )

    LogisticModel.setup_anndata(dataset, layer="cdf", level_keys=level_keys)

    # train the logistic model
    sc_model = LogisticModel(dataset)
    # test full-data training
    sc_model.train(max_epochs=10, batch_size=None)
    # test prediction on a new dataset
    dataset2 = synthetic_iid(n_labels=5)
    dataset2.layers["cdf"] = np.apply_along_axis(
        data_to_zero_truncated_cdf, 0, dataset2.X
    )
    dataset2 = sc_model.export_posterior(
        dataset2, prediction=True, sample_kwargs={"num_samples": 10}
    )
--------------------------------------------------------------------------------
/tests/test_hierarchical_logist_nohierarchy.py:
--------------------------------------------------------------------------------
import os

import numpy as np

from schierarchy import LogisticModel
from schierarchy.utils.data_transformation import data_to_zero_truncated_cdf
from schierarchy.utils.simulation import hierarchical_iid


def test_hierarchical_logist_nohierarchy():
    save_path = "./cell2location_model_test"
    hlevels = [20]
    dataset = hierarchical_iid(hlevels)
    level_keys = [f"level_{i}" for i in range(len(hlevels))]
    # tree = dataset.uns["tree"]
    del dataset.uns["tree"]
    dataset.layers["cdf"] = np.apply_along_axis(
        data_to_zero_truncated_cdf, 0, dataset.X
    )
    for learning_mode in [
        "fixed-sigma",
        "learn-sigma-single",
        "learn-sigma-gene",
        "learn-sigma-celltype",
        "learn-sigma-gene-celltype",
    ]:
        LogisticModel.setup_anndata(dataset, layer="cdf", level_keys=level_keys)

        # train the logistic model
        sc_model = LogisticModel(
            dataset,
            laplace_learning_mode=learning_mode,
        )
        # test full-data training
        sc_model.train(max_epochs=10, batch_size=None)
        # test minibatch training
        sc_model.train(max_epochs=10, batch_size=100)
        # export the estimated parameters (summary of the posterior distribution)
        dataset = sc_model.export_posterior(dataset, sample_kwargs={"num_samples": 10})
        dataset = sc_model.export_posterior(
            dataset,
            sample_kwargs={
                "use_median": True,
            },
            use_quantiles=True,
            add_to_varm=["q50"],
        )
        # test plot_QC
        # sc_model.plot_QC()
        # test save/load
        sc_model.save(save_path, overwrite=True, save_anndata=True)
        sc_model = LogisticModel.load(save_path)
        os.system(f"rm -rf {save_path}")
--------------------------------------------------------------------------------
/tests/test_regression.py:
--------------------------------------------------------------------------------
from scvi.data import synthetic_iid

from schierarchy import RegressionModel


def test_regression():
    save_path = "./cell2location_model_test"
    dataset = synthetic_iid(n_labels=5)
    RegressionModel.setup_anndata(dataset, labels_key="labels", batch_key="batch")

    # train the regression model to get signatures of cell types
    sc_model = RegressionModel(dataset)
    # test full-data training
    sc_model.train(max_epochs=10, batch_size=None)
    # test minibatch training
    sc_model.train(max_epochs=10, batch_size=100)
    # export the estimated cell-type signatures (summary of the posterior distribution)
    dataset = sc_model.export_posterior(dataset, sample_kwargs={"num_samples": 10})
    # test plot_QC
    sc_model.plot_QC()
    # test save/load
    sc_model.save(save_path, overwrite=True, save_anndata=True)
    sc_model = RegressionModel.load(save_path)
    # export the estimated expression in each cluster
    if "means_per_cluster_mu_fg" in dataset.varm.keys():
        inf_aver = dataset.varm["means_per_cluster_mu_fg"][
            [f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]
        ].copy()
    else:
        inf_aver = dataset.var[
            [f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]
        ].copy()
    inf_aver.columns = dataset.uns["mod"]["factor_names"]
--------------------------------------------------------------------------------