├── .editorconfig
├── .flake8
├── .github
│   └── workflows
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── codecov.yml
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── css
│   │       └── custom.css
│   ├── _templates
│   │   ├── autosummary
│   │   │   └── class.rst
│   │   └── class_no_inherited.rst
│   ├── api
│   │   └── index.rst
│   ├── conf.py
│   ├── extensions
│   │   └── typed_returns.py
│   ├── index.rst
│   ├── installation.rst
│   ├── make.bat
│   ├── notebooks
│   │   ├── hierarchical_selsction_prototype.ipynb
│   │   └── marker_selection_example.ipynb
│   ├── references.rst
│   └── release_notes
│       ├── index.rst
│       ├── v0.1.0.rst
│       └── v0.1.1.rst
├── pyproject.toml
├── readthedocs.yml
├── schierarchy
│   ├── __init__.py
│   ├── base
│   │   ├── __init__.py
│   │   └── _pyro_base_regression_module.py
│   ├── logistic
│   │   ├── __init__.py
│   │   ├── _logistic_model.py
│   │   └── _logistic_module.py
│   ├── regression
│   │   ├── __init__.py
│   │   ├── _reference_model.py
│   │   └── _reference_module.py
│   └── utils
│       ├── __init__.py
│       ├── data_transformation.py
│       └── simulation.py
├── setup.py
└── tests
    ├── __init__.py
    ├── test_hierarchical_logist.py
    ├── test_hierarchical_logist_nohierarchy.py
    └── test_regression.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = false
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203, E266, E501, W503, W605, N812
3 | exclude = .git,docs
4 | max-line-length = 119
5 |
6 |
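7 | # Notes on the ignored checks (standard practice, not specific to this repo):
8 | # E203 (whitespace before ':'), W503 (line break before a binary operator) and
9 | # E501 (line too long) are commonly ignored when black manages formatting and
10 | # line length; N812 allows imports such as `import torch.nn.functional as F`.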
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: schierarchy
5 |
6 | on:
7 | push:
8 | branches: [main]
9 | pull_request:
10 | branches: [main]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | python-version: [3.9]
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Cache pip
27 | uses: actions/cache@v2
28 | with:
29 | path: ~/.cache/pip
30 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
31 | restore-keys: |
32 | ${{ runner.os }}-pip-
33 | - name: Install dependencies
34 | run: |
35 | pip install pytest-cov
36 | pip install .[dev]
37 | - name: Lint with flake8
38 | run: |
39 | flake8
40 | - name: Format with black
41 | run: |
42 | black --check .
43 | - name: Test with pytest
44 | run: |
45 |           pytest --cov-report=xml --cov=schierarchy
46 | - name: After success
47 | run: |
48 | bash <(curl -s https://codecov.io/bash)
49 | pip list
50 |
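51 | # Note: the codecov bash uploader in the "After success" step discovers and
52 | # uploads the coverage.xml file written by the pytest step above.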
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project-specific
2 | .idea/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # vscode
135 | .vscode/settings.json
136 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v2.3.0
4 | hooks:
5 | - id: check-yaml
6 | - repo: https://github.com/python/black
7 | rev: 22.3.0
8 | hooks:
9 | - id: black
10 | - repo: https://gitlab.com/pycqa/flake8
11 | rev: 3.8.4
12 | hooks:
13 | - id: flake8
14 | - repo: https://github.com/pycqa/isort
15 | rev: 5.7.0
16 | hooks:
17 | - id: isort
18 | name: isort (python)
19 | additional_dependencies: [toml]
20 |
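21 | # To run these hooks automatically before every commit, install them once with
22 | # `pip install pre-commit && pre-commit install`; to check the whole repository
23 | # in one pass, run `pre-commit run --all-files`.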
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # scHierarchy: Hierarchical logistic regression model for marker gene selection using hierarchical cell annotations
2 |
3 | [![Stars](https://img.shields.io/github/stars/vitkl/scHierarchy?logo=GitHub&color=yellow)](https://github.com/vitkl/scHierarchy/stargazers)
4 | [![Documentation Status](https://readthedocs.org/projects/schierarchy/badge/?version=stable)](https://scHierarchy.readthedocs.io/en/stable/?badge=stable)
5 | ![Build Status](https://github.com/vitkl/scHierarchy/workflows/schierarchy/badge.svg)
6 | [![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black)
7 |
8 |
9 |
10 | ## Installation
11 |
12 | Linux installation
13 | ```bash
14 | conda create -y -n schierarchy-env python=3.9
15 | conda activate schierarchy-env
16 | pip install git+https://github.com/dissatisfaction-ai/scHierarchy.git
17 | ```
18 |
19 | Mac installation
20 | ```bash
21 | conda create -y -n schierarchy-env python=3.8
22 | conda activate schierarchy-env
23 | pip install git+https://github.com/pyro-ppl/pyro.git@dev
24 | conda install -y -c anaconda hdf5 pytables netcdf4
25 | pip install git+https://github.com/dissatisfaction-ai/scHierarchy.git
26 | ```
27 |
28 | ## Usage example notebook
29 |
30 | https://github.com/dissatisfaction-ai/scHierarchy/blob/main/docs/notebooks/marker_selection_example.ipynb
31 |
32 | This will be updated to use a publicly available dataset and Google Colab; however, it is challenging to find a published dataset that has several annotation levels yet is small enough to run on Colab.
33 |
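34 | Below is a minimal sketch of the workflow from that notebook. It assumes `adata` is an `AnnData` object with log-normalised counts in `adata.layers['processed']`; the layer name, label columns and training settings are taken from the notebook and are dataset-specific:
35 |
36 | ```python
37 | import numpy as np
38 | from schierarchy import LogisticModel
39 | from schierarchy.utils.data_transformation import data_to_zero_truncated_cdf
40 |
41 | # quantile-normalise the log-transformed expression matrix
42 | adata.layers["cdf"] = np.apply_along_axis(
43 |     data_to_zero_truncated_cdf, 0, adata.layers["processed"].toarray()
44 | )
45 |
46 | # register the label columns, ordered from the coarsest to the finest level
47 | level_keys = ["celltype_major", "celltype_minor", "celltype_subset"]
48 | LogisticModel.setup_anndata(adata, layer="cdf", level_keys=level_keys)
49 |
50 | # train the model and export per-gene weights for every level of the hierarchy
51 | mod = LogisticModel(adata, laplace_learning_mode="fixed-sigma")
52 | mod.train(max_epochs=600, batch_size=2500, train_size=1, lr=0.01)
53 | adata = mod.export_posterior(
54 |     adata, sample_kwargs={"num_samples": 50, "batch_size": 2500}
55 | )
56 | ```
57 |
58 | The exported weights land in `adata.varm[f'means_weight_{level}']`, one column per cell type, as used for the dotplots in the notebook.
59 |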
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | # Run to check if valid
2 | # curl --data-binary @codecov.yml https://codecov.io/validate
3 | coverage:
4 | status:
5 | project:
6 | default:
7 | target: 80%
8 | threshold: 1%
9 | patch: off
10 |
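11 | # As configured above, the project status check targets 80% overall coverage,
12 | # allowing up to a 1% drop relative to the base commit before failing;
13 | # patch-level coverage checks are disabled.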
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python -msphinx
7 | SPHINXPROJ    = schierarchy
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* influenced by and borrowed from: https://github.com/cvxgrp/pymde/blob/main/docs_src/source/_static/css/custom.css */
2 |
3 | @import url('https://fonts.googleapis.com/css2?family=Roboto:ital,wght@0,300;0,400;0,600;1,300;1,400;1,600&display=swap');
4 |
5 | :root {
6 | --sidebarcolor: #003262;
7 | --sidebarfontcolor: #ffffff;
8 | --sidebarhover: #295e97;
9 |
10 | --bodyfontcolor: #333;
11 | --webfont: 'Roboto';
12 |
13 | --contentwidth: 1000px;
14 | }
15 |
16 | /* Fonts and text */
17 | h1, h2, h3, h4, h5, h6 {
18 | font-family: var(--webfont), 'Helvetica Neue', Helvetica, Arial, sans-serif;
19 | font-weight: 400;
20 | }
21 |
22 | h2, h3, h4, h5, h6 {
23 | padding-top: 0.25em;
24 | margin-bottom: 0.5em;
25 | }
26 |
27 | h1 {
28 | font-size: 225%;
29 | }
30 |
31 | body {
32 | font-family: var(--webfont), 'Helvetica Neue', Helvetica, Arial, sans-serif;
33 | color: var(--bodyfontcolor);
34 | }
35 |
36 | p {
37 | font-size: 1em;
38 | line-height: 150%;
39 | }
40 |
41 |
42 | /* Sidebar */
43 | .wy-side-nav-search {
44 | background-color: var(--sidebarcolor);
45 | }
46 |
47 | .wy-nav-side {
48 | background: var(--sidebarcolor);
49 | }
50 |
51 | .wy-menu-vertical header, .wy-menu-vertical p.caption {
52 | color: var(--sidebarfontcolor);
53 | }
54 |
55 | .wy-menu-vertical a {
56 | color: var(--sidebarfontcolor);
57 | }
58 |
59 | .wy-side-nav-search > div.version {
60 | color: var(--sidebarfontcolor);
61 | }
62 |
63 | .wy-menu-vertical a:hover {
64 | background-color: var(--sidebarhover);
65 | }
66 |
67 | /* Main content */
68 | .wy-nav-content {
69 | max-width: var(--contentwidth);
70 | }
71 |
72 |
73 | html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list) > dt{
74 | margin-bottom: 6px;
75 | border-left: none;
76 | background: none;
77 | color: #555;
78 | }
79 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | .. rubric:: Attributes
12 |
13 | .. autosummary::
14 | :toctree: .
15 | {% for item in attributes %}
16 | ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | .. rubric:: Methods
24 |
25 | .. autosummary::
26 | :toctree: .
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 | ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
--------------------------------------------------------------------------------
/docs/_templates/class_no_inherited.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block methods %}
10 | {% if methods %}
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 | :toctree: .
15 | {% for item in methods %}
16 | {%- if item != '__init__' and item not in inherited_members%}
17 | ~{{ fullname }}.{{ item }}
18 | {%- endif -%}
19 |
20 | {%- endfor %}
21 | {% endif %}
22 | {% endblock %}
23 |
--------------------------------------------------------------------------------
/docs/api/index.rst:
--------------------------------------------------------------------------------
1 | ===
2 | API
3 | ===
4 |
5 | .. currentmodule:: mypackage
6 |
7 | .. note:: External functions can be imported from other packages like `scvi-tools` and be displayed on your docs website. As an example, `setup_anndata` here is directly from `scvi-tools`.
8 |
9 | Data
10 | ~~~~
11 | .. autosummary::
12 | :toctree: reference/
13 |
14 | setup_anndata
15 |
16 | MyModel
17 | ~~~~~~~
18 |
19 | .. autosummary::
20 | :toctree: reference/
21 |
22 | MyModel
23 | MyPyroModel
24 |
25 | MyModule
26 | ~~~~~~~~
27 | .. autosummary::
28 | :toctree: reference/
29 | :template: class_no_inherited.rst
30 |
31 | MyModule
32 | MyPyroModule
33 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | import sys
10 | from pathlib import Path
11 |
12 | HERE = Path(__file__).parent
13 | sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")]
14 |
15 | import schierarchy  # noqa
16 |
17 |
18 | # -- General configuration ---------------------------------------------
19 |
20 | # If your documentation needs a minimal Sphinx version, state it here.
21 | #
22 | needs_sphinx = "3.0" # Nicer param docs
23 |
24 | # Add any Sphinx extension module names here, as strings. They can be
25 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
26 | extensions = [
27 | "sphinx.ext.autodoc",
28 | "sphinx.ext.viewcode",
29 | "nbsphinx",
30 | "nbsphinx_link",
31 | "sphinx.ext.mathjax",
32 | "sphinx.ext.napoleon",
33 | "sphinx_autodoc_typehints", # needs to be after napoleon
34 | "sphinx.ext.intersphinx",
35 | "sphinx.ext.autosummary",
36 | "scanpydoc.elegant_typehints",
37 | "scanpydoc.definition_list_typed_field",
38 | "scanpydoc.autosummary_generate_imported",
39 | *[p.stem for p in (HERE / "extensions").glob("*.py")],
40 | ]
41 |
42 | # nbsphinx specific settings
43 | exclude_patterns = ["_build", "**.ipynb_checkpoints"]
44 | nbsphinx_execute = "never"
45 |
46 | # Add any paths that contain templates here, relative to this directory.
47 | templates_path = ["_templates"]
48 |
49 | # The suffix(es) of source filenames.
50 | # You can specify multiple suffix as a list of string:
51 | #
52 | # source_suffix = ['.rst', '.md']
53 | source_suffix = ".rst"
54 |
55 | # Generate the API documentation when building
56 | autosummary_generate = True
57 | autodoc_member_order = "bysource"
58 | napoleon_google_docstring = False
59 | napoleon_numpy_docstring = True
60 | napoleon_include_init_with_doc = False
61 | napoleon_use_rtype = True
62 | napoleon_use_param = True
63 | napoleon_custom_sections = [("Params", "Parameters")]
64 | todo_include_todos = False
65 | numpydoc_show_class_members = False
66 | annotate_defaults = True
67 | # The master toctree document.
68 | master_doc = "index"
69 |
70 |
71 | intersphinx_mapping = dict(
72 | anndata=("https://anndata.readthedocs.io/en/stable/", None),
73 | ipython=("https://ipython.readthedocs.io/en/stable/", None),
74 | matplotlib=("https://matplotlib.org/", None),
75 | numpy=("https://docs.scipy.org/doc/numpy/", None),
76 | pandas=("https://pandas.pydata.org/pandas-docs/stable/", None),
77 | python=("https://docs.python.org/3", None),
78 | scipy=("https://docs.scipy.org/doc/scipy/reference/", None),
79 | sklearn=("https://scikit-learn.org/stable/", None),
80 | torch=("https://pytorch.org/docs/master/", None),
81 | scanpy=("https://scanpy.readthedocs.io/en/stable/", None),
82 | pytorch_lightning=("https://pytorch-lightning.readthedocs.io/en/stable/", None),
83 | )
84 |
85 |
86 | # General information about the project.
87 | project = "schierarchy"
88 | copyright = "2021, Yosef Lab, UC Berkeley"
89 | author = "Adam Gayoso"
90 |
91 | # The version info for the project you're documenting, acts as replacement
92 | # for |version| and |release|, also used in various other places throughout
93 | # the built documents.
94 | #
95 | # The short X.Y version.
96 | version = schierarchy.__version__
97 | # The full version, including alpha/beta/rc tags.
98 | release = schierarchy.__version__
99 |
100 | # The language for content autogenerated by Sphinx. Refer to documentation
101 | # for a list of supported languages.
102 | #
103 | # This is also used if you do content translation via gettext catalogs.
104 | # Usually you set "language" from the command line for these cases.
105 | language = None
106 |
107 | # List of patterns, relative to source directory, that match files and
108 | # directories to ignore when looking for source files.
109 | # This patterns also effect to html_static_path and html_extra_path
110 | exclude_patterns += ["Thumbs.db", ".DS_Store"]
111 |
112 | # The name of the Pygments (syntax highlighting) style to use.
113 | pygments_style = "tango"
114 |
115 | # If true, `todo` and `todoList` produce output, else they produce nothing.
116 | todo_include_todos = False
117 |
118 | # -- Options for HTML output -------------------------------------------------
119 |
120 | # The theme to use for HTML and HTML Help pages. See the documentation for
121 | # a list of builtin themes.
122 | #
123 | html_theme = "sphinx_rtd_theme"
124 |
125 | html_show_sourcelink = False
126 |
127 | html_show_copyright = False
128 |
129 | display_version = True
130 |
131 | # Add any paths that contain custom static files (such as style sheets) here,
132 | # relative to this directory. They are copied after the builtin static files,
133 | # so a file named "default.css" will overwrite the builtin "default.css".
134 | html_static_path = ["_static"]
135 |
136 | html_css_files = [
137 | "css/custom.css",
138 | ]
139 |
140 | html_favicon = "favicon.ico"
141 |
--------------------------------------------------------------------------------
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | import re
4 |
5 | from sphinx.application import Sphinx
6 | from sphinx.ext.napoleon import NumpyDocstring
7 |
8 |
9 | def process_return(lines):
10 | for line in lines:
11 |         m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line)
12 | if m:
13 | # Once this is in scanpydoc, we can use the fancy hover stuff
14 | yield f'-{m["param"]} (:class:`~{m["type"]}`)'
15 | else:
16 | yield line
17 |
18 |
19 | def scanpy_parse_returns_section(self, section):
20 | lines_raw = list(process_return(self._dedent(self._consume_to_next_section())))
21 | lines = self._format_block(":returns: ", lines_raw)
22 | if lines and lines[-1]:
23 | lines.append("")
24 | return lines
25 |
26 |
27 | def setup(app: Sphinx):
28 | NumpyDocstring._parse_returns_section = scanpy_parse_returns_section
29 |
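30 | # Example of the rewrite performed above: a NumPy-style "Returns" entry such as
31 | #     adata : anndata.AnnData
32 | # is emitted as ":returns: -adata (:class:`~anndata.AnnData`)", so the return
33 | # type renders as a cross-referenced link in the built documentation.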
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | ========================
2 | mypackage documentation
3 | ========================
4 |
5 | Welcome! This is the corresponding documentation website for the `scvi-tools-skeleton
6 | <https://github.com/YosefLab/scvi-tools-skeleton>`_. The purpose of this website is to demonstrate some of the functionality available for scvi-tools developer API users.
7 |
8 | We recommend building your own package by first using this repository as a `template <https://github.com/YosefLab/scvi-tools-skeleton/generate>`_. Subsequently, the info in the `/docs` directory can be updated to contain information relevant to your package.
9 |
10 |
11 | .. toctree::
12 | :maxdepth: 1
13 | :hidden:
14 |
15 | installation
16 | api/index
17 | release_notes/index
18 | references
19 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | Prerequisites
5 | ~~~~~~~~~~~~~~
6 |
7 | my_package can be installed via PyPI.
8 |
9 | conda prerequisites
10 | ###################
11 |
12 | 1. Install Conda. We typically use the Miniconda_ Python distribution. Use Python version >=3.7.
13 |
14 | 2. Create a new conda environment::
15 |
16 | conda create -n scvi-env python=3.7
17 |
18 | 3. Activate your environment::
19 |
20 | source activate scvi-env
21 |
22 | pip prerequisites
23 | ##################
24 |
25 | 1. Install Python_; we prefer the `pyenv <https://github.com/pyenv/pyenv>`_ version management system, along with `pyenv-virtualenv <https://github.com/pyenv/pyenv-virtualenv>`_.
26 |
27 | 2. Install PyTorch_. If you have an Nvidia GPU, be sure to install a version of PyTorch that supports it -- scvi-tools runs much faster with a discrete GPU.
28 |
29 | .. _Miniconda: https://conda.io/miniconda.html
30 | .. _Python: https://www.python.org/downloads/
31 | .. _PyTorch: http://pytorch.org
32 |
33 | my_package installation
34 | ~~~~~~~~~~~~~~~~~~~~~~~
35 |
36 | Install my_package in one of the following ways:
37 |
38 | Through **pip**::
39 |
40 |     pip install my_package
41 |
42 | Through pip with packages to run notebooks. This installs scanpy, etc.::
43 |
44 |     pip install my_package[tutorials]
45 |
46 | Nightly version - clone this repo and run::
47 |
48 | pip install .
49 |
50 | For development - clone this repo and run::
51 |
52 | pip install -e .[dev,docs]
53 |
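54 | To check that the installation succeeded, import the package (replace
55 | ``my_package`` with your package's actual import name)::
56 |
57 |     python -c "import my_package"
58 |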
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=python -msphinx
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=schierarchy
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 | echo.then set the SPHINXBUILD environment variable to point to the full
21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 | echo.Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/notebooks/hierarchical_selsction_prototype.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "1e96289b-7640-4294-9e10-ba8f0ef5eae9",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import os\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "import pandas as pd\n",
13 | "import torch\n",
14 | "from torch.nn import Parameter\n",
15 | "import numpy as np\n",
16 | "\n",
17 | "\n",
18 | "from pyro.infer import MCMC, NUTS, Predictive\n",
19 | "import pyro\n",
20 | "from pyro import poutine\n",
21 | "from pyro import distributions as dist\n",
22 | "from pyro.optim import Adam\n",
23 | "from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO\n",
24 | "from pyro.infer.autoguide import AutoDelta, AutoNormal\n",
25 | "pyro.enable_validation(True)\n",
26 | "import tqdm\n",
27 | "\n",
28 | "import scipy\n",
29 | "from scipy import ndimage\n",
30 | "from scipy.sparse import load_npz\n",
31 | "from scipy.sparse import coo_matrix\n",
32 | "#import skimage.color\n",
33 | "\n",
34 | "from sklearn.neighbors import KDTree\n",
35 | "\n",
36 | "from collections import Counter\n",
37 | "\n",
38 | "import scanpy as sc\n",
39 | "\n",
40 | "import seaborn as sns"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "id": "810ef782-d010-41ea-bd35-869d65244d8d",
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "def subsample_clusters(obj, group_label = 'celltype_subset', subsample_factor = 0.2):\n",
51 | " chosen = []\n",
52 | " indxs = np.arange(obj.n_obs)\n",
53 | " for group in obj.obs[group_label].unique():\n",
54 | " group_indxs = np.where(obj.obs[group_label] == group)[0]\n",
55 | " group_chosen = np.random.choice(group_indxs,\n",
56 | "                                        size=max(int(len(group_indxs) * subsample_factor), min(len(group_indxs), 250)),\n",
57 | " replace=False)\n",
58 | " #print(len(group_chosen))\n",
59 | " chosen += list(group_chosen)\n",
60 | "\n",
61 | " return np.array(chosen)\n",
62 | "\n",
63 | "def data_to_zero_truncated_cdf(x):\n",
64 | "    # empirical CDF transform: zero entries map to 0, nonzero entries to their quantile\n",
65 | " indx_sorted = np.argsort(x)\n",
66 | " x_sorted = x[indx_sorted]\n",
67 | " zero_ind = np.where(x_sorted == 0)[0][-1]\n",
68 | " p = np.concatenate([np.zeros(zero_ind), 1. * np.arange(len(x_sorted) - zero_ind) / (len(x_sorted) - zero_ind - 1)])\n",
69 | " cdfs = np.zeros_like(x)\n",
70 | " cdfs[indx_sorted] = p\n",
71 | " return cdfs"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "id": "ccbe68b4-97d5-425a-a019-4a62576a201f",
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "data = sc.read_h5ad('../data_atlas/atals_processed.h5ad')\n",
82 | "data.layers['raw_data'] = ((data.X.expm1() / 10000).multiply(data.obs[['nCount_RNA']].values)).tocsr() #real data \n",
83 | "\n",
84 | "data.obs['celltype_subset_alt'] = list(data.obs['celltype_subset'].values)\n",
85 | "#data.obs.loc[data.obs['gene_module'] != 'no_gene_module', 'celltype_subset_alt'] = data.obs.loc[data.obs['gene_module'] != 'no_gene_module', 'gene_module']\n",
86 | "data.obs.loc[data.obs['celltype_major'] == 'CAFs', 'celltype_subset_alt'] = data.obs.loc[data.obs['celltype_major'] == 'CAFs', 'celltype_minor']\n",
87 | "\n",
88 | "group_today = 'celltype_subset_alt'\n",
89 | "\n",
90 | "data_sub = data[subsample_clusters(data, subsample_factor=0.1, group_label=group_today)]\n",
91 | "\n",
92 | "mean_exp_vals = []\n",
93 | "for g in data_sub.obs[group_today].unique():\n",
94 | " #print(g)\n",
95 | " tmp = data_sub.X[np.where(data_sub.obs[group_today] == g)[0],:]\n",
96 | " tmp = tmp.toarray().mean(axis=0)\n",
97 | " mean_exp_vals.append(tmp)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "id": "e15d27dc-05e0-457b-bd0d-df923c423c3b",
104 | "metadata": {},
105 | "outputs": [],
106 | "source": []
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "id": "2d326189-e48c-4b4e-9588-1dcaddb110bd",
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "expressed_genes = np.where((np.exp(np.stack(mean_exp_vals, axis=0).max(axis=0)) - 1 > \n",
116 | " (np.exp(np.stack(mean_exp_vals, axis=0).max(axis=0))[np.where(data_sub.var.index == 'RPLP0')[0]] - 1)/20))[0]\n",
117 | "\n",
118 | "data_sub = data_sub[:, expressed_genes].copy()"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "id": "215c94c8-27a2-4e68-8680-5bd2ac933233",
125 | "metadata": {},
126 | "outputs": [],
127 | "source": []
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "id": "58613c69-f272-46c0-a40a-ed4442c422b1",
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "def multi_logist_hierarchical(D, X, label_graph, class_size, device='cuda'):\n",
137 | " '''\n",
138 | " D ~ Multinomial(softmax(Xw))\n",
139 | " \n",
140 | " Parameters\n",
141 | " ----------\n",
142 | " D : torch.tensor\n",
143 | "        Cell type assignment (cells x levels)\n",
144 | " X : torch.tensor\n",
145 | " Expression matrix (cells x genes)\n",
146 | "    label_graph : dict\n",
147 | "        Mapping from each parent class index to the array of its child class indices\n",
148 | "    class_size : list of torch.tensor\n",
149 | "        Number of cells in each class, one tensor per level\n",
150 | "    \n",
151 | "    \n",
152 | "    device : str or torch.device\n",
153 | "        Device specification (default 'cuda'); for CPU use 'cpu'\n",
154 | "\n",
155 | " '''\n",
156 | " i_cells, g_genes = X.shape\n",
157 | " c_class_levels = [class_size[i].size()[0] for i in range(2)]\n",
158 | " \n",
159 | " #top level weights\n",
160 | " w_top = pyro.sample('w_top', dist.Laplace(torch.tensor([0.]).to(device),\n",
161 | " torch.tensor([0.5]).to(device)).expand([g_genes, c_class_levels[0]]).to_event(2)) \n",
162 | " f_top = torch.nn.functional.softmax(torch.matmul(X, w_top / (class_size[0])**0.5), dim=1)\n",
163 | " #bottom levels\n",
164 | " w_bottom = pyro.sample('w_bottom', dist.Laplace(torch.tensor([0.]).to(device),\n",
165 | " torch.tensor([0.5]).to(device)).expand([g_genes, c_class_levels[1]]).to_event(2)) \n",
166 | " \n",
167 | " \n",
168 | " #w_top_actual = torch.zeros_like(w_top)\n",
169 | " f_bottom = torch.zeros((i_cells, c_class_levels[1])).to(device)\n",
170 | " #for parent in range(c_class_levels[0]):\n",
171 | " #w_top_actual[:, parent] = torch.maximum(w_top[:, parent] - w_bottom[:, label_graph[parent]].sum(axis=1), torch.zeros_like(w_top_actual[:, parent]))\n",
172 | " #f_bottom[:, label_graph[parent]] += torch.nn.functional.softmax(torch.matmul(X, w_bottom[:, label_graph[parent]] / (class_size[1][label_graph[parent]])[None,:]**0.5), dim=1) / len(label_graph[parent])\n",
173 | " #w_top_actual = pyro.deterministic('w_top_actual', w_top_actual)\n",
174 | " #f_top = torch.nn.functional.softmax(torch.matmul(X, w_top_actual / (class_size[0])**0.5), dim=1)\n",
175 | "\n",
176 | " for parent in range(c_class_levels[0]):\n",
177 | " f_bottom[:, label_graph[parent]] += torch.nn.functional.softmax(torch.matmul(X, w_bottom[:, label_graph[parent]] / (class_size[1][label_graph[parent]])[None,:]**0.5), dim=1) * f_top[:,parent,None]\n",
178 | " \n",
179 | " \n",
180 | "    # likelihood terms for the top and bottom levels\n",
181 | " obs_top = pyro.sample('likelihood_top', dist.Categorical(f_top).to_event(1), obs=D[:,0])\n",
182 | " obs_bottom = pyro.sample('likelihood_bottom', dist.Categorical(f_bottom).to_event(1), obs=D[:,1]) "
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "id": "c50759be-a918-470f-b8f4-9c170d6dd605",
189 | "metadata": {},
190 | "outputs": [],
191 | "source": []
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": null,
196 | "id": "cac2b31a-10e0-458e-ac6d-ed772d0637dd",
197 | "metadata": {},
198 | "outputs": [],
199 | "source": []
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": null,
204 | "id": "1eae2906-27dd-4319-ae6a-a5929e40168b",
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "#hierarchical labeling \n",
209 | "np.random.seed(42)\n",
210 | "pyro.set_rng_seed(42)\n",
211 | "torch.manual_seed(42)\n",
212 | "\n",
213 | "device='cpu'\n",
214 | "D = torch.tensor(np.float32(np.stack([data_sub.obs['celltype_major'].cat.codes.values, data_sub.obs[group_today].cat.codes.values], axis=1))).to(device=device)\n",
215 | "X = np.apply_along_axis(data_to_zero_truncated_cdf, 0, data_sub.X.toarray()) #/ (data_sub.layers['raw_data'].toarray().max(axis=0)**0.5)[None,:]\n",
216 | "X = torch.tensor(np.float32(X)).to(device=device)\n",
217 | "class_size = [torch.tensor(np.float32(data_sub.obs.groupby(group).size().values)).to(device=device) for group in ['celltype_major', group_today]]\n",
218 | "label_graph = {}\n",
219 | "for i, g in enumerate(data_sub.obs['celltype_major'].cat.categories):\n",
220 | " label_graph[i] = data_sub.obs[data_sub.obs['celltype_major'] == g][group_today].unique().codes\n",
221 | "\n",
222 | "pyro.clear_param_store()\n",
223 | "\n",
224 | "model = multi_logist_hierarchical\n",
225 | "guide = AutoNormal(model)\n",
226 | "\n",
227 | "adam_params = {\"lr\": 0.01, \"betas\": (0.95, 0.999)}\n",
228 | "optimizer = Adam(adam_params)\n",
229 | "\n",
230 | "svi = SVI(model, guide, optimizer, loss=Trace_ELBO())\n",
231 | "\n",
232 | "n_steps = 10000\n",
233 | "# do gradient steps\n",
234 | "loss_list = []\n",
235 | "for step in tqdm.tqdm(range(n_steps)):\n",
236 | " loss = svi.step(D, X, label_graph, class_size, 'cpu')\n",
237 | " loss_list.append(loss)\n"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "id": "730b52b9-9c49-43f4-aa68-faebd9e0168b",
244 | "metadata": {},
245 | "outputs": [],
246 | "source": []
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": null,
251 | "id": "8cef7858-7757-4e68-be57-57c550bf03d4",
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "trace = Predictive(model, guide=guide, num_samples=100)(D, X, label_graph, class_size, 'cpu')"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": null,
261 | "id": "f054f8d2-478f-4ce3-b5bd-4671e611cd52",
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "gene_names = data_sub.var.index.values\n",
266 | "\n",
267 | "selected_dcit_top = {}\n",
268 | "for i, name in enumerate(data_sub.obs['celltype_major'].cat.categories):\n",
269 | " weights = trace['w_top'].mean(axis=0)[:, i].cpu().numpy()\n",
270 | " top4 = np.argpartition(weights, -4)[-4:]\n",
271 | " selected_dcit_top[name] = gene_names[top4]\n",
272 | "selected_dcit_top['hk'] = np.array(['RPLP0'])"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": null,
278 | "id": "41a0b2e6-91e4-4761-8c20-7387e09689c1",
279 | "metadata": {},
280 | "outputs": [],
281 | "source": [
282 | "fig = sc.pl.dotplot(data, selected_dcit_top, 'celltype_major')"
283 | ]
284 | }
285 | ],
286 | "metadata": {
287 | "kernelspec": {
288 | "display_name": "Environment (test_schierarchy)",
289 | "language": "python",
290 | "name": "test_schierarchy"
291 | },
292 | "language_info": {
293 | "codemirror_mode": {
294 | "name": "ipython",
295 | "version": 3
296 | },
297 | "file_extension": ".py",
298 | "mimetype": "text/x-python",
299 | "name": "python",
300 | "nbconvert_exporter": "python",
301 | "pygments_lexer": "ipython3",
302 | "version": "3.9.7"
303 | }
304 | },
305 | "nbformat": 4,
306 | "nbformat_minor": 5
307 | }
308 |
--------------------------------------------------------------------------------
/docs/notebooks/marker_selection_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "7f51c100-e7c5-499e-8d55-82145a997d3c",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import sys\n",
11 | "\n",
12 | "#if installed somewhere else\n",
13 | "sys.path.insert(1, '/nfs/team205/vk7/sanger_projects/BayraktarLab/cell2location/')\n",
14 | "sys.path.insert(1, '/lustre/scratch119/casm/team299ly/al15/projects/scHierarchy/')\n",
15 | "\n",
16 | "import scanpy as sc\n",
17 | "import anndata\n",
18 | "import pandas as pd\n",
19 | "import numpy as np\n",
20 | "import matplotlib.pyplot as plt \n",
21 | "import matplotlib as mpl\n",
22 | "\n",
23 | "import cell2location\n",
24 | "import scvi\n",
25 | "import schierarchy\n",
26 | "\n",
27 | "from matplotlib import rcParams\n",
28 | "rcParams['pdf.fonttype'] = 42 # enables correct plotting of text\n",
29 | "import seaborn as sns"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "id": "db0a5b58-f850-4825-b7a5-344f2eb53375",
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "#location of scRNA data\n",
40 | "sc_data_folder = '/nfs/casm/team299ly/al15/projects/sc-breast/data_atlas/'\n",
41 | "#location where the result is stored \n",
42 | "results_folder = '/nfs/casm/team299ly/al15/projects/sc-breast/data_atlas/results/'\n",
43 | "\n",
44 | "#prefix for experiment\n",
45 | "ref_run_name = f'{results_folder}hierarchical_logist/'"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "id": "0a892146-efb1-4c2b-a4b9-39da51888bbb",
51 | "metadata": {},
52 | "source": [
53 | "## Load Breast cancer scRNA dataset"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "id": "8d257eb3-2399-447c-ae54-ad0a795afe08",
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "## read data\n",
64 | "adata_ref = anndata.read_h5ad(sc_data_folder + \"atals_processed.h5ad\")\n",
65 | "adata_ref.layers['processed'] = adata_ref.X\n",
66 | "# revert the log transformation (if the data were stored log-transformed)\n",
67 | "adata_ref.X = ((adata_ref.layers['processed'].expm1() / 10000).multiply(adata_ref.obs[['nCount_RNA']].values)).tocsr() #real data \n",
68 | "\n",
69 | "# mitochondrial genes\n",
70 | "adata_ref.var['mt'] = adata_ref.var_names.str.startswith('MT-') \n",
71 | "# ribosomal genes\n",
72 | "adata_ref.var['ribo'] = adata_ref.var_names.str.startswith((\"RPS\",\"RPL\"))\n",
73 | "# hemoglobin genes.\n",
74 | "adata_ref.var['hb'] = adata_ref.var_names.str.contains((\"^HB[^(P)]\"))\n",
75 | "\n",
76 | "#delete ribo mt and hb genes \n",
77 | "adata_ref = adata_ref[:, np.logical_and(np.logical_and(~adata_ref.var['mt'], ~adata_ref.var['ribo']), ~adata_ref.var['hb'])]"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "id": "e88ed26c-d3b4-464e-8513-82d1102bf65c",
83 | "metadata": {},
84 | "source": [
85 | "### Process single cell data"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "id": "9250e711-e373-4aea-8856-e28696ca1591",
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "# before estimating the reference cell type signatures we recommend performing very permissive gene selection\n",
96 | "# in this 2D histogram the orange rectangle lies over the excluded genes.\n",
97 | "# In this case, the downloaded dataset was already filtered with this method,\n",
98 | "# hence there is no density under the orange rectangle\n",
99 | "from cell2location.utils.filtering import filter_genes\n",
100 | "selected = filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)\n",
101 | "adata_ref = adata_ref[:, selected].copy()\n",
102 | "\n",
103 | "# remove genes that are expressed almost everywhere (uninformative as markers)\n",
104 | "max_cutoff = (adata_ref.var['n_cells'] / adata_ref.n_obs) > 0.8\n",
105 | "print(f'Fraction of genes expressed in more than 80% of cells: {max_cutoff.mean()}')\n",
106 | "\n",
107 | "# filter the object\n",
108 | "adata_ref = adata_ref[:, ~max_cutoff].copy()"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "id": "c9540a8d-892e-4364-a0b4-d182e241bf83",
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "%%time\n",
119 | "# quantile-normalise log-transformed data; this could be replaced with a transformation of your choice\n",
120 | "from schierarchy.utils.data_transformation import data_to_zero_truncated_cdf\n",
121 | "adata_ref.layers[\"cdf\"] = np.apply_along_axis(\n",
122 | " data_to_zero_truncated_cdf, 0, adata_ref.layers[\"processed\"].toarray()\n",
123 | ")"
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "id": "3cf6c2f7-c5ee-4fc4-8ae0-76d1c8b7b235",
129 | "metadata": {},
130 | "source": [
131 | "## Initialise and run the model"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "id": "43a42209-900d-4a1c-9be6-031aa9028f3a",
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "from schierarchy import LogisticModel\n",
142 | "\n",
143 | "# names of label columns, ordered from the coarsest to the finest level\n",
144 | "level_keys = ['celltype_major', 'celltype_minor', 'celltype_subset']\n",
145 | "\n",
146 | "LogisticModel.setup_anndata(adata_ref, layer=\"cdf\", level_keys=level_keys)"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "id": "edfb39e4-f98a-4a2d-b4e7-f16291ed7a11",
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "# train the hierarchical logistic regression model to get per-cell-type marker gene weights\n",
157 | "from schierarchy import LogisticModel\n",
158 | "learning_mode = 'fixed-sigma'\n",
159 | "mod = LogisticModel(adata_ref, laplace_learning_mode=learning_mode)\n",
160 | "\n",
161 | "# Use all data for training (validation not implemented yet, train_size=1)\n",
162 | "mod.train(max_epochs=600, batch_size=2500, train_size=1, lr=0.01, use_gpu=True)\n",
163 | "\n",
164 | "# plot ELBO loss history during training, removing the first 50 epochs from the plot\n",
165 | "mod.plot_history(50)"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "id": "8b3fb8ca-c349-4a1f-8308-490877955c18",
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "%%time\n",
176 | "\n",
177 | "# In this section, we export the estimated gene weights and per-cell probabilities \n",
178 | "# (summary of the posterior distribution).\n",
179 | "adata_ref = mod.export_posterior(\n",
180 | " adata_ref, sample_kwargs={'num_samples': 50, 'batch_size': 2500, 'use_gpu': True}\n",
181 | ")\n",
182 | "\n",
183 | "# Save model\n",
184 | "mod.save(f\"{ref_run_name}\", overwrite=True)\n",
185 | "\n",
186 | "# Save anndata object with results\n",
187 | "adata_file = f\"{ref_run_name}/sc.h5ad\"\n",
188 | "adata_ref.write(adata_file)\n",
189 | "adata_file"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "id": "206c4872-c283-48c0-8e10-0a5f06113f27",
195 | "metadata": {},
196 | "source": [
197 | "## Load model"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": null,
203 | "id": "1fec3185-f973-494d-879c-106acc2ac45d",
204 | "metadata": {},
205 | "outputs": [],
206 | "source": [
207 | "# if you're not making new predictions, just work with adata_file; it already stores the results\n",
208 | "model = LogisticModel.load(ref_run_name, adata_ref)\n",
209 | "adata_ref = model.export_posterior(\n",
210 | " adata_ref, sample_kwargs={'num_samples': 50, 'batch_size': 2500, 'use_gpu': True}\n",
211 | ")"
212 | ]
213 | },
214 | {
215 | "cell_type": "markdown",
216 | "id": "504b907f-b12a-438d-a96a-51f3524d5b1d",
217 | "metadata": {},
218 | "source": [
219 | "## Visualise hierarchy of marker genes"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "id": "6423dfe8-8688-4567-bbaa-138438d50338",
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "adata_file = f\"{ref_run_name}/sc.h5ad\"\n",
230 | "adata_ref = sc.read(adata_file)"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": null,
236 | "id": "d41bc1c8-a183-4e16-a38f-944e4b9b407f",
237 | "metadata": {},
238 | "outputs": [],
239 | "source": [
240 | "# complete selected-gene plots\n",
241 | "gene_names = adata_ref.var['gene_ids'].values\n",
242 | "observed_labels = []\n",
243 | "\n",
244 | "for level in level_keys:\n",
245 | " selected_dcit = {}\n",
246 | " for i, name in enumerate(adata_ref.obs[level].cat.categories):\n",
247 | " weights = adata_ref.varm[f'means_weight_{level}'][f'means_weight_{level}_{name}'].values\n",
248 | " top_n = np.argpartition(weights, -3)[-3:]\n",
249 | " if name not in observed_labels:\n",
250 | " selected_dcit[name] = gene_names[top_n]\n",
251 | " fig = sc.pl.dotplot(adata_ref, selected_dcit, level, log=True, gene_symbols='gene_ids')"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": null,
257 | "id": "4fdd013a-2f7e-40f4-8065-2d86d6ca5262",
258 | "metadata": {},
259 | "outputs": [],
260 | "source": [
261 | "ind = adata_ref.obs[level_keys[0]].isin(['T-cells'])\n",
262 | "adata_ref_subset = adata_ref[ind, :]\n",
263 | "\n",
264 | "gene_names = adata_ref.var['gene_ids'].values\n",
265 | "observed_labels = []\n",
266 | "\n",
267 | "\n",
268 | "for level in level_keys:\n",
269 | "    selected_dict = {}\n",
270 | "    for i, name in enumerate(adata_ref_subset.obs[level].cat.categories):\n",
271 | "        weights = adata_ref_subset.varm[f'means_weight_{level}'][f'means_weight_{level}_{name}'].values\n",
272 | "        top_n = np.argpartition(weights, -5)[-5:]\n",
273 | "        if name not in observed_labels:\n",
274 | "            selected_dict[name] = gene_names[top_n]\n",
275 | "            observed_labels.append(name)\n",
276 | "    ind = adata_ref_subset.obs[level].isin(list(selected_dict.keys()))\n",
277 | "    adata_ref_subset_v2 = adata_ref_subset[ind, :]\n",
278 | "    if adata_ref_subset_v2.n_obs > 0:\n",
279 | "        fig = sc.pl.dotplot(adata_ref_subset_v2, selected_dict, level, log=True, gene_symbols='gene_ids')"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": null,
285 | "id": "0184e2b9-758e-498a-88a6-d62799739ecb",
286 | "metadata": {},
287 | "outputs": [],
288 | "source": []
289 | }
290 | ],
291 | "metadata": {
292 | "kernelspec": {
293 | "display_name": "Environment (schierarchy)",
294 | "language": "python",
295 | "name": "schierarchy"
296 | },
297 | "language_info": {
298 | "codemirror_mode": {
299 | "name": "ipython",
300 | "version": 3
301 | },
302 | "file_extension": ".py",
303 | "mimetype": "text/x-python",
304 | "name": "python",
305 | "nbconvert_exporter": "python",
306 | "pygments_lexer": "ipython3",
307 | "version": "3.9.7"
308 | }
309 | },
310 | "nbformat": 4,
311 | "nbformat_minor": 5
312 | }
313 |
--------------------------------------------------------------------------------
/docs/references.rst:
--------------------------------------------------------------------------------
1 | References
2 | ==========
3 |
4 | .. note:: The reference below can be cited in docstrings. See :class:`~mypackage.MyModule` for an example.
5 |
6 | .. [Lopez18] Romain Lopez, Jeffrey Regier, Michael Cole, Michael I. Jordan, Nir Yosef (2018),
7 | *Deep generative modeling for single-cell transcriptomics*,
8 |     `Nature Methods <https://www.nature.com/articles/s41592-018-0229-2>`__.
9 |
--------------------------------------------------------------------------------
/docs/release_notes/index.rst:
--------------------------------------------------------------------------------
1 | Release notes
2 | =============
3 |
4 | This is the list of changes to schierarchy between releases. Full commit history
5 | is available in the `commit logs <https://github.com/vitkl/schierarchy/commits>`_.
6 |
7 | Version 0.1
8 | -----------
9 | .. toctree::
10 | :maxdepth: 2
11 |
12 | v0.1.1
13 | v0.1.0
14 |
--------------------------------------------------------------------------------
/docs/release_notes/v0.1.0.rst:
--------------------------------------------------------------------------------
1 | New in 0.1.0 (2020-01-25)
2 | -------------------------
3 | Initial release of the skeleton. Complies with `scvi-tools>=0.9.0a0`.
4 |
--------------------------------------------------------------------------------
/docs/release_notes/v0.1.1.rst:
--------------------------------------------------------------------------------
1 | New in 0.1.1 (2020-01-25)
2 | -------------------------
3 | Add more documentation.
4 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.isort]
2 | include_trailing_comma = true
3 | multi_line_output = 3
4 | profile = "black"
5 | skip_glob = ["docs/*", "schierarchy/__init__.py"]
6 |
7 | [tool.poetry]
8 | authors = ["Artem Lomakin ", "Vitalii Kleshchevnikov "]
9 | classifiers = [
10 | "Development Status :: 4 - Beta",
11 | "Intended Audience :: Science/Research",
12 | "Natural Language :: English",
13 | "Programming Language :: Python :: 3.9",
14 | "Programming Language :: Python :: 3.10",
15 | "Operating System :: MacOS :: MacOS X",
16 | "Operating System :: Microsoft :: Windows",
17 | "Operating System :: POSIX :: Linux",
18 | "Topic :: Scientific/Engineering :: Bio-Informatics",
19 | ]
20 | description = "Hierarchical cell identity toolkit."
21 | documentation = "https://github.com/vitkl/schierarchy"
22 | homepage = "https://github.com/vitkl/schierarchy"
23 | license = "Apache License, Version 2.0"
24 | name = "schierarchy"
25 | packages = [
26 | {include = "schierarchy"},
27 | ]
28 | readme = "README.md"
29 | version = "0.0.2"
30 |
31 | [tool.poetry.dependencies]
32 | anndata = ">=0.7.5"
33 | black = {version = "==22.3.0", optional = true}
34 | codecov = {version = ">=2.0.8", optional = true}
35 | flake8 = {version = ">=3.7.7", optional = true}
36 | importlib-metadata = {version = "^1.0", python = "<3.8"}
37 | ipython = {version = ">=7.1.1", optional = true}
38 | isort = {version = ">=5.7", optional = true}
39 | jupyter = {version = ">=1.0", optional = true}
40 | leidenalg = {version = "*", optional = true}
41 | loompy = {version = ">=3.0.6", optional = true}
42 | nbconvert = {version = ">=5.4.0", optional = true}
43 | nbformat = {version = ">=4.4.0", optional = true}
44 | nbsphinx = {version = "*", optional = true}
45 | nbsphinx-link = {version = "*", optional = true}
46 | pre-commit = {version = ">=2.7.1", optional = true}
47 | pydata-sphinx-theme = {version = ">=0.4.0", optional = true}
48 | pytest = {version = ">=4.4", optional = true}
49 | python = ">=3.9.18"
50 | python-igraph = {version = "*", optional = true}
51 | scanpy = {version = ">=1.6", optional = true}
52 | scanpydoc = {version = ">=0.5", optional = true}
53 | scikit-misc = {version = ">=0.1.3", optional = true}
54 | cell2location = {git = "https://github.com/BayraktarLab/cell2location.git@hires_sliding_window"}
55 | pyro-ppl = {version = ">=1.8.0"}
56 | scvi-tools = ">=1.0.0"
57 | sphinx = {version = "^3.0", optional = true}
58 | sphinx-autodoc-typehints = {version = "*", optional = true}
59 | sphinx-rtd-theme = {version = "*", optional = true}
60 | typing_extensions = {version = "*", python = "<3.8"}
61 |
62 | [tool.poetry.extras]
63 | dev = ["black", "pytest", "flake8", "codecov", "scanpy", "loompy", "jupyter", "nbformat", "nbconvert", "pre-commit", "isort"]
64 | docs = [
65 | "sphinx",
66 | "scanpydoc",
67 | "nbsphinx",
68 | "nbsphinx-link",
69 | "ipython",
70 | "pydata-sphinx-theme",
71 | "typing_extensions",
72 | "sphinx-autodoc-typehints",
73 | "sphinx-rtd-theme",
74 | ]
75 | tutorials = ["scanpy", "leidenalg", "python-igraph", "loompy", "scikit-misc"]
76 |
77 | [tool.poetry.dev-dependencies]
78 |
79 | [build-system]
80 | build-backend = "poetry.masonry.api"
81 | requires = [
82 | "poetry>=1.0",
83 | "setuptools", # keep it here or "pip install -e" would fail
84 | ]
85 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | build:
3 | image: latest
4 | sphinx:
5 | configuration: docs/conf.py
6 | python:
7 | version: 3.9
8 | install:
9 | - method: pip
10 | path: .
11 | extra_requirements:
12 | - docs
13 |
--------------------------------------------------------------------------------
/schierarchy/__init__.py:
--------------------------------------------------------------------------------
1 | """scvi-tools-skeleton."""
2 |
3 | import logging
4 |
5 | from rich.console import Console
6 | from rich.logging import RichHandler
7 |
8 | from .regression._reference_model import RegressionModel
9 | from .regression._reference_module import RegressionBackgroundDetectionTechPyroModel
10 |
11 | from .logistic._logistic_model import LogisticModel
12 | from .logistic._logistic_module import HierarchicalLogisticPyroModel
13 |
14 | # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
15 | # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
16 | try:
17 | import importlib.metadata as importlib_metadata
18 | except ModuleNotFoundError:
19 | import importlib_metadata
20 |
21 | package_name = "schierarchy"
22 | __version__ = importlib_metadata.version(package_name)
23 |
24 | logger = logging.getLogger(__name__)
25 | # set the logging level
26 | logger.setLevel(logging.INFO)
27 |
28 | # nice logging outputs
29 | console = Console(force_terminal=True)
30 | if console.is_jupyter is True:
31 | console.is_jupyter = False
32 | ch = RichHandler(show_path=False, console=console, show_time=False)
33 | formatter = logging.Formatter("scHierarchy: %(message)s")
34 | ch.setFormatter(formatter)
35 | logger.addHandler(ch)
36 |
37 | # this prevents double outputs
38 | logger.propagate = False
39 |
40 | __all__ = [
41 | "RegressionModel",
42 | "LogisticModel",
43 | "HierarchicalLogisticPyroModel",
44 | "RegressionBackgroundDetectionTechPyroModel",
45 | ]
46 |
--------------------------------------------------------------------------------
/schierarchy/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/schierarchy/base/__init__.py
--------------------------------------------------------------------------------
/schierarchy/base/_pyro_base_regression_module.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional
2 |
3 | from cell2location.models.base._pyro_mixin import AutoGuideMixinModule, init_to_value
4 |
5 | from scvi.module.base import PyroBaseModuleClass
6 |
7 |
8 | class RegressionBaseModule(PyroBaseModuleClass, AutoGuideMixinModule):
9 | def __init__(
10 | self,
11 | model,
12 | amortised: bool = False,
13 | encoder_mode: Literal["single", "multiple", "single-multiple"] = "single",
14 | encoder_kwargs=None,
15 | create_autoguide_kwargs: Optional[dict] = None,
16 | **kwargs,
17 | ):
18 | """
19 | Module class which defines AutoGuide given model. Supports multiple model architectures.
20 |
21 | Parameters
22 | ----------
23 | amortised
24 | boolean, use a Neural Network to approximate posterior distribution of location-specific (local) parameters?
25 | encoder_mode
26 | Use single encoder for all variables ("single"), one encoder per variable ("multiple")
27 | or a single encoder in the first step and multiple encoders in the second step ("single-multiple").
28 | encoder_kwargs
29 | arguments for Neural Network construction (scvi.nn.FCLayers)
30 | kwargs
31 | arguments for specific model class - e.g. number of genes, values of the prior distribution
32 | """
33 | super().__init__()
34 | self.hist = []
35 |
36 | self._model = model(**kwargs)
37 | self._amortised = amortised
38 | if create_autoguide_kwargs is None:
39 | create_autoguide_kwargs = dict()
40 |
41 | self._guide = self._create_autoguide(
42 | model=self.model,
43 | amortised=self.is_amortised,
44 | encoder_kwargs=encoder_kwargs,
45 | encoder_mode=encoder_mode,
46 | init_loc_fn=self.init_to_value,
47 | n_cat_list=[kwargs["n_batch"]],
48 | **create_autoguide_kwargs,
49 | )
50 |
51 | self._get_fn_args_from_batch = self._model._get_fn_args_from_batch
52 |
53 | @property
54 | def model(self):
55 | return self._model
56 |
57 | @property
58 | def guide(self):
59 | return self._guide
60 |
61 | @property
62 | def is_amortised(self):
63 | return self._amortised
64 |
65 | @property
66 | def list_obs_plate_vars(self):
67 | return self.model.list_obs_plate_vars()
68 |
69 | def init_to_value(self, site):
70 |
71 | if getattr(self.model, "np_init_vals", None) is not None:
72 | init_vals = {
73 | k: getattr(self.model, f"init_val_{k}")
74 | for k in self.model.np_init_vals.keys()
75 | }
76 | else:
77 | init_vals = dict()
78 | return init_to_value(site=site, values=init_vals)
79 |
--------------------------------------------------------------------------------
/schierarchy/logistic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/schierarchy/logistic/__init__.py
--------------------------------------------------------------------------------
/schierarchy/logistic/_logistic_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from datetime import date
3 | from typing import Optional
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from anndata import AnnData
8 | from cell2location.models.base._pyro_mixin import (
9 | AutoGuideMixinModule,
10 | PltExportMixin,
11 | QuantileMixin,
12 | init_to_value,
13 | )
14 | from pyro import clear_param_store
15 | from pyro.infer.autoguide import AutoNormalMessenger, init_to_feasible, init_to_mean
16 | from scvi import REGISTRY_KEYS
17 | from scvi.data import AnnDataManager
18 | from scvi.data.fields import CategoricalJointObsField, LayerField, NumericalObsField
19 | from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin
20 | from scvi.module.base import PyroBaseModuleClass
21 | from scvi.utils import setup_anndata_dsp
22 |
23 | from ._logistic_module import HierarchicalLogisticPyroModel
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | def infer_tree(labels_df, level_keys):
29 | """
30 |
31 | Parameters
32 | ----------
33 | labels_df
34 | DataFrame with annotations
35 | level_keys
36 | List of column names from top to bottom levels (from less detailed to more detailed)
37 |
38 | Returns
39 | -------
40 |     List of len(level_keys) - 1 dicts mapping each parent label to its child labels
41 |
42 | """
43 | # for multiple layers of hierarchy
44 | if len(level_keys) > 1:
45 | tree_inferred = [{} for i in range(len(level_keys) - 1)]
46 | for i in range(len(level_keys) - 1):
47 | layer_p = labels_df.loc[:, level_keys[i]]
48 | layer_ch = labels_df.loc[:, level_keys[i + 1]]
49 | for j in range(labels_df.shape[0]):
50 | if layer_p[j] not in tree_inferred[i].keys():
51 | tree_inferred[i][layer_p[j]] = [layer_ch[j]]
52 | else:
53 | if layer_ch[j] not in tree_inferred[i][layer_p[j]]:
54 | tree_inferred[i][layer_p[j]].append(layer_ch[j])
55 | # if only one level
56 | else:
57 | tree_inferred = [list(labels_df[level_keys[0]].unique())]
58 |
59 | return tree_inferred
60 |
61 |
62 | """def _setup_summary_stats(adata, level_keys):
63 | n_cells = adata.shape[0]
64 | n_vars = adata.shape[1]
65 | n_cells_per_label_per_level = [
66 | adata.obs.groupby(group).size().values.astype(int) for group in level_keys
67 | ]
68 |
69 | n_levels = len(level_keys)
70 |
71 | summary_stats = {
72 | "n_cells": n_cells,
73 | "n_vars": n_vars,
74 | "n_levels": n_levels,
75 | "n_cells_per_label_per_level": n_cells_per_label_per_level,
76 | }
77 |
78 | adata.uns["_scvi"]["summary_stats"] = summary_stats
79 | adata.uns["tree"] = infer_tree(adata.obsm["_scvi_extra_categoricals"], level_keys)
80 |
81 | logger.info(
82 | "Successfully registered anndata object containing {} cells, {} vars, "
83 | "{} cell annotation levels.".format(n_cells, n_vars, n_levels)
84 | )
85 | return summary_stats"""
86 |
87 |
88 | class LogisticBaseModule(PyroBaseModuleClass, AutoGuideMixinModule):
89 | def __init__(
90 | self,
91 | model,
92 | init_loc_fn=init_to_mean(fallback=init_to_feasible),
93 | guide_class=AutoNormalMessenger,
94 | guide_kwargs: Optional[dict] = None,
95 | **kwargs,
96 | ):
97 |         """
98 |         Module class which defines an AutoGuide for a given model.
99 | 
100 |         Parameters
101 |         ----------
102 |         model
103 |             Pyro model class, instantiated with ``**kwargs`` (e.g. number of genes, prior values)
104 |         init_loc_fn
105 |             initialisation function for guide parameter locations
106 |         guide_class
107 |             AutoGuide class used to wrap the model (default ``AutoNormalMessenger``)
108 |         guide_kwargs
109 |             additional keyword arguments for ``guide_class``
110 |         kwargs
111 |         """
112 | super().__init__()
113 | self.hist = []
114 | self._model = model(**kwargs)
115 |
116 | if guide_kwargs is None:
117 | guide_kwargs = dict()
118 |
119 | self._guide = guide_class(
120 | self.model,
121 | init_loc_fn=init_loc_fn,
122 | **guide_kwargs
123 | # create_plates=model.create_plates,
124 | )
125 |
126 | self._get_fn_args_from_batch = self._model._get_fn_args_from_batch
127 |
128 | @property
129 | def model(self):
130 | return self._model
131 |
132 | @property
133 | def guide(self):
134 | return self._guide
135 |
136 | @property
137 | def list_obs_plate_vars(self):
138 | return self.model.list_obs_plate_vars()
139 |
140 | def init_to_value(self, site):
141 |
142 | if getattr(self.model, "np_init_vals", None) is not None:
143 | init_vals = {
144 | k: getattr(self.model, f"init_val_{k}")
145 | for k in self.model.np_init_vals.keys()
146 | }
147 | else:
148 | init_vals = dict()
149 | return init_to_value(site=site, values=init_vals)
150 |
151 |
152 | class LogisticModel(
153 | QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, PltExportMixin, BaseModelClass
154 | ):
155 |     """
156 |     Hierarchical logistic regression model of cell identity across annotation levels. User-end model class.
157 | 
158 |     https://github.com/BayraktarLab/cell2location
159 | 
160 |     Parameters
161 |     ----------
162 |     adata
163 |         single-cell AnnData object that has been registered via :meth:`~schierarchy.LogisticModel.setup_anndata`.
164 |     laplace_learning_mode
165 |         how the scale of the Laplace prior on weights is treated (see HierarchicalLogisticPyroModel)
166 |     model_class
167 |         Pyro model class to use (default :class:`~schierarchy.HierarchicalLogisticPyroModel`)
168 |     **model_kwargs
169 |         Keyword args for the model class
170 | 
171 |     Examples
172 |     --------
173 |     TODO add example
174 |     >>>
175 |     """
176 |
177 | def __init__(
178 | self,
179 | adata: AnnData,
180 | laplace_learning_mode: str = "fixed-sigma",
181 | # tree: list,
182 | model_class=None,
183 | **model_kwargs,
184 | ):
185 | # in case any other model was created before that shares the same parameter names.
186 | clear_param_store()
187 |
188 | super().__init__(adata)
189 |
190 | level_keys = self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)[
191 | "field_keys"
192 | ]
193 | self.n_cells_per_label_per_level_ = [
194 | self.adata.obs.groupby(group).size().values.astype(int)
195 | for group in level_keys
196 | ]
197 | self.n_levels_ = len(level_keys)
198 | self.level_keys_ = level_keys
199 | self.tree_ = infer_tree(
200 | self.adata_manager.get_from_registry(REGISTRY_KEYS.CAT_COVS_KEY), level_keys
201 | )
202 | self.laplace_learning_mode_ = laplace_learning_mode
203 |
204 | if model_class is None:
205 | model_class = HierarchicalLogisticPyroModel
206 | self.module = LogisticBaseModule(
207 | model=model_class,
208 | n_obs=self.summary_stats["n_cells"],
209 | n_vars=self.summary_stats["n_vars"],
210 | n_levels=self.n_levels_,
211 | n_cells_per_label_per_level=self.n_cells_per_label_per_level_,
212 | tree=self.tree_,
213 | laplace_learning_mode=self.laplace_learning_mode_,
214 | **model_kwargs,
215 | )
216 | self.samples = dict()
217 | self.init_params_ = self._get_init_params(locals())
218 |
219 | @classmethod
220 | @setup_anndata_dsp.dedent
221 | def setup_anndata(
222 | cls,
223 | adata: AnnData,
224 | layer: Optional[str] = None,
225 | level_keys: Optional[list] = None,
226 | **kwargs,
227 | ):
228 | """
229 | %(summary)s.
230 |
231 | Parameters
232 | ----------
233 | %(param_layer)s
234 |         level_keys
235 |             List of obs column names from top to bottom levels (from less detailed to more detailed)
236 |         kwargs
237 |             additional arguments for ``adata_manager.register_fields``
238 | """
239 | setup_method_args = cls._get_setup_method_args(**locals())
240 | adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
241 | anndata_fields = [
242 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
243 | CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, level_keys),
244 | NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
245 | ]
246 | adata_manager = AnnDataManager(
247 | fields=anndata_fields, setup_method_args=setup_method_args
248 | )
249 | adata_manager.register_fields(adata, **kwargs)
250 | cls.register_manager(adata_manager)
251 |
252 | def _export2adata(self, samples):
253 | r"""
254 | Export key model variables and samples
255 |
256 | Parameters
257 | ----------
258 | samples
259 | dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by ``.sample_posterior()``
260 |
261 | Returns
262 | -------
263 | Updated dictionary with additional details is saved to ``adata.uns['mod']``.
264 | """
265 | # add factor filter and samples of all parameters to unstructured data
266 | results = {
267 | "model_name": str(self.module.__class__.__name__),
268 | "date": str(date.today()),
269 | "var_names": self.adata.var_names.tolist(),
270 | "obs_names": self.adata.obs_names.tolist(),
271 | "post_sample_means": samples["post_sample_means"] if "post_sample_means" in samples else None,
272 | "post_sample_stds": samples["post_sample_stds"] if "post_sample_stds" in samples else None,
273 | }
274 | # add posterior quantiles
275 | for k, v in samples.items():
276 | if k.startswith("post_sample_"):
277 | results[k] = v
278 |
279 | return results
280 |
281 | def export_posterior(
282 | self,
283 | adata,
284 | prediction: bool = False,
285 | use_quantiles: bool = False,
286 | sample_kwargs: Optional[dict] = None,
287 | export_slot: str = "mod",
288 | add_to_varm: list = ["means", "stds", "q05", "q95"],
289 | ):
290 |         """
291 |         Summarise posterior distribution and export results (gene weights and label probabilities) to anndata object:
292 |         1. adata.varm: Estimated per-gene weights for every label, as pd.DataFrames for each posterior
293 |             distribution summary `add_to_varm` (['means', 'stds', 'q05', 'q95']).
294 |             If export to adata.varm fails with error, results are saved to adata.var instead.
295 |         2. adata.obsm: Estimated per-cell label probabilities for every level of the hierarchy
296 |             (if export to adata.obsm fails, results are saved to adata.obs instead).
297 |         3. adata.uns: Posterior of all parameters, model name, date, obs and var names.
298 |
299 | Parameters
300 | ----------
301 | adata
302 | anndata object where results should be saved
303 | prediction
304 | Prediction mode predicts cell labels on new data.
305 | sample_kwargs
306 | arguments for self.sample_posterior (generating and summarising posterior samples), namely:
307 | num_samples - number of samples to use (Default = 1000).
308 | batch_size - data batch size (keep low enough to fit on GPU, default 2048).
309 | use_gpu - use gpu for generating samples?
310 | export_slot
311 | adata.uns slot where to export results
312 | add_to_varm
313 | posterior distribution summary to export in adata.varm (['means', 'stds', 'q05', 'q95']).
314 | Returns
315 | -------
316 |
317 | """
318 |
319 | sample_kwargs = sample_kwargs if isinstance(sample_kwargs, dict) else dict()
320 |
321 | label_keys = list(
322 | self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)[
323 | "field_keys"
324 | ]
325 | )
326 |
327 |         # in prediction mode, switch to evaluation mode and swap the adata object
328 | if prediction:
329 | self.module.eval()
330 | self.module.model.prediction = True
331 | # use version of this function for prediction
332 | self.module._get_fn_args_from_batch = (
333 | self.module.model._get_fn_args_from_batch
334 | )
335 |             # resize plates according to the validation object
336 | self.module.model.n_obs = adata.n_obs
337 | # create index column
338 | adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
339 |             # for minibatch learning, selected indices lie in "ind_x"
340 | # scvi.data.register_tensor_from_anndata(
341 | # adata,
342 | # registry_key="ind_x",
343 | # adata_attr_name="obs",
344 | # adata_key_name="_indices",
345 | # )
346 |             # if none of the label columns exist, create them and fill with 0s
347 | if np.all(~np.isin(label_keys, adata.obs.columns)):
348 | adata.obs.loc[:, label_keys] = 0
349 | # substitute adata object
350 | adata_train = self.adata.copy()
351 | self.adata = self._validate_anndata(adata)
352 | # self.adata = adata
353 |
354 | if use_quantiles:
355 | add_to_varm = [i for i in add_to_varm if (i not in ["means", "stds"]) and ("q" in i)]
356 | if len(add_to_varm) == 0:
357 |                 raise ValueError("No quantiles to export - please use add_to_varm=['q05', 'q50', 'q95'].")
358 | self.samples = dict()
359 | for i in add_to_varm:
360 | q = float(f"0.{i[1:]}")
361 | self.samples[f"post_sample_{i}"] = self.posterior_quantile(q=q, **sample_kwargs)
362 | else:
363 | # generate samples from posterior distributions for all parameters
364 | # and compute mean, 5%/95% quantiles and standard deviation
365 | self.samples = self.sample_posterior(**sample_kwargs)
366 |
367 | # revert adata object substitution
368 | self.adata = adata_train
369 | self.module.eval()
370 | self.module.model.prediction = False
371 | # re-set default version of this function
372 | self.module._get_fn_args_from_batch = (
373 | self.module.model._get_fn_args_from_batch
374 | )
375 | obs_names = adata.obs_names
376 | else:
377 | if use_quantiles:
378 | add_to_varm = [i for i in add_to_varm if (i not in ["means", "stds"]) and ("q" in i)]
379 | if len(add_to_varm) == 0:
380 |                 raise ValueError("No quantiles to export - please use add_to_varm=['q05', 'q50', 'q95'].")
381 | self.samples = dict()
382 | for i in add_to_varm:
383 | q = float(f"0.{i[1:]}")
384 | self.samples[f"post_sample_{i}"] = self.posterior_quantile(q=q, **sample_kwargs)
385 | else:
386 | # generate samples from posterior distributions for all parameters
387 | # and compute mean, 5%/95% quantiles and standard deviation
388 | self.samples = self.sample_posterior(**sample_kwargs)
389 | obs_names = self.adata.obs_names
390 |
391 | # export posterior distribution summary for all parameters and
392 | # annotation (model, date, var, obs and cell type names) to anndata object
393 | adata.uns[export_slot] = self._export2adata(self.samples)
394 |
395 | # export estimated expression in each cluster
396 | # first convert np.arrays to pd.DataFrames with cell type and observation names
397 | # data frames contain mean, 5%/95% quantiles and standard deviation, denoted by a prefix
398 | for i in range(self.n_levels_):
399 | categories = list(
400 | list(
401 | self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)[
402 | "mappings"
403 | ].values()
404 | )[i]
405 | )
406 | for k in add_to_varm:
407 | sample_df = pd.DataFrame(
408 | self.samples[f"post_sample_{k}"].get(f"weight_level_{i}", None),
409 | columns=[f"{k}_weight_{label_keys[i]}_{c}" for c in categories],
410 | index=self.adata.var_names,
411 | )
412 | try:
413 | adata.varm[f"{k}_weight_{label_keys[i]}"] = sample_df.loc[
414 | adata.var_names, :
415 | ]
416 | except ValueError:
417 | # Catching weird error with obsm: `ValueError: value.index does not match parent’s axis 1 names`
418 | adata.var[sample_df.columns] = sample_df.loc[adata.var_names, :]
419 |
420 | sample_df = pd.DataFrame(
421 | self.samples[f"post_sample_{k}"].get(f"label_prob_{i}", None),
422 | columns=obs_names,
423 | index=[f"{k}_label_{label_keys[i]}_{c}" for c in categories],
424 | ).T
425 | try:
426 | # TODO change to user input name
427 | adata.obsm[f"{k}_label_prob_{label_keys[i]}"] = sample_df.loc[
428 | adata.obs_names, :
429 | ]
430 | except ValueError:
431 | # Catching weird error with obsm: `ValueError: value.index does not match parent’s axis 1 names`
432 | adata.obs[sample_df.columns] = sample_df.loc[adata.obs_names, :]
433 |
434 | return adata
435 |
436 |
437 | # TODO plot QC - prediction accuracy curve
438 |
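A minimal usage sketch for the model above (an editorial illustration, not repository source; the
input file name and annotation column names are assumptions, and the export call mirrors the
marker selection notebook above):

    import scanpy as sc
    from schierarchy import LogisticModel
    from schierarchy.logistic._logistic_model import infer_tree

    adata = sc.read("sc.h5ad")  # hypothetical AnnData with raw counts in .X
    level_keys = ["level_0", "level_1"]  # coarse -> fine annotation columns

    # infer_tree maps each parent label to its children, e.g.
    # [{'Immune': ['T-cells', 'B-cells'], 'Stromal': ['Fibroblasts']}];
    # LogisticModel builds the same tree internally from the registered covariates
    tree = infer_tree(adata.obs[level_keys].reset_index(drop=True), level_keys)

    LogisticModel.setup_anndata(adata, level_keys=level_keys)
    mod = LogisticModel(adata)
    mod.train(max_epochs=300, batch_size=2500)

    # gene weights are exported to adata.varm (e.g. 'means_weight_level_0'),
    # per-cell label probabilities to adata.obsm
    adata = mod.export_posterior(
        adata, sample_kwargs={"num_samples": 50, "batch_size": 2500, "use_gpu": True}
    )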
--------------------------------------------------------------------------------
/schierarchy/logistic/_logistic_module.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pyro
4 | import pyro.distributions as dist
5 | import torch
6 | from pyro.nn import PyroModule
7 | from scvi import REGISTRY_KEYS
8 |
9 |
10 | class HierarchicalLogisticPyroModel(PyroModule):
11 |
12 | prediction = False
13 |
14 | def __init__(
15 | self,
16 | n_obs,
17 | n_vars,
18 | n_levels,
19 | n_cells_per_label_per_level,
20 | tree,
21 | laplace_prior={"mu": 0.0, "sigma": 0.5, "exp_rate": 3.0},
22 | laplace_learning_mode="fixed-sigma",
23 | init_vals: Optional[dict] = None,
24 | dropout_p: float = 0.1,
25 | use_dropout: bool = True,
26 | use_gene_dropout: bool = False,
27 | ):
28 |         r"""
29 |         Parameters
30 |         ----------
31 |         n_obs
32 |         n_vars
33 |         n_levels
34 |         n_cells_per_label_per_level
35 |         tree
36 |         laplace_prior
37 |         laplace_learning_mode = 'fixed-sigma', 'learn-sigma-single', 'learn-sigma-gene', 'learn-sigma-celltype',
38 |             'learn-sigma-gene-celltype'
39 |         """
40 |
41 | ############# Initialise parameters ################
42 | super().__init__()
43 |
44 | self.dropout = torch.nn.Dropout(p=dropout_p)
45 |
46 | self.n_obs = n_obs
47 | self.n_vars = n_vars
48 | self.n_levels = n_levels
49 | self.n_cells_per_label_per_level = n_cells_per_label_per_level
50 | self.tree = tree
51 | self.laplace_prior = laplace_prior
52 | self.laplace_learning_mode = laplace_learning_mode
53 | self.use_dropout = use_dropout
54 | self.use_gene_dropout = use_gene_dropout
55 |
56 | if self.laplace_learning_mode not in [
57 | "fixed-sigma",
58 | "learn-sigma-single",
59 | "learn-sigma-gene",
60 | "learn-sigma-celltype",
61 | "learn-sigma-gene-celltype",
62 | ]:
63 | raise NotImplementedError
64 |
65 | if (init_vals is not None) & (type(init_vals) is dict):
66 | self.np_init_vals = init_vals
67 | for k in init_vals.keys():
68 | self.register_buffer(f"init_val_{k}", torch.tensor(init_vals[k]))
69 |
70 | for i in range(self.n_levels):
71 | self.register_buffer(
72 | f"n_cells_per_label_per_level_{i}",
73 | torch.tensor(n_cells_per_label_per_level[i]),
74 | )
75 |
76 | self.register_buffer(
77 | "laplace_prior_mu",
78 | torch.tensor(self.laplace_prior["mu"]),
79 | )
80 |
81 | self.register_buffer(
82 | "laplace_prior_sigma",
83 | torch.tensor(self.laplace_prior["sigma"]),
84 | )
85 |
86 | self.register_buffer(
87 | "exponential_prior_rate",
88 | torch.tensor(self.laplace_prior["exp_rate"]),
89 | )
90 |
91 | self.register_buffer("ones", torch.ones((1, 1)))
92 |
93 | @property
94 | def layers_size(self):
95 | if self.tree is not None:
96 | if len(self.tree) > 0 and type(self.tree[0]) is not list:
97 | return [len(x) for x in self.tree] + [
98 | len(
99 | [item for sublist in self.tree[-1].values() for item in sublist]
100 | )
101 | ]
102 | else:
103 | return [len(self.tree[0])]
104 | else:
105 | return None
106 |
107 | @property
108 | def _get_fn_args_from_batch(self):
109 | if self.prediction:
110 | return self._get_fn_args_from_batch_prediction
111 | else:
112 | return self._get_fn_args_from_batch_training
113 |
114 | @staticmethod
115 | def _get_fn_args_from_batch_training(tensor_dict):
116 | x_data = tensor_dict[REGISTRY_KEYS.X_KEY]
117 | idx = tensor_dict["ind_x"].long().squeeze()
118 | levels = tensor_dict[REGISTRY_KEYS.CAT_COVS_KEY]
119 | return (x_data, idx, levels), {}
120 |
121 | @staticmethod
122 | def _get_fn_args_from_batch_prediction(tensor_dict):
123 | x_data = tensor_dict[REGISTRY_KEYS.X_KEY]
124 | idx = tensor_dict["ind_x"].long().squeeze()
125 |         return (x_data, idx, idx), {}  # idx stands in for the unused labels
126 |
127 | ############# Define the model ################
128 |
129 | def create_plates(self, x_data, idx, levels):
130 | return pyro.plate("obs_plate", size=self.n_obs, dim=-1, subsample=idx)
131 |
132 | def list_obs_plate_vars(self):
133 | """Create a dictionary with the name of observation/minibatch plate,
134 | indexes of model args to provide to encoder,
135 | variable names that belong to the observation plate
136 | and the number of dimensions in non-plate axis of each variable"""
137 |
138 | return {
139 | "name": "obs_plate",
140 | "input": [], # expression data + (optional) batch index
141 | "input_transform": [], # how to transform input data before passing to NN
142 | "sites": {},
143 | }
144 |
145 | def forward(self, x_data, idx, levels):
146 | obs_plate = self.create_plates(x_data, idx, levels)
147 |
148 | if self.use_dropout:
149 | x_data = self.dropout(x_data)
150 | if self.use_gene_dropout:
151 | x_data = x_data * self.dropout(self.ones.expand([1, self.n_vars]).clone())
152 |
153 | f = []
154 | for i in range(self.n_levels):
155 | # create weights for level i
156 | if self.laplace_learning_mode == "fixed-sigma":
157 | w_i = pyro.sample(
158 | f"weight_level_{i}",
159 | dist.Laplace(self.laplace_prior_mu, self.laplace_prior_sigma)
160 | .expand([self.n_vars, self.layers_size[i]])
161 | .to_event(2),
162 | )
163 | elif self.laplace_learning_mode == "learn-sigma-single":
164 | sigma_i = pyro.sample(
165 | f"sigma_level_{i}",
166 | dist.Exponential(self.exponential_prior_rate).expand([1, 1]),
167 | )
168 | w_i = pyro.sample(
169 | f"weight_level_{i}",
170 | dist.Laplace(self.laplace_prior_mu, sigma_i)
171 | .expand([self.n_vars, self.layers_size[i]])
172 | .to_event(2),
173 | )
174 | elif self.laplace_learning_mode == "learn-sigma-gene":
175 | sigma_ig = pyro.sample(
176 | f"sigma_ig_level_{i}",
177 | dist.Exponential(self.exponential_prior_rate)
178 | .expand([self.n_vars])
179 | .to_event(1),
180 | )
181 | w_i = pyro.sample(
182 | f"weight_level_{i}",
183 | dist.Laplace(self.laplace_prior_mu, sigma_ig[:, None])
184 | .expand([self.n_vars, self.layers_size[i]])
185 | .to_event(2),
186 | )
187 | elif self.laplace_learning_mode == "learn-sigma-celltype":
188 | sigma_ic = pyro.sample(
189 | f"sigma_ic_level_{i}",
190 | dist.Exponential(self.exponential_prior_rate)
191 | .expand([self.layers_size[i]])
192 | .to_event(1),
193 | )
194 | w_i = pyro.sample(
195 | f"weight_level_{i}",
196 | dist.Laplace(self.laplace_prior_mu, sigma_ic[None, :])
197 | .expand([self.n_vars, self.layers_size[i]])
198 | .to_event(2),
199 | )
200 | elif self.laplace_learning_mode == "learn-sigma-gene-celltype":
201 | sigma_ig = pyro.sample(
202 | f"sigma_ig_level_{i}",
203 | dist.Exponential(self.exponential_prior_rate)
204 | .expand([self.n_vars])
205 | .to_event(1),
206 | )
207 |
208 | sigma_ic = pyro.sample(
209 | f"sigma_ic_level_{i}",
210 | dist.Exponential(self.exponential_prior_rate)
211 | .expand([self.layers_size[i]])
212 | .to_event(1),
213 | )
214 | w_i = pyro.sample(
215 | f"weight_level_{i}",
216 | dist.Laplace(
217 | self.laplace_prior_mu, sigma_ig[:, None] @ sigma_ic[None, :]
218 | ).to_event(2),
219 | )
220 | # parameter for cluster size weight normalisation w / sqrt(n_cells per cluster)
221 | n_cells_per_label = self.get_buffer(f"n_cells_per_label_per_level_{i}")
222 | if i == 0:
223 |                 # compute f for level 0 (it is independent of previous levels, as none exist)
224 | f_i = torch.nn.functional.softmax(
225 | torch.matmul(x_data, w_i / n_cells_per_label**0.5), dim=1
226 | )
227 | else:
228 |                 # initialise f for level > 0
229 | f_i = torch.ones((x_data.shape[0], self.layers_size[i])).to(
230 | x_data.device
231 | )
232 | # compute f as f_(i) * f_(i-1) for each cluster group under the parent node
233 | # multiplication could handle non-tree structures (multiple parents for one child cluster)
234 | for parent, children in self.tree[i - 1].items():
235 | f_i[:, children] *= (
236 | torch.nn.functional.softmax(
237 | torch.matmul(
238 | x_data,
239 | w_i[:, children]
240 | / (n_cells_per_label[children])[None, :] ** 0.5,
241 | ),
242 | dim=1,
243 | )
244 | * f[i - 1][:, parent, None]
245 | )
246 | # record level i probabilities as level i+1 depends on them
247 | f.append(f_i)
248 | with obs_plate:
249 | pyro.deterministic(f"label_prob_{i}", f_i.T)
250 | if not self.prediction:
251 | pyro.sample(
252 | f"likelihood_{i}", dist.Categorical(f_i), obs=levels[:, i]
253 | )
254 |
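A small numerical sketch (editorial illustration) of how `forward` above composes probabilities
down the hierarchy: level 0 is a plain multinomial logistic regression, and each deeper level
applies a per-parent softmax scaled by the parent probability, so leaf probabilities still sum
to one. The per-label 1/sqrt(n_cells) weight normalisation used above is omitted for brevity:

    import torch

    torch.manual_seed(0)
    n_cells, n_genes = 4, 10
    x_data = torch.rand(n_cells, n_genes)

    # one parent level with 2 labels, one child level with 5 labels
    tree = [{0: [0, 1], 1: [2, 3, 4]}]
    layers_size = [2, 5]
    weights = [torch.randn(n_genes, s) for s in layers_size]

    # level 0: softmax over the coarse labels
    f = [torch.softmax(x_data @ weights[0], dim=1)]

    # level 1: per-parent softmax over the children, scaled by the parent probability
    f_1 = torch.ones(n_cells, layers_size[1])
    for parent, children in tree[0].items():
        f_1[:, children] = (
            torch.softmax(x_data @ weights[1][:, children], dim=1)
            * f[0][:, parent, None]
        )
    f.append(f_1)

    # fine-level probabilities sum to 1 across all leaves
    assert torch.allclose(f[1].sum(dim=1), torch.ones(n_cells))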
--------------------------------------------------------------------------------
/schierarchy/regression/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/schierarchy/regression/__init__.py
--------------------------------------------------------------------------------
/schierarchy/regression/_reference_model.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from anndata import AnnData
7 | from cell2location.cluster_averages import compute_cluster_averages
8 | from cell2location.models.base._pyro_base_reference_module import RegressionBaseModule
9 | from cell2location.models.base._pyro_mixin import PltExportMixin, QuantileMixin
10 | from pyro import clear_param_store
11 | from scvi import REGISTRY_KEYS
12 | from scvi.data import AnnDataManager
13 | from scvi.data.fields import (
14 | CategoricalJointObsField,
15 | CategoricalObsField,
16 | LayerField,
17 | NumericalJointObsField,
18 | NumericalObsField,
19 | )
20 | from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin
21 | from scvi.utils import setup_anndata_dsp
22 |
23 | from ._reference_module import RegressionBackgroundDetectionTechPyroModel
24 |
25 |
26 | class RegressionModel(
27 | QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, PltExportMixin, BaseModelClass
28 | ):
29 | """
30 |     Model which estimates per cluster average mRNA count accounting for batch effects. User-end model class.
31 |
32 | https://github.com/BayraktarLab/cell2location
33 |
34 | Parameters
35 | ----------
36 | adata
37 | single-cell AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
38 |     use_average_as_initial
39 |         Use per-cluster average expression as the initial value for cell type signatures?
40 | **model_kwargs
41 |         Keyword args for :class:`~schierarchy.RegressionBackgroundDetectionTechPyroModel`
42 |
43 | Examples
44 | --------
45 | TODO add example
46 | >>>
47 | """
48 |
49 | def __init__(
50 | self,
51 | adata: AnnData,
52 | model_class=None,
53 | use_average_as_initial: bool = True,
54 | **model_kwargs,
55 | ):
56 | # in case any other model was created before that shares the same parameter names.
57 | clear_param_store()
58 |
59 | super().__init__(adata)
60 |
61 | if model_class is None:
62 | model_class = RegressionBackgroundDetectionTechPyroModel
63 |
64 | # annotations for cell types
65 | self.n_factors_ = self.summary_stats["n_labels"]
66 | self.factor_names_ = self.adata_manager.get_state_registry(
67 | REGISTRY_KEYS.LABELS_KEY
68 | ).categorical_mapping
69 | # annotations for extra categorical covariates
70 | if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry:
71 | self.extra_categoricals_ = self.adata_manager.get_state_registry(
72 | REGISTRY_KEYS.CAT_COVS_KEY
73 | )
74 | self.n_extra_categoricals_ = self.extra_categoricals_.n_cats_per_key
75 | model_kwargs["n_extra_categoricals"] = self.n_extra_categoricals_
76 |
77 | # use per class average as initial value
78 | if use_average_as_initial:
79 | # compute cluster average expression
80 | aver = self._compute_cluster_averages(key=REGISTRY_KEYS.LABELS_KEY)
81 | model_kwargs["init_vals"] = {
82 | "per_cluster_mu_fg": aver.values.T.astype("float32") + 0.0001
83 | }
84 |
85 | self.module = RegressionBaseModule(
86 | model=model_class,
87 | n_obs=self.summary_stats["n_cells"],
88 | n_vars=self.summary_stats["n_vars"],
89 | n_factors=self.n_factors_,
90 | n_batch=self.summary_stats["n_batch"],
91 | **model_kwargs,
92 | )
93 | self._model_summary_string = f'RegressionBackgroundDetectionTech model with the following params: \nn_factors: {self.n_factors_} \nn_batch: {self.summary_stats["n_batch"]} '
94 | self.init_params_ = self._get_init_params(locals())
95 |
96 | @classmethod
97 | @setup_anndata_dsp.dedent
98 | def setup_anndata(
99 | cls,
100 | adata: AnnData,
101 | layer: Optional[str] = None,
102 | batch_key: Optional[str] = None,
103 | labels_key: Optional[str] = None,
104 | categorical_covariate_keys: Optional[List[str]] = None,
105 | continuous_covariate_keys: Optional[List[str]] = None,
106 | **kwargs,
107 | ):
108 | """
109 | %(summary)s.
110 |
111 | Parameters
112 | ----------
113 | %(param_layer)s
114 | %(param_batch_key)s
115 | %(param_labels_key)s
116 | %(param_cat_cov_keys)s
117 | %(param_cont_cov_keys)s
118 | """
119 | setup_method_args = cls._get_setup_method_args(**locals())
120 | adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
121 | anndata_fields = [
122 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
123 | CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
124 | CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
125 | CategoricalJointObsField(
126 | REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
127 | ),
128 | NumericalJointObsField(
129 | REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
130 | ),
131 | NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
132 | ]
133 | adata_manager = AnnDataManager(
134 | fields=anndata_fields, setup_method_args=setup_method_args
135 | )
136 | adata_manager.register_fields(adata, **kwargs)
137 | cls.register_manager(adata_manager)
138 |
139 | def train(
140 | self,
141 | max_epochs: Optional[int] = None,
142 | batch_size: int = 2500,
143 | train_size: float = 1,
144 | lr: float = 0.002,
145 | **kwargs,
146 | ):
147 | """Train the model with useful defaults
148 |
149 | Parameters
150 | ----------
151 | max_epochs
152 | Number of passes through the dataset. If `None`, defaults to
153 | `np.min([round((20000 / n_cells) * 400), 400])`
154 | train_size
155 | Size of training set in the range [0.0, 1.0].
156 | batch_size
157 | Minibatch size to use during training. If `None`, no minibatching occurs and all
158 | data is copied to device (e.g., GPU).
159 | lr
160 | Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`).
161 | Specifying optimiser via plan_kwargs overrides this choice of lr.
162 | kwargs
163 | Other arguments to scvi.model.base.PyroSviTrainMixin().train() method
164 | """
165 |
166 | kwargs["max_epochs"] = max_epochs
167 | kwargs["batch_size"] = batch_size
168 | kwargs["train_size"] = train_size
169 | kwargs["lr"] = lr
170 |
171 | super().train(**kwargs)
172 |
173 | def _compute_cluster_averages(self, key=REGISTRY_KEYS.LABELS_KEY):
174 | """
175 | Compute average per cluster (key=REGISTRY_KEYS.LABELS_KEY) or per batch (key=REGISTRY_KEYS.BATCH_KEY).
176 |
177 | Returns
178 | -------
179 | pd.DataFrame with variables in rows and labels in columns
180 | """
181 | # find cell label column
182 | label_col = self.adata_manager.get_state_registry(key).original_key
183 |
184 | # find data slot
185 | x_dict = self.adata_manager.data_registry["X"]
186 | if x_dict["attr_name"] == "X":
187 | use_raw = False
188 | else:
189 | use_raw = True
190 | if x_dict["attr_name"] == "layers":
191 | layer = x_dict["attr_key"]
192 | else:
193 | layer = None
194 |
195 | # compute mean expression of each gene in each cluster/batch
196 | aver = compute_cluster_averages(
197 | self.adata, labels=label_col, use_raw=use_raw, layer=layer
198 | )
199 |
200 | return aver
201 |
202 | def export_posterior(
203 | self,
204 | adata,
205 | sample_kwargs: Optional[dict] = None,
206 | export_slot: str = "mod",
207 | add_to_varm: list = ["means", "stds", "q05", "q95"],
208 | scale_average_detection: bool = True,
209 | ):
210 | """
211 | Summarise posterior distribution and export results (cell abundance) to anndata object:
212 | 1. adata.obsm: Estimated references expression signatures (average mRNA count in each cell type),
213 | as pd.DataFrames for each posterior distribution summary `add_to_varm`,
214 | posterior mean, sd, 5% and 95% quantiles (['means', 'stds', 'q05', 'q95']).
215 | If export to adata.varm fails with error, results are saved to adata.var instead.
216 | 2. adata.uns: Posterior of all parameters, model name, date,
217 | cell type names ('factor_names'), obs and var names.
218 |
219 | Parameters
220 | ----------
221 | adata
222 | anndata object where results should be saved
223 | sample_kwargs
224 | arguments for self.sample_posterior (generating and summarising posterior samples), namely:
225 | num_samples - number of samples to use (Default = 1000).
226 | batch_size - data batch size (keep low enough to fit on GPU, default 2048).
227 | use_gpu - use gpu for generating samples?
228 | export_slot
229 | adata.uns slot where to export results
230 | add_to_varm
231 | posterior distribution summary to export in adata.varm (['means', 'stds', 'q05', 'q95']).
232 | Returns
233 | -------
234 |
235 | """
236 |
237 | sample_kwargs = sample_kwargs if isinstance(sample_kwargs, dict) else dict()
238 |
239 | # generate samples from posterior distributions for all parameters
240 | # and compute mean, 5%/95% quantiles and standard deviation
241 | self.samples = self.sample_posterior(**sample_kwargs)
242 |
243 | # export posterior distribution summary for all parameters and
244 | # annotation (model, date, var, obs and cell type names) to anndata object
245 | adata.uns[export_slot] = self._export2adata(self.samples)
246 |
247 | # export estimated expression in each cluster
248 | # first convert np.arrays to pd.DataFrames with cell type and observation names
249 | # data frames contain mean, 5%/95% quantiles and standard deviation, denoted by a prefix
250 | for k in add_to_varm:
251 | sample_df = self.sample2df_vars(
252 | self.samples,
253 | site_name="per_cluster_mu_fg",
254 | summary_name=k,
255 | name_prefix="",
256 | )
257 | if scale_average_detection and (
258 | "detection_y_c" in list(self.samples[f"post_sample_{k}"].keys())
259 | ):
260 | sample_df = (
261 | sample_df * self.samples[f"post_sample_{k}"]["detection_y_c"].mean()
262 | )
263 | try:
264 | adata.varm[f"{k}_per_cluster_mu_fg"] = sample_df.loc[adata.var.index, :]
265 | except ValueError:
266 | # Catching weird error with obsm: `ValueError: value.index does not match parent’s axis 1 names`
267 | adata.var[sample_df.columns] = sample_df.loc[adata.var.index, :]
268 |
269 | return adata
270 |
271 | def plot_QC(
272 | self,
273 | summary_name: str = "means",
274 | use_n_obs: int = 1000,
275 | scale_average_detection: bool = True,
276 | ):
277 | """
278 | Show quality control plots:
279 | 1. Reconstruction accuracy to assess if there are any issues with model training.
280 | The plot should be roughly diagonal, strong deviations signal problems that need to be investigated.
281 | Plotting is slow because expected value of mRNA count needs to be computed from model parameters. Random
282 | observations are used to speed up computation.
283 |
284 | 2. Estimated reference expression signatures (accounting for batch effect)
285 | compared to average expression in each cluster. We expect the signatures to be different
286 |             from average when batch effects are present; however, when this plot is very different from
287 |             a perfect diagonal (such as very low values on the Y-axis, or non-zero density everywhere),
288 |             it indicates problems with signature estimation.
289 |
290 | Parameters
291 | ----------
292 | summary_name
293 | posterior distribution summary to use ('means', 'stds', 'q05', 'q95')
294 |
295 | Returns
296 | -------
297 |
298 | """
299 |
300 | super().plot_QC(summary_name=summary_name, use_n_obs=use_n_obs)
301 | plt.show()
302 |
303 | inf_aver = self.samples[f"post_sample_{summary_name}"]["per_cluster_mu_fg"].T
304 | if scale_average_detection and (
305 | "detection_y_c" in list(self.samples[f"post_sample_{summary_name}"].keys())
306 | ):
307 | inf_aver = (
308 | inf_aver
309 | * self.samples[f"post_sample_{summary_name}"]["detection_y_c"].mean()
310 | )
311 | aver = self._compute_cluster_averages(key=REGISTRY_KEYS.LABELS_KEY)
312 | aver = aver[self.factor_names_]
313 |
314 | plt.hist2d(
315 | np.log10(aver.values.flatten() + 1),
316 | np.log10(inf_aver.flatten() + 1),
317 | bins=50,
318 | norm=matplotlib.colors.LogNorm(),
319 | )
320 | plt.xlabel("Mean expression for every gene in every cluster")
321 | plt.ylabel("Estimated expression for every gene in every cluster")
322 | plt.show()
323 |
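A minimal usage sketch for the regression model above (editorial illustration; the input file
and obs column names are assumptions):

    import scanpy as sc
    from schierarchy import RegressionModel

    adata = sc.read("sc.h5ad")  # hypothetical AnnData with raw counts in .X
    RegressionModel.setup_anndata(adata, batch_key="batch", labels_key="cell_type")

    mod = RegressionModel(adata)
    mod.train(max_epochs=250)  # uses the defaults defined in train() above

    # per-cluster signature summaries are written to adata.varm ('means_per_cluster_mu_fg', ...)
    adata = mod.export_posterior(adata, sample_kwargs={"num_samples": 50})
    mod.plot_QC()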
--------------------------------------------------------------------------------
/schierarchy/regression/_reference_module.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pyro
6 | import pyro.distributions as dist
7 | import torch
8 | from pyro.nn import PyroModule
9 | from scvi import REGISTRY_KEYS
10 | from scvi.nn import one_hot
11 |
12 |
13 | class RegressionBackgroundDetectionTechPyroModel(PyroModule):
14 | r"""
15 | Given cell type annotation for each cell, the corresponding reference cell type signatures :math:`g_{f,g}`,
16 | which represent the average mRNA count of each gene `g` in each cell type `f={1, .., F}`,
17 | are estimated from sc/snRNA-seq data using Negative Binomial regression,
18 | which allows to robustly combine data across technologies and batches.
19 |
20 | This model combines batches, and treats data :math:`D` as Negative Binomial distributed,
21 | given mean :math:`\mu` and overdispersion :math:`\alpha`:
22 |
23 | .. math::
24 | D_{c,g} \sim \mathtt{NB}(alpha=\alpha_{g}, mu=\mu_{c,g})
25 | .. math::
26 | \mu_{c,g} = (\mu_{f,g} + s_{e,g}) * y_e * y_{t,g}
27 |
28 | Which is equivalent to:
29 |
30 | .. math::
31 |         D_{c,g} \sim \mathtt{Poisson}(\mathtt{Gamma}(\alpha_{g}, \alpha_{g} / \mu_{c,g}))
32 |
33 | Here, :math:`\mu_{f,g}` denotes average mRNA count in each cell type :math:`f` for each gene :math:`g`;
34 |     :math:`y_e` denotes normalisation for each experiment :math:`e` to account for sequencing depth.
35 | :math:`y_{t,g}` denotes per gene :math:`g` detection efficiency normalisation for each technology :math:`t`.
36 |
37 | """
38 |
39 | def __init__(
40 | self,
41 | n_obs,
42 | n_vars,
43 | n_factors,
44 | n_batch,
45 | n_extra_categoricals=None,
46 | alpha_g_phi_hyp_prior={"alpha": 9.0, "beta": 3.0},
47 | gene_add_alpha_hyp_prior={"alpha": 9.0, "beta": 3.0},
48 | gene_add_mean_hyp_prior={
49 | "alpha": 1.0,
50 | "beta": 100.0,
51 | },
52 | detection_hyp_prior={"mean_alpha": 1.0, "mean_beta": 1.0},
53 | gene_tech_prior={"mean": 1, "alpha": 200},
54 | init_vals: Optional[dict] = None,
55 | ):
56 | """
57 |
58 | Parameters
59 | ----------
60 | n_obs
61 | n_vars
62 | n_factors
63 | n_batch
64 | n_extra_categoricals
65 | alpha_g_phi_hyp_prior
66 | gene_add_alpha_hyp_prior
67 | gene_add_mean_hyp_prior
68 | detection_hyp_prior
69 | gene_tech_prior
70 | """
71 |
72 | ############# Initialise parameters ################
73 | super().__init__()
74 |
75 | self.n_obs = n_obs
76 | self.n_vars = n_vars
77 | self.n_factors = n_factors
78 | self.n_batch = n_batch
79 | self.n_extra_categoricals = n_extra_categoricals
80 |
81 | self.alpha_g_phi_hyp_prior = alpha_g_phi_hyp_prior
82 | self.gene_add_alpha_hyp_prior = gene_add_alpha_hyp_prior
83 | self.gene_add_mean_hyp_prior = gene_add_mean_hyp_prior
84 | self.detection_hyp_prior = detection_hyp_prior
85 | self.gene_tech_prior = gene_tech_prior
86 |
87 | if (init_vals is not None) & (type(init_vals) is dict):
88 | self.np_init_vals = init_vals
89 | for k in init_vals.keys():
90 | self.register_buffer(f"init_val_{k}", torch.tensor(init_vals[k]))
91 |
92 | self.register_buffer(
93 | "detection_mean_hyp_prior_alpha",
94 | torch.tensor(self.detection_hyp_prior["mean_alpha"]),
95 | )
96 | self.register_buffer(
97 | "detection_mean_hyp_prior_beta",
98 | torch.tensor(self.detection_hyp_prior["mean_beta"]),
99 | )
100 | self.register_buffer(
101 | "gene_tech_prior_alpha",
102 | torch.tensor(self.gene_tech_prior["alpha"]),
103 | )
104 | self.register_buffer(
105 | "gene_tech_prior_beta",
106 | torch.tensor(self.gene_tech_prior["alpha"] / self.gene_tech_prior["mean"]),
107 | )
108 |
109 | self.register_buffer(
110 | "alpha_g_phi_hyp_prior_alpha",
111 | torch.tensor(self.alpha_g_phi_hyp_prior["alpha"]),
112 | )
113 | self.register_buffer(
114 | "alpha_g_phi_hyp_prior_beta",
115 | torch.tensor(self.alpha_g_phi_hyp_prior["beta"]),
116 | )
117 | self.register_buffer(
118 | "gene_add_alpha_hyp_prior_alpha",
119 | torch.tensor(self.gene_add_alpha_hyp_prior["alpha"]),
120 | )
121 | self.register_buffer(
122 | "gene_add_alpha_hyp_prior_beta",
123 | torch.tensor(self.gene_add_alpha_hyp_prior["beta"]),
124 | )
125 | self.register_buffer(
126 | "gene_add_mean_hyp_prior_alpha",
127 | torch.tensor(self.gene_add_mean_hyp_prior["alpha"]),
128 | )
129 | self.register_buffer(
130 | "gene_add_mean_hyp_prior_beta",
131 | torch.tensor(self.gene_add_mean_hyp_prior["beta"]),
132 | )
133 |
134 | self.register_buffer("ones", torch.ones((1, 1)))
135 | self.register_buffer("eps", torch.tensor(1e-8))
136 |
137 | ############# Define the model ################
138 | @staticmethod
139 | def _get_fn_args_from_batch_no_cat(tensor_dict):
140 | x_data = tensor_dict[REGISTRY_KEYS.X_KEY]
141 | ind_x = tensor_dict["ind_x"].long().squeeze()
142 | batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY]
143 | label_index = tensor_dict[REGISTRY_KEYS.LABELS_KEY]
144 |         return (x_data, ind_x, batch_index, label_index, label_index), {}  # placeholder for extra_categoricals
145 |
146 | @staticmethod
147 | def _get_fn_args_from_batch_cat(tensor_dict):
148 | x_data = tensor_dict[REGISTRY_KEYS.X_KEY]
149 | ind_x = tensor_dict["ind_x"].long().squeeze()
150 | batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY]
151 | label_index = tensor_dict[REGISTRY_KEYS.LABELS_KEY]
152 | extra_categoricals = tensor_dict[REGISTRY_KEYS.CAT_COVS_KEY]
153 | return (x_data, ind_x, batch_index, label_index, extra_categoricals), {}
154 |
155 | @property
156 | def _get_fn_args_from_batch(self):
157 | if self.n_extra_categoricals is not None:
158 | return self._get_fn_args_from_batch_cat
159 | else:
160 | return self._get_fn_args_from_batch_no_cat
161 |
162 | def create_plates(self, x_data, idx, batch_index, label_index, extra_categoricals):
163 | return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=idx)
164 |
165 | def list_obs_plate_vars(self):
166 | """Create a dictionary with the name of observation/minibatch plate,
167 | indexes of model args to provide to encoder,
168 | variable names that belong to the observation plate
169 | and the number of dimensions in non-plate axis of each variable"""
170 |
171 | return {
172 | "name": "obs_plate",
173 | "input": [], # expression data + (optional) batch index
174 | "input_transform": [], # how to transform input data before passing to NN
175 | "sites": {},
176 | }
177 |
178 | def forward(self, x_data, idx, batch_index, label_index, extra_categoricals):
179 |
180 | obs2sample = one_hot(batch_index, self.n_batch)
181 | obs2label = one_hot(label_index, self.n_factors)
182 | if self.n_extra_categoricals is not None:
183 | obs2extra_categoricals = torch.cat(
184 | [
185 | one_hot(
186 | extra_categoricals[:, i].view((extra_categoricals.shape[0], 1)),
187 | n_cat,
188 | )
189 | for i, n_cat in enumerate(self.n_extra_categoricals)
190 | ],
191 | dim=1,
192 | )
193 |
194 | obs_plate = self.create_plates(
195 | x_data, idx, batch_index, label_index, extra_categoricals
196 | )
197 |
198 | # =====================Per-cluster average mRNA count ======================= #
199 | # \mu_{f,g}
200 | per_cluster_mu_fg = pyro.sample(
201 | "per_cluster_mu_fg",
202 | dist.Gamma(self.ones, self.ones)
203 | .expand([self.n_factors, self.n_vars])
204 | .to_event(2),
205 | )
206 |
207 | # =====================Gene-specific multiplicative component ======================= #
208 | # `y_{t, g}` per gene multiplicative effect that explains the difference
209 | # in sensitivity between genes in each technology or covariate effect
210 | if self.n_extra_categoricals is not None:
211 | detection_tech_gene_tg = pyro.sample(
212 | "detection_tech_gene_tg",
213 | dist.Gamma(
214 | self.ones * self.gene_tech_prior_alpha,
215 | self.ones * self.gene_tech_prior_beta,
216 | )
217 | .expand([np.sum(self.n_extra_categoricals), self.n_vars])
218 | .to_event(2),
219 | )
220 |
221 | # =====================Cell-specific detection efficiency ======================= #
222 | # y_c with hierarchical mean prior
223 | detection_mean_y_e = pyro.sample(
224 | "detection_mean_y_e",
225 | dist.Gamma(
226 | self.ones * self.detection_mean_hyp_prior_alpha,
227 | self.ones * self.detection_mean_hyp_prior_beta,
228 | )
229 | .expand([self.n_batch, 1])
230 | .to_event(2),
231 | )
232 | detection_y_c = obs2sample @ detection_mean_y_e # (self.n_obs, 1)
233 |
234 | # =====================Gene-specific additive component ======================= #
235 | # s_{e,g} accounting for background, free-floating RNA
236 | s_g_gene_add_alpha_hyp = pyro.sample(
237 | "s_g_gene_add_alpha_hyp",
238 | dist.Gamma(
239 | self.ones * self.gene_add_alpha_hyp_prior_alpha,
240 | self.ones * self.gene_add_alpha_hyp_prior_beta,
241 | ),
242 | )
243 | s_g_gene_add_mean = pyro.sample(
244 | "s_g_gene_add_mean",
245 | dist.Gamma(
246 | self.gene_add_mean_hyp_prior_alpha,
247 | self.gene_add_mean_hyp_prior_beta,
248 | )
249 | .expand([self.n_batch, 1])
250 | .to_event(2),
251 | ) # (self.n_batch)
252 | s_g_gene_add_alpha_e_inv = pyro.sample(
253 | "s_g_gene_add_alpha_e_inv",
254 | dist.Exponential(s_g_gene_add_alpha_hyp)
255 | .expand([self.n_batch, 1])
256 | .to_event(2),
257 | ) # (self.n_batch)
258 | s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2)
259 |
260 | s_g_gene_add = pyro.sample(
261 | "s_g_gene_add",
262 | dist.Gamma(s_g_gene_add_alpha_e, s_g_gene_add_alpha_e / s_g_gene_add_mean)
263 | .expand([self.n_batch, self.n_vars])
264 | .to_event(2),
265 | ) # (self.n_batch, n_vars)
266 |
267 | # =====================Gene-specific overdispersion ======================= #
268 | alpha_g_phi_hyp = pyro.sample(
269 | "alpha_g_phi_hyp",
270 | dist.Gamma(
271 | self.ones * self.alpha_g_phi_hyp_prior_alpha,
272 | self.ones * self.alpha_g_phi_hyp_prior_beta,
273 | ),
274 | )
275 | alpha_g_inverse = pyro.sample(
276 | "alpha_g_inverse",
277 | dist.Exponential(alpha_g_phi_hyp).expand([1, self.n_vars]).to_event(2),
278 |         ) # (1, self.n_vars)
279 |
280 | # =====================Expected expression ======================= #
281 |
282 | # overdispersion
283 | alpha = self.ones / alpha_g_inverse.pow(2)
284 | # biological expression
285 | mu = (
286 | obs2label @ per_cluster_mu_fg
287 | + obs2sample @ s_g_gene_add # contaminating RNA
288 | ) * detection_y_c # cell-specific normalisation
289 | if self.n_extra_categoricals is not None:
290 |             # gene-specific normalisation for covariates
291 | mu = mu * (obs2extra_categoricals @ detection_tech_gene_tg)
292 | # total_count, logits = _convert_mean_disp_to_counts_logits(
293 | # mu, alpha, eps=self.eps
294 | # )
295 |
296 | # =====================DATA likelihood ======================= #
297 |         # Likelihood (sampling distribution) of data_target; overdispersion is added via the Negative Binomial (GammaPoisson)
298 | with obs_plate:
299 | pyro.sample(
300 | "data_target",
301 | dist.GammaPoisson(concentration=alpha, rate=alpha / mu),
302 | # dist.NegativeBinomial(total_count=total_count, logits=logits),
303 | obs=x_data,
304 | )
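        # Note on this parameterisation (a brief explanatory aside): GammaPoisson(
        # concentration=alpha, rate=alpha / mu) is a Negative Binomial with mean mu and
        # variance mu + mu ** 2 / alpha, so smaller alpha means stronger per-gene overdispersion.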
305 |
306 | # =====================Other functions======================= #
307 | def compute_expected(self, samples, adata_manager, ind_x=None):
308 | r"""Compute expected expression of each gene in each cell. Useful for evaluating how well
309 |         the model learned the expression patterns of all genes in the data.
310 |
311 | Parameters
312 | ----------
313 | samples
314 | dictionary with values of the posterior
315 |         adata_manager
316 |             AnnDataManager with the registered anndata
317 | ind_x
318 | indices of cells to use (to reduce data size)
319 | """
320 | if ind_x is None:
321 | ind_x = np.arange(adata_manager.adata.n_obs).astype(int)
322 | else:
323 | ind_x = ind_x.astype(int)
324 | obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY)
325 | obs2sample = (
326 | pd.get_dummies(obs2sample.flatten()).values[ind_x, :].astype("float32")
327 | )
328 | obs2label = adata_manager.get_from_registry(REGISTRY_KEYS.LABELS_KEY)
329 | obs2label = (
330 | pd.get_dummies(obs2label.flatten()).values[ind_x, :].astype("float32")
331 | )
332 | if self.n_extra_categoricals is not None:
333 | extra_categoricals = adata_manager.get_from_registry(
334 | REGISTRY_KEYS.CAT_COVS_KEY
335 | )
336 | obs2extra_categoricals = np.concatenate(
337 | [
338 | pd.get_dummies(extra_categoricals.iloc[ind_x, i]).astype("float32")
339 | for i, n_cat in enumerate(self.n_extra_categoricals)
340 | ],
341 | axis=1,
342 | )
343 |
344 | alpha = 1 / np.power(samples["alpha_g_inverse"], 2)
345 |
346 | mu = (
347 | np.dot(obs2label, samples["per_cluster_mu_fg"])
348 | + np.dot(obs2sample, samples["s_g_gene_add"])
349 | ) * np.dot(
350 | obs2sample, samples["detection_mean_y_e"]
351 | ) # samples["detection_y_c"][ind_x, :]
352 | if self.n_extra_categoricals is not None:
353 | mu = mu * np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"])
354 |
355 | return {"mu": mu, "alpha": alpha}
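        # Usage note (illustrative names): the returned dict parameterises the Negative
        # Binomial posterior predictive per cell and gene; for out = compute_expected(...),
        # the expected count is out["mu"] with variance out["mu"] + out["mu"] ** 2 / out["alpha"].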
356 |
357 | def compute_expected_subset(self, samples, adata_manager, fact_ind, cell_ind):
358 | r"""Compute expected expression of each gene in each cell that comes from
359 | a subset of factors (cell types) or cells.
360 |
361 |         Useful for evaluating how well the model learned the expression patterns of all genes in the data.
362 |
363 | Parameters
364 | ----------
365 | samples
366 | dictionary with values of the posterior
367 |         adata_manager
368 |             AnnDataManager with the registered anndata
369 | fact_ind
370 | indices of factors/cell types to use
371 | cell_ind
372 | indices of cells to use
373 | """
374 | obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY)
375 |         obs2sample = pd.get_dummies(obs2sample.flatten()).values  # needed for positional indexing below
376 | obs2label = adata_manager.get_from_registry(REGISTRY_KEYS.LABELS_KEY)
377 |         obs2label = pd.get_dummies(obs2label.flatten()).values
378 | if self.n_extra_categoricals is not None:
379 | extra_categoricals = adata_manager.get_from_registry(
380 | REGISTRY_KEYS.CAT_COVS_KEY
381 | )
382 | obs2extra_categoricals = np.concatenate(
383 | [
384 | pd.get_dummies(extra_categoricals.iloc[:, i])
385 | for i, n_cat in enumerate(self.n_extra_categoricals)
386 | ],
387 | axis=1,
388 | )
389 |
390 | alpha = 1 / np.power(samples["alpha_g_inverse"], 2)
391 |
392 | mu = (
393 | np.dot(
394 |                 obs2label[np.ix_(cell_ind, fact_ind)], samples["per_cluster_mu_fg"][fact_ind, :]
395 | )
396 | + np.dot(obs2sample[cell_ind, :], samples["s_g_gene_add"])
397 | ) * np.dot(
398 |             obs2sample[cell_ind, :], samples["detection_mean_y_e"]
399 | ) # samples["detection_y_c"]
400 | if self.n_extra_categoricals is not None:
401 | mu = mu * np.dot(
402 | obs2extra_categoricals[cell_ind, :], samples["detection_tech_gene_tg"]
403 | )
404 |
405 | return {"mu": mu, "alpha": alpha}
406 |
407 | def normalise(self, samples, adata_manager, adata):
408 | r"""Normalise expression data by estimated technical variables.
409 |
410 | Parameters
411 | ----------
412 | samples
413 | dictionary with values of the posterior
414 |         adata_manager
415 |             AnnDataManager with the registered anndata
416 |
417 | """
418 | obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY)
419 | obs2sample = pd.get_dummies(obs2sample.flatten())
420 | if self.n_extra_categoricals is not None:
421 | extra_categoricals = adata_manager.get_from_registry(
422 | REGISTRY_KEYS.CAT_COVS_KEY
423 | )
424 | obs2extra_categoricals = np.concatenate(
425 | [
426 | pd.get_dummies(extra_categoricals.iloc[:, i])
427 | for i, n_cat in enumerate(self.n_extra_categoricals)
428 | ],
429 | axis=1,
430 | )
431 | # get counts matrix
432 | corrected = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
433 | # normalise per-sample scaling
434 | corrected = corrected / np.dot(obs2sample, samples["detection_mean_y_e"])
435 | # normalise per gene effects
436 | if self.n_extra_categoricals is not None:
437 | corrected = corrected / np.dot(
438 | obs2extra_categoricals, samples["detection_tech_gene_tg"]
439 | )
440 |
441 | # remove additive sample effects
442 | corrected = corrected - np.dot(obs2sample, samples["s_g_gene_add"])
443 |
444 |         # shift values so the global minimum is 0 (a hack to avoid negative values)
445 | corrected = corrected - corrected.min()
446 |
447 | return corrected
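        # A minimal usage sketch (illustrative, assuming posterior mean estimates were
        # collected into a dict `samples` keyed by the site names used in this model):
        # corrected = module.normalise(samples, adata_manager, adata)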
448 |
--------------------------------------------------------------------------------
/schierarchy/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/schierarchy/utils/__init__.py
--------------------------------------------------------------------------------
/schierarchy/utils/data_transformation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import stats
3 |
4 |
5 | def data_to_zero_truncated_cdf(x):
6 | """
7 |     Quantile transformation of expression. To convert a matrix, apply along the cell dimension.
8 | Parameters
9 | ----------
10 | x: log1p(expression / total_expression * 1e4) transformed expression vector (for 1 gene)
11 |
12 | Returns
13 | -------
14 | quantile normalised expression vector
15 | """
16 |     # sort once via argsort and keep the ordering to map quantiles back to input positions
17 |     indx_sorted = np.argsort(x)
18 |     x_sorted = x[indx_sorted]
19 | try:
20 | zero_ind = np.where(x_sorted == 0)[0][-1]
21 | p = np.concatenate(
22 | [
23 | np.zeros(zero_ind),
24 | 1.0
25 | * np.arange(len(x_sorted) - zero_ind)
26 | / (len(x_sorted) - zero_ind - 1),
27 | ]
28 | )
29 | except IndexError:
30 | p = 1.0 * np.arange(len(x_sorted)) / (len(x_sorted))
31 | cdfs = np.zeros_like(x)
32 | cdfs[indx_sorted] = p
33 | return cdfs
34 |
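# A worked example (hypothetical input, following the definition above):
# >>> data_to_zero_truncated_cdf(np.array([0.0, 0.0, 1.0, 3.0, 2.0]))
# array([0.        , 0.        , 0.33333333, 1.        , 0.66666667])
# All zeros map to 0; non-zero values are spread over (0, 1] with the maximum at 1.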
35 |
36 | def cdf_to_pseudonorm(cdf, clip=1e-3):
37 | """
38 | Parameters
39 | ----------
40 | cdf: Quantile transformed expression
41 | clip: clipping of the normal distribution tails (sets the position of zeros)
42 |
43 | Returns
44 | -------
45 | Pseudonormal expression
46 | """
47 | return stats.norm.ppf(np.maximum(np.minimum(cdf, 1.0 - clip), clip))
48 |
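# For example, cdf_to_pseudonorm(np.array([0.0, 0.5, 1.0])) clips the input to
# [1e-3, 1 - 1e-3] and returns approximately [-3.09, 0.0, 3.09], i.e. the
# standard-normal quantiles of the clipped CDF values.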
49 |
50 | def variance_normalistaion(x, std, clip=10):
51 | """
52 | Normalisation by standard deviation
53 | Parameters
54 | ----------
55 | x: Expression data
56 | std: standard deviation, broadcastable to the shape of x
57 | clip: clipping threshold for standard deviation
58 |
59 | Returns
60 | -------
61 | Expression normalised by the (clipped) standard deviation
62 | """
63 | if clip is not None:
64 | return x / np.minimum(std, clip)
65 | else:
66 |         return x / std  # np.minimum requires two arguments; no clipping is intended here
67 |
--------------------------------------------------------------------------------
/schierarchy/utils/simulation.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from scvi.data import synthetic_iid
4 |
5 |
6 | def random_tree(hlevels):
7 | """
8 |     Generate a random tree from the list of level widths and return it as a list of edge dicts
9 |
10 | Parameters
11 | ----------
12 | hlevels
13 | List of number of classes per level
14 |
15 | Returns
16 | -------
17 | List of edges (parent:[children]) between levels
18 |
19 |
20 | Examples
21 | --------
22 | >>> hlevels = [4,19,20]
23 | >>> tree = random_tree(hlevels)
24 | """
25 |
26 | edge_dicts = []
27 | for i_level in range(len(hlevels) - 1):
28 | level1 = np.arange(hlevels[i_level])
29 | level2 = np.arange(hlevels[i_level + 1])
30 |
31 | edge_dict = {k: [] for k in level1}
32 | garant = np.random.choice(
33 | level2, size=len(level1), replace=False
34 |         ) # guarantee each parent at least one child
35 | for i in range(len(level1)):
36 | edge_dict[i].append(garant[i])
37 | rest = np.random.choice(level1, size=len(level2)) # distribute the rest
38 | for i in range(len(level2)):
39 | if i not in garant:
40 | edge_dict[rest[i]].append(i)
41 |
42 | edge_dicts.append(edge_dict)
43 | return edge_dicts
44 |
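# One possible draw for random_tree([2, 3]) (output is random): [{0: [2, 0], 1: [1]}].
# Each parent keeps at least one guaranteed child and every child node is attached to
# exactly one parent, so adjacent levels always form a valid tree.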
45 |
46 | def invert_dict(d):
47 | inv_d = {}
48 | for k, v_list in d.items():
49 | for v in v_list:
50 | inv_d[v] = k
51 | return inv_d
52 |
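# invert_dict turns a {parent: [children]} mapping into {child: parent}, e.g.
# invert_dict({0: [1, 2]}) == {1: 0, 2: 0}; it assumes each child has a single parent.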
53 |
54 | def plot_tree(hlevels, tree):
55 | """
56 | Graph plotting
57 |
58 | Parameters
59 | ----------
60 | hlevels
61 | List of number of classes per level
62 | tree
63 | List of edges between levels
64 | """
65 |     ylim = np.max(hlevels)  # width of the widest level, used to centre each level
66 | xlim = len(hlevels)
67 | x_offset = [(ylim - hlevels[i]) / 2 for i in range(xlim)]
68 |     x_points = []
69 |     y_points = []
70 |
71 | x_lines = [[], []]
72 | y_lines = [[], []]
73 |
74 | for i in range(xlim):
75 | for j in range(hlevels[i]):
76 |             y_points.append(i)
77 |             x_points.append(j + x_offset[i])
78 |
79 | for i in range(xlim - 1):
80 | for k, v in tree[i].items():
81 | for v_i in v:
82 | x_lines[0].append(k + x_offset[i])
83 | x_lines[1].append(v_i + x_offset[i + 1])
84 |
85 | y_lines[0].append(i)
86 | y_lines[1].append(i + 1)
87 |
88 |     plt.scatter(x_points, y_points)
89 | plt.plot(x_lines, y_lines, color="black", alpha=0.5)
90 | plt.gca().invert_yaxis()
91 |
92 |
93 | def hierarchical_iid(hlevels, *args, **kwargs):
94 | """
95 |     Wrapper around scvi.data.synthetic_iid to produce hierarchical labels
96 |
97 | Parameters
98 | ----------
99 | hlevels
100 | List of number of classes per level
101 |
102 | Returns
103 | -------
104 |     AnnData with batch info (``.obs['batch']``) and label info for each level i
105 |     (``.obs['level_i']``). List of edges (parent:[children]) between levels (``.uns['tree']``)
106 | """
107 | tree = random_tree(hlevels)
108 |
109 | bottom_level_n_labels = hlevels[-1]
110 | synthetic_data = synthetic_iid(n_labels=bottom_level_n_labels, *args, **kwargs)
111 | from re import sub
112 |
113 | synthetic_data.obs["labels"] = [
114 | int(sub("label_", "", i)) for i in synthetic_data.obs["labels"]
115 | ]
116 |
117 | levels = ["labels"] + [f"level_{i}" for i in range(len(hlevels) - 2, -1, -1)]
118 | for i in range(len(levels) - 1):
119 | level_up = synthetic_data.obs[levels[i]].apply(
120 | lambda x: invert_dict(tree[len(tree) - 1 - i])[x]
121 | )
122 | synthetic_data.obs[levels[i + 1]] = level_up
123 |
124 | synthetic_data.obs = synthetic_data.obs.rename(
125 | columns={"labels": f"level_{len(levels) - 1}"}
126 | )
127 | synthetic_data.uns["tree"] = tree
128 | return synthetic_data
129 |
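# A minimal usage sketch (mirroring the tests in this repository):
# adata = hierarchical_iid([4, 10, 20])
# adata.obs[["level_0", "level_1", "level_2"]]  # labels from coarse to fine
# adata.uns["tree"]  # list of {parent: [children]} edge dicts between adjacent levels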
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # This is a shim to allow GitHub to detect the package; the build is done with Poetry
4 |
5 | import setuptools
6 |
7 | if __name__ == "__main__":
8 | setuptools.setup(name="schierarchy")
9 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dissatisfaction-ai/scHierarchy/c0deba823adfe3fde344db2b05ce2dea7963e80e/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_hierarchical_logist.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from scvi.data import synthetic_iid
5 |
6 | from schierarchy import LogisticModel
7 | from schierarchy.utils.data_transformation import data_to_zero_truncated_cdf
8 | from schierarchy.utils.simulation import hierarchical_iid
9 |
10 |
11 | def test_hierarchical_logist():
12 | save_path = "./cell2location_model_test"
13 | hlevels = [4, 10, 20]
14 | dataset = hierarchical_iid(hlevels)
15 | level_keys = [f"level_{i}" for i in range(len(hlevels))]
16 | # tree = dataset.uns["tree"]
17 | del dataset.uns["tree"]
18 | dataset.layers["cdf"] = np.apply_along_axis(
19 | data_to_zero_truncated_cdf, 0, dataset.X
20 | )
21 |
22 | for learning_mode in [
23 | "fixed-sigma",
24 | "learn-sigma-single",
25 | "learn-sigma-gene",
26 | "learn-sigma-celltype",
27 | "learn-sigma-gene-celltype",
28 | ]:
29 | for use_dropout in [
30 | {"use_dropout": True},
31 | {"use_dropout": False},
32 | {"use_gene_dropout": True},
33 | {"use_gene_dropout": False},
34 | ]:
35 | LogisticModel.setup_anndata(dataset, layer="cdf", level_keys=level_keys)
36 | # train regression model to get signatures of cell types
37 | sc_model = LogisticModel(
38 | dataset,
39 | laplace_learning_mode=learning_mode,
40 | **use_dropout,
41 | )
42 | # test full data training
43 | sc_model.train(max_epochs=10, batch_size=None)
44 | # test minibatch training
45 | sc_model.train(max_epochs=10, batch_size=100)
46 | # export the estimated cell abundance (summary of the posterior distribution)
47 | dataset = sc_model.export_posterior(
48 | dataset, sample_kwargs={"num_samples": 10}
49 | )
50 |         # test plot_QC
51 | # sc_model.plot_QC()
52 | # test save/load
53 | sc_model.save(save_path, overwrite=True, save_anndata=True)
54 | sc_model = LogisticModel.load(save_path)
55 | os.system(f"rm -rf {save_path}")
56 |
57 |
58 | def test_hierarchical_logist_prediction():
59 | hlevels = [4, 10, 20]
60 | dataset = hierarchical_iid(hlevels)
61 | level_keys = [f"level_{i}" for i in range(len(hlevels))]
62 | del dataset.uns["tree"]
63 | dataset.layers["cdf"] = np.apply_along_axis(
64 | data_to_zero_truncated_cdf, 0, dataset.X
65 | )
66 |
67 | LogisticModel.setup_anndata(dataset, layer="cdf", level_keys=level_keys)
68 |
69 | # train regression model to get signatures of cell types
70 | sc_model = LogisticModel(dataset)
71 | # test full data training
72 | sc_model.train(max_epochs=10, batch_size=None)
73 | # test prediction
74 | dataset2 = synthetic_iid(n_labels=5)
75 | dataset2.layers["cdf"] = np.apply_along_axis(
76 | data_to_zero_truncated_cdf, 0, dataset2.X
77 | )
78 | dataset2 = sc_model.export_posterior(
79 | dataset2, prediction=True, sample_kwargs={"num_samples": 10}
80 | )
81 |
--------------------------------------------------------------------------------
/tests/test_hierarchical_logist_nohierarchy.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | from schierarchy import LogisticModel
6 | from schierarchy.utils.data_transformation import data_to_zero_truncated_cdf
7 | from schierarchy.utils.simulation import hierarchical_iid
8 |
9 |
10 | def test_hierarchical_logist_nohierarchy():
11 | save_path = "./cell2location_model_test"
12 | hlevels = [20]
13 | dataset = hierarchical_iid(hlevels)
14 | level_keys = [f"level_{i}" for i in range(len(hlevels))]
15 | # tree = dataset.uns["tree"]
16 | del dataset.uns["tree"]
17 | dataset.layers["cdf"] = np.apply_along_axis(
18 | data_to_zero_truncated_cdf, 0, dataset.X
19 | )
20 | for learning_mode in [
21 | "fixed-sigma",
22 | "learn-sigma-single",
23 | "learn-sigma-gene",
24 | "learn-sigma-celltype",
25 | "learn-sigma-gene-celltype",
26 | ]:
27 | LogisticModel.setup_anndata(dataset, layer="cdf", level_keys=level_keys)
28 |
29 | # train regression model to get signatures of cell types
30 | sc_model = LogisticModel(
31 | dataset,
32 | laplace_learning_mode=learning_mode,
33 | )
34 | # test full data training
35 | sc_model.train(max_epochs=10, batch_size=None)
36 | # test minibatch training
37 | sc_model.train(max_epochs=10, batch_size=100)
38 | # export the estimated cell abundance (summary of the posterior distribution)
39 | dataset = sc_model.export_posterior(dataset, sample_kwargs={"num_samples": 10})
40 |         dataset = sc_model.export_posterior(
41 |             dataset,
42 |             sample_kwargs={"use_median": True},
43 |             use_quantiles=True,
44 |             add_to_varm=["q50"],
45 |         )
46 |
47 |         # test plot_QC
48 | # sc_model.plot_QC()
49 | # test save/load
50 | sc_model.save(save_path, overwrite=True, save_anndata=True)
51 | sc_model = LogisticModel.load(save_path)
52 | os.system(f"rm -rf {save_path}")
53 |
--------------------------------------------------------------------------------
/tests/test_regression.py:
--------------------------------------------------------------------------------
1 | from scvi.data import synthetic_iid
2 |
3 | from schierarchy import RegressionModel
4 |
5 |
6 | def test_regression():
7 | save_path = "./cell2location_model_test"
8 | dataset = synthetic_iid(n_labels=5)
9 | RegressionModel.setup_anndata(dataset, labels_key="labels", batch_key="batch")
10 |
11 | # train regression model to get signatures of cell types
12 | sc_model = RegressionModel(dataset)
13 | # test full data training
14 | sc_model.train(max_epochs=10, batch_size=None)
15 | # test minibatch training
16 | sc_model.train(max_epochs=10, batch_size=100)
17 | # export the estimated cell abundance (summary of the posterior distribution)
18 | dataset = sc_model.export_posterior(dataset, sample_kwargs={"num_samples": 10})
19 |     # test plot_QC
20 | sc_model.plot_QC()
21 | # test save/load
22 | sc_model.save(save_path, overwrite=True, save_anndata=True)
23 | sc_model = RegressionModel.load(save_path)
24 | # export estimated expression in each cluster
25 | if "means_per_cluster_mu_fg" in dataset.varm.keys():
26 | inf_aver = dataset.varm["means_per_cluster_mu_fg"][
27 | [f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]
28 | ].copy()
29 | else:
30 | inf_aver = dataset.var[
31 | [f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]
32 | ].copy()
33 | inf_aver.columns = dataset.uns["mod"]["factor_names"]
34 |
--------------------------------------------------------------------------------