├── .editorconfig
├── .github
│   └── workflows
│       ├── release.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── codecov.yml
├── docs
│   ├── Makefile
│   ├── _templates
│   │   ├── autosummary
│   │   │   └── class.rst
│   │   └── class_no_inherited.rst
│   ├── api
│   │   └── index.md
│   ├── conf.py
│   ├── extensions
│   │   └── typed_returns.py
│   ├── index.md
│   ├── make.bat
│   ├── references.bib
│   ├── references.md
│   ├── release_notes
│   │   ├── index.rst
│   │   └── v0.1.0.rst
│   └── tutorial.ipynb
├── pyproject.toml
├── readthedocs.yml
├── setup.py
├── tests
│   ├── __init__.py
│   └── test_velovi.py
└── velovi
    ├── __init__.py
    ├── _constants.py
    ├── _model.py
    ├── _module.py
    └── _utils.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = false
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | push:
5 | tags:
6 | - "*.*.*"
7 |
8 | jobs:
9 | release:
10 | name: Release
11 | runs-on: ubuntu-latest
12 | steps:
13 | # will use ref/SHA that triggered it
14 | - name: Checkout code
15 | uses: actions/checkout@v3
16 |
17 | - name: Set up Python 3.10
18 | uses: actions/setup-python@v4
19 | with:
20 | python-version: "3.10"
21 |
22 | - name: Install poetry
23 | uses: abatilo/actions-poetry@v2.0.0
24 | with:
25 | poetry-version: 1.4.2
26 |
27 | - name: Build project for distribution
28 | run: poetry build
29 |
30 | - name: Check Version
31 | id: check-version
32 | run: |
33 | [[ "$(poetry version --short)" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]] \
34 | || echo "prerelease=true" >> "$GITHUB_OUTPUT"
35 |
36 | - name: Publish to PyPI
37 | env:
38 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }}
39 | run: poetry publish
40 |
--------------------------------------------------------------------------------
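The `Check Version` step above treats anything other than a bare `X.Y.Z` version as a prerelease. Below is a minimal Python sketch of the same regex test; the sample version strings are hypothetical:

```python
# Same prerelease test as the workflow's "Check Version" step: only a bare
# X.Y.Z version string counts as a full release.
import re

RELEASE_RE = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")

for version in ["0.4.0", "0.4.0rc1", "1.0.0.dev2"]:  # hypothetical examples
    kind = "release" if RELEASE_RE.fullmatch(version) else "prerelease"
    print(version, "->", kind)
```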
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies and run tests with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: velovi
5 |
6 | on:
7 | push:
8 | branches: [main]
9 | pull_request:
10 | branches: [main]
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-latest
15 | strategy:
16 | matrix:
17 | python-version: ["3.9", "3.10", "3.11"]
18 |
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Cache pip
26 | uses: actions/cache@v2
27 | with:
28 | path: ~/.cache/pip
29 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
30 | restore-keys: |
31 | ${{ runner.os }}-pip-
32 | - name: Install dependencies
33 | run: |
34 | pip install pytest-cov
35 | pip install .[dev]
36 | - name: Test with pytest
37 | run: |
38 | pytest --cov-report=xml --cov=velovi
39 | - name: After success
40 | run: |
41 | bash <(curl -s https://codecov.io/bash)
42 | pip list
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # DS_Store
2 | .DS_Store
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # vscode
135 | .vscode/settings.json
136 | docs/api/reference/
137 | *.h5ad
138 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | fail_fast: false
2 | default_language_version:
3 | python: python3
4 | default_stages:
5 | - commit
6 | - push
7 | minimum_pre_commit_version: 2.16.0
8 | repos:
9 | - repo: https://github.com/psf/black
10 | rev: "23.1.0"
11 | hooks:
12 | - id: black
13 | - repo: https://github.com/asottile/blacken-docs
14 | rev: 1.13.0
15 | hooks:
16 | - id: blacken-docs
17 | - repo: https://github.com/pre-commit/mirrors-prettier
18 | rev: v3.0.0-alpha.6
19 | hooks:
20 | - id: prettier
21 | # Newer versions of node don't work on systems that have an older version of GLIBC
22 | # (in particular Ubuntu 18.04 and CentOS 7)
23 | # CentOS 7 reaches EOL in 2024-06; we can probably get rid of this then.
24 | # See https://github.com/scverse/cookiecutter-scverse/issues/143 and
25 | # https://github.com/jupyterlab/jupyterlab/issues/12675
26 | language_version: "17.9.1"
27 | - repo: https://github.com/charliermarsh/ruff-pre-commit
28 | rev: v0.0.254
29 | hooks:
30 | - id: ruff
31 | args: [--fix, --exit-non-zero-on-fix]
32 | - repo: https://github.com/pre-commit/pre-commit-hooks
33 | rev: v4.4.0
34 | hooks:
35 | - id: detect-private-key
36 | - id: check-ast
37 | - id: end-of-file-fixer
38 | - id: mixed-line-ending
39 | args: [--fix=lf]
40 | - id: trailing-whitespace
41 | - id: check-case-conflict
42 | - repo: local
43 | hooks:
44 | - id: forbid-to-commit
45 | name: Don't commit rej files
46 | entry: |
47 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates.
48 | Fix the merge conflicts manually and remove the .rej files.
49 | language: fail
50 | files: '.*\.rej$'
51 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, Yosef Lab
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # velovi
2 |
3 | [![Tests][badge-tests]][link-tests]
4 | [![Documentation][badge-docs]][link-docs]
5 |
6 | [badge-tests]: https://img.shields.io/github/actions/workflow/status/yoseflab/velovi/test.yml?branch=main
7 | [link-tests]: https://github.com/yoseflab/velovi/actions/workflows/test.yml
8 | [badge-docs]: https://img.shields.io/readthedocs/velovi
9 |
10 | 🚧 :warning: This package is no longer being actively developed or maintained. Please use the
11 | [scvi-tools](https://github.com/scverse/scvi-tools) package instead. See this
12 | [thread](https://github.com/scverse/scvi-tools/issues/2610) for more details. :warning: 🚧
13 |
14 | Variational inference for RNA velocity. This is an experimental repo for the veloVI model. Installation instructions and tutorials are in the docs.
15 |
16 | ## Getting started
17 |
18 | Please refer to the [documentation][link-docs].
19 |
20 | ## Installation
21 |
22 | You need to have Python 3.9 or newer installed on your system. If you don't have
23 | Python installed, we recommend installing [Miniconda](https://docs.conda.io/en/latest/miniconda.html).
24 |
25 | There are several alternative options to install velovi:
26 |
27 |
34 |
35 | 1. Install the latest release on PyPI:
36 |
37 | ```bash
38 | pip install velovi
39 | ```
40 |
41 | 2. Install the latest development version:
42 |
43 | ```bash
44 | pip install git+https://github.com/yoseflab/velovi.git@main
45 | ```
46 |
47 | ## Release notes
48 |
49 | See the [changelog][changelog].
50 |
51 | ## Contact
52 |
53 | For questions and help requests, you can reach out in the [scverse discourse][scverse-discourse].
54 | If you found a bug, please use the [issue tracker][issue-tracker].
55 |
56 | ## Citation
57 |
58 | ```
59 | @article{gayoso2022deep,
60 | title={Deep generative modeling of transcriptional dynamics for RNA velocity analysis in single cells},
61 | author={Gayoso, Adam and Weiler, Philipp and Lotfollahi, Mohammad and Klein, Dominik and Hong, Justin and Streets, Aaron M and Theis, Fabian J and Yosef, Nir},
62 | journal={bioRxiv},
63 | pages={2022--08},
64 | year={2022},
65 | publisher={Cold Spring Harbor Laboratory}
66 | }
67 | ```
68 |
69 | [scverse-discourse]: https://discourse.scverse.org/
70 | [issue-tracker]: https://github.com/yoseflab/velovi/issues
71 | [changelog]: https://velovi.readthedocs.io/latest/changelog.html
72 | [link-docs]: https://velovi.readthedocs.io
73 | [link-api]: https://velovi.readthedocs.io/latest/api.html
74 |
--------------------------------------------------------------------------------
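For orientation, here is a minimal end-to-end sketch distilled from `tests/test_velovi.py`. The synthetic `AnnData` is illustrative only; real workflows would preprocess spliced/unspliced counts with scVelo first (see the tutorial in the docs):

```python
# Minimal VELOVI workflow, adapted from tests/test_velovi.py. The synthetic
# AnnData stands in for real spliced/unspliced abundances.
from scvi.data import synthetic_iid

from velovi import VELOVI

adata = synthetic_iid()
adata.layers["spliced"] = adata.X.copy()
adata.layers["unspliced"] = adata.X.copy()

# Register the two layers with the model, then train.
VELOVI.setup_anndata(adata, spliced_layer="spliced", unspliced_layer="unspliced")
model = VELOVI(adata, n_latent=10)
model.train(max_epochs=1)  # keep the default 500 epochs on real data

# Posterior summaries come back as cells-by-genes DataFrames by default.
velocity = model.get_velocity()
latent_time = model.get_latent_time()
```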
/codecov.yml:
--------------------------------------------------------------------------------
1 | # Run to check if valid
2 | # curl --data-binary @codecov.yml https://codecov.io/validate
3 | coverage:
4 | status:
5 | project:
6 | default:
7 | target: 80%
8 | threshold: 1%
9 | patch: off
10 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python -msphinx
7 | SPHINXPROJ = velovi
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | Attributes table
12 | ~~~~~~~~~~~~~~~~~~
13 |
14 | .. autosummary::
15 | {% for item in attributes %}
16 | ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | Methods table
24 | ~~~~~~~~~~~~~
25 |
26 | .. autosummary::
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 | ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
35 | {% block attributes_documentation %}
36 | {% if attributes %}
37 | Attributes
38 | ~~~~~~~~~~~
39 |
40 | {% for item in attributes %}
41 |
42 | .. autoattribute:: {{ [objname, item] | join(".") }}
43 | {%- endfor %}
44 |
45 | {% endif %}
46 | {% endblock %}
47 |
48 | {% block methods_documentation %}
49 | {% if methods %}
50 | Methods
51 | ~~~~~~~
52 |
53 | {% for item in methods %}
54 | {%- if item != '__init__' %}
55 |
56 | .. automethod:: {{ [objname, item] | join(".") }}
57 | {%- endif -%}
58 | {%- endfor %}
59 |
60 | {% endif %}
61 | {% endblock %}
62 |
--------------------------------------------------------------------------------
/docs/_templates/class_no_inherited.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 | :show-inheritance:
9 |
10 | {% block attributes %}
11 | {% if attributes %}
12 | Attributes table
13 | ~~~~~~~~~~~~~~~~
14 |
15 | .. autosummary::
16 | {% for item in attributes %}
17 | {%- if item not in inherited_members%}
18 | ~{{ fullname }}.{{ item }}
19 | {%- endif -%}
20 | {%- endfor %}
21 | {% endif %}
22 | {% endblock %}
23 |
24 |
25 | {% block methods %}
26 | {% if methods %}
27 | Methods table
28 | ~~~~~~~~~~~~~~
29 |
30 | .. autosummary::
31 | {% for item in methods %}
32 | {%- if item != '__init__' and item not in inherited_members%}
33 | ~{{ fullname }}.{{ item }}
34 | {%- endif -%}
35 |
36 | {%- endfor %}
37 | {% endif %}
38 | {% endblock %}
39 |
40 | {% block attributes_documentation %}
41 | {% if attributes %}
42 | Attributes
43 | ~~~~~~~~~~
44 |
45 | {% for item in attributes %}
46 | {%- if item not in inherited_members%}
47 |
48 | .. autoattribute:: {{ [objname, item] | join(".") }}
49 | {%- endif -%}
50 | {%- endfor %}
51 |
52 | {% endif %}
53 | {% endblock %}
54 |
55 | {% block methods_documentation %}
56 | {% if methods %}
57 | Methods
58 | ~~~~~~~
59 |
60 | {% for item in methods %}
61 | {%- if item != '__init__' and item not in inherited_members%}
62 |
63 | .. automethod:: {{ [objname, item] | join(".") }}
64 | {%- endif -%}
65 | {%- endfor %}
66 |
67 | {% endif %}
68 | {% endblock %}
69 |
--------------------------------------------------------------------------------
/docs/api/index.md:
--------------------------------------------------------------------------------
1 | # API
2 |
3 | ```{eval-rst}
4 | .. currentmodule:: velovi
5 | ```
6 |
7 | ```{eval-rst}
8 | .. autosummary::
9 | :toctree: reference/
10 | :nosignatures:
11 |
12 | VELOVI
13 | ```
14 |
15 | ```{eval-rst}
16 | .. autosummary::
17 | :toctree: reference/
18 | :template: class_no_inherited.rst
19 | :nosignatures:
20 |
21 | VELOVAE
22 | ```
23 |
24 | ```{eval-rst}
25 | .. autosummary::
26 | :toctree: reference/
27 | :nosignatures:
28 |
29 | get_permutation_scores
30 | ```
31 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | import inspect
3 | import os
4 | import re
5 | import subprocess
6 | import sys
7 | from datetime import datetime
8 | from importlib.metadata import metadata
9 | from pathlib import Path
10 | from typing import Any
11 |
12 | HERE = Path(__file__).parent
13 | sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")]
14 |
15 |
16 | # -- Project information -----------------------------------------------------
17 |
18 | project_name = "velovi"
19 | info = metadata(project_name)
20 | package_name = "velovi"
21 | author = info["Author"]
22 | copyright = f"{datetime.now():%Y}, {author}."
23 | version = info["Version"]
24 | repository_url = f"https://github.com/yoseflab/{project_name}"
25 |
26 | # The full version, including alpha/beta/rc tags
27 | release = info["Version"]
28 |
29 | bibtex_bibfiles = ["references.bib"]
30 | templates_path = ["_templates"]
31 | nitpicky = True # Warn about broken links
32 | needs_sphinx = "4.0"
33 |
34 | html_context = {
35 | "display_github": True, # Integrate GitHub
36 | "github_user": "yoseflab", # Username
37 | "github_repo": project_name, # Repo name
38 | "github_version": "main", # Version
39 | "conf_py_path": "/docs/", # Path in the checkout to the docs root
40 | }
41 |
42 | # -- General configuration ---------------------------------------------------
43 |
44 | # Add any Sphinx extension module names here, as strings.
45 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
46 | extensions = [
47 | "myst_nb",
48 | "sphinx.ext.autodoc",
49 | "sphinx.ext.linkcode",
50 | "sphinx.ext.intersphinx",
51 | "sphinx.ext.autosummary",
52 | "sphinx.ext.napoleon",
53 | "sphinxcontrib.bibtex",
54 | "sphinx.ext.mathjax",
55 | "sphinx.ext.extlinks",
56 | *[p.stem for p in (HERE / "extensions").glob("*.py")],
57 | "sphinx_copybutton",
58 | ]
59 |
60 | autosummary_generate = True
61 | autodoc_member_order = "bysource"
62 | default_role = "literal"
63 | autodoc_typehints = "description"
64 | bibtex_reference_style = "author_year"
65 | napoleon_google_docstring = True
66 | napoleon_numpy_docstring = True
67 | napoleon_include_init_with_doc = False
68 | napoleon_use_rtype = True # having a separate entry generally helps readability
69 | napoleon_use_param = True
70 | myst_enable_extensions = [
71 | "amsmath",
72 | "colon_fence",
73 | "deflist",
74 | "dollarmath",
75 | "html_image",
76 | "html_admonition",
77 | ]
78 | myst_url_schemes = ("http", "https", "mailto")
79 | nb_output_stderr = "remove"
80 | nb_execution_mode = "off"
81 | nb_merge_streams = True
82 |
83 | source_suffix = {
84 | ".rst": "restructuredtext",
85 | ".ipynb": "myst-nb",
86 | ".myst": "myst-nb",
87 | }
88 |
89 | intersphinx_mapping = {
90 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
91 | "ipython": ("https://ipython.readthedocs.io/en/stable/", None),
92 | "matplotlib": ("https://matplotlib.org/", None),
93 | "numpy": ("https://numpy.org/doc/stable/", None),
94 | "pandas": ("https://pandas.pydata.org/docs/", None),
95 | "python": ("https://docs.python.org/3", None),
96 | "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
97 | "sklearn": ("https://scikit-learn.org/stable/", None),
98 | "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None),
99 | "jax": ("https://jax.readthedocs.io/en/latest/", None),
100 | "torch": ("https://pytorch.org/docs/master/", None),
101 | "plottable": ("https://plottable.readthedocs.io/en/latest/", None),
102 | "scvi-tools": ("https://docs.scvi-tools.org/en/stable/", None),
103 | "mudata": ("https://mudata.readthedocs.io/en/latest/", None),
104 | }
105 |
106 | # List of patterns, relative to source directory, that match files and
107 | # directories to ignore when looking for source files.
108 | # This pattern also affects html_static_path and html_extra_path.
109 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]
110 |
111 | # extlinks config
112 | extlinks = {
113 | "issue": (f"{repository_url}/issues/%s", "#%s"),
114 | "pr": (f"{repository_url}/pull/%s", "#%s"),
115 | "ghuser": ("https://github.com/%s", "@%s"),
116 | }
117 |
118 |
119 | # -- Linkcode settings -------------------------------------------------
120 |
121 |
122 | def git(*args):
123 | """Run a git command and return the output."""
124 | return subprocess.check_output(["git", *args]).strip().decode()
125 |
126 |
127 | # https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192
128 | # Current git reference. Uses branch/tag name if found, otherwise uses commit hash
129 | git_ref = None
130 | try:
131 | git_ref = git("name-rev", "--name-only", "--no-undefined", "HEAD")
132 | git_ref = re.sub(r"^(remotes/[^/]+|tags)/", "", git_ref)
133 | except Exception:
134 | pass
135 |
136 | # (if no name found or relative ref, use commit hash instead)
137 | if not git_ref or re.search(r"[\^~]", git_ref):
138 | try:
139 | git_ref = git("rev-parse", "HEAD")
140 | except Exception:
141 | git_ref = "main"
142 |
143 | # https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192
144 | github_repo = "https://github.com/" + html_context["github_user"] + "/" + project_name
145 | _project_module_path = os.path.dirname(importlib.util.find_spec(package_name).origin) # type: ignore
146 |
147 |
148 | def linkcode_resolve(domain, info):
149 | """Resolve links for the linkcode extension."""
150 | if domain != "py":
151 | return None
152 |
153 | try:
154 | obj: Any = sys.modules[info["module"]]
155 | for part in info["fullname"].split("."):
156 | obj = getattr(obj, part)
157 | obj = inspect.unwrap(obj)
158 |
159 | if isinstance(obj, property):
160 | obj = inspect.unwrap(obj.fget) # type: ignore
161 |
162 | path = os.path.relpath(inspect.getsourcefile(obj), start=_project_module_path) # type: ignore
163 | src, lineno = inspect.getsourcelines(obj)
164 | except Exception:
165 | return None
166 |
167 | path = f"{path}#L{lineno}-L{lineno + len(src) - 1}"
168 | return f"{github_repo}/blob/{git_ref}/{package_name}/{path}"
169 |
170 |
171 | # -- Options for HTML output -------------------------------------------------
172 |
173 | # The theme to use for HTML and HTML Help pages. See the documentation for
174 | # a list of builtin themes.
175 | #
176 | html_theme = "sphinx_book_theme"
177 | html_static_path = ["_static"]
178 | html_title = "velovi"
179 |
180 | html_theme_options = {
181 | "repository_url": github_repo,
182 | "use_repository_button": True,
183 | }
184 |
185 | pygments_style = "default"
186 |
187 | nitpick_ignore = [
188 | # If building the documentation fails because of a missing link that is outside your control,
189 | # you can add an exception to this list.
190 | ]
191 |
192 |
193 | def setup(app):
194 | """App setup hook."""
195 | app.add_config_value(
196 | "recommonmark_config",
197 | {
198 | "auto_toc_tree_section": "Contents",
199 | "enable_auto_toc_tree": True,
200 | "enable_math": True,
201 | "enable_inline_math": False,
202 | "enable_eval_rst": True,
203 | },
204 | True,
205 | )
206 |
--------------------------------------------------------------------------------
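To make the `linkcode_resolve` logic above concrete, here is a standalone sketch of the same `inspect`-based source resolution; `json.dumps` is an arbitrary stand-in target, not part of this repo:

```python
# Standalone sketch of the inspect-based resolution in linkcode_resolve:
# find an object's source file and line span, then build a #L fragment.
import inspect
import json
import os

obj = inspect.unwrap(json.dumps)  # arbitrary stand-in target
path = os.path.relpath(
    inspect.getsourcefile(obj), start=os.path.dirname(json.__file__)
)
src, lineno = inspect.getsourcelines(obj)
# Mirrors the f-string in conf.py: relative path plus an #Lstart-Lend anchor.
print(f"{path}#L{lineno}-L{lineno + len(src) - 1}")
```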
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | import re
4 |
5 | from sphinx.application import Sphinx
6 | from sphinx.ext.napoleon import NumpyDocstring
7 |
8 |
9 | def process_return(lines):
10 | """Process the return section of a docstring."""
11 | for line in lines:
12 | m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line)
13 | if m:
14 | # Once this is in scanpydoc, we can use the fancy hover stuff
15 | yield f'-{m["param"]} (:class:`~{m["type"]}`)'
16 | else:
17 | yield line
18 |
19 |
20 | def scanpy_parse_returns_section(self, section):
21 | """Parse the returns section of the docstring."""
22 | lines_raw = list(process_return(self._dedent(self._consume_to_next_section())))
23 | lines = self._format_block(":returns: ", lines_raw)
24 | if lines and lines[-1]:
25 | lines.append("")
26 | return lines
27 |
28 |
29 | def setup(app: Sphinx):
30 | """Setup the extension."""
31 | NumpyDocstring._parse_returns_section = scanpy_parse_returns_section
32 |
--------------------------------------------------------------------------------
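As a self-contained check of the rewrite `process_return` performs, the snippet below restates the (fixed) function so it runs standalone; the sample docstring line is made up:

```python
# A NumPy-style "name : module.Type" return line becomes an rst bullet with a
# :class: cross-reference; anything else passes through unchanged.
import re


def process_return(lines):
    for line in lines:
        m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line)
        if m:
            yield f'-{m["param"]} (:class:`~{m["type"]}`)'
        else:
            yield line


print(list(process_return(["velocities : pandas.DataFrame", "plain text"])))
# ['-velocities (:class:`~pandas.DataFrame`)', 'plain text']
```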
/docs/index.md:
--------------------------------------------------------------------------------
1 | ```{include} ../README.md
2 |
3 | ```
4 |
5 | Welcome to the velovi documentation.
6 |
7 | ```{toctree}
8 | :hidden: true
9 | :maxdepth: 1
10 |
11 | api/index
12 | tutorial
13 | release_notes/index
14 | references
15 | ```
16 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=python -msphinx
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=velovi
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 | echo.then set the SPHINXBUILD environment variable to point to the full
21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 | echo.Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/references.bib:
--------------------------------------------------------------------------------
1 | @article{GayosoWeiler2022,
2 | title={Deep generative modeling of transcriptional dynamics for RNA velocity analysis in single cells},
3 | author={Gayoso, Adam and Weiler, Philipp and Lotfollahi, Mohammad and Klein, Dominik and Hong, Justin and Streets, Aaron M and Theis, Fabian J and Yosef, Nir},
4 | journal={bioRxiv},
5 | pages={2022--08},
6 | year={2022},
7 | publisher={Cold Spring Harbor Laboratory}
8 | }
9 |
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
1 | # References
2 |
3 | ```{bibliography}
4 | :cited:
5 | ```
6 |
--------------------------------------------------------------------------------
/docs/release_notes/index.rst:
--------------------------------------------------------------------------------
1 | Release notes
2 | =============
3 |
4 | Version 0.1
5 | -----------
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | v0.1.0
10 |
--------------------------------------------------------------------------------
/docs/release_notes/v0.1.0.rst:
--------------------------------------------------------------------------------
1 | New in 0.1.0 (2022-MM-DD)
2 | -------------------------
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["poetry-core>=1.0.0"]
3 | build-backend = "poetry.core.masonry.api"
4 |
5 | [tool.poetry]
6 | authors = ["Adam Gayoso "]
7 | classifiers = [
8 | "Development Status :: 4 - Beta",
9 | "Intended Audience :: Science/Research",
10 | "Natural Language :: English",
11 | "Programming Language :: Python :: 3.9",
12 | "Programming Language :: Python :: 3.10",
13 | "Programming Language :: Python :: 3.11",
14 | "Operating System :: MacOS :: MacOS X",
15 | "Operating System :: Microsoft :: Windows",
16 | "Operating System :: POSIX :: Linux",
17 | "Topic :: Scientific/Engineering :: Bio-Informatics",
18 | ]
19 | description = "Estimation of RNA velocity with variational inference."
20 | documentation = "https://scvi-tools.org"
21 | homepage = "https://github.com/YosefLab/velovi"
22 | license = "BSD-3-Clause"
23 | name = "velovi"
24 | packages = [
25 | {include = "velovi"},
26 | ]
27 | readme = "README.md"
28 | version = "0.4.0"
29 |
30 | [tool.poetry.dependencies]
31 | anndata = ">=0.7.5"
32 | black = {version = ">=20.8b1", optional = true}
33 | codecov = {version = ">=2.0.8", optional = true}
34 | ruff = {version = "*", optional = true}
35 | importlib-metadata = {version = "^1.0", python = "<3.8"}
36 | ipython = {version = ">=7.1.1", optional = true}
37 | jupyter = {version = ">=1.0", optional = true}
38 | pre-commit = {version = ">=2.7.1", optional = true}
39 | sphinx-book-theme = {version = ">=1.0.0", optional = true}
40 | myst-nb = {version = "*", optional = true}
41 | sphinx-copybutton = {version = "*", optional = true}
42 | sphinxcontrib-bibtex = {version = "*", optional = true}
43 | ipykernel = {version = "*", optional = true}
44 | pytest = {version = ">=4.4", optional = true}
45 | pytest-cov = {version = "*", optional = true}
46 | python = ">=3.9,<4.0"
47 | python-igraph = {version = "*", optional = true}
48 | scanpy = {version = ">=1.6", optional = true}
49 | scanpydoc = {version = ">=0.5", optional = true}
50 | scvelo = ">=0.2.5"
51 | scvi-tools = ">=1.0.0"
52 | scikit-learn = ">=0.21.2"
53 | sphinx = {version = ">=4.1", optional = true}
54 | sphinx-autodoc-typehints = {version = "*", optional = true}
55 |
56 | [tool.poetry.extras]
57 | dev = ["black", "pytest", "pytest-cov", "ruff", "codecov", "scanpy", "loompy", "jupyter", "pre-commit"]
58 | docs = [
59 | "sphinx",
60 | "scanpydoc",
61 | "ipython",
62 | "myst-nb",
63 | "sphinx-book-theme",
64 | "sphinx-copybutton",
65 | "sphinxcontrib-bibtex",
66 | "ipykernel",
67 | "ipython",
68 | ]
69 | tutorials = ["scanpy"]
70 |
71 | [tool.poetry.dev-dependencies]
72 |
73 |
74 | [tool.coverage.run]
75 | source = ["velovi"]
76 | omit = [
77 | "**/test_*.py",
78 | ]
79 |
80 | [tool.pytest.ini_options]
81 | testpaths = ["tests"]
82 | xfail_strict = true
83 |
84 |
85 | [tool.black]
86 | include = '\.pyi?$'
87 | exclude = '''
88 | (
89 | /(
90 | \.eggs
91 | | \.git
92 | | \.hg
93 | | \.mypy_cache
94 | | \.tox
95 | | \.venv
96 | | _build
97 | | buck-out
98 | | build
99 | | dist
100 | )/
101 | )
102 | '''
103 |
104 | [tool.ruff]
105 | src = ["."]
106 | line-length = 119
107 | target-version = "py39"
108 | select = [
109 | "F", # Errors detected by Pyflakes
110 | "E", # Error detected by Pycodestyle
111 | "W", # Warning detected by Pycodestyle
112 | "I", # isort
113 | "D", # pydocstyle
114 | "B", # flake8-bugbear
115 | "TID", # flake8-tidy-imports
116 | "C4", # flake8-comprehensions
117 | "BLE", # flake8-blind-except
118 | "UP", # pyupgrade
119 | "RUF100", # Report unused noqa directives
120 | ]
121 | ignore = [
122 | # line too long -> we accept long comment lines; black gets rid of long code lines
123 | "E501",
124 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
125 | "E731",
126 | # allow I, O, l as variable names -> I is the identity matrix
127 | "E741",
128 | # Missing docstring in public package
129 | "D104",
130 | # Missing docstring in public module
131 | "D100",
132 | # Missing docstring in __init__
133 | "D107",
134 | # Errors from function calls in argument defaults. These are fine when the result is immutable.
135 | "B008",
136 | # __magic__ methods are often self-explanatory, allow missing docstrings
137 | "D105",
138 | # first line should end with a period [Bug: doesn't work with single-line docstrings]
139 | "D400",
140 | # First line should be in imperative mood; try rephrasing
141 | "D401",
142 | ## Disable one in each pair of mutually incompatible rules
143 | # We don’t want a blank line before a class docstring
144 | "D203",
145 | # We want docstrings to start immediately after the opening triple quote
146 | "D213",
147 | # Missing argument description in the docstring TODO: enable
148 | "D417",
149 | ]
150 |
151 | [tool.ruff.per-file-ignores]
152 | "docs/*" = ["I", "BLE001"]
153 | "tests/*" = ["D"]
154 | "*/__init__.py" = ["F401"]
155 | "velovi/__init__.py" = ["I"]
156 |
157 | [tool.jupytext]
158 | formats = "ipynb,md"
159 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | build:
3 | os: ubuntu-20.04
4 | tools:
5 | python: "3.10"
6 | sphinx:
7 | configuration: docs/conf.py
8 | python:
9 | install:
10 | - method: pip
11 | path: .
12 | extra_requirements:
13 | - docs
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # This is a shim to (hopefully) allow GitHub to detect the package; the build is done with Poetry
4 |
5 | import setuptools
6 |
7 | if __name__ == "__main__":
8 | setuptools.setup(name="velovi")
9 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YosefLab/velovi/63de3f0b5c2588da056b8e95018f4f136255935d/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_velovi.py:
--------------------------------------------------------------------------------
1 | import scvelo as scv
2 | from scvi.data import synthetic_iid
3 |
4 |
5 | def test_preprocess_data():
6 | adata = synthetic_iid()
7 | adata.layers["spliced"] = adata.X.copy()
8 | adata.layers["unspliced"] = adata.X.copy()
9 | scv.pp.normalize_per_cell(adata)
10 | scv.pp.log1p(adata)
11 | scv.pp.moments(adata, n_pcs=30, n_neighbors=30)
12 | # TODO(adamgayoso): use real data for this test
13 | # preprocess_data(adata)
14 |
15 |
16 | def test_velovi():
17 | from velovi import VELOVI
18 |
19 | n_latent = 5
20 | adata = synthetic_iid()
21 | adata.layers["spliced"] = adata.X.copy()
22 | adata.layers["unspliced"] = adata.X.copy()
23 | VELOVI.setup_anndata(adata, unspliced_layer="unspliced", spliced_layer="spliced")
24 | model = VELOVI(adata, n_latent=n_latent)
25 | model.train(1, check_val_every_n_epoch=1, train_size=0.5)
26 | model.get_latent_representation()
27 | model.get_velocity()
28 | model.get_latent_time()
29 | model.get_state_assignment()
30 | model.get_expression_fit()
31 | model.get_directional_uncertainty()
32 | model.get_permutation_scores(labels_key="labels")
33 |
34 | model.history
35 |
36 | # tests __repr__
37 | print(model)
38 |
--------------------------------------------------------------------------------
/velovi/__init__.py:
--------------------------------------------------------------------------------
1 | """velovi."""
2 |
3 | import logging
4 | import warnings
5 |
6 | # warning has to be at the top level to print on import
7 | warnings.warn(
8 | "The velovi package is no longer being actively developed or maintained as of v0.4.0. Please "
9 | "use the implementation in the scvi-tools package instead. For more information, see "
10 | "https://github.com/scverse/scvi-tools/issues/2610.",
11 | UserWarning,
12 | stacklevel=1,
13 | )
14 |
15 | from rich.console import Console # noqa
16 | from rich.logging import RichHandler # noqa
17 |
18 | from ._constants import REGISTRY_KEYS # noqa
19 | from ._model import VELOVI, VELOVAE # noqa
20 | from ._utils import get_permutation_scores, preprocess_data # noqa
21 |
22 | # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
23 | # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
24 | try:
25 | import importlib.metadata as importlib_metadata
26 | except ModuleNotFoundError:
27 | import importlib_metadata
28 |
29 | package_name = "velovi"
30 | __version__ = importlib_metadata.version(package_name)
31 |
32 | logger = logging.getLogger(__name__)
33 | # set the logging level
34 | logger.setLevel(logging.INFO)
35 |
36 | # nice logging outputs
37 | console = Console(force_terminal=True)
38 | if console.is_jupyter is True:
39 | console.is_jupyter = False
40 | ch = RichHandler(show_path=False, console=console, show_time=False)
41 | formatter = logging.Formatter("velovi: %(message)s")
42 | ch.setFormatter(formatter)
43 | logger.addHandler(ch)
44 |
45 | # this prevents double outputs
46 | logger.propagate = False
47 |
48 |
49 | __all__ = [
50 | "VELOVI",
51 | "VELOVAE",
52 | "REGISTRY_KEYS",
53 | "get_permutation_scores",
54 | "preprocess_data",
55 | ]
56 |
--------------------------------------------------------------------------------
/velovi/_constants.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 |
4 | class _REGISTRY_KEYS_NT(NamedTuple):
5 | X_KEY: str = "X"
6 | U_KEY: str = "U"
7 |
8 |
9 | REGISTRY_KEYS = _REGISTRY_KEYS_NT()
10 |
--------------------------------------------------------------------------------
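`REGISTRY_KEYS` is just a `NamedTuple` instance, so the layer keys used by the `get_from_registry` calls in `_model.py` are available as attributes. A minimal check, assuming the package is importable:

```python
# X_KEY names the spliced layer and U_KEY the unspliced layer in the
# scvi-tools data registry.
from velovi._constants import REGISTRY_KEYS

assert REGISTRY_KEYS.X_KEY == "X"
assert REGISTRY_KEYS.U_KEY == "U"
print(REGISTRY_KEYS)  # _REGISTRY_KEYS_NT(X_KEY='X', U_KEY='U')
```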
/velovi/_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import warnings
3 | from typing import Iterable, List, Literal, Optional, Sequence, Tuple, Union
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | import torch.nn.functional as F
9 | from anndata import AnnData
10 | from joblib import Parallel, delayed
11 | from scipy.stats import ttest_ind
12 | from scvi.data import AnnDataManager
13 | from scvi.data.fields import LayerField
14 | from scvi.dataloaders import DataSplitter
15 | from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
16 | from scvi.train import TrainingPlan, TrainRunner
17 | from scvi.utils._docstrings import devices_dsp, setup_anndata_dsp
18 |
19 | from ._constants import REGISTRY_KEYS
20 | from ._module import VELOVAE
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | def _softplus_inverse(x: np.ndarray) -> np.ndarray:
26 | x = torch.from_numpy(x)
27 | x_inv = torch.where(x > 20, x, x.expm1().log()).numpy()  # inverse softplus: log(e^x - 1); softplus(x) ≈ x for x > 20
28 | return x_inv
29 |
30 |
31 | class VELOVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
32 | """Velocity Variational Inference.
33 |
34 | Parameters
35 | ----------
36 | adata
37 | AnnData object that has been registered via :func:`~velovi.VELOVI.setup_anndata`.
38 | n_hidden
39 | Number of nodes per hidden layer.
40 | n_latent
41 | Dimensionality of the latent space.
42 | n_layers
43 | Number of hidden layers used for encoder and decoder NNs.
44 | dropout_rate
45 | Dropout rate for neural networks.
46 | gamma_init_data
47 | Initialize gamma using the data-driven technique.
48 | linear_decoder
49 | Use a linear decoder from latent space to time.
50 | **model_kwargs
51 | Keyword args for :class:`~velovi.VELOVAE`
52 | """
53 |
54 | def __init__(
55 | self,
56 | adata: AnnData,
57 | n_hidden: int = 256,
58 | n_latent: int = 10,
59 | n_layers: int = 1,
60 | dropout_rate: float = 0.1,
61 | gamma_init_data: bool = False,
62 | linear_decoder: bool = False,
63 | **model_kwargs,
64 | ):
65 | super().__init__(adata)
66 | self.n_latent = n_latent
67 |
68 | spliced = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
69 | unspliced = self.adata_manager.get_from_registry(REGISTRY_KEYS.U_KEY)
70 |
71 | sorted_unspliced = np.argsort(unspliced, axis=0)
72 | ind = int(adata.n_obs * 0.99)
73 | us_upper_ind = sorted_unspliced[ind:, :]
74 |
75 | us_upper = []
76 | ms_upper = []
77 | for i in range(len(us_upper_ind)):
78 | row = us_upper_ind[i]
79 | us_upper += [unspliced[row, np.arange(adata.n_vars)][np.newaxis, :]]
80 | ms_upper += [spliced[row, np.arange(adata.n_vars)][np.newaxis, :]]
81 | us_upper = np.median(np.concatenate(us_upper, axis=0), axis=0)
82 | ms_upper = np.median(np.concatenate(ms_upper, axis=0), axis=0)
83 |
84 | alpha_unconstr = _softplus_inverse(us_upper)
85 | alpha_unconstr = np.asarray(alpha_unconstr).ravel()
86 |
87 | alpha_1_unconstr = np.zeros(us_upper.shape).ravel()
88 | lambda_alpha_unconstr = np.zeros(us_upper.shape).ravel()
89 |
90 | if gamma_init_data:
91 | gamma_unconstr = np.clip(_softplus_inverse(us_upper / ms_upper), None, 10)
92 | else:
93 | gamma_unconstr = None
94 |
95 | self.module = VELOVAE(
96 | n_input=self.summary_stats["n_vars"],
97 | n_hidden=n_hidden,
98 | n_latent=n_latent,
99 | n_layers=n_layers,
100 | dropout_rate=dropout_rate,
101 | gamma_unconstr_init=gamma_unconstr,
102 | alpha_unconstr_init=alpha_unconstr,
103 | alpha_1_unconstr_init=alpha_1_unconstr,
104 | lambda_alpha_unconstr_init=lambda_alpha_unconstr,
105 | switch_spliced=ms_upper,
106 | switch_unspliced=us_upper,
107 | linear_decoder=linear_decoder,
108 | **model_kwargs,
109 | )
110 | self._model_summary_string = (
111 | "VELOVI Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: "
112 | "{}"
113 | ).format(
114 | n_hidden,
115 | n_latent,
116 | n_layers,
117 | dropout_rate,
118 | )
119 | self.init_params_ = self._get_init_params(locals())
120 |
121 | @devices_dsp.dedent
122 | def train(
123 | self,
124 | max_epochs: Optional[int] = 500,
125 | lr: float = 1e-2,
126 | weight_decay: float = 1e-2,
127 | accelerator: str = "auto",
128 | devices: Union[int, list[int], str] = "auto",
129 | train_size: float = 0.9,
130 | validation_size: Optional[float] = None,
131 | batch_size: int = 256,
132 | early_stopping: bool = True,
133 | gradient_clip_val: float = 10,
134 | plan_kwargs: Optional[dict] = None,
135 | **trainer_kwargs,
136 | ):
137 | """Train the model.
138 |
139 | Parameters
140 | ----------
141 | max_epochs
142 | Number of passes through the dataset. If `None`, defaults to
143 | `np.min([round((20000 / n_cells) * 400), 400])`
144 | lr
145 | Learning rate for optimization
146 | weight_decay
147 | Weight decay for optimization
148 | %(param_accelerator)s
149 | %(param_devices)s
150 | train_size
151 | Size of training set in the range [0.0, 1.0].
152 | validation_size
153 | Size of the test set. If `None`, defaults to 1 - `train_size`. If
154 | `train_size + validation_size < 1`, the remaining cells belong to a test set.
155 | batch_size
156 | Minibatch size to use during training.
157 | early_stopping
158 | Perform early stopping. Additional arguments can be passed in `**kwargs`.
159 | See :class:`~scvi.train.Trainer` for further options.
160 | gradient_clip_val
161 | Val for gradient clipping
162 | plan_kwargs
163 | Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
164 | `train()` will overwrite values present in `plan_kwargs`, when appropriate.
165 | **trainer_kwargs
166 | Other keyword args for :class:`~scvi.train.Trainer`.
167 | """
168 | user_plan_kwargs = plan_kwargs.copy() if isinstance(plan_kwargs, dict) else {}
169 | plan_kwargs = {"lr": lr, "weight_decay": weight_decay, "optimizer": "AdamW"}
170 | plan_kwargs.update(user_plan_kwargs)
171 |
172 | user_train_kwargs = trainer_kwargs.copy()
173 | trainer_kwargs = {"gradient_clip_val": gradient_clip_val}
174 | trainer_kwargs.update(user_train_kwargs)
175 |
176 | data_splitter = DataSplitter(
177 | self.adata_manager,
178 | train_size=train_size,
179 | validation_size=validation_size,
180 | batch_size=batch_size,
181 | )
182 | training_plan = TrainingPlan(self.module, **plan_kwargs)
183 |
184 | es = "early_stopping"
185 | trainer_kwargs[es] = (
186 | early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
187 | )
188 | runner = TrainRunner(
189 | self,
190 | training_plan=training_plan,
191 | data_splitter=data_splitter,
192 | max_epochs=max_epochs,
193 | accelerator=accelerator,
194 | devices=devices,
195 | **trainer_kwargs,
196 | )
197 | return runner()
198 |
199 | @torch.inference_mode()
200 | def get_state_assignment(
201 | self,
202 | adata: Optional[AnnData] = None,
203 | indices: Optional[Sequence[int]] = None,
204 | gene_list: Optional[Sequence[str]] = None,
205 | hard_assignment: bool = False,
206 | n_samples: int = 20,
207 | batch_size: Optional[int] = None,
208 | return_mean: bool = True,
209 | return_numpy: Optional[bool] = None,
210 | ) -> Tuple[Union[np.ndarray, pd.DataFrame], List[str]]:
211 | """Returns cells by genes by states probabilities.
212 |
213 | Parameters
214 | ----------
215 | adata
216 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
217 | AnnData object used to initialize the model.
218 | indices
219 | Indices of cells in adata to use. If `None`, all cells are used.
220 | gene_list
221 | Return frequencies of expression for a subset of genes.
222 | This can save memory when working with large datasets and few genes are
223 | of interest.
224 | hard_assignment
225 | Return a hard state assignment
226 | n_samples
227 | Number of posterior samples to use for estimation.
228 | batch_size
229 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
230 | return_mean
231 | Whether to return the mean of the samples.
232 | return_numpy
233 | Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
234 | gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
235 | Otherwise, it defaults to `True`.
236 |
237 | Returns
238 | -------
239 | If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
240 | Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
241 | """
242 | adata = self._validate_anndata(adata)
243 | scdl = self._make_data_loader(
244 | adata=adata, indices=indices, batch_size=batch_size
245 | )
246 |
247 | if gene_list is None:
248 | gene_mask = slice(None)
249 | else:
250 | all_genes = adata.var_names
251 | gene_mask = [gene in gene_list for gene in all_genes]
252 |
253 | if n_samples > 1 and return_mean is False:
254 | if return_numpy is False:
255 | warnings.warn(
256 | "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
257 | )
258 | return_numpy = True
259 | if indices is None:
260 | indices = np.arange(adata.n_obs)
261 |
262 | states = []
263 | for tensors in scdl:
264 | minibatch_samples = []
265 | for _ in range(n_samples):
266 | _, generative_outputs = self.module.forward(
267 | tensors=tensors,
268 | compute_loss=False,
269 | )
270 | output = generative_outputs["px_pi"]
271 | output = output[..., gene_mask, :]
272 | output = output.cpu().numpy()
273 | minibatch_samples.append(output)
274 | # samples by cells by genes by four
275 | states.append(np.stack(minibatch_samples, axis=0))
276 | if return_mean:
277 | states[-1] = np.mean(states[-1], axis=0)
278 |
279 | states = np.concatenate(states, axis=0)
280 | state_cats = [
281 | "induction",
282 | "induction_steady",
283 | "repression",
284 | "repression_steady",
285 | ]
286 | if hard_assignment and return_mean:
287 | hard_assign = states.argmax(-1)
288 |
289 | hard_assign = pd.DataFrame(
290 | data=hard_assign, index=adata.obs_names, columns=adata.var_names
291 | )
292 | for i, s in enumerate(state_cats):
293 | hard_assign = hard_assign.replace(i, s)
294 |
295 | states = hard_assign
296 |
297 | return states, state_cats
298 |
299 | @torch.inference_mode()
300 | def get_latent_time(
301 | self,
302 | adata: Optional[AnnData] = None,
303 | indices: Optional[Sequence[int]] = None,
304 | gene_list: Optional[Sequence[str]] = None,
305 | time_statistic: Literal["mean", "max"] = "mean",
306 | n_samples: int = 1,
307 | n_samples_overall: Optional[int] = None,
308 | batch_size: Optional[int] = None,
309 | return_mean: bool = True,
310 | return_numpy: Optional[bool] = None,
311 | ) -> Union[np.ndarray, pd.DataFrame]:
312 | """Returns the cells by genes latent time.
313 |
314 | Parameters
315 | ----------
316 | adata
317 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
318 | AnnData object used to initialize the model.
319 | indices
320 | Indices of cells in adata to use. If `None`, all cells are used.
321 | gene_list
322 | Return frequencies of expression for a subset of genes.
323 | This can save memory when working with large datasets and few genes are
324 | of interest.
325 | time_statistic
326 | Whether to compute expected time over states, or maximum a posteriori time over maximal
327 | probability state.
328 | n_samples
329 | Number of posterior samples to use for estimation.
330 | n_samples_overall
331 | Number of overall samples to return. Setting this forces n_samples=1.
332 | batch_size
333 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
334 | return_mean
335 | Whether to return the mean of the samples.
336 | return_numpy
337 | Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
338 | gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
339 | Otherwise, it defaults to `True`.
340 |
341 | Returns
342 | -------
343 | If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
344 | Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
345 | """
346 | adata = self._validate_anndata(adata)
347 | if indices is None:
348 | indices = np.arange(adata.n_obs)
349 | if n_samples_overall is not None:
350 | indices = np.random.choice(indices, n_samples_overall)
351 | scdl = self._make_data_loader(
352 | adata=adata, indices=indices, batch_size=batch_size
353 | )
354 |
355 | if gene_list is None:
356 | gene_mask = slice(None)
357 | else:
358 | all_genes = adata.var_names
359 | gene_mask = [gene in gene_list for gene in all_genes]
360 |
361 | if n_samples > 1 and return_mean is False:
362 | if return_numpy is False:
363 | warnings.warn(
364 | "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
365 | )
366 | return_numpy = True
367 | if indices is None:
368 | indices = np.arange(adata.n_obs)
369 |
370 | times = []
371 | for tensors in scdl:
372 | minibatch_samples = []
373 | for _ in range(n_samples):
374 | _, generative_outputs = self.module.forward(
375 | tensors=tensors,
376 | compute_loss=False,
377 | )
378 | pi = generative_outputs["px_pi"]
379 | ind_prob = pi[..., 0]
380 | steady_prob = pi[..., 1]
381 | rep_prob = pi[..., 2]
382 | # rep_steady_prob = pi[..., 3]
383 | switch_time = F.softplus(self.module.switch_time_unconstr)
384 |
385 | ind_time = generative_outputs["px_rho"] * switch_time
386 | rep_time = switch_time + (
387 | generative_outputs["px_tau"] * (self.module.t_max - switch_time)
388 | )
389 |
390 | if time_statistic == "mean":
391 | output = (
392 | ind_prob * ind_time
393 | + rep_prob * rep_time
394 | + steady_prob * switch_time
395 | # + rep_steady_prob * self.module.t_max
396 | )
397 | else:
398 | t = torch.stack(
399 | [
400 | ind_time,
401 | switch_time.expand(ind_time.shape),
402 | rep_time,
403 | torch.zeros_like(ind_time),
404 | ],
405 | dim=2,
406 | )
407 | max_prob = torch.amax(pi, dim=-1)
408 | max_prob = torch.stack([max_prob] * 4, dim=2)
409 | max_prob_mask = pi.ge(max_prob)
410 | output = (t * max_prob_mask).sum(dim=-1)
411 |
412 | output = output[..., gene_mask]
413 | output = output.cpu().numpy()
414 | minibatch_samples.append(output)
415 | # samples by cells by genes
416 | times.append(np.stack(minibatch_samples, axis=0))
417 | if return_mean:
418 | times[-1] = np.mean(times[-1], axis=0)
419 |
420 | if n_samples > 1:
421 | # The -2 axis corresponds to cells.
422 | times = np.concatenate(times, axis=-2)
423 | else:
424 | times = np.concatenate(times, axis=0)
425 |
426 | if return_numpy is None or return_numpy is False:
427 | return pd.DataFrame(
428 | times,
429 | columns=adata.var_names[gene_mask],
430 | index=adata.obs_names[indices],
431 | )
432 | else:
433 | return times
434 |
435 | @torch.inference_mode()
436 | def get_velocity(
437 | self,
438 | adata: Optional[AnnData] = None,
439 | indices: Optional[Sequence[int]] = None,
440 | gene_list: Optional[Sequence[str]] = None,
441 | n_samples: int = 1,
442 | n_samples_overall: Optional[int] = None,
443 | batch_size: Optional[int] = None,
444 | return_mean: bool = True,
445 | return_numpy: Optional[bool] = None,
446 | velo_statistic: str = "mean",
447 | velo_mode: Literal["spliced", "unspliced"] = "spliced",
448 | clip: bool = True,
449 | ) -> Union[np.ndarray, pd.DataFrame]:
450 | """Returns cells by genes velocity estimates.
451 |
452 | Parameters
453 | ----------
454 | adata
455 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
456 | AnnData object used to initialize the model.
457 | indices
458 | Indices of cells in adata to use. If `None`, all cells are used.
459 | gene_list
460 | Return velocities for a subset of genes.
461 | This can save memory when working with large datasets and few genes are
462 | of interest.
463 | n_samples
464 | Number of posterior samples to use for estimation for each cell.
465 | n_samples_overall
466 | Number of overall samples to return. Setting this forces n_samples=1.
467 | batch_size
468 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
469 | return_mean
470 | Whether to return the mean of the samples.
471 | return_numpy
472 | Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
473 | gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
474 | Otherwise, it defaults to `True`.
475 | velo_statistic
476 | Whether to compute expected velocity over states, or maximum a posteriori velocity over maximal
477 | probability state.
478 | velo_mode
479 | Compute ds/dt or du/dt.
480 | clip
481 | Clip to minus spliced value
482 |
483 | Returns
484 | -------
485 | If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
486 | Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
487 | """
488 | adata = self._validate_anndata(adata)
489 | if indices is None:
490 | indices = np.arange(adata.n_obs)
491 | if n_samples_overall is not None:
492 | indices = np.random.choice(indices, n_samples_overall)
493 | n_samples = 1
494 | scdl = self._make_data_loader(
495 | adata=adata, indices=indices, batch_size=batch_size
496 | )
497 |
498 | if gene_list is None:
499 | gene_mask = slice(None)
500 | else:
501 | all_genes = adata.var_names
502 | gene_mask = [gene in gene_list for gene in all_genes]
503 |
504 | if n_samples > 1 and return_mean is False:
505 | if return_numpy is False:
506 | warnings.warn(
507 | "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
508 | )
509 | return_numpy = True
512 |
513 | velos = []
514 | for tensors in scdl:
515 | minibatch_samples = []
516 | for _ in range(n_samples):
517 | inference_outputs, generative_outputs = self.module.forward(
518 | tensors=tensors,
519 | compute_loss=False,
520 | )
521 | pi = generative_outputs["px_pi"]
522 | alpha = inference_outputs["alpha"]
523 | alpha_1 = inference_outputs["alpha_1"]
524 | lambda_alpha = inference_outputs["lambda_alpha"]
525 | beta = inference_outputs["beta"]
526 | gamma = inference_outputs["gamma"]
527 | tau = generative_outputs["px_tau"]
528 | rho = generative_outputs["px_rho"]
529 |
530 | ind_prob = pi[..., 0]
531 | steady_prob = pi[..., 1]
532 | rep_prob = pi[..., 2]
533 | switch_time = F.softplus(self.module.switch_time_unconstr)
534 |
535 | ind_time = switch_time * rho
536 | u_0, s_0 = self.module._get_induction_unspliced_spliced(
537 | alpha, alpha_1, lambda_alpha, beta, gamma, switch_time
538 | )
539 | rep_time = (self.module.t_max - switch_time) * tau
540 | mean_u_rep, mean_s_rep = self.module._get_repression_unspliced_spliced(
541 | u_0,
542 | s_0,
543 | beta,
544 | gamma,
545 | rep_time,
546 | )
547 | if velo_mode == "spliced":
548 | velo_rep = beta * mean_u_rep - gamma * mean_s_rep
549 | else:
550 | velo_rep = -beta * mean_u_rep
551 | mean_u_ind, mean_s_ind = self.module._get_induction_unspliced_spliced(
552 | alpha, alpha_1, lambda_alpha, beta, gamma, ind_time
553 | )
554 | if velo_mode == "spliced":
555 | velo_ind = beta * mean_u_ind - gamma * mean_s_ind
556 | else:
557 | transcription_rate = alpha_1 - (alpha_1 - alpha) * torch.exp(
558 | -lambda_alpha * ind_time
559 | )
560 | velo_ind = transcription_rate - beta * mean_u_ind
561 |
562 | if velo_mode == "spliced":
563 | # velo_steady = beta * u_0 - gamma * s_0
564 | velo_steady = torch.zeros_like(velo_ind)
565 | else:
566 | # velo_steady = alpha - beta * u_0
567 | velo_steady = torch.zeros_like(velo_ind)
568 |
569 | # expectation
570 | if velo_statistic == "mean":
571 | output = (
572 | ind_prob * velo_ind
573 | + rep_prob * velo_rep
574 | + steady_prob * velo_steady
575 | )
576 | # maximum
577 | else:
578 | v = torch.stack(
579 | [
580 | velo_ind,
581 | velo_steady.expand(velo_ind.shape),
582 | velo_rep,
583 | torch.zeros_like(velo_rep),
584 | ],
585 | dim=2,
586 | )
587 | max_prob = torch.amax(pi, dim=-1)
588 | max_prob = torch.stack([max_prob] * 4, dim=2)
589 | max_prob_mask = pi.ge(max_prob)
590 | output = (v * max_prob_mask).sum(dim=-1)
591 |
592 | output = output[..., gene_mask]
593 | output = output.cpu().numpy()
594 | minibatch_samples.append(output)
595 | # samples by cells by genes
596 | velos.append(np.stack(minibatch_samples, axis=0))
597 | if return_mean:
598 | # mean over samples axis
599 | velos[-1] = np.mean(velos[-1], axis=0)
600 |
601 | if n_samples > 1:
602 |             # The -2 axis corresponds to cells.
603 | velos = np.concatenate(velos, axis=-2)
604 | else:
605 | velos = np.concatenate(velos, axis=0)
606 |
    |         # clip against the observed spliced abundance of the queried AnnData
607 |         spliced = self.get_anndata_manager(adata).get_from_registry(REGISTRY_KEYS.X_KEY)
608 |
609 | if clip:
610 | velos = np.clip(velos, -spliced[indices], None)
611 |
612 | if return_numpy is None or return_numpy is False:
613 | return pd.DataFrame(
614 | velos,
615 | columns=adata.var_names[gene_mask],
616 | index=adata.obs_names[indices],
617 | )
618 | else:
619 | return velos
620 |
621 | @torch.inference_mode()
622 | def get_expression_fit(
623 | self,
624 | adata: Optional[AnnData] = None,
625 | indices: Optional[Sequence[int]] = None,
626 | gene_list: Optional[Sequence[str]] = None,
627 | n_samples: int = 1,
628 | batch_size: Optional[int] = None,
629 | return_mean: bool = True,
630 | return_numpy: Optional[bool] = None,
631 | restrict_to_latent_dim: Optional[int] = None,
632 | ) -> Union[np.ndarray, pd.DataFrame]:
633 | r"""Returns the fitted spliced and unspliced abundance (s(t) and u(t)).
634 |
635 | Parameters
636 | ----------
637 | adata
638 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
639 | AnnData object used to initialize the model.
640 | indices
641 | Indices of cells in adata to use. If `None`, all cells are used.
642 | gene_list
643 | Return frequencies of expression for a subset of genes.
644 | This can save memory when working with large datasets and few genes are
645 | of interest.
646 | n_samples
647 | Number of posterior samples to use for estimation.
648 | batch_size
649 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
650 | return_mean
651 | Whether to return the mean of the samples.
652 | return_numpy
653 | Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
654 | gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
655 | Otherwise, it defaults to `True`.
656 |
657 | Returns
658 | -------
659 | If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
660 | Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
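    |
    |         Examples
    |         --------
    |         A minimal usage sketch (assumes a trained model).
    |
    |         >>> fit_s, fit_u = model.get_expression_fit(n_samples=10)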
661 | """
662 | adata = self._validate_anndata(adata)
663 |
664 | scdl = self._make_data_loader(
665 | adata=adata, indices=indices, batch_size=batch_size
666 | )
667 |
668 | if gene_list is None:
669 | gene_mask = slice(None)
670 | else:
671 | all_genes = adata.var_names
672 |             gene_mask = [gene in gene_list for gene in all_genes]
673 |
674 | if n_samples > 1 and return_mean is False:
675 | if return_numpy is False:
676 | warnings.warn(
677 | "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
678 | )
679 | return_numpy = True
680 | if indices is None:
681 | indices = np.arange(adata.n_obs)
682 |
683 | fits_s = []
684 | fits_u = []
685 | for tensors in scdl:
686 | minibatch_samples_s = []
687 | minibatch_samples_u = []
688 | for _ in range(n_samples):
689 | inference_outputs, generative_outputs = self.module.forward(
690 | tensors=tensors,
691 | compute_loss=False,
692 | generative_kwargs={"latent_dim": restrict_to_latent_dim},
693 | )
694 |
695 | gamma = inference_outputs["gamma"]
696 | beta = inference_outputs["beta"]
697 | alpha = inference_outputs["alpha"]
698 | alpha_1 = inference_outputs["alpha_1"]
699 | lambda_alpha = inference_outputs["lambda_alpha"]
700 | px_pi = generative_outputs["px_pi"]
701 | scale = generative_outputs["scale"]
702 | px_rho = generative_outputs["px_rho"]
703 | px_tau = generative_outputs["px_tau"]
704 |
705 | (
706 | mixture_dist_s,
707 | mixture_dist_u,
708 | _,
709 | ) = self.module.get_px(
710 | px_pi,
711 | px_rho,
712 | px_tau,
713 | scale,
714 | gamma,
715 | beta,
716 | alpha,
717 | alpha_1,
718 | lambda_alpha,
719 | )
720 | fit_s = mixture_dist_s.mean
721 | fit_u = mixture_dist_u.mean
722 |
723 | fit_s = fit_s[..., gene_mask]
724 | fit_s = fit_s.cpu().numpy()
725 | fit_u = fit_u[..., gene_mask]
726 | fit_u = fit_u.cpu().numpy()
727 |
728 | minibatch_samples_s.append(fit_s)
729 | minibatch_samples_u.append(fit_u)
730 |
731 | # samples by cells by genes
732 | fits_s.append(np.stack(minibatch_samples_s, axis=0))
733 | if return_mean:
734 | # mean over samples axis
735 | fits_s[-1] = np.mean(fits_s[-1], axis=0)
736 | # samples by cells by genes
737 | fits_u.append(np.stack(minibatch_samples_u, axis=0))
738 | if return_mean:
739 | # mean over samples axis
740 | fits_u[-1] = np.mean(fits_u[-1], axis=0)
741 |
742 | if n_samples > 1:
743 |             # The -2 axis corresponds to cells.
744 | fits_s = np.concatenate(fits_s, axis=-2)
745 | fits_u = np.concatenate(fits_u, axis=-2)
746 | else:
747 | fits_s = np.concatenate(fits_s, axis=0)
748 | fits_u = np.concatenate(fits_u, axis=0)
749 |
750 | if return_numpy is None or return_numpy is False:
751 | df_s = pd.DataFrame(
752 | fits_s,
753 | columns=adata.var_names[gene_mask],
754 | index=adata.obs_names[indices],
755 | )
756 | df_u = pd.DataFrame(
757 | fits_u,
758 | columns=adata.var_names[gene_mask],
759 | index=adata.obs_names[indices],
760 | )
761 | return df_s, df_u
762 | else:
763 | return fits_s, fits_u
764 |
765 | @torch.inference_mode()
766 | def get_gene_likelihood(
767 | self,
768 | adata: Optional[AnnData] = None,
769 | indices: Optional[Sequence[int]] = None,
770 | gene_list: Optional[Sequence[str]] = None,
771 | n_samples: int = 1,
772 | batch_size: Optional[int] = None,
773 | return_mean: bool = True,
774 | return_numpy: Optional[bool] = None,
775 | ) -> Union[np.ndarray, pd.DataFrame]:
776 |         r"""Returns the log-likelihood of the spliced and unspliced abundance per cell and gene. Higher is better.
777 |
780 | Parameters
781 | ----------
782 | adata
783 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
784 | AnnData object used to initialize the model.
785 | indices
786 | Indices of cells in adata to use. If `None`, all cells are used.
793 |         gene_list
794 |             Return likelihoods for a subset of genes.
795 |             This can save memory when working with large datasets and few genes are
796 |             of interest.
801 | n_samples
802 | Number of posterior samples to use for estimation.
803 | batch_size
804 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
805 | return_mean
806 | Whether to return the mean of the samples.
807 | return_numpy
808 | Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
809 | gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
810 | Otherwise, it defaults to `True`.
811 |
812 | Returns
813 | -------
814 |         If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
815 |         Otherwise, shape is `(cells, genes)`. Unlike the other methods, this always returns a :class:`~numpy.ndarray`.
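    |
    |         Examples
    |         --------
    |         A minimal usage sketch (assumes a trained model).
    |
    |         >>> log_lik = model.get_gene_likelihood(n_samples=5)
    |         >>> log_lik.mean(0)  # per-gene average over cells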
816 | """
817 | adata = self._validate_anndata(adata)
818 | scdl = self._make_data_loader(
819 | adata=adata, indices=indices, batch_size=batch_size
820 | )
821 |
822 | if gene_list is None:
823 | gene_mask = slice(None)
824 | else:
825 | all_genes = adata.var_names
826 |             gene_mask = [gene in gene_list for gene in all_genes]
827 |
828 | if n_samples > 1 and return_mean is False:
829 | if return_numpy is False:
830 | warnings.warn(
831 | "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
832 | )
833 | return_numpy = True
834 | if indices is None:
835 | indices = np.arange(adata.n_obs)
836 |
837 | rls = []
838 | for tensors in scdl:
839 | minibatch_samples = []
840 | for _ in range(n_samples):
841 | inference_outputs, generative_outputs = self.module.forward(
842 | tensors=tensors,
843 | compute_loss=False,
844 | )
845 | spliced = tensors[REGISTRY_KEYS.X_KEY]
846 | unspliced = tensors[REGISTRY_KEYS.U_KEY]
847 |
848 | gamma = inference_outputs["gamma"]
849 | beta = inference_outputs["beta"]
850 | alpha = inference_outputs["alpha"]
851 | alpha_1 = inference_outputs["alpha_1"]
852 | lambda_alpha = inference_outputs["lambda_alpha"]
853 | px_pi = generative_outputs["px_pi"]
854 | scale = generative_outputs["scale"]
855 | px_rho = generative_outputs["px_rho"]
856 | px_tau = generative_outputs["px_tau"]
857 |
858 | (
859 | mixture_dist_s,
860 | mixture_dist_u,
861 | _,
862 | ) = self.module.get_px(
863 | px_pi,
864 | px_rho,
865 | px_tau,
866 | scale,
867 | gamma,
868 | beta,
869 | alpha,
870 | alpha_1,
871 | lambda_alpha,
872 | )
873 | device = gamma.device
874 | reconst_loss_s = -mixture_dist_s.log_prob(spliced.to(device))
875 | reconst_loss_u = -mixture_dist_u.log_prob(unspliced.to(device))
876 | output = -(reconst_loss_s + reconst_loss_u)
877 | output = output[..., gene_mask]
878 | output = output.cpu().numpy()
879 | minibatch_samples.append(output)
880 |             # samples by cells by genes
881 | rls.append(np.stack(minibatch_samples, axis=0))
882 | if return_mean:
883 | rls[-1] = np.mean(rls[-1], axis=0)
884 |
885 |         if n_samples > 1:
    |             # The -2 axis corresponds to cells.
    |             rls = np.concatenate(rls, axis=-2)
    |         else:
    |             rls = np.concatenate(rls, axis=0)
886 |         return rls
887 |
888 | @torch.inference_mode()
889 |     def get_rates(self):
    |         """Return the fitted kinetic rates as a dictionary of NumPy arrays (one entry per rate)."""
890 | gamma, beta, alpha, alpha_1, lambda_alpha = self.module._get_rates()
891 |
892 | return {
893 | "beta": beta.cpu().numpy(),
894 | "gamma": gamma.cpu().numpy(),
895 | "alpha": alpha.cpu().numpy(),
896 | "alpha_1": alpha_1.cpu().numpy(),
897 | "lambda_alpha": lambda_alpha.cpu().numpy(),
898 | }
899 |
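    |     # Usage sketch (illustrative): inspect the fitted per-gene kinetic rates.
    |     #
    |     #     rates = model.get_rates()
    |     #     rates["gamma"].shape  # (n_genes,)
    |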
900 | @classmethod
901 | @setup_anndata_dsp.dedent
902 | def setup_anndata(
903 | cls,
904 | adata: AnnData,
905 | spliced_layer: str,
906 | unspliced_layer: str,
907 | **kwargs,
908 | ) -> Optional[AnnData]:
909 | """%(summary)s.
910 |
911 | Parameters
912 | ----------
913 | %(param_adata)s
914 | spliced_layer
915 |             Layer in adata with spliced normalized expression.
916 | unspliced_layer
917 | Layer in adata with unspliced normalized expression.
918 |
919 | Returns
920 | -------
921 | %(returns)s
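    |
    |         Examples
    |         --------
    |         A minimal sketch, assuming `adata` carries smoothed layers named `"Ms"` and `"Mu"`
    |         (the defaults produced by :func:`~velovi.preprocess_data`).
    |
    |         >>> VELOVI.setup_anndata(adata, spliced_layer="Ms", unspliced_layer="Mu")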
922 | """
923 | setup_method_args = cls._get_setup_method_args(**locals())
924 | anndata_fields = [
925 | LayerField(REGISTRY_KEYS.X_KEY, spliced_layer, is_count_data=False),
926 | LayerField(REGISTRY_KEYS.U_KEY, unspliced_layer, is_count_data=False),
927 | ]
928 | adata_manager = AnnDataManager(
929 | fields=anndata_fields, setup_method_args=setup_method_args
930 | )
931 | adata_manager.register_fields(adata, **kwargs)
932 | cls.register_manager(adata_manager)
933 |
934 | def get_directional_uncertainty(
935 | self,
936 | adata: Optional[AnnData] = None,
937 | n_samples: int = 50,
938 |         gene_list: Optional[Iterable[str]] = None,
939 | n_jobs: int = -1,
940 |     ) -> Tuple[pd.DataFrame, np.ndarray]:
    |         """Estimate the directional uncertainty of velocity per cell.
    |
    |         Samples `n_samples` velocity fields from the posterior and summarizes, for each
    |         cell, how much their directions vary around the mean velocity (cosine similarity).
    |         """
941 |         adata = self._validate_anndata(adata)
942 |
943 | logger.info("Sampling from model...")
944 | velocities_all = self.get_velocity(
945 | n_samples=n_samples, return_mean=False, gene_list=gene_list
946 | ) # (n_samples, n_cells, n_genes)
947 |
948 | df, cosine_sims = _compute_directional_statistics_tensor(
949 | tensor=velocities_all, n_jobs=n_jobs, n_cells=adata.n_obs
950 | )
951 | df.index = adata.obs_names
952 |
953 | return df, cosine_sims
954 |
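    |     # Usage sketch (illustrative): per-cell uncertainty from 50 posterior samples.
    |     #
    |     #     df, cosine_sims = model.get_directional_uncertainty(n_samples=50)
    |     #     df["directional_cosine_sim_variance"].head()
    |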
955 | def get_permutation_scores(
956 | self, labels_key: str, adata: Optional[AnnData] = None
957 | ) -> Tuple[pd.DataFrame, AnnData]:
958 | """Compute permutation scores.
959 |
960 | Parameters
961 | ----------
962 | labels_key
963 | Key in adata.obs encoding cell types
964 | adata
965 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
966 | AnnData object used to initialize the model.
967 |
968 | Returns
969 | -------
970 | Tuple of DataFrame and AnnData. DataFrame is genes by cell types with score per cell type.
971 |         AnnData is the permuted version of the original AnnData.
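    |
    |         Examples
    |         --------
    |         A minimal sketch; the `labels_key` column name is illustrative.
    |
    |         >>> scores, bdata = model.get_permutation_scores(labels_key="clusters")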
972 | """
973 | adata = self._validate_anndata(adata)
974 | adata_manager = self.get_anndata_manager(adata)
975 | if labels_key not in adata.obs:
976 | raise ValueError(f"{labels_key} not found in adata.obs")
977 |
978 | # shuffle spliced then unspliced
979 | bdata = self._shuffle_layer_celltype(
980 | adata_manager, labels_key, REGISTRY_KEYS.X_KEY
981 | )
982 | bdata_manager = self.get_anndata_manager(bdata)
983 | bdata = self._shuffle_layer_celltype(
984 | bdata_manager, labels_key, REGISTRY_KEYS.U_KEY
985 | )
986 | bdata_manager = self.get_anndata_manager(bdata)
987 |
988 | ms_ = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
989 | mu_ = adata_manager.get_from_registry(REGISTRY_KEYS.U_KEY)
990 |
991 | ms_p = bdata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
992 | mu_p = bdata_manager.get_from_registry(REGISTRY_KEYS.U_KEY)
993 |
994 |         spliced_, unspliced_ = self.get_expression_fit(adata, n_samples=10)
    |         # summed absolute reconstruction error over spliced and unspliced
995 |         abs_error = np.abs(spliced_ - ms_)
996 |         abs_error += np.abs(unspliced_ - mu_)
997 |
998 |         spliced_p, unspliced_p = self.get_expression_fit(bdata, n_samples=10)
999 |         abs_error_p = np.abs(spliced_p - ms_p)
1000 |         abs_error_p += np.abs(unspliced_p - mu_p)
1001 |
1002 | celltypes = np.unique(adata.obs[labels_key])
1003 |
1004 | dynamical_df = pd.DataFrame(
1005 | index=adata.var_names,
1006 | columns=celltypes,
1007 | data=np.zeros((adata.shape[1], len(celltypes))),
1008 | )
1009 |         N = 200  # number of cells per cell type used in the t-test
1010 |         for ct in celltypes:
1011 |             for g in adata.var_names.tolist():
1012 |                 x = abs_error_p[g][adata.obs[labels_key] == ct]
1013 |                 y = abs_error[g][adata.obs[labels_key] == ct]
1014 |                 t_stat = ttest_ind(x[:N], y[:N])[0]
1015 |                 dynamical_df.loc[g, ct] = t_stat
1016 |
1017 | return dynamical_df, bdata
1018 |
1019 | def _shuffle_layer_celltype(
1020 | self, adata_manager: AnnDataManager, labels_key: str, registry_key: str
1021 | ) -> AnnData:
1022 | """Shuffle cells within cell types for each gene."""
1023 | from scvi.data._constants import _SCVI_UUID_KEY
1024 |
1025 | bdata = adata_manager.adata.copy()
1026 | labels = bdata.obs[labels_key]
1027 | del bdata.uns[_SCVI_UUID_KEY]
1028 | self._validate_anndata(bdata)
1029 | bdata_manager = self.get_anndata_manager(bdata)
1030 |
1031 | # get registry info to later set data back in bdata
1032 | # in a way that doesn't require actual knowledge of location
1033 |         layer_data = bdata_manager.get_from_registry(registry_key)
1034 |         registry = bdata_manager.data_registry[registry_key]
1035 |         attr_name = registry.attr_name
1036 |         attr_key = registry.attr_key
1037 |
1038 |         for lab in np.unique(labels):
1039 |             mask = np.asarray(labels == lab)
1040 |             layer_data_ct = layer_data[mask].copy()
1041 |             layer_data_ct = np.apply_along_axis(
1042 |                 np.random.permutation, axis=0, arr=layer_data_ct
1043 |             )
1044 |             layer_data[mask] = layer_data_ct
1045 |         # e.g., if using adata.X
1046 |         if attr_key is None:
1047 |             setattr(bdata, attr_name, layer_data)
1048 |         # e.g., if using a layer
1049 |         else:
1050 |             attribute = getattr(bdata, attr_name)
1051 |             attribute[attr_key] = layer_data
1052 |             setattr(bdata, attr_name, attribute)
1053 |
1054 | return bdata
1055 |
1056 |
1057 | def _compute_directional_statistics_tensor(
1058 | tensor: np.ndarray, n_jobs: int, n_cells: int
1059 | ) -> Tuple[pd.DataFrame, np.ndarray]:
1060 | df = pd.DataFrame(index=np.arange(n_cells))
1061 | df["directional_variance"] = np.nan
1062 | df["directional_difference"] = np.nan
1063 | df["directional_cosine_sim_variance"] = np.nan
1064 | df["directional_cosine_sim_difference"] = np.nan
1065 | df["directional_cosine_sim_mean"] = np.nan
1066 | logger.info("Computing the uncertainties...")
1067 | results = Parallel(n_jobs=n_jobs, verbose=3)(
1068 | delayed(_directional_statistics_per_cell)(tensor[:, cell_index, :])
1069 | for cell_index in range(n_cells)
1070 | )
1071 | # cells by samples
1072 | cosine_sims = np.stack([results[i][0] for i in range(n_cells)])
1073 | df.loc[:, "directional_cosine_sim_variance"] = [
1074 | results[i][1] for i in range(n_cells)
1075 | ]
1076 | df.loc[:, "directional_cosine_sim_difference"] = [
1077 | results[i][2] for i in range(n_cells)
1078 | ]
1079 | df.loc[:, "directional_variance"] = [results[i][3] for i in range(n_cells)]
1080 | df.loc[:, "directional_difference"] = [results[i][4] for i in range(n_cells)]
1081 | df.loc[:, "directional_cosine_sim_mean"] = [results[i][5] for i in range(n_cells)]
1082 |
1083 | return df, cosine_sims
1084 |
1085 |
1086 | def _directional_statistics_per_cell(
1087 | tensor: np.ndarray,
1088 | ) -> Tuple[np.ndarray, float, float, float, float, float]:
1089 | """Internal function for parallelization.
1090 |
1091 | Parameters
1092 | ----------
1093 |     tensor
1094 |         Array of shape ``(samples, genes)`` holding velocity samples for one cell.
1095 | """
1096 | n_samples = tensor.shape[0]
1097 | # over samples axis
1098 | mean_velocity_of_cell = tensor.mean(0)
1099 | cosine_sims = [
1100 | _cosine_sim(tensor[i, :], mean_velocity_of_cell) for i in range(n_samples)
1101 | ]
1102 | angle_samples = [np.arccos(el) for el in cosine_sims]
1103 | return (
1104 | cosine_sims,
1105 | np.var(cosine_sims),
1106 | np.percentile(cosine_sims, 95) - np.percentile(cosine_sims, 5),
1107 | np.var(angle_samples),
1108 | np.percentile(angle_samples, 95) - np.percentile(angle_samples, 5),
1109 | np.mean(cosine_sims),
1110 | )
1111 |
1112 |
1113 | def _centered_unit_vector(vector: np.ndarray) -> np.ndarray:
1114 |     """Return the mean-centered, unit-norm version of a vector."""
1115 | vector = vector - np.mean(vector)
1116 | return vector / np.linalg.norm(vector)
1117 |
1118 |
1119 | def _cosine_sim(v1: np.ndarray, v2: np.ndarray) -> np.ndarray:
1120 | """Returns cosine similarity of the vectors."""
1121 | v1_u = _centered_unit_vector(v1)
1122 | v2_u = _centered_unit_vector(v2)
1123 | return np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)
1124 |
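    | # Note: because both inputs are mean-centered before normalization, this
    | # centered cosine similarity coincides with the Pearson correlation of v1 and v2.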
--------------------------------------------------------------------------------
/velovi/_module.py:
--------------------------------------------------------------------------------
1 | """Main module."""
2 | from typing import Callable, Iterable, Literal, Optional, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
8 | from scvi.nn import Encoder, FCLayers
9 | from torch import nn as nn
10 | from torch.distributions import Categorical, Dirichlet, MixtureSameFamily, Normal
11 | from torch.distributions import kl_divergence as kl
12 |
13 | from ._constants import REGISTRY_KEYS
14 |
15 | torch.backends.cudnn.benchmark = True
16 |
17 |
18 | class DecoderVELOVI(nn.Module):
19 |     """Decodes data from latent space of ``n_input`` dimensions into ``n_output`` dimensions.
20 |
21 |     Uses a fully-connected neural network of ``n_layers`` hidden layers with ``n_hidden`` nodes each.
22 |
23 | Parameters
24 | ----------
25 | n_input
26 | The dimensionality of the input (latent space)
27 | n_output
28 | The dimensionality of the output (data space)
29 | n_cat_list
30 | A list containing the number of categories
31 | for each category of interest. Each category will be
32 | included using a one-hot encoding
33 | n_layers
34 | The number of fully-connected hidden layers
35 | n_hidden
36 | The number of nodes per hidden layer
37 | dropout_rate
38 | Dropout rate to apply to each of the hidden layers
39 | inject_covariates
40 | Whether to inject covariates in each layer, or just the first (default).
41 | use_batch_norm
42 | Whether to use batch norm in layers
43 | use_layer_norm
44 | Whether to use layer norm in layers
45 | linear_decoder
46 | Whether to use linear decoder for time
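    |
    |     Examples
    |     --------
    |     A minimal shape-check sketch (sizes are illustrative).
    |
    |     >>> import torch
    |     >>> decoder = DecoderVELOVI(n_input=10, n_output=100)
    |     >>> px_pi, px_rho, px_tau = decoder(torch.randn(8, 10))
    |     >>> px_pi.shape  # (8, 100, 4): cells by genes by states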
47 | """
48 |
49 | def __init__(
50 | self,
51 | n_input: int,
52 | n_output: int,
53 |         n_cat_list: Optional[Iterable[int]] = None,
54 | n_layers: int = 1,
55 | n_hidden: int = 128,
56 | inject_covariates: bool = True,
57 | use_batch_norm: bool = True,
58 | use_layer_norm: bool = False,
59 | dropout_rate: float = 0.0,
60 | linear_decoder: bool = False,
61 | **kwargs,
62 | ):
63 | super().__init__()
64 |         self.n_output = n_output
65 | self.linear_decoder = linear_decoder
66 | self.rho_first_decoder = FCLayers(
67 | n_in=n_input,
68 | n_out=n_hidden if not linear_decoder else n_output,
69 | n_cat_list=n_cat_list,
70 | n_layers=n_layers if not linear_decoder else 1,
71 | n_hidden=n_hidden,
72 | dropout_rate=dropout_rate,
73 | inject_covariates=inject_covariates,
74 | use_batch_norm=use_batch_norm,
75 | use_layer_norm=use_layer_norm if not linear_decoder else False,
76 | use_activation=not linear_decoder,
77 | bias=not linear_decoder,
78 | **kwargs,
79 | )
80 |
81 | self.pi_first_decoder = FCLayers(
82 | n_in=n_input,
83 | n_out=n_hidden,
84 | n_cat_list=n_cat_list,
85 | n_layers=n_layers,
86 | n_hidden=n_hidden,
87 | dropout_rate=dropout_rate,
88 | inject_covariates=inject_covariates,
89 | use_batch_norm=use_batch_norm,
90 | use_layer_norm=use_layer_norm,
91 | **kwargs,
92 | )
93 |
94 | # categorical pi
95 | # 4 states
96 | self.px_pi_decoder = nn.Linear(n_hidden, 4 * n_output)
97 |
98 | # rho for induction
99 | self.px_rho_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Sigmoid())
100 |
101 | # tau for repression
102 | self.px_tau_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Sigmoid())
103 |
104 | self.linear_scaling_tau = nn.Parameter(torch.zeros(n_output))
105 | self.linear_scaling_tau_intercept = nn.Parameter(torch.zeros(n_output))
106 |
107 | def forward(self, z: torch.Tensor, latent_dim: int = None):
108 |         """The forward computation for a single sample.
109 |
110 |          #. Decodes the data from the latent space using the decoder network
111 |          #. Returns the state assignment probabilities and latent time parameters
112 |
113 |         Parameters
114 |         ----------
115 |         z
116 |             tensor with shape ``(n_input,)``
117 |         latent_dim
118 |             if not ``None``, mask out all latent dimensions except this one before decoding
119 |
120 |         Returns
121 |         -------
122 |         3-tuple of :py:class:`torch.Tensor`
123 |             state probabilities ``px_pi``, induction time ``px_rho``, and repression time ``px_tau``
124 |
125 |         """
127 | z_in = z
128 | if latent_dim is not None:
129 | mask = torch.zeros_like(z)
130 | mask[..., latent_dim] = 1
131 | z_in = z * mask
132 |         # The decoder returns the state probabilities and latent time parameters
133 | rho_first = self.rho_first_decoder(z_in)
134 |
135 | if not self.linear_decoder:
136 | px_rho = self.px_rho_decoder(rho_first)
137 | px_tau = self.px_tau_decoder(rho_first)
138 | else:
139 |             px_rho = torch.sigmoid(rho_first)
140 |             px_tau = 1 - torch.sigmoid(
141 |                 rho_first * self.linear_scaling_tau.exp()
142 |                 + self.linear_scaling_tau_intercept
143 |             )
144 |
145 | # cells by genes by 4
146 | pi_first = self.pi_first_decoder(z)
147 |         px_pi = F.softplus(
148 |             torch.reshape(self.px_pi_decoder(pi_first), (z.shape[0], self.n_output, 4))
149 |         )
150 |
151 | return px_pi, px_rho, px_tau
152 |
153 |
154 | # VAE model
155 | class VELOVAE(BaseModuleClass):
156 | """Variational auto-encoder model.
157 |
158 |     This is an implementation of the veloVI model described in :cite:p:`GayosoWeiler2022`.
159 |
160 | Parameters
161 | ----------
162 | n_input
163 | Number of input genes
164 | n_hidden
165 | Number of nodes per hidden layer
166 | n_latent
167 | Dimensionality of the latent space
168 | n_layers
169 | Number of hidden layers used for encoder and decoder NNs
170 | dropout_rate
171 | Dropout rate for neural networks
172 | log_variational
173 |         Log(data + 0.01) prior to encoding for numerical stability. Not normalization.
174 | latent_distribution
175 | One of
176 |
177 | * ``'normal'`` - Isotropic normal
178 | * ``'ln'`` - Logistic normal with normal params N(0, 1)
    |     use_batch_norm
    |         Whether to use batch norm in layers
179 |     use_layer_norm
180 |         Whether to use layer norm in layers
181 | use_observed_lib_size
182 | Use observed library size for RNA as scaling factor in mean of conditional distribution
183 | var_activation
184 | Callable used to ensure positivity of the variational distributions' variance.
185 | When `None`, defaults to `torch.exp`.
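    |
    |     Examples
    |     --------
    |     A minimal construction sketch (the module is normally instantiated by the model class).
    |
    |     >>> module = VELOVAE(n_input=100)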
186 | """
187 |
188 | def __init__(
189 | self,
190 | n_input: int,
191 | true_time_switch: Optional[np.ndarray] = None,
192 | n_hidden: int = 128,
193 | n_latent: int = 10,
194 | n_layers: int = 1,
195 | dropout_rate: float = 0.1,
196 | log_variational: bool = False,
197 | latent_distribution: str = "normal",
198 | use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
199 | use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both",
200 | use_observed_lib_size: bool = True,
201 | var_activation: Optional[Callable] = torch.nn.Softplus(),
202 | model_steady_states: bool = True,
203 | gamma_unconstr_init: Optional[np.ndarray] = None,
204 | alpha_unconstr_init: Optional[np.ndarray] = None,
205 | alpha_1_unconstr_init: Optional[np.ndarray] = None,
206 | lambda_alpha_unconstr_init: Optional[np.ndarray] = None,
207 | switch_spliced: Optional[np.ndarray] = None,
208 | switch_unspliced: Optional[np.ndarray] = None,
209 | t_max: float = 20,
210 | penalty_scale: float = 0.2,
211 | dirichlet_concentration: float = 0.25,
212 | linear_decoder: bool = False,
213 | time_dep_transcription_rate: bool = False,
214 | ):
215 | super().__init__()
216 | self.n_latent = n_latent
217 | self.log_variational = log_variational
218 | self.latent_distribution = latent_distribution
219 | self.use_observed_lib_size = use_observed_lib_size
220 | self.n_input = n_input
221 | self.model_steady_states = model_steady_states
222 | self.t_max = t_max
223 | self.penalty_scale = penalty_scale
224 | self.dirichlet_concentration = dirichlet_concentration
225 | self.time_dep_transcription_rate = time_dep_transcription_rate
226 |
227 | if switch_spliced is not None:
228 | self.register_buffer("switch_spliced", torch.from_numpy(switch_spliced))
229 | else:
230 | self.switch_spliced = None
231 | if switch_unspliced is not None:
232 | self.register_buffer("switch_unspliced", torch.from_numpy(switch_unspliced))
233 | else:
234 | self.switch_unspliced = None
235 |
236 | n_genes = n_input * 2
237 |
238 | # switching time
239 | self.switch_time_unconstr = torch.nn.Parameter(7 + 0.5 * torch.randn(n_input))
240 | if true_time_switch is not None:
241 | self.register_buffer("true_time_switch", torch.from_numpy(true_time_switch))
242 | else:
243 | self.true_time_switch = None
244 |
245 | # degradation
246 | if gamma_unconstr_init is None:
247 | self.gamma_mean_unconstr = torch.nn.Parameter(-1 * torch.ones(n_input))
248 | else:
249 | self.gamma_mean_unconstr = torch.nn.Parameter(
250 | torch.from_numpy(gamma_unconstr_init)
251 | )
252 |
253 | # splicing
254 | # first samples around 1
255 | self.beta_mean_unconstr = torch.nn.Parameter(0.5 * torch.ones(n_input))
256 |
257 | # transcription
258 | if alpha_unconstr_init is None:
259 | self.alpha_unconstr = torch.nn.Parameter(0 * torch.ones(n_input))
260 | else:
261 | self.alpha_unconstr = torch.nn.Parameter(
262 | torch.from_numpy(alpha_unconstr_init)
263 | )
264 |
265 | # TODO: Add `require_grad`
266 | if alpha_1_unconstr_init is None:
267 | self.alpha_1_unconstr = torch.nn.Parameter(0 * torch.ones(n_input))
268 | else:
269 | self.alpha_1_unconstr = torch.nn.Parameter(
270 | torch.from_numpy(alpha_1_unconstr_init)
271 | )
272 | self.alpha_1_unconstr.requires_grad = time_dep_transcription_rate
273 |
274 | if lambda_alpha_unconstr_init is None:
275 | self.lambda_alpha_unconstr = torch.nn.Parameter(0 * torch.ones(n_input))
276 | else:
277 | self.lambda_alpha_unconstr = torch.nn.Parameter(
278 | torch.from_numpy(lambda_alpha_unconstr_init)
279 | )
280 | self.lambda_alpha_unconstr.requires_grad = time_dep_transcription_rate
281 |
282 | # likelihood dispersion
283 | # for now, with normal dist, this is just the variance
284 | self.scale_unconstr = torch.nn.Parameter(-1 * torch.ones(n_genes, 4))
285 |
286 |         use_batch_norm_encoder = use_batch_norm in ("encoder", "both")
287 |         use_batch_norm_decoder = use_batch_norm in ("decoder", "both")
288 |         use_layer_norm_encoder = use_layer_norm in ("encoder", "both")
289 |         use_layer_norm_decoder = use_layer_norm in ("decoder", "both")
290 | self.use_batch_norm_decoder = use_batch_norm_decoder
291 |
292 | # z encoder goes from the n_input-dimensional data to an n_latent-d
293 | # latent space representation
294 | n_input_encoder = n_genes
295 | self.z_encoder = Encoder(
296 | n_input_encoder,
297 | n_latent,
298 | n_layers=n_layers,
299 | n_hidden=n_hidden,
300 | dropout_rate=dropout_rate,
301 | distribution=latent_distribution,
302 | use_batch_norm=use_batch_norm_encoder,
303 | use_layer_norm=use_layer_norm_encoder,
304 | var_activation=var_activation,
305 | activation_fn=torch.nn.ReLU,
306 | )
307 | # decoder goes from n_latent-dimensional space to n_input-d data
308 | n_input_decoder = n_latent
309 | self.decoder = DecoderVELOVI(
310 | n_input_decoder,
311 | n_input,
312 | n_layers=n_layers,
313 | n_hidden=n_hidden,
314 | use_batch_norm=use_batch_norm_decoder,
315 | use_layer_norm=use_layer_norm_decoder,
316 | activation_fn=torch.nn.ReLU,
317 | linear_decoder=linear_decoder,
318 | )
319 |
320 | def _get_inference_input(self, tensors):
321 | spliced = tensors[REGISTRY_KEYS.X_KEY]
322 | unspliced = tensors[REGISTRY_KEYS.U_KEY]
323 |
324 | input_dict = {
325 | "spliced": spliced,
326 | "unspliced": unspliced,
327 | }
328 | return input_dict
329 |
330 | def _get_generative_input(self, tensors, inference_outputs):
331 | z = inference_outputs["z"]
332 | gamma = inference_outputs["gamma"]
333 | beta = inference_outputs["beta"]
334 | alpha = inference_outputs["alpha"]
335 | alpha_1 = inference_outputs["alpha_1"]
336 | lambda_alpha = inference_outputs["lambda_alpha"]
337 |
338 | input_dict = {
339 | "z": z,
340 | "gamma": gamma,
341 | "beta": beta,
342 | "alpha": alpha,
343 | "alpha_1": alpha_1,
344 | "lambda_alpha": lambda_alpha,
345 | }
346 | return input_dict
347 |
348 | @auto_move_data
349 | def inference(
350 | self,
351 | spliced,
352 | unspliced,
353 | n_samples=1,
354 | ):
355 | """High level inference method.
356 |
357 | Runs the inference (encoder) model.
358 | """
359 | spliced_ = spliced
360 | unspliced_ = unspliced
361 | if self.log_variational:
362 | spliced_ = torch.log(0.01 + spliced)
363 | unspliced_ = torch.log(0.01 + unspliced)
364 |
365 | encoder_input = torch.cat((spliced_, unspliced_), dim=-1)
366 |
367 | qz_m, qz_v, z = self.z_encoder(encoder_input)
368 |
369 | if n_samples > 1:
370 | qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
371 | qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
372 | # when z is normal, untran_z == z
373 | untran_z = Normal(qz_m, qz_v.sqrt()).sample()
374 | z = self.z_encoder.z_transformation(untran_z)
375 |
376 | gamma, beta, alpha, alpha_1, lambda_alpha = self._get_rates()
377 |
378 | outputs = {
379 | "z": z,
380 | "qz_m": qz_m,
381 | "qz_v": qz_v,
382 | "gamma": gamma,
383 | "beta": beta,
384 | "alpha": alpha,
385 | "alpha_1": alpha_1,
386 | "lambda_alpha": lambda_alpha,
387 | }
388 | return outputs
389 |
390 | def _get_rates(self):
391 | # globals
392 | # degradation
393 | gamma = torch.clamp(F.softplus(self.gamma_mean_unconstr), 0, 50)
394 | # splicing
395 | beta = torch.clamp(F.softplus(self.beta_mean_unconstr), 0, 50)
396 | # transcription
397 | alpha = torch.clamp(F.softplus(self.alpha_unconstr), 0, 50)
398 | if self.time_dep_transcription_rate:
399 | alpha_1 = torch.clamp(F.softplus(self.alpha_1_unconstr), 0, 50)
400 | lambda_alpha = torch.clamp(F.softplus(self.lambda_alpha_unconstr), 0, 50)
401 | else:
402 | alpha_1 = self.alpha_1_unconstr
403 | lambda_alpha = self.lambda_alpha_unconstr
404 |
405 | return gamma, beta, alpha, alpha_1, lambda_alpha
406 |
407 | @auto_move_data
408 | def generative(self, z, gamma, beta, alpha, alpha_1, lambda_alpha, latent_dim=None):
409 | """Runs the generative model."""
410 | decoder_input = z
411 | px_pi_alpha, px_rho, px_tau = self.decoder(decoder_input, latent_dim=latent_dim)
412 | px_pi = Dirichlet(px_pi_alpha).rsample()
413 |
414 | scale_unconstr = self.scale_unconstr
415 | scale = F.softplus(scale_unconstr)
416 |
417 | mixture_dist_s, mixture_dist_u, end_penalty = self.get_px(
418 | px_pi,
419 | px_rho,
420 | px_tau,
421 | scale,
422 | gamma,
423 | beta,
424 | alpha,
425 | alpha_1,
426 | lambda_alpha,
427 | )
428 |
429 | return {
430 | "px_pi": px_pi,
431 | "px_rho": px_rho,
432 | "px_tau": px_tau,
433 | "scale": scale,
434 | "px_pi_alpha": px_pi_alpha,
435 | "mixture_dist_u": mixture_dist_u,
436 | "mixture_dist_s": mixture_dist_s,
437 | "end_penalty": end_penalty,
438 | }
439 |
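    |     # Objective (per cell): reconstruction loss for spliced and unspliced,
    |     # plus kl_weight * KL(q(z|x) || N(0, I)) + KL(Dirichlet(pi)), with an
    |     # additional penalty_scale * (1 - kl_weight) * end_penalty term anchoring
    |     # the switch-point abundances (u_0, s_0) to their empirical estimates.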
440 | def loss(
441 | self,
442 | tensors,
443 | inference_outputs,
444 | generative_outputs,
445 | kl_weight: float = 1.0,
446 | n_obs: float = 1.0,
447 | ):
448 | spliced = tensors[REGISTRY_KEYS.X_KEY]
449 | unspliced = tensors[REGISTRY_KEYS.U_KEY]
450 |
451 | qz_m = inference_outputs["qz_m"]
452 | qz_v = inference_outputs["qz_v"]
453 |
454 | px_pi = generative_outputs["px_pi"]
455 | px_pi_alpha = generative_outputs["px_pi_alpha"]
456 |
457 | end_penalty = generative_outputs["end_penalty"]
458 | mixture_dist_s = generative_outputs["mixture_dist_s"]
459 | mixture_dist_u = generative_outputs["mixture_dist_u"]
460 |
461 | kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1)
462 |
463 | reconst_loss_s = -mixture_dist_s.log_prob(spliced)
464 | reconst_loss_u = -mixture_dist_u.log_prob(unspliced)
465 |
466 | reconst_loss = reconst_loss_u.sum(dim=-1) + reconst_loss_s.sum(dim=-1)
467 |
468 | kl_pi = kl(
469 | Dirichlet(px_pi_alpha),
470 | Dirichlet(self.dirichlet_concentration * torch.ones_like(px_pi)),
471 | ).sum(dim=-1)
472 |
473 | # local loss
474 | kl_local = kl_divergence_z + kl_pi
475 | weighted_kl_local = kl_weight * (kl_divergence_z) + kl_pi
476 |
477 | local_loss = torch.mean(reconst_loss + weighted_kl_local)
478 |
479 | loss = local_loss + self.penalty_scale * (1 - kl_weight) * end_penalty
480 |
481 |         loss_recorder = LossOutput(
482 |             loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local
483 |         )
484 |
485 |         return loss_recorder
486 |
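    |     # get_px builds, per gene, a 4-component mixture over transcriptional states:
    |     # (1) induction, (2) induction steady state, (3) repression, and
    |     # (4) repression steady state. Each component is a Normal over (un)spliced
    |     # abundance; the mixture weights are the sampled px_pi.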
487 | @auto_move_data
488 | def get_px(
489 | self,
490 | px_pi,
491 | px_rho,
492 | px_tau,
493 | scale,
494 | gamma,
495 | beta,
496 | alpha,
497 | alpha_1,
498 | lambda_alpha,
499 |     ) -> Tuple[MixtureSameFamily, MixtureSameFamily, torch.Tensor]:
500 | t_s = torch.clamp(F.softplus(self.switch_time_unconstr), 0, self.t_max)
501 |
502 | n_cells = px_pi.shape[0]
503 |
504 | # component dist
505 | comp_dist = Categorical(probs=px_pi)
506 |
507 | # induction
508 | mean_u_ind, mean_s_ind = self._get_induction_unspliced_spliced(
509 | alpha, alpha_1, lambda_alpha, beta, gamma, t_s * px_rho
510 | )
511 |
512 | if self.time_dep_transcription_rate:
513 | mean_u_ind_steady = (alpha_1 / beta).expand(n_cells, self.n_input)
514 | mean_s_ind_steady = (alpha_1 / gamma).expand(n_cells, self.n_input)
515 | else:
516 | mean_u_ind_steady = (alpha / beta).expand(n_cells, self.n_input)
517 | mean_s_ind_steady = (alpha / gamma).expand(n_cells, self.n_input)
518 | scale_u = scale[: self.n_input, :].expand(n_cells, self.n_input, 4).sqrt()
519 |
520 | # repression
521 | u_0, s_0 = self._get_induction_unspliced_spliced(
522 | alpha, alpha_1, lambda_alpha, beta, gamma, t_s
523 | )
524 |
525 | tau = px_tau
526 | mean_u_rep, mean_s_rep = self._get_repression_unspliced_spliced(
527 | u_0,
528 | s_0,
529 | beta,
530 | gamma,
531 | (self.t_max - t_s) * tau,
532 | )
533 | mean_u_rep_steady = torch.zeros_like(mean_u_ind)
534 | mean_s_rep_steady = torch.zeros_like(mean_u_ind)
535 | scale_s = scale[self.n_input :, :].expand(n_cells, self.n_input, 4).sqrt()
536 |
537 | end_penalty = ((u_0 - self.switch_unspliced).pow(2)).sum() + (
538 | (s_0 - self.switch_spliced).pow(2)
539 | ).sum()
540 |
541 | # unspliced
542 | mean_u = torch.stack(
543 | (
544 | mean_u_ind,
545 | mean_u_ind_steady,
546 | mean_u_rep,
547 | mean_u_rep_steady,
548 | ),
549 | dim=2,
550 | )
551 | scale_u = torch.stack(
552 | (
553 | scale_u[..., 0],
554 | scale_u[..., 0],
555 | scale_u[..., 0],
556 | 0.1 * scale_u[..., 0],
557 | ),
558 | dim=2,
559 | )
560 | dist_u = Normal(mean_u, scale_u)
561 | mixture_dist_u = MixtureSameFamily(comp_dist, dist_u)
562 |
563 | # spliced
564 | mean_s = torch.stack(
565 | (mean_s_ind, mean_s_ind_steady, mean_s_rep, mean_s_rep_steady),
566 | dim=2,
567 | )
568 | scale_s = torch.stack(
569 | (
570 | scale_s[..., 0],
571 | scale_s[..., 0],
572 | scale_s[..., 0],
573 | 0.1 * scale_s[..., 0],
574 | ),
575 | dim=2,
576 | )
577 | dist_s = Normal(mean_s, scale_s)
578 | mixture_dist_s = MixtureSameFamily(comp_dist, dist_s)
579 |
580 | return mixture_dist_s, mixture_dist_u, end_penalty
581 |
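    |     # The induction solution below integrates the standard splicing kinetics
    |     #     du/dt = alpha(t) - beta * u,    ds/dt = beta * u - gamma * s
    |     # from u(0) = s(0) = 0. With a constant transcription rate (alpha(t) = alpha):
    |     #     u(t) = (alpha / beta) * (1 - exp(-beta * t))
    |     #     s(t) = (alpha / gamma) * (1 - exp(-gamma * t))
    |     #            + (alpha / (gamma - beta)) * (exp(-gamma * t) - exp(-beta * t))
    |     # The time-dependent branch instead uses
    |     #     alpha(t) = alpha_1 - (alpha_1 - alpha) * exp(-lambda_alpha * t).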
582 | def _get_induction_unspliced_spliced(
583 | self, alpha, alpha_1, lambda_alpha, beta, gamma, t, eps=1e-6
584 | ):
585 | if self.time_dep_transcription_rate:
586 | unspliced = alpha_1 / beta * (1 - torch.exp(-beta * t)) - (
587 | alpha_1 - alpha
588 | ) / (beta - lambda_alpha) * (
589 | torch.exp(-lambda_alpha * t) - torch.exp(-beta * t)
590 | )
591 |
592 | spliced = (
593 | alpha_1 / gamma * (1 - torch.exp(-gamma * t))
594 | + alpha_1
595 | / (gamma - beta + eps)
596 | * (torch.exp(-gamma * t) - torch.exp(-beta * t))
597 | - beta
598 | * (alpha_1 - alpha)
599 | / (beta - lambda_alpha + eps)
600 | / (gamma - lambda_alpha + eps)
601 | * (torch.exp(-lambda_alpha * t) - torch.exp(-gamma * t))
602 | + beta
603 | * (alpha_1 - alpha)
604 | / (beta - lambda_alpha + eps)
605 | / (gamma - beta + eps)
606 | * (torch.exp(-beta * t) - torch.exp(-gamma * t))
607 | )
608 | else:
609 | unspliced = (alpha / beta) * (1 - torch.exp(-beta * t))
610 | spliced = (alpha / gamma) * (1 - torch.exp(-gamma * t)) + (
611 | alpha / ((gamma - beta) + eps)
612 | ) * (torch.exp(-gamma * t) - torch.exp(-beta * t))
613 |
614 | return unspliced, spliced
615 |
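    |     # The repression solution integrates the same kinetics with alpha = 0 and
    |     # initial conditions (u_0, s_0) taken at the switch time:
    |     #     u(t) = u_0 * exp(-beta * t)
    |     #     s(t) = s_0 * exp(-gamma * t)
    |     #            - (beta * u_0 / (gamma - beta)) * (exp(-gamma * t) - exp(-beta * t))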
616 | def _get_repression_unspliced_spliced(self, u_0, s_0, beta, gamma, t, eps=1e-6):
617 | unspliced = torch.exp(-beta * t) * u_0
618 | spliced = s_0 * torch.exp(-gamma * t) - (
619 | beta * u_0 / ((gamma - beta) + eps)
620 | ) * (torch.exp(-gamma * t) - torch.exp(-beta * t))
621 | return unspliced, spliced
622 |
623 | def sample(
624 | self,
625 | ) -> np.ndarray:
626 | """Not implemented."""
627 | raise NotImplementedError
628 |
629 | @torch.no_grad()
630 | def get_loadings(self) -> np.ndarray:
631 | """Extract per-gene weights (for each Z, shape is genes by dim(Z)) in the linear decoder."""
632 | # This is BW, where B is diag(b) batch norm, W is weight matrix
633 |         if not self.decoder.linear_decoder:
634 | raise ValueError("Model not trained with linear decoder")
635 | w = self.decoder.rho_first_decoder.fc_layers[0][0].weight
636 | if self.use_batch_norm_decoder:
637 | bn = self.decoder.rho_first_decoder.fc_layers[0][1]
638 | sigma = torch.sqrt(bn.running_var + bn.eps)
639 | gamma = bn.weight
640 | b = gamma / sigma
641 | b_identity = torch.diag(b)
642 | loadings = torch.matmul(b_identity, w)
643 | else:
644 | loadings = w
645 | loadings = loadings.detach().cpu().numpy()
646 |
647 | return loadings
648 |
--------------------------------------------------------------------------------
/velovi/_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional, Union
3 | from urllib.request import urlretrieve
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import scvelo as scv
8 | from anndata import AnnData
9 | from sklearn.preprocessing import MinMaxScaler
10 |
11 |
12 | def get_permutation_scores(save_path: Union[str, Path] = Path("data/")) -> pd.DataFrame:
13 | """Get the reference permutation scores on positive and negative controls.
14 |
15 | Parameters
16 | ----------
17 | save_path
18 |         Path to save the csv file.
19 |
    |     Returns
    |     -------
    |     Reference permutation scores as a :class:`~pandas.DataFrame`.
    |
20 | """
21 | if isinstance(save_path, str):
22 | save_path = Path(save_path)
23 | save_path.mkdir(parents=True, exist_ok=True)
24 |
25 | if not (save_path / "permutation_scores.csv").is_file():
26 | URL = "https://figshare.com/ndownloader/files/36658185"
27 | urlretrieve(url=URL, filename=save_path / "permutation_scores.csv")
28 |
29 | return pd.read_csv(save_path / "permutation_scores.csv")
30 |
31 |
32 | def preprocess_data(
33 | adata: AnnData,
34 | spliced_layer: Optional[str] = "Ms",
35 | unspliced_layer: Optional[str] = "Mu",
36 | min_max_scale: bool = True,
37 | filter_on_r2: bool = True,
38 | ) -> AnnData:
39 | """Preprocess data.
40 |
41 | This function removes poorly detected genes and minmax scales the data.
42 |
43 | Parameters
44 | ----------
45 | adata
46 | Annotated data matrix.
47 | spliced_layer
48 | Name of the spliced layer.
49 | unspliced_layer
50 | Name of the unspliced layer.
51 |     min_max_scale
52 |         Min-max scale the spliced and unspliced layers.
53 |     filter_on_r2
54 |         Filter out genes according to the fit of a deterministic linear-regression velocity model.
55 |
56 | Returns
57 | -------
58 | Preprocessed adata.
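    |
    |     Examples
    |     --------
    |     A typical sketch with scVelo preprocessing upstream (parameter values are illustrative).
    |
    |     >>> import scvelo as scv
    |     >>> adata = scv.datasets.pancreas()
    |     >>> scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)
    |     >>> scv.pp.moments(adata, n_pcs=30, n_neighbors=30)
    |     >>> adata = preprocess_data(adata)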
59 | """
60 | if min_max_scale:
61 | scaler = MinMaxScaler()
62 | adata.layers[spliced_layer] = scaler.fit_transform(adata.layers[spliced_layer])
63 |
64 | scaler = MinMaxScaler()
65 | adata.layers[unspliced_layer] = scaler.fit_transform(
66 | adata.layers[unspliced_layer]
67 | )
68 |
69 | if filter_on_r2:
70 | scv.tl.velocity(adata, mode="deterministic")
71 |
72 | adata = adata[
73 | :, np.logical_and(adata.var.velocity_r2 > 0, adata.var.velocity_gamma > 0)
74 | ].copy()
75 | adata = adata[:, adata.var.velocity_genes].copy()
76 |
77 | return adata
78 |
--------------------------------------------------------------------------------