├── .editorconfig ├── .github └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── codecov.yml ├── docs ├── Makefile ├── _templates │ ├── autosummary │ │ └── class.rst │ └── class_no_inherited.rst ├── api │ └── index.md ├── conf.py ├── extensions │ └── typed_returns.py ├── index.md ├── make.bat ├── references.bib ├── references.md ├── release_notes │ ├── index.rst │ └── v0.1.0.rst └── tutorial.ipynb ├── pyproject.toml ├── readthedocs.yml ├── setup.py ├── tests ├── __init__.py └── test_velovi.py └── velovi ├── __init__.py ├── _constants.py ├── _model.py ├── _module.py └── _utils.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "*.*.*" 7 | 8 | jobs: 9 | release: 10 | name: Release 11 | runs-on: ubuntu-latest 12 | steps: 13 | # will use ref/SHA that triggered it 14 | - name: Checkout code 15 | uses: actions/checkout@v3 16 | 17 | - name: Set up Python 3.10 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: "3.9" 21 | 22 | - name: Install poetry 23 | uses: abatilo/actions-poetry@v2.0.0 24 | with: 25 | poetry-version: 1.4.2 26 | 27 | - name: Build project for distribution 28 | run: poetry build 29 | 30 | - name: Check Version 31 | id: check-version 32 | run: | 33 | [[ "$(poetry version --short)" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]] \ 34 | || echo ::set-output name=prerelease::true 35 | 36 | - name: Publish to PyPI 37 | env: 38 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }} 39 | run: poetry publish 40 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: velovi 5 | 6 | on: 7 | push: 8 | branches: [main] 9 | pull_request: 10 | branches: [main] 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: ["3.9", "3.10", "3.11"] 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Cache pip 26 | uses: actions/cache@v2 27 | with: 28 | path: ~/.cache/pip 29 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 30 | restore-keys: | 31 | ${{ runner.os }}-pip- 32 | - name: Install dependencies 33 | run: | 34 | pip install pytest-cov 35 | pip install .[dev] 36 | - name: Test with pytest 37 | run: | 38 | pytest --cov-report=xml --cov=velovi 39 | - name: After success 40 | run: | 41 | bash <(curl -s https://codecov.io/bash) 42 | pip list 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # DS_Store 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # vscode 135 | .vscode/settings.json 136 | docs/api/reference/ 137 | *.h5ad 138 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | default_stages: 5 | - commit 6 | - push 7 | minimum_pre_commit_version: 2.16.0 8 | repos: 9 | - repo: https://github.com/psf/black 10 | rev: "23.1.0" 11 | hooks: 12 | - id: black 13 | - repo: https://github.com/asottile/blacken-docs 14 | rev: 1.13.0 15 | hooks: 16 | - id: blacken-docs 17 | - repo: https://github.com/pre-commit/mirrors-prettier 18 | rev: v3.0.0-alpha.6 19 | hooks: 20 | - id: prettier 21 | # Newer versions of node don't work on systems that have an older version of GLIBC 22 | # (in particular Ubuntu 18.04 and Centos 7) 23 | # EOL of Centos 7 is in 2024-06, we can probably get rid of this then. 24 | # See https://github.com/scverse/cookiecutter-scverse/issues/143 and 25 | # https://github.com/jupyterlab/jupyterlab/issues/12675 26 | language_version: "17.9.1" 27 | - repo: https://github.com/charliermarsh/ruff-pre-commit 28 | rev: v0.0.254 29 | hooks: 30 | - id: ruff 31 | args: [--fix, --exit-non-zero-on-fix] 32 | - repo: https://github.com/pre-commit/pre-commit-hooks 33 | rev: v4.4.0 34 | hooks: 35 | - id: detect-private-key 36 | - id: check-ast 37 | - id: end-of-file-fixer 38 | - id: mixed-line-ending 39 | args: [--fix=lf] 40 | - id: trailing-whitespace 41 | - id: check-case-conflict 42 | - repo: local 43 | hooks: 44 | - id: forbid-to-commit 45 | name: Don't commit rej files 46 | entry: | 47 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. 48 | Fix the merge conflicts manually and remove the .rej files. 49 | language: fail 50 | files: '.*\.rej$' 51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Yosef Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # velovi 2 | 3 | [![Tests][badge-tests]][link-tests] 4 | [![Documentation][badge-docs]][link-docs] 5 | 6 | [badge-tests]: https://img.shields.io/github/actions/workflow/status/yoseflab/velovi/test.yml?branch=main 7 | [link-tests]: https://github.com/yoseflab/velovi/actions/workflows/test.yml 8 | [badge-docs]: https://img.shields.io/readthedocs/velovi 9 | 10 | 🚧 :warning: This package is no longer being actively developed or maintained. Please use the 11 | [scvi-tools](https://github.com/scverse/scvi-tools) package instead. See this 12 | [thread](https://github.com/scverse/scvi-tools/issues/2610) for more details. :warning: 🚧 13 | 14 | Variational inference for RNA velocity. This is an experimental repo for the veloVI model. Installation instructions and tutorials are in the docs. 15 | 16 | ## Getting started 17 | 18 | Please refer to the [documentation][link-docs]. 19 | 20 | ## Installation 21 | 22 | You need to have Python 3.8 or newer installed on your system. If you don't have 23 | Python installed, we recommend installing [Miniconda](https://docs.conda.io/en/latest/miniconda.html). 24 | 25 | There are several alternative options to install velovi: 26 | 27 | 34 | 35 | 1. Install the latest release on PyPI: 36 | 37 | ```bash 38 | pip install velovi 39 | ``` 40 | 41 | 2. Install the latest development version: 42 | 43 | ```bash 44 | pip install git+https://github.com/yoseflab/velovi.git@main 45 | ``` 46 | 47 | ## Release notes 48 | 49 | See the [changelog][changelog]. 50 | 51 | ## Contact 52 | 53 | For questions and help requests, you can reach out in the [scverse discourse][scverse-discourse]. 54 | If you found a bug, please use the [issue tracker][issue-tracker]. 55 | 56 | ## Citation 57 | 58 | ``` 59 | @article{gayoso2022deep, 60 | title={Deep generative modeling of transcriptional dynamics for RNA velocity analysis in single cells}, 61 | author={Gayoso, Adam and Weiler, Philipp and Lotfollahi, Mohammad and Klein, Dominik and Hong, Justin and Streets, Aaron M and Theis, Fabian J and Yosef, Nir}, 62 | journal={bioRxiv}, 63 | pages={2022--08}, 64 | year={2022}, 65 | publisher={Cold Spring Harbor Laboratory} 66 | } 67 | ``` 68 | 69 | [scverse-discourse]: https://discourse.scverse.org/ 70 | [issue-tracker]: https://github.com/yoseflab/velovi/issues 71 | [changelog]: https://velovi.readthedocs.io/latest/changelog.html 72 | [link-docs]: https://velovi.readthedocs.io 73 | [link-api]: https://velovi.readthedocs.io/latest/api.html 74 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | # Run to check if valid 2 | # curl --data-binary @codecov.yml https://codecov.io/validate 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: 80% 8 | threshold: 1% 9 | patch: off 10 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = scvi 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | Attributes table 12 | ~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autosummary:: 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | Methods table 24 | ~~~~~~~~~~~~~ 25 | 26 | .. autosummary:: 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | 35 | {% block attributes_documentation %} 36 | {% if attributes %} 37 | Attributes 38 | ~~~~~~~~~~~ 39 | 40 | {% for item in attributes %} 41 | 42 | .. autoattribute:: {{ [objname, item] | join(".") }} 43 | {%- endfor %} 44 | 45 | {% endif %} 46 | {% endblock %} 47 | 48 | {% block methods_documentation %} 49 | {% if methods %} 50 | Methods 51 | ~~~~~~~ 52 | 53 | {% for item in methods %} 54 | {%- if item != '__init__' %} 55 | 56 | .. automethod:: {{ [objname, item] | join(".") }} 57 | {%- endif -%} 58 | {%- endfor %} 59 | 60 | {% endif %} 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /docs/_templates/class_no_inherited.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | :show-inheritance: 9 | 10 | {% block attributes %} 11 | {% if attributes %} 12 | Attributes table 13 | ~~~~~~~~~~~~~~~~ 14 | 15 | .. autosummary:: 16 | {% for item in attributes %} 17 | {%- if item not in inherited_members%} 18 | ~{{ fullname }}.{{ item }} 19 | {%- endif -%} 20 | {%- endfor %} 21 | {% endif %} 22 | {% endblock %} 23 | 24 | 25 | {% block methods %} 26 | {% if methods %} 27 | Methods table 28 | ~~~~~~~~~~~~~~ 29 | 30 | .. autosummary:: 31 | {% for item in methods %} 32 | {%- if item != '__init__' and item not in inherited_members%} 33 | ~{{ fullname }}.{{ item }} 34 | {%- endif -%} 35 | 36 | {%- endfor %} 37 | {% endif %} 38 | {% endblock %} 39 | 40 | {% block attributes_documentation %} 41 | {% if attributes %} 42 | Attributes 43 | ~~~~~~~~~~ 44 | 45 | {% for item in attributes %} 46 | {%- if item not in inherited_members%} 47 | 48 | .. autoattribute:: {{ [objname, item] | join(".") }} 49 | {%- endif -%} 50 | {%- endfor %} 51 | 52 | {% endif %} 53 | {% endblock %} 54 | 55 | {% block methods_documentation %} 56 | {% if methods %} 57 | Methods 58 | ~~~~~~~ 59 | 60 | {% for item in methods %} 61 | {%- if item != '__init__' and item not in inherited_members%} 62 | 63 | .. automethod:: {{ [objname, item] | join(".") }} 64 | {%- endif -%} 65 | {%- endfor %} 66 | 67 | {% endif %} 68 | {% endblock %} 69 | -------------------------------------------------------------------------------- /docs/api/index.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ```{eval-rst} 4 | .. currentmodule:: velovi 5 | ``` 6 | 7 | ```{eval-rst} 8 | .. autosummary:: 9 | :toctree: reference/ 10 | :nosignatures: 11 | 12 | VELOVI 13 | ``` 14 | 15 | ```{eval-rst} 16 | .. autosummary:: 17 | :toctree: reference/ 18 | :template: class_no_inherited.rst 19 | :nosignatures: 20 | 21 | VELOVAE 22 | ``` 23 | 24 | ```{eval-rst} 25 | .. autosummary:: 26 | :toctree: reference/ 27 | :nosignatures: 28 | 29 | get_permutation_scores 30 | ``` 31 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import subprocess 3 | import os 4 | import importlib 5 | import inspect 6 | import re 7 | import sys 8 | from datetime import datetime 9 | from importlib.metadata import metadata 10 | from pathlib import Path 11 | 12 | HERE = Path(__file__).parent 13 | sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")] 14 | 15 | 16 | # -- Project information ----------------------------------------------------- 17 | 18 | project_name = "velovi" 19 | info = metadata(project_name) 20 | package_name = "velovi" 21 | author = info["Author"] 22 | copyright = f"{datetime.now():%Y}, {author}." 23 | version = info["Version"] 24 | repository_url = f"https://github.com/yoseflab/{project_name}" 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = info["Version"] 28 | 29 | bibtex_bibfiles = ["references.bib"] 30 | templates_path = ["_templates"] 31 | nitpicky = True # Warn about broken links 32 | needs_sphinx = "4.0" 33 | 34 | html_context = { 35 | "display_github": True, # Integrate GitHub 36 | "github_user": "yoseflab", # Username 37 | "github_repo": project_name, # Repo name 38 | "github_version": "main", # Version 39 | "conf_py_path": "/docs/", # Path in the checkout to the docs root 40 | } 41 | 42 | # -- General configuration --------------------------------------------------- 43 | 44 | # Add any Sphinx extension module names here, as strings. 45 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 46 | extensions = [ 47 | "myst_nb", 48 | "sphinx.ext.autodoc", 49 | "sphinx.ext.linkcode", 50 | "sphinx.ext.intersphinx", 51 | "sphinx.ext.autosummary", 52 | "sphinx.ext.napoleon", 53 | "sphinxcontrib.bibtex", 54 | "sphinx.ext.mathjax", 55 | "sphinx.ext.extlinks", 56 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 57 | "sphinx_copybutton", 58 | ] 59 | 60 | autosummary_generate = True 61 | autodoc_member_order = "bysource" 62 | default_role = "literal" 63 | autodoc_typehints = "description" 64 | bibtex_reference_style = "author_year" 65 | napoleon_google_docstring = True 66 | napoleon_numpy_docstring = True 67 | napoleon_include_init_with_doc = False 68 | napoleon_use_rtype = True # having a separate entry generally helps readability 69 | napoleon_use_param = True 70 | myst_enable_extensions = [ 71 | "amsmath", 72 | "colon_fence", 73 | "deflist", 74 | "dollarmath", 75 | "html_image", 76 | "html_admonition", 77 | ] 78 | myst_url_schemes = ("http", "https", "mailto") 79 | nb_output_stderr = "remove" 80 | nb_execution_mode = "off" 81 | nb_merge_streams = True 82 | 83 | source_suffix = { 84 | ".rst": "restructuredtext", 85 | ".ipynb": "myst-nb", 86 | ".myst": "myst-nb", 87 | } 88 | 89 | intersphinx_mapping = { 90 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None), 91 | "ipython": ("https://ipython.readthedocs.io/en/stable/", None), 92 | "matplotlib": ("https://matplotlib.org/", None), 93 | "numpy": ("https://numpy.org/doc/stable/", None), 94 | "pandas": ("https://pandas.pydata.org/docs/", None), 95 | "python": ("https://docs.python.org/3", None), 96 | "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), 97 | "sklearn": ("https://scikit-learn.org/stable/", None), 98 | "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), 99 | "jax": ("https://jax.readthedocs.io/en/latest/", None), 100 | "torch": ("https://pytorch.org/docs/master/", None), 101 | "plottable": ("https://plottable.readthedocs.io/en/latest/", None), 102 | "scvi-tools": ("https://docs.scvi-tools.org/en/stable/", None), 103 | "mudata": ("https://mudata.readthedocs.io/en/latest/", None), 104 | } 105 | 106 | # List of patterns, relative to source directory, that match files and 107 | # directories to ignore when looking for source files. 108 | # This pattern also affects html_static_path and html_extra_path. 109 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 110 | 111 | # extlinks config 112 | extlinks = { 113 | "issue": (f"{repository_url}/issues/%s", "#%s"), 114 | "pr": (f"{repository_url}/pull/%s", "#%s"), 115 | "ghuser": ("https://github.com/%s", "@%s"), 116 | } 117 | 118 | 119 | # -- Linkcode settings ------------------------------------------------- 120 | 121 | 122 | def git(*args): 123 | """Run a git command and return the output.""" 124 | return subprocess.check_output(["git", *args]).strip().decode() 125 | 126 | 127 | # https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192 128 | # Current git reference. Uses branch/tag name if found, otherwise uses commit hash 129 | git_ref = None 130 | try: 131 | git_ref = git("name-rev", "--name-only", "--no-undefined", "HEAD") 132 | git_ref = re.sub(r"^(remotes/[^/]+|tags)/", "", git_ref) 133 | except Exception: 134 | pass 135 | 136 | # (if no name found or relative ref, use commit hash instead) 137 | if not git_ref or re.search(r"[\^~]", git_ref): 138 | try: 139 | git_ref = git("rev-parse", "HEAD") 140 | except Exception: 141 | git_ref = "main" 142 | 143 | # https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192 144 | github_repo = "https://github.com/" + html_context["github_user"] + "/" + project_name 145 | _project_module_path = os.path.dirname(importlib.util.find_spec(package_name).origin) # type: ignore 146 | 147 | 148 | def linkcode_resolve(domain, info): 149 | """Resolve links for the linkcode extension.""" 150 | if domain != "py": 151 | return None 152 | 153 | try: 154 | obj: Any = sys.modules[info["module"]] 155 | for part in info["fullname"].split("."): 156 | obj = getattr(obj, part) 157 | obj = inspect.unwrap(obj) 158 | 159 | if isinstance(obj, property): 160 | obj = inspect.unwrap(obj.fget) # type: ignore 161 | 162 | path = os.path.relpath(inspect.getsourcefile(obj), start=_project_module_path) # type: ignore 163 | src, lineno = inspect.getsourcelines(obj) 164 | except Exception: 165 | return None 166 | 167 | path = f"{path}#L{lineno}-L{lineno + len(src) - 1}" 168 | return f"{github_repo}/blob/{git_ref}/{package_name}/{path}" 169 | 170 | 171 | # -- Options for HTML output ------------------------------------------------- 172 | 173 | # The theme to use for HTML and HTML Help pages. See the documentation for 174 | # a list of builtin themes. 175 | # 176 | html_theme = "sphinx_book_theme" 177 | html_static_path = ["_static"] 178 | html_title = "velovi" 179 | 180 | html_theme_options = { 181 | "repository_url": github_repo, 182 | "use_repository_button": True, 183 | } 184 | 185 | pygments_style = "default" 186 | 187 | nitpick_ignore = [ 188 | # If building the documentation fails because of a missing link that is outside your control, 189 | # you can add an exception to this list. 190 | ] 191 | 192 | 193 | def setup(app): 194 | """App setup hook.""" 195 | app.add_config_value( 196 | "recommonmark_config", 197 | { 198 | "auto_toc_tree_section": "Contents", 199 | "enable_auto_toc_tree": True, 200 | "enable_math": True, 201 | "enable_inline_math": False, 202 | "enable_eval_rst": True, 203 | }, 204 | True, 205 | ) 206 | -------------------------------------------------------------------------------- /docs/extensions/typed_returns.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py 2 | # with some minor adjustment 3 | import re 4 | 5 | from sphinx.application import Sphinx 6 | from sphinx.ext.napoleon import NumpyDocstring 7 | 8 | 9 | def process_return(lines): 10 | """Process the return section of a docstring.""" 11 | for line in lines: 12 | m = re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line) 13 | if m: 14 | # Once this is in scanpydoc, we can use the fancy hover stuff 15 | yield f'-{m["param"]} (:class:`~{m["type"]}`)' 16 | else: 17 | yield line 18 | 19 | 20 | def scanpy_parse_returns_section(self, section): 21 | """Parse the returns section of the docstring.""" 22 | lines_raw = list(process_return(self._dedent(self._consume_to_next_section()))) 23 | lines = self._format_block(":returns: ", lines_raw) 24 | if lines and lines[-1]: 25 | lines.append("") 26 | return lines 27 | 28 | 29 | def setup(app: Sphinx): 30 | """Setup the extension.""" 31 | NumpyDocstring._parse_returns_section = scanpy_parse_returns_section 32 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ```{include} ../README.md 2 | 3 | ``` 4 | 5 | Welcome to the velovi documentation. 6 | 7 | ```{toctree} 8 | :hidden: true 9 | :maxdepth: 1 10 | 11 | api/index 12 | tutorial 13 | release_notes/index 14 | references 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=scvi 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{GayosoWeiler2022, 2 | title={Deep generative modeling of transcriptional dynamics for RNA velocity analysis in single cells}, 3 | author={Gayoso, Adam and Weiler, Philipp and Lotfollahi, Mohammad and Klein, Dominik and Hong, Justin and Streets, Aaron M and Theis, Fabian J and Yosef, Nir}, 4 | journal={bioRxiv}, 5 | pages={2022--08}, 6 | year={2022}, 7 | publisher={Cold Spring Harbor Laboratory} 8 | } 9 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | ```{bibliography} 4 | :cited: 5 | ``` 6 | -------------------------------------------------------------------------------- /docs/release_notes/index.rst: -------------------------------------------------------------------------------- 1 | Release notes 2 | ============= 3 | 4 | Version 0.1 5 | ----------- 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | v0.1.0 10 | -------------------------------------------------------------------------------- /docs/release_notes/v0.1.0.rst: -------------------------------------------------------------------------------- 1 | New in 0.1.0 (2022-MM-DD) 2 | ------------------------- 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core>=1.0.0"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | authors = ["Adam Gayoso "] 7 | classifiers = [ 8 | "Development Status :: 4 - Beta", 9 | "Intended Audience :: Science/Research", 10 | "Natural Language :: English", 11 | "Programming Language :: Python :: 3.9", 12 | "Programming Language :: Python :: 3.10", 13 | "Programming Language :: Python :: 3.11", 14 | "Operating System :: MacOS :: MacOS X", 15 | "Operating System :: Microsoft :: Windows", 16 | "Operating System :: POSIX :: Linux", 17 | "Topic :: Scientific/Engineering :: Bio-Informatics", 18 | ] 19 | description = "Estimation of RNA velocity with variational inference." 20 | documentation = "https://scvi-tools.org" 21 | homepage = "https://github.com/YosefLab/velovi" 22 | license = "BSD-3-Clause" 23 | name = "velovi" 24 | packages = [ 25 | {include = "velovi"}, 26 | ] 27 | readme = "README.md" 28 | version = "0.4.0" 29 | 30 | [tool.poetry.dependencies] 31 | anndata = ">=0.7.5" 32 | black = {version = ">=20.8b1", optional = true} 33 | codecov = {version = ">=2.0.8", optional = true} 34 | ruff = {version = "*", optional = true} 35 | importlib-metadata = {version = "^1.0", python = "<3.8"} 36 | ipython = {version = ">=7.1.1", optional = true} 37 | jupyter = {version = ">=1.0", optional = true} 38 | pre-commit = {version = ">=2.7.1", optional = true} 39 | sphinx-book-theme = {version = ">=1.0.0", optional = true} 40 | myst-nb = {version = "*", optional = true} 41 | sphinx-copybutton = {version = "*", optional = true} 42 | sphinxcontrib-bibtex = {version = "*", optional = true} 43 | ipykernel = {version = "*", optional = true} 44 | pytest = {version = ">=4.4", optional = true} 45 | pytest-cov = {version = "*", optional = true} 46 | python = ">=3.9,<4.0" 47 | python-igraph = {version = "*", optional = true} 48 | scanpy = {version = ">=1.6", optional = true} 49 | scanpydoc = {version = ">=0.5", optional = true} 50 | scvelo = ">=0.2.5" 51 | scvi-tools = ">=1.0.0" 52 | scikit-learn = ">=0.21.2" 53 | sphinx = {version = ">=4.1", optional = true} 54 | sphinx-autodoc-typehints = {version = "*", optional = true} 55 | 56 | [tool.poetry.extras] 57 | dev = ["black", "pytest", "pytest-cov", "ruff", "codecov", "scanpy", "loompy", "jupyter", "pre-commit"] 58 | docs = [ 59 | "sphinx", 60 | "scanpydoc", 61 | "ipython", 62 | "myst-nb", 63 | "sphinx-book-theme", 64 | "sphinx-copybutton", 65 | "sphinxcontrib-bibtex", 66 | "ipykernel", 67 | "ipython", 68 | ] 69 | tutorials = ["scanpy"] 70 | 71 | [tool.poetry.dev-dependencies] 72 | 73 | 74 | [tool.coverage.run] 75 | source = ["velovi"] 76 | omit = [ 77 | "**/test_*.py", 78 | ] 79 | 80 | [tool.pytest.ini_options] 81 | testpaths = ["tests"] 82 | xfail_strict = true 83 | 84 | 85 | [tool.black] 86 | include = '\.pyi?$' 87 | exclude = ''' 88 | ( 89 | /( 90 | \.eggs 91 | | \.git 92 | | \.hg 93 | | \.mypy_cache 94 | | \.tox 95 | | \.venv 96 | | _build 97 | | buck-out 98 | | build 99 | | dist 100 | )/ 101 | ) 102 | ''' 103 | 104 | [tool.ruff] 105 | src = ["."] 106 | line-length = 119 107 | target-version = "py38" 108 | select = [ 109 | "F", # Errors detected by Pyflakes 110 | "E", # Error detected by Pycodestyle 111 | "W", # Warning detected by Pycodestyle 112 | "I", # isort 113 | "D", # pydocstyle 114 | "B", # flake8-bugbear 115 | "TID", # flake8-tidy-imports 116 | "C4", # flake8-comprehensions 117 | "BLE", # flake8-blind-except 118 | "UP", # pyupgrade 119 | "RUF100", # Report unused noqa directives 120 | ] 121 | ignore = [ 122 | # line too long -> we accept long comment lines; black gets rid of long code lines 123 | "E501", 124 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient 125 | "E731", 126 | # allow I, O, l as variable names -> I is the identity matrix 127 | "E741", 128 | # Missing docstring in public package 129 | "D104", 130 | # Missing docstring in public module 131 | "D100", 132 | # Missing docstring in __init__ 133 | "D107", 134 | # Errors from function calls in argument defaults. These are fine when the result is immutable. 135 | "B008", 136 | # __magic__ methods are are often self-explanatory, allow missing docstrings 137 | "D105", 138 | # first line should end with a period [Bug: doesn't work with single-line docstrings] 139 | "D400", 140 | # First line should be in imperative mood; try rephrasing 141 | "D401", 142 | ## Disable one in each pair of mutually incompatible rules 143 | # We don’t want a blank line before a class docstring 144 | "D203", 145 | # We want docstrings to start immediately after the opening triple quote 146 | "D213", 147 | # Missing argument description in the docstring TODO: enable 148 | "D417", 149 | ] 150 | 151 | [tool.ruff.per-file-ignores] 152 | "docs/*" = ["I", "BLE001"] 153 | "tests/*" = ["D"] 154 | "*/__init__.py" = ["F401"] 155 | "velovi/__init__.py" = ["I"] 156 | 157 | [tool.jupytext] 158 | formats = "ipynb,md" 159 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | build: 3 | os: ubuntu-20.04 4 | tools: 5 | python: "3.10" 6 | sphinx: 7 | configuration: docs/conf.py 8 | python: 9 | install: 10 | - method: pip 11 | path: . 12 | extra_requirements: 13 | - docs 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # This is a shim to hopefully allow Github to detect the package, build is done with poetry 4 | 5 | import setuptools 6 | 7 | if __name__ == "__main__": 8 | setuptools.setup(name="velovi") 9 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YosefLab/velovi/63de3f0b5c2588da056b8e95018f4f136255935d/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_velovi.py: -------------------------------------------------------------------------------- 1 | import scvelo as scv 2 | from scvi.data import synthetic_iid 3 | 4 | 5 | def test_preprocess_data(): 6 | adata = synthetic_iid() 7 | adata.layers["spliced"] = adata.X.copy() 8 | adata.layers["unspliced"] = adata.X.copy() 9 | scv.pp.normalize_per_cell(adata) 10 | scv.pp.log1p(adata) 11 | scv.pp.moments(adata, n_pcs=30, n_neighbors=30) 12 | # TODO(adamgayoso): use real data for this test 13 | # preprocess_data(adata) 14 | 15 | 16 | def test_velovi(): 17 | from velovi import VELOVI 18 | 19 | n_latent = 5 20 | adata = synthetic_iid() 21 | adata.layers["spliced"] = adata.X.copy() 22 | adata.layers["unspliced"] = adata.X.copy() 23 | VELOVI.setup_anndata(adata, unspliced_layer="unspliced", spliced_layer="spliced") 24 | model = VELOVI(adata, n_latent=n_latent) 25 | model.train(1, check_val_every_n_epoch=1, train_size=0.5) 26 | model.get_latent_representation() 27 | model.get_velocity() 28 | model.get_latent_time() 29 | model.get_state_assignment() 30 | model.get_expression_fit() 31 | model.get_directional_uncertainty() 32 | model.get_permutation_scores(labels_key="labels") 33 | 34 | model.history 35 | 36 | # tests __repr__ 37 | print(model) 38 | -------------------------------------------------------------------------------- /velovi/__init__.py: -------------------------------------------------------------------------------- 1 | """velovi.""" 2 | 3 | import logging 4 | import warnings 5 | 6 | # warning has to be at the top level to print on import 7 | warnings.warn( 8 | "The velovi package is no longer being actively developed or maintained as of v0.4.0. Please " 9 | "use the implementation in the scvi-tools package instead. For more information, see " 10 | "https://github.com/scverse/scvi-tools/issues/2610.", 11 | UserWarning, 12 | stacklevel=1, 13 | ) 14 | 15 | from rich.console import Console # noqa 16 | from rich.logging import RichHandler # noqa 17 | 18 | from ._constants import REGISTRY_KEYS # noqa 19 | from ._model import VELOVI, VELOVAE # noqa 20 | from ._utils import get_permutation_scores, preprocess_data # noqa 21 | 22 | # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094 23 | # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302 24 | try: 25 | import importlib.metadata as importlib_metadata 26 | except ModuleNotFoundError: 27 | import importlib_metadata 28 | 29 | package_name = "velovi" 30 | __version__ = importlib_metadata.version(package_name) 31 | 32 | logger = logging.getLogger(__name__) 33 | # set the logging level 34 | logger.setLevel(logging.INFO) 35 | 36 | # nice logging outputs 37 | console = Console(force_terminal=True) 38 | if console.is_jupyter is True: 39 | console.is_jupyter = False 40 | ch = RichHandler(show_path=False, console=console, show_time=False) 41 | formatter = logging.Formatter("velovi: %(message)s") 42 | ch.setFormatter(formatter) 43 | logger.addHandler(ch) 44 | 45 | # this prevents double outputs 46 | logger.propagate = False 47 | 48 | 49 | __all__ = [ 50 | "VELOVI", 51 | "VELOVAE", 52 | "REGISTRY_KEYS", 53 | "get_permutation_scores", 54 | "preprocess_data", 55 | ] 56 | -------------------------------------------------------------------------------- /velovi/_constants.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | 4 | class _REGISTRY_KEYS_NT(NamedTuple): 5 | X_KEY: str = "X" 6 | U_KEY: str = "U" 7 | 8 | 9 | REGISTRY_KEYS = _REGISTRY_KEYS_NT() 10 | -------------------------------------------------------------------------------- /velovi/_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from typing import Iterable, List, Literal, Optional, Sequence, Tuple, Union 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torch.nn.functional as F 9 | from anndata import AnnData 10 | from joblib import Parallel, delayed 11 | from scipy.stats import ttest_ind 12 | from scvi.data import AnnDataManager 13 | from scvi.data.fields import LayerField 14 | from scvi.dataloaders import DataSplitter 15 | from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin 16 | from scvi.train import TrainingPlan, TrainRunner 17 | from scvi.utils._docstrings import devices_dsp, setup_anndata_dsp 18 | 19 | from ._constants import REGISTRY_KEYS 20 | from ._module import VELOVAE 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def _softplus_inverse(x: np.ndarray) -> np.ndarray: 26 | x = torch.from_numpy(x) 27 | x_inv = torch.where(x > 20, x, x.expm1().log()).numpy() 28 | return x_inv 29 | 30 | 31 | class VELOVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): 32 | """Velocity Variational Inference. 33 | 34 | Parameters 35 | ---------- 36 | adata 37 | AnnData object that has been registered via :func:`~velovi.VELOVI.setup_anndata`. 38 | n_hidden 39 | Number of nodes per hidden layer. 40 | n_latent 41 | Dimensionality of the latent space. 42 | n_layers 43 | Number of hidden layers used for encoder and decoder NNs. 44 | dropout_rate 45 | Dropout rate for neural networks. 46 | gamma_init_data 47 | Initialize gamma using the data-driven technique. 48 | linear_decoder 49 | Use a linear decoder from latent space to time. 50 | **model_kwargs 51 | Keyword args for :class:`~velovi.VELOVAE` 52 | """ 53 | 54 | def __init__( 55 | self, 56 | adata: AnnData, 57 | n_hidden: int = 256, 58 | n_latent: int = 10, 59 | n_layers: int = 1, 60 | dropout_rate: float = 0.1, 61 | gamma_init_data: bool = False, 62 | linear_decoder: bool = False, 63 | **model_kwargs, 64 | ): 65 | super().__init__(adata) 66 | self.n_latent = n_latent 67 | 68 | spliced = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) 69 | unspliced = self.adata_manager.get_from_registry(REGISTRY_KEYS.U_KEY) 70 | 71 | sorted_unspliced = np.argsort(unspliced, axis=0) 72 | ind = int(adata.n_obs * 0.99) 73 | us_upper_ind = sorted_unspliced[ind:, :] 74 | 75 | us_upper = [] 76 | ms_upper = [] 77 | for i in range(len(us_upper_ind)): 78 | row = us_upper_ind[i] 79 | us_upper += [unspliced[row, np.arange(adata.n_vars)][np.newaxis, :]] 80 | ms_upper += [spliced[row, np.arange(adata.n_vars)][np.newaxis, :]] 81 | us_upper = np.median(np.concatenate(us_upper, axis=0), axis=0) 82 | ms_upper = np.median(np.concatenate(ms_upper, axis=0), axis=0) 83 | 84 | alpha_unconstr = _softplus_inverse(us_upper) 85 | alpha_unconstr = np.asarray(alpha_unconstr).ravel() 86 | 87 | alpha_1_unconstr = np.zeros(us_upper.shape).ravel() 88 | lambda_alpha_unconstr = np.zeros(us_upper.shape).ravel() 89 | 90 | if gamma_init_data: 91 | gamma_unconstr = np.clip(_softplus_inverse(us_upper / ms_upper), None, 10) 92 | else: 93 | gamma_unconstr = None 94 | 95 | self.module = VELOVAE( 96 | n_input=self.summary_stats["n_vars"], 97 | n_hidden=n_hidden, 98 | n_latent=n_latent, 99 | n_layers=n_layers, 100 | dropout_rate=dropout_rate, 101 | gamma_unconstr_init=gamma_unconstr, 102 | alpha_unconstr_init=alpha_unconstr, 103 | alpha_1_unconstr_init=alpha_1_unconstr, 104 | lambda_alpha_unconstr_init=lambda_alpha_unconstr, 105 | switch_spliced=ms_upper, 106 | switch_unspliced=us_upper, 107 | linear_decoder=linear_decoder, 108 | **model_kwargs, 109 | ) 110 | self._model_summary_string = ( 111 | "VELOVI Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: " 112 | "{}" 113 | ).format( 114 | n_hidden, 115 | n_latent, 116 | n_layers, 117 | dropout_rate, 118 | ) 119 | self.init_params_ = self._get_init_params(locals()) 120 | 121 | @devices_dsp.dedent 122 | def train( 123 | self, 124 | max_epochs: Optional[int] = 500, 125 | lr: float = 1e-2, 126 | weight_decay: float = 1e-2, 127 | accelerator: str = "auto", 128 | devices: Union[int, list[int], str] = "auto", 129 | train_size: float = 0.9, 130 | validation_size: Optional[float] = None, 131 | batch_size: int = 256, 132 | early_stopping: bool = True, 133 | gradient_clip_val: float = 10, 134 | plan_kwargs: Optional[dict] = None, 135 | **trainer_kwargs, 136 | ): 137 | """Train the model. 138 | 139 | Parameters 140 | ---------- 141 | max_epochs 142 | Number of passes through the dataset. If `None`, defaults to 143 | `np.min([round((20000 / n_cells) * 400), 400])` 144 | lr 145 | Learning rate for optimization 146 | weight_decay 147 | Weight decay for optimization 148 | %(param_accelerator)s 149 | %(param_devices)s 150 | train_size 151 | Size of training set in the range [0.0, 1.0]. 152 | validation_size 153 | Size of the test set. If `None`, defaults to 1 - `train_size`. If 154 | `train_size + validation_size < 1`, the remaining cells belong to a test set. 155 | batch_size 156 | Minibatch size to use during training. 157 | early_stopping 158 | Perform early stopping. Additional arguments can be passed in `**kwargs`. 159 | See :class:`~scvi.train.Trainer` for further options. 160 | gradient_clip_val 161 | Val for gradient clipping 162 | plan_kwargs 163 | Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to 164 | `train()` will overwrite values present in `plan_kwargs`, when appropriate. 165 | **trainer_kwargs 166 | Other keyword args for :class:`~scvi.train.Trainer`. 167 | """ 168 | user_plan_kwargs = plan_kwargs.copy() if isinstance(plan_kwargs, dict) else {} 169 | plan_kwargs = {"lr": lr, "weight_decay": weight_decay, "optimizer": "AdamW"} 170 | plan_kwargs.update(user_plan_kwargs) 171 | 172 | user_train_kwargs = trainer_kwargs.copy() 173 | trainer_kwargs = {"gradient_clip_val": gradient_clip_val} 174 | trainer_kwargs.update(user_train_kwargs) 175 | 176 | data_splitter = DataSplitter( 177 | self.adata_manager, 178 | train_size=train_size, 179 | validation_size=validation_size, 180 | batch_size=batch_size, 181 | ) 182 | training_plan = TrainingPlan(self.module, **plan_kwargs) 183 | 184 | es = "early_stopping" 185 | trainer_kwargs[es] = ( 186 | early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] 187 | ) 188 | runner = TrainRunner( 189 | self, 190 | training_plan=training_plan, 191 | data_splitter=data_splitter, 192 | max_epochs=max_epochs, 193 | accelerator=accelerator, 194 | devices=devices, 195 | **trainer_kwargs, 196 | ) 197 | return runner() 198 | 199 | @torch.inference_mode() 200 | def get_state_assignment( 201 | self, 202 | adata: Optional[AnnData] = None, 203 | indices: Optional[Sequence[int]] = None, 204 | gene_list: Optional[Sequence[str]] = None, 205 | hard_assignment: bool = False, 206 | n_samples: int = 20, 207 | batch_size: Optional[int] = None, 208 | return_mean: bool = True, 209 | return_numpy: Optional[bool] = None, 210 | ) -> Tuple[Union[np.ndarray, pd.DataFrame], List[str]]: 211 | """Returns cells by genes by states probabilities. 212 | 213 | Parameters 214 | ---------- 215 | adata 216 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the 217 | AnnData object used to initialize the model. 218 | indices 219 | Indices of cells in adata to use. If `None`, all cells are used. 220 | gene_list 221 | Return frequencies of expression for a subset of genes. 222 | This can save memory when working with large datasets and few genes are 223 | of interest. 224 | hard_assignment 225 | Return a hard state assignment 226 | n_samples 227 | Number of posterior samples to use for estimation. 228 | batch_size 229 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. 230 | return_mean 231 | Whether to return the mean of the samples. 232 | return_numpy 233 | Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes 234 | gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`. 235 | Otherwise, it defaults to `True`. 236 | 237 | Returns 238 | ------- 239 | If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`. 240 | Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. 241 | """ 242 | adata = self._validate_anndata(adata) 243 | scdl = self._make_data_loader( 244 | adata=adata, indices=indices, batch_size=batch_size 245 | ) 246 | 247 | if gene_list is None: 248 | gene_mask = slice(None) 249 | else: 250 | all_genes = adata.var_names 251 | gene_mask = [True if gene in gene_list else False for gene in all_genes] 252 | 253 | if n_samples > 1 and return_mean is False: 254 | if return_numpy is False: 255 | warnings.warn( 256 | "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray" 257 | ) 258 | return_numpy = True 259 | if indices is None: 260 | indices = np.arange(adata.n_obs) 261 | 262 | states = [] 263 | for tensors in scdl: 264 | minibatch_samples = [] 265 | for _ in range(n_samples): 266 | _, generative_outputs = self.module.forward( 267 | tensors=tensors, 268 | compute_loss=False, 269 | ) 270 | output = generative_outputs["px_pi"] 271 | output = output[..., gene_mask, :] 272 | output = output.cpu().numpy() 273 | minibatch_samples.append(output) 274 | # samples by cells by genes by four 275 | states.append(np.stack(minibatch_samples, axis=0)) 276 | if return_mean: 277 | states[-1] = np.mean(states[-1], axis=0) 278 | 279 | states = np.concatenate(states, axis=0) 280 | state_cats = [ 281 | "induction", 282 | "induction_steady", 283 | "repression", 284 | "repression_steady", 285 | ] 286 | if hard_assignment and return_mean: 287 | hard_assign = states.argmax(-1) 288 | 289 | hard_assign = pd.DataFrame( 290 | data=hard_assign, index=adata.obs_names, columns=adata.var_names 291 | ) 292 | for i, s in enumerate(state_cats): 293 | hard_assign = hard_assign.replace(i, s) 294 | 295 | states = hard_assign 296 | 297 | return states, state_cats 298 | 299 | @torch.inference_mode() 300 | def get_latent_time( 301 | self, 302 | adata: Optional[AnnData] = None, 303 | indices: Optional[Sequence[int]] = None, 304 | gene_list: Optional[Sequence[str]] = None, 305 | time_statistic: Literal["mean", "max"] = "mean", 306 | n_samples: int = 1, 307 | n_samples_overall: Optional[int] = None, 308 | batch_size: Optional[int] = None, 309 | return_mean: bool = True, 310 | return_numpy: Optional[bool] = None, 311 | ) -> Union[np.ndarray, pd.DataFrame]: 312 | """Returns the cells by genes latent time. 313 | 314 | Parameters 315 | ---------- 316 | adata 317 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the 318 | AnnData object used to initialize the model. 319 | indices 320 | Indices of cells in adata to use. If `None`, all cells are used. 321 | gene_list 322 | Return frequencies of expression for a subset of genes. 323 | This can save memory when working with large datasets and few genes are 324 | of interest. 325 | time_statistic 326 | Whether to compute expected time over states, or maximum a posteriori time over maximal 327 | probability state. 328 | n_samples 329 | Number of posterior samples to use for estimation. 330 | n_samples_overall 331 | Number of overall samples to return. Setting this forces n_samples=1. 332 | batch_size 333 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. 334 | return_mean 335 | Whether to return the mean of the samples. 336 | return_numpy 337 | Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes 338 | gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`. 339 | Otherwise, it defaults to `True`. 340 | 341 | Returns 342 | ------- 343 | If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`. 344 | Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. 345 | """ 346 | adata = self._validate_anndata(adata) 347 | if indices is None: 348 | indices = np.arange(adata.n_obs) 349 | if n_samples_overall is not None: 350 | indices = np.random.choice(indices, n_samples_overall) 351 | scdl = self._make_data_loader( 352 | adata=adata, indices=indices, batch_size=batch_size 353 | ) 354 | 355 | if gene_list is None: 356 | gene_mask = slice(None) 357 | else: 358 | all_genes = adata.var_names 359 | gene_mask = [True if gene in gene_list else False for gene in all_genes] 360 | 361 | if n_samples > 1 and return_mean is False: 362 | if return_numpy is False: 363 | warnings.warn( 364 | "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray" 365 | ) 366 | return_numpy = True 367 | if indices is None: 368 | indices = np.arange(adata.n_obs) 369 | 370 | times = [] 371 | for tensors in scdl: 372 | minibatch_samples = [] 373 | for _ in range(n_samples): 374 | _, generative_outputs = self.module.forward( 375 | tensors=tensors, 376 | compute_loss=False, 377 | ) 378 | pi = generative_outputs["px_pi"] 379 | ind_prob = pi[..., 0] 380 | steady_prob = pi[..., 1] 381 | rep_prob = pi[..., 2] 382 | # rep_steady_prob = pi[..., 3] 383 | switch_time = F.softplus(self.module.switch_time_unconstr) 384 | 385 | ind_time = generative_outputs["px_rho"] * switch_time 386 | rep_time = switch_time + ( 387 | generative_outputs["px_tau"] * (self.module.t_max - switch_time) 388 | ) 389 | 390 | if time_statistic == "mean": 391 | output = ( 392 | ind_prob * ind_time 393 | + rep_prob * rep_time 394 | + steady_prob * switch_time 395 | # + rep_steady_prob * self.module.t_max 396 | ) 397 | else: 398 | t = torch.stack( 399 | [ 400 | ind_time, 401 | switch_time.expand(ind_time.shape), 402 | rep_time, 403 | torch.zeros_like(ind_time), 404 | ], 405 | dim=2, 406 | ) 407 | max_prob = torch.amax(pi, dim=-1) 408 | max_prob = torch.stack([max_prob] * 4, dim=2) 409 | max_prob_mask = pi.ge(max_prob) 410 | output = (t * max_prob_mask).sum(dim=-1) 411 | 412 | output = output[..., gene_mask] 413 | output = output.cpu().numpy() 414 | minibatch_samples.append(output) 415 | # samples by cells by genes by four 416 | times.append(np.stack(minibatch_samples, axis=0)) 417 | if return_mean: 418 | times[-1] = np.mean(times[-1], axis=0) 419 | 420 | if n_samples > 1: 421 | # The -2 axis correspond to cells. 422 | times = np.concatenate(times, axis=-2) 423 | else: 424 | times = np.concatenate(times, axis=0) 425 | 426 | if return_numpy is None or return_numpy is False: 427 | return pd.DataFrame( 428 | times, 429 | columns=adata.var_names[gene_mask], 430 | index=adata.obs_names[indices], 431 | ) 432 | else: 433 | return times 434 | 435 | @torch.inference_mode() 436 | def get_velocity( 437 | self, 438 | adata: Optional[AnnData] = None, 439 | indices: Optional[Sequence[int]] = None, 440 | gene_list: Optional[Sequence[str]] = None, 441 | n_samples: int = 1, 442 | n_samples_overall: Optional[int] = None, 443 | batch_size: Optional[int] = None, 444 | return_mean: bool = True, 445 | return_numpy: Optional[bool] = None, 446 | velo_statistic: str = "mean", 447 | velo_mode: Literal["spliced", "unspliced"] = "spliced", 448 | clip: bool = True, 449 | ) -> Union[np.ndarray, pd.DataFrame]: 450 | """Returns cells by genes velocity estimates. 451 | 452 | Parameters 453 | ---------- 454 | adata 455 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the 456 | AnnData object used to initialize the model. 457 | indices 458 | Indices of cells in adata to use. If `None`, all cells are used. 459 | gene_list 460 | Return velocities for a subset of genes. 461 | This can save memory when working with large datasets and few genes are 462 | of interest. 463 | n_samples 464 | Number of posterior samples to use for estimation for each cell. 465 | n_samples_overall 466 | Number of overall samples to return. Setting this forces n_samples=1. 467 | batch_size 468 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. 469 | return_mean 470 | Whether to return the mean of the samples. 471 | return_numpy 472 | Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes 473 | gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`. 474 | Otherwise, it defaults to `True`. 475 | velo_statistic 476 | Whether to compute expected velocity over states, or maximum a posteriori velocity over maximal 477 | probability state. 478 | velo_mode 479 | Compute ds/dt or du/dt. 480 | clip 481 | Clip to minus spliced value 482 | 483 | Returns 484 | ------- 485 | If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`. 486 | Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. 487 | """ 488 | adata = self._validate_anndata(adata) 489 | if indices is None: 490 | indices = np.arange(adata.n_obs) 491 | if n_samples_overall is not None: 492 | indices = np.random.choice(indices, n_samples_overall) 493 | n_samples = 1 494 | scdl = self._make_data_loader( 495 | adata=adata, indices=indices, batch_size=batch_size 496 | ) 497 | 498 | if gene_list is None: 499 | gene_mask = slice(None) 500 | else: 501 | all_genes = adata.var_names 502 | gene_mask = [True if gene in gene_list else False for gene in all_genes] 503 | 504 | if n_samples > 1 and return_mean is False: 505 | if return_numpy is False: 506 | warnings.warn( 507 | "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray" 508 | ) 509 | return_numpy = True 510 | if indices is None: 511 | indices = np.arange(adata.n_obs) 512 | 513 | velos = [] 514 | for tensors in scdl: 515 | minibatch_samples = [] 516 | for _ in range(n_samples): 517 | inference_outputs, generative_outputs = self.module.forward( 518 | tensors=tensors, 519 | compute_loss=False, 520 | ) 521 | pi = generative_outputs["px_pi"] 522 | alpha = inference_outputs["alpha"] 523 | alpha_1 = inference_outputs["alpha_1"] 524 | lambda_alpha = inference_outputs["lambda_alpha"] 525 | beta = inference_outputs["beta"] 526 | gamma = inference_outputs["gamma"] 527 | tau = generative_outputs["px_tau"] 528 | rho = generative_outputs["px_rho"] 529 | 530 | ind_prob = pi[..., 0] 531 | steady_prob = pi[..., 1] 532 | rep_prob = pi[..., 2] 533 | switch_time = F.softplus(self.module.switch_time_unconstr) 534 | 535 | ind_time = switch_time * rho 536 | u_0, s_0 = self.module._get_induction_unspliced_spliced( 537 | alpha, alpha_1, lambda_alpha, beta, gamma, switch_time 538 | ) 539 | rep_time = (self.module.t_max - switch_time) * tau 540 | mean_u_rep, mean_s_rep = self.module._get_repression_unspliced_spliced( 541 | u_0, 542 | s_0, 543 | beta, 544 | gamma, 545 | rep_time, 546 | ) 547 | if velo_mode == "spliced": 548 | velo_rep = beta * mean_u_rep - gamma * mean_s_rep 549 | else: 550 | velo_rep = -beta * mean_u_rep 551 | mean_u_ind, mean_s_ind = self.module._get_induction_unspliced_spliced( 552 | alpha, alpha_1, lambda_alpha, beta, gamma, ind_time 553 | ) 554 | if velo_mode == "spliced": 555 | velo_ind = beta * mean_u_ind - gamma * mean_s_ind 556 | else: 557 | transcription_rate = alpha_1 - (alpha_1 - alpha) * torch.exp( 558 | -lambda_alpha * ind_time 559 | ) 560 | velo_ind = transcription_rate - beta * mean_u_ind 561 | 562 | if velo_mode == "spliced": 563 | # velo_steady = beta * u_0 - gamma * s_0 564 | velo_steady = torch.zeros_like(velo_ind) 565 | else: 566 | # velo_steady = alpha - beta * u_0 567 | velo_steady = torch.zeros_like(velo_ind) 568 | 569 | # expectation 570 | if velo_statistic == "mean": 571 | output = ( 572 | ind_prob * velo_ind 573 | + rep_prob * velo_rep 574 | + steady_prob * velo_steady 575 | ) 576 | # maximum 577 | else: 578 | v = torch.stack( 579 | [ 580 | velo_ind, 581 | velo_steady.expand(velo_ind.shape), 582 | velo_rep, 583 | torch.zeros_like(velo_rep), 584 | ], 585 | dim=2, 586 | ) 587 | max_prob = torch.amax(pi, dim=-1) 588 | max_prob = torch.stack([max_prob] * 4, dim=2) 589 | max_prob_mask = pi.ge(max_prob) 590 | output = (v * max_prob_mask).sum(dim=-1) 591 | 592 | output = output[..., gene_mask] 593 | output = output.cpu().numpy() 594 | minibatch_samples.append(output) 595 | # samples by cells by genes 596 | velos.append(np.stack(minibatch_samples, axis=0)) 597 | if return_mean: 598 | # mean over samples axis 599 | velos[-1] = np.mean(velos[-1], axis=0) 600 | 601 | if n_samples > 1: 602 | # The -2 axis correspond to cells. 603 | velos = np.concatenate(velos, axis=-2) 604 | else: 605 | velos = np.concatenate(velos, axis=0) 606 | 607 | spliced = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) 608 | 609 | if clip: 610 | velos = np.clip(velos, -spliced[indices], None) 611 | 612 | if return_numpy is None or return_numpy is False: 613 | return pd.DataFrame( 614 | velos, 615 | columns=adata.var_names[gene_mask], 616 | index=adata.obs_names[indices], 617 | ) 618 | else: 619 | return velos 620 | 621 | @torch.inference_mode() 622 | def get_expression_fit( 623 | self, 624 | adata: Optional[AnnData] = None, 625 | indices: Optional[Sequence[int]] = None, 626 | gene_list: Optional[Sequence[str]] = None, 627 | n_samples: int = 1, 628 | batch_size: Optional[int] = None, 629 | return_mean: bool = True, 630 | return_numpy: Optional[bool] = None, 631 | restrict_to_latent_dim: Optional[int] = None, 632 | ) -> Union[np.ndarray, pd.DataFrame]: 633 | r"""Returns the fitted spliced and unspliced abundance (s(t) and u(t)). 634 | 635 | Parameters 636 | ---------- 637 | adata 638 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the 639 | AnnData object used to initialize the model. 640 | indices 641 | Indices of cells in adata to use. If `None`, all cells are used. 642 | gene_list 643 | Return frequencies of expression for a subset of genes. 644 | This can save memory when working with large datasets and few genes are 645 | of interest. 646 | n_samples 647 | Number of posterior samples to use for estimation. 648 | batch_size 649 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. 650 | return_mean 651 | Whether to return the mean of the samples. 652 | return_numpy 653 | Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes 654 | gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`. 655 | Otherwise, it defaults to `True`. 656 | 657 | Returns 658 | ------- 659 | If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`. 660 | Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. 661 | """ 662 | adata = self._validate_anndata(adata) 663 | 664 | scdl = self._make_data_loader( 665 | adata=adata, indices=indices, batch_size=batch_size 666 | ) 667 | 668 | if gene_list is None: 669 | gene_mask = slice(None) 670 | else: 671 | all_genes = adata.var_names 672 | gene_mask = [True if gene in gene_list else False for gene in all_genes] 673 | 674 | if n_samples > 1 and return_mean is False: 675 | if return_numpy is False: 676 | warnings.warn( 677 | "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray" 678 | ) 679 | return_numpy = True 680 | if indices is None: 681 | indices = np.arange(adata.n_obs) 682 | 683 | fits_s = [] 684 | fits_u = [] 685 | for tensors in scdl: 686 | minibatch_samples_s = [] 687 | minibatch_samples_u = [] 688 | for _ in range(n_samples): 689 | inference_outputs, generative_outputs = self.module.forward( 690 | tensors=tensors, 691 | compute_loss=False, 692 | generative_kwargs={"latent_dim": restrict_to_latent_dim}, 693 | ) 694 | 695 | gamma = inference_outputs["gamma"] 696 | beta = inference_outputs["beta"] 697 | alpha = inference_outputs["alpha"] 698 | alpha_1 = inference_outputs["alpha_1"] 699 | lambda_alpha = inference_outputs["lambda_alpha"] 700 | px_pi = generative_outputs["px_pi"] 701 | scale = generative_outputs["scale"] 702 | px_rho = generative_outputs["px_rho"] 703 | px_tau = generative_outputs["px_tau"] 704 | 705 | ( 706 | mixture_dist_s, 707 | mixture_dist_u, 708 | _, 709 | ) = self.module.get_px( 710 | px_pi, 711 | px_rho, 712 | px_tau, 713 | scale, 714 | gamma, 715 | beta, 716 | alpha, 717 | alpha_1, 718 | lambda_alpha, 719 | ) 720 | fit_s = mixture_dist_s.mean 721 | fit_u = mixture_dist_u.mean 722 | 723 | fit_s = fit_s[..., gene_mask] 724 | fit_s = fit_s.cpu().numpy() 725 | fit_u = fit_u[..., gene_mask] 726 | fit_u = fit_u.cpu().numpy() 727 | 728 | minibatch_samples_s.append(fit_s) 729 | minibatch_samples_u.append(fit_u) 730 | 731 | # samples by cells by genes 732 | fits_s.append(np.stack(minibatch_samples_s, axis=0)) 733 | if return_mean: 734 | # mean over samples axis 735 | fits_s[-1] = np.mean(fits_s[-1], axis=0) 736 | # samples by cells by genes 737 | fits_u.append(np.stack(minibatch_samples_u, axis=0)) 738 | if return_mean: 739 | # mean over samples axis 740 | fits_u[-1] = np.mean(fits_u[-1], axis=0) 741 | 742 | if n_samples > 1: 743 | # The -2 axis correspond to cells. 744 | fits_s = np.concatenate(fits_s, axis=-2) 745 | fits_u = np.concatenate(fits_u, axis=-2) 746 | else: 747 | fits_s = np.concatenate(fits_s, axis=0) 748 | fits_u = np.concatenate(fits_u, axis=0) 749 | 750 | if return_numpy is None or return_numpy is False: 751 | df_s = pd.DataFrame( 752 | fits_s, 753 | columns=adata.var_names[gene_mask], 754 | index=adata.obs_names[indices], 755 | ) 756 | df_u = pd.DataFrame( 757 | fits_u, 758 | columns=adata.var_names[gene_mask], 759 | index=adata.obs_names[indices], 760 | ) 761 | return df_s, df_u 762 | else: 763 | return fits_s, fits_u 764 | 765 | @torch.inference_mode() 766 | def get_gene_likelihood( 767 | self, 768 | adata: Optional[AnnData] = None, 769 | indices: Optional[Sequence[int]] = None, 770 | gene_list: Optional[Sequence[str]] = None, 771 | n_samples: int = 1, 772 | batch_size: Optional[int] = None, 773 | return_mean: bool = True, 774 | return_numpy: Optional[bool] = None, 775 | ) -> Union[np.ndarray, pd.DataFrame]: 776 | r"""Returns the likelihood per gene. Higher is better. 777 | 778 | This is denoted as :math:`\rho_n` in the scVI paper. 779 | 780 | Parameters 781 | ---------- 782 | adata 783 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the 784 | AnnData object used to initialize the model. 785 | indices 786 | Indices of cells in adata to use. If `None`, all cells are used. 787 | transform_batch 788 | Batch to condition on. 789 | If transform_batch is: 790 | 791 | - None, then real observed batch is used. 792 | - int, then batch transform_batch is used. 793 | gene_list 794 | Return frequencies of expression for a subset of genes. 795 | This can save memory when working with large datasets and few genes are 796 | of interest. 797 | library_size 798 | Scale the expression frequencies to a common library size. 799 | This allows gene expression levels to be interpreted on a common scale of relevant 800 | magnitude. If set to `"latent"`, use the latent libary size. 801 | n_samples 802 | Number of posterior samples to use for estimation. 803 | batch_size 804 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. 805 | return_mean 806 | Whether to return the mean of the samples. 807 | return_numpy 808 | Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes 809 | gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`. 810 | Otherwise, it defaults to `True`. 811 | 812 | Returns 813 | ------- 814 | If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`. 815 | Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. 816 | """ 817 | adata = self._validate_anndata(adata) 818 | scdl = self._make_data_loader( 819 | adata=adata, indices=indices, batch_size=batch_size 820 | ) 821 | 822 | if gene_list is None: 823 | gene_mask = slice(None) 824 | else: 825 | all_genes = adata.var_names 826 | gene_mask = [True if gene in gene_list else False for gene in all_genes] 827 | 828 | if n_samples > 1 and return_mean is False: 829 | if return_numpy is False: 830 | warnings.warn( 831 | "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray" 832 | ) 833 | return_numpy = True 834 | if indices is None: 835 | indices = np.arange(adata.n_obs) 836 | 837 | rls = [] 838 | for tensors in scdl: 839 | minibatch_samples = [] 840 | for _ in range(n_samples): 841 | inference_outputs, generative_outputs = self.module.forward( 842 | tensors=tensors, 843 | compute_loss=False, 844 | ) 845 | spliced = tensors[REGISTRY_KEYS.X_KEY] 846 | unspliced = tensors[REGISTRY_KEYS.U_KEY] 847 | 848 | gamma = inference_outputs["gamma"] 849 | beta = inference_outputs["beta"] 850 | alpha = inference_outputs["alpha"] 851 | alpha_1 = inference_outputs["alpha_1"] 852 | lambda_alpha = inference_outputs["lambda_alpha"] 853 | px_pi = generative_outputs["px_pi"] 854 | scale = generative_outputs["scale"] 855 | px_rho = generative_outputs["px_rho"] 856 | px_tau = generative_outputs["px_tau"] 857 | 858 | ( 859 | mixture_dist_s, 860 | mixture_dist_u, 861 | _, 862 | ) = self.module.get_px( 863 | px_pi, 864 | px_rho, 865 | px_tau, 866 | scale, 867 | gamma, 868 | beta, 869 | alpha, 870 | alpha_1, 871 | lambda_alpha, 872 | ) 873 | device = gamma.device 874 | reconst_loss_s = -mixture_dist_s.log_prob(spliced.to(device)) 875 | reconst_loss_u = -mixture_dist_u.log_prob(unspliced.to(device)) 876 | output = -(reconst_loss_s + reconst_loss_u) 877 | output = output[..., gene_mask] 878 | output = output.cpu().numpy() 879 | minibatch_samples.append(output) 880 | # samples by cells by genes by four 881 | rls.append(np.stack(minibatch_samples, axis=0)) 882 | if return_mean: 883 | rls[-1] = np.mean(rls[-1], axis=0) 884 | 885 | rls = np.concatenate(rls, axis=0) 886 | return rls 887 | 888 | @torch.inference_mode() 889 | def get_rates(self): 890 | gamma, beta, alpha, alpha_1, lambda_alpha = self.module._get_rates() 891 | 892 | return { 893 | "beta": beta.cpu().numpy(), 894 | "gamma": gamma.cpu().numpy(), 895 | "alpha": alpha.cpu().numpy(), 896 | "alpha_1": alpha_1.cpu().numpy(), 897 | "lambda_alpha": lambda_alpha.cpu().numpy(), 898 | } 899 | 900 | @classmethod 901 | @setup_anndata_dsp.dedent 902 | def setup_anndata( 903 | cls, 904 | adata: AnnData, 905 | spliced_layer: str, 906 | unspliced_layer: str, 907 | **kwargs, 908 | ) -> Optional[AnnData]: 909 | """%(summary)s. 910 | 911 | Parameters 912 | ---------- 913 | %(param_adata)s 914 | spliced_layer 915 | Layer in adata with spliced normalized expression 916 | unspliced_layer 917 | Layer in adata with unspliced normalized expression. 918 | 919 | Returns 920 | ------- 921 | %(returns)s 922 | """ 923 | setup_method_args = cls._get_setup_method_args(**locals()) 924 | anndata_fields = [ 925 | LayerField(REGISTRY_KEYS.X_KEY, spliced_layer, is_count_data=False), 926 | LayerField(REGISTRY_KEYS.U_KEY, unspliced_layer, is_count_data=False), 927 | ] 928 | adata_manager = AnnDataManager( 929 | fields=anndata_fields, setup_method_args=setup_method_args 930 | ) 931 | adata_manager.register_fields(adata, **kwargs) 932 | cls.register_manager(adata_manager) 933 | 934 | def get_directional_uncertainty( 935 | self, 936 | adata: Optional[AnnData] = None, 937 | n_samples: int = 50, 938 | gene_list: Iterable[str] = None, 939 | n_jobs: int = -1, 940 | ): 941 | adata = self._validate_anndata(adata) 942 | 943 | logger.info("Sampling from model...") 944 | velocities_all = self.get_velocity( 945 | n_samples=n_samples, return_mean=False, gene_list=gene_list 946 | ) # (n_samples, n_cells, n_genes) 947 | 948 | df, cosine_sims = _compute_directional_statistics_tensor( 949 | tensor=velocities_all, n_jobs=n_jobs, n_cells=adata.n_obs 950 | ) 951 | df.index = adata.obs_names 952 | 953 | return df, cosine_sims 954 | 955 | def get_permutation_scores( 956 | self, labels_key: str, adata: Optional[AnnData] = None 957 | ) -> Tuple[pd.DataFrame, AnnData]: 958 | """Compute permutation scores. 959 | 960 | Parameters 961 | ---------- 962 | labels_key 963 | Key in adata.obs encoding cell types 964 | adata 965 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the 966 | AnnData object used to initialize the model. 967 | 968 | Returns 969 | ------- 970 | Tuple of DataFrame and AnnData. DataFrame is genes by cell types with score per cell type. 971 | AnnData is the permutated version of the original AnnData. 972 | """ 973 | adata = self._validate_anndata(adata) 974 | adata_manager = self.get_anndata_manager(adata) 975 | if labels_key not in adata.obs: 976 | raise ValueError(f"{labels_key} not found in adata.obs") 977 | 978 | # shuffle spliced then unspliced 979 | bdata = self._shuffle_layer_celltype( 980 | adata_manager, labels_key, REGISTRY_KEYS.X_KEY 981 | ) 982 | bdata_manager = self.get_anndata_manager(bdata) 983 | bdata = self._shuffle_layer_celltype( 984 | bdata_manager, labels_key, REGISTRY_KEYS.U_KEY 985 | ) 986 | bdata_manager = self.get_anndata_manager(bdata) 987 | 988 | ms_ = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) 989 | mu_ = adata_manager.get_from_registry(REGISTRY_KEYS.U_KEY) 990 | 991 | ms_p = bdata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) 992 | mu_p = bdata_manager.get_from_registry(REGISTRY_KEYS.U_KEY) 993 | 994 | spliced_, unspliced_ = self.get_expression_fit(adata, n_samples=10) 995 | root_squared_error = np.abs(spliced_ - ms_) 996 | root_squared_error += np.abs(unspliced_ - mu_) 997 | 998 | spliced_p, unspliced_p = self.get_expression_fit(bdata, n_samples=10) 999 | root_squared_error_p = np.abs(spliced_p - ms_p) 1000 | root_squared_error_p += np.abs(unspliced_p - mu_p) 1001 | 1002 | celltypes = np.unique(adata.obs[labels_key]) 1003 | 1004 | dynamical_df = pd.DataFrame( 1005 | index=adata.var_names, 1006 | columns=celltypes, 1007 | data=np.zeros((adata.shape[1], len(celltypes))), 1008 | ) 1009 | N = 200 1010 | for ct in celltypes: 1011 | for g in adata.var_names.tolist(): 1012 | x = root_squared_error_p[g][adata.obs[labels_key] == ct] 1013 | y = root_squared_error[g][adata.obs[labels_key] == ct] 1014 | ratio = ttest_ind(x[:N], y[:N])[0] 1015 | dynamical_df.loc[g, ct] = ratio 1016 | 1017 | return dynamical_df, bdata 1018 | 1019 | def _shuffle_layer_celltype( 1020 | self, adata_manager: AnnDataManager, labels_key: str, registry_key: str 1021 | ) -> AnnData: 1022 | """Shuffle cells within cell types for each gene.""" 1023 | from scvi.data._constants import _SCVI_UUID_KEY 1024 | 1025 | bdata = adata_manager.adata.copy() 1026 | labels = bdata.obs[labels_key] 1027 | del bdata.uns[_SCVI_UUID_KEY] 1028 | self._validate_anndata(bdata) 1029 | bdata_manager = self.get_anndata_manager(bdata) 1030 | 1031 | # get registry info to later set data back in bdata 1032 | # in a way that doesn't require actual knowledge of location 1033 | unspliced = bdata_manager.get_from_registry(registry_key) 1034 | u_registry = bdata_manager.data_registry[registry_key] 1035 | attr_name = u_registry.attr_name 1036 | attr_key = u_registry.attr_key 1037 | 1038 | for lab in np.unique(labels): 1039 | mask = np.asarray(labels == lab) 1040 | unspliced_ct = unspliced[mask].copy() 1041 | unspliced_ct = np.apply_along_axis( 1042 | np.random.permutation, axis=0, arr=unspliced_ct 1043 | ) 1044 | unspliced[mask] = unspliced_ct 1045 | # e.g., if using adata.X 1046 | if attr_key is None: 1047 | setattr(bdata, attr_name, unspliced) 1048 | # e.g., if using a layer 1049 | elif attr_key is not None: 1050 | attribute = getattr(bdata, attr_name) 1051 | attribute[attr_key] = unspliced 1052 | setattr(bdata, attr_name, attribute) 1053 | 1054 | return bdata 1055 | 1056 | 1057 | def _compute_directional_statistics_tensor( 1058 | tensor: np.ndarray, n_jobs: int, n_cells: int 1059 | ) -> pd.DataFrame: 1060 | df = pd.DataFrame(index=np.arange(n_cells)) 1061 | df["directional_variance"] = np.nan 1062 | df["directional_difference"] = np.nan 1063 | df["directional_cosine_sim_variance"] = np.nan 1064 | df["directional_cosine_sim_difference"] = np.nan 1065 | df["directional_cosine_sim_mean"] = np.nan 1066 | logger.info("Computing the uncertainties...") 1067 | results = Parallel(n_jobs=n_jobs, verbose=3)( 1068 | delayed(_directional_statistics_per_cell)(tensor[:, cell_index, :]) 1069 | for cell_index in range(n_cells) 1070 | ) 1071 | # cells by samples 1072 | cosine_sims = np.stack([results[i][0] for i in range(n_cells)]) 1073 | df.loc[:, "directional_cosine_sim_variance"] = [ 1074 | results[i][1] for i in range(n_cells) 1075 | ] 1076 | df.loc[:, "directional_cosine_sim_difference"] = [ 1077 | results[i][2] for i in range(n_cells) 1078 | ] 1079 | df.loc[:, "directional_variance"] = [results[i][3] for i in range(n_cells)] 1080 | df.loc[:, "directional_difference"] = [results[i][4] for i in range(n_cells)] 1081 | df.loc[:, "directional_cosine_sim_mean"] = [results[i][5] for i in range(n_cells)] 1082 | 1083 | return df, cosine_sims 1084 | 1085 | 1086 | def _directional_statistics_per_cell( 1087 | tensor: np.ndarray, 1088 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 1089 | """Internal function for parallelization. 1090 | 1091 | Parameters 1092 | ---------- 1093 | tensor 1094 | Shape of samples by genes for a given cell. 1095 | """ 1096 | n_samples = tensor.shape[0] 1097 | # over samples axis 1098 | mean_velocity_of_cell = tensor.mean(0) 1099 | cosine_sims = [ 1100 | _cosine_sim(tensor[i, :], mean_velocity_of_cell) for i in range(n_samples) 1101 | ] 1102 | angle_samples = [np.arccos(el) for el in cosine_sims] 1103 | return ( 1104 | cosine_sims, 1105 | np.var(cosine_sims), 1106 | np.percentile(cosine_sims, 95) - np.percentile(cosine_sims, 5), 1107 | np.var(angle_samples), 1108 | np.percentile(angle_samples, 95) - np.percentile(angle_samples, 5), 1109 | np.mean(cosine_sims), 1110 | ) 1111 | 1112 | 1113 | def _centered_unit_vector(vector: np.ndarray) -> np.ndarray: 1114 | """Returns the centered unit vector of the vector.""" 1115 | vector = vector - np.mean(vector) 1116 | return vector / np.linalg.norm(vector) 1117 | 1118 | 1119 | def _cosine_sim(v1: np.ndarray, v2: np.ndarray) -> np.ndarray: 1120 | """Returns cosine similarity of the vectors.""" 1121 | v1_u = _centered_unit_vector(v1) 1122 | v2_u = _centered_unit_vector(v2) 1123 | return np.clip(np.dot(v1_u, v2_u), -1.0, 1.0) 1124 | -------------------------------------------------------------------------------- /velovi/_module.py: -------------------------------------------------------------------------------- 1 | """Main module.""" 2 | from typing import Callable, Iterable, Literal, Optional 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data 8 | from scvi.nn import Encoder, FCLayers 9 | from torch import nn as nn 10 | from torch.distributions import Categorical, Dirichlet, MixtureSameFamily, Normal 11 | from torch.distributions import kl_divergence as kl 12 | 13 | from ._constants import REGISTRY_KEYS 14 | 15 | torch.backends.cudnn.benchmark = True 16 | 17 | 18 | class DecoderVELOVI(nn.Module): 19 | """Decodes data from latent space of ``n_input`` dimensions ``n_output``dimensions. 20 | 21 | Uses a fully-connected neural network of ``n_hidden`` layers. 22 | 23 | Parameters 24 | ---------- 25 | n_input 26 | The dimensionality of the input (latent space) 27 | n_output 28 | The dimensionality of the output (data space) 29 | n_cat_list 30 | A list containing the number of categories 31 | for each category of interest. Each category will be 32 | included using a one-hot encoding 33 | n_layers 34 | The number of fully-connected hidden layers 35 | n_hidden 36 | The number of nodes per hidden layer 37 | dropout_rate 38 | Dropout rate to apply to each of the hidden layers 39 | inject_covariates 40 | Whether to inject covariates in each layer, or just the first (default). 41 | use_batch_norm 42 | Whether to use batch norm in layers 43 | use_layer_norm 44 | Whether to use layer norm in layers 45 | linear_decoder 46 | Whether to use linear decoder for time 47 | """ 48 | 49 | def __init__( 50 | self, 51 | n_input: int, 52 | n_output: int, 53 | n_cat_list: Iterable[int] = None, 54 | n_layers: int = 1, 55 | n_hidden: int = 128, 56 | inject_covariates: bool = True, 57 | use_batch_norm: bool = True, 58 | use_layer_norm: bool = False, 59 | dropout_rate: float = 0.0, 60 | linear_decoder: bool = False, 61 | **kwargs, 62 | ): 63 | super().__init__() 64 | self.n_ouput = n_output 65 | self.linear_decoder = linear_decoder 66 | self.rho_first_decoder = FCLayers( 67 | n_in=n_input, 68 | n_out=n_hidden if not linear_decoder else n_output, 69 | n_cat_list=n_cat_list, 70 | n_layers=n_layers if not linear_decoder else 1, 71 | n_hidden=n_hidden, 72 | dropout_rate=dropout_rate, 73 | inject_covariates=inject_covariates, 74 | use_batch_norm=use_batch_norm, 75 | use_layer_norm=use_layer_norm if not linear_decoder else False, 76 | use_activation=not linear_decoder, 77 | bias=not linear_decoder, 78 | **kwargs, 79 | ) 80 | 81 | self.pi_first_decoder = FCLayers( 82 | n_in=n_input, 83 | n_out=n_hidden, 84 | n_cat_list=n_cat_list, 85 | n_layers=n_layers, 86 | n_hidden=n_hidden, 87 | dropout_rate=dropout_rate, 88 | inject_covariates=inject_covariates, 89 | use_batch_norm=use_batch_norm, 90 | use_layer_norm=use_layer_norm, 91 | **kwargs, 92 | ) 93 | 94 | # categorical pi 95 | # 4 states 96 | self.px_pi_decoder = nn.Linear(n_hidden, 4 * n_output) 97 | 98 | # rho for induction 99 | self.px_rho_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Sigmoid()) 100 | 101 | # tau for repression 102 | self.px_tau_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Sigmoid()) 103 | 104 | self.linear_scaling_tau = nn.Parameter(torch.zeros(n_output)) 105 | self.linear_scaling_tau_intercept = nn.Parameter(torch.zeros(n_output)) 106 | 107 | def forward(self, z: torch.Tensor, latent_dim: int = None): 108 | """The forward computation for a single sample. 109 | 110 | #. Decodes the data from the latent space using the decoder network 111 | #. Returns parameters for the ZINB distribution of expression 112 | #. If ``dispersion != 'gene-cell'`` then value for that param will be ``None`` 113 | 114 | Parameters 115 | ---------- 116 | z : 117 | tensor with shape ``(n_input,)`` 118 | cat_list 119 | list of category membership(s) for this sample 120 | 121 | Returns 122 | ------- 123 | 4-tuple of :py:class:`torch.Tensor` 124 | parameters for the ZINB distribution of expression 125 | 126 | """ 127 | z_in = z 128 | if latent_dim is not None: 129 | mask = torch.zeros_like(z) 130 | mask[..., latent_dim] = 1 131 | z_in = z * mask 132 | # The decoder returns values for the parameters of the ZINB distribution 133 | rho_first = self.rho_first_decoder(z_in) 134 | 135 | if not self.linear_decoder: 136 | px_rho = self.px_rho_decoder(rho_first) 137 | px_tau = self.px_tau_decoder(rho_first) 138 | else: 139 | px_rho = nn.Sigmoid()(rho_first) 140 | px_tau = 1 - nn.Sigmoid()( 141 | rho_first * self.linear_scaling_tau.exp() 142 | + self.linear_scaling_tau_intercept 143 | ) 144 | 145 | # cells by genes by 4 146 | pi_first = self.pi_first_decoder(z) 147 | px_pi = nn.Softplus()( 148 | torch.reshape(self.px_pi_decoder(pi_first), (z.shape[0], self.n_ouput, 4)) 149 | ) 150 | 151 | return px_pi, px_rho, px_tau 152 | 153 | 154 | # VAE model 155 | class VELOVAE(BaseModuleClass): 156 | """Variational auto-encoder model. 157 | 158 | This is an implementation of the veloVI model descibed in :cite:p:`GayosoWeiler2022` 159 | 160 | Parameters 161 | ---------- 162 | n_input 163 | Number of input genes 164 | n_hidden 165 | Number of nodes per hidden layer 166 | n_latent 167 | Dimensionality of the latent space 168 | n_layers 169 | Number of hidden layers used for encoder and decoder NNs 170 | dropout_rate 171 | Dropout rate for neural networks 172 | log_variational 173 | Log(data+1) prior to encoding for numerical stability. Not normalization. 174 | latent_distribution 175 | One of 176 | 177 | * ``'normal'`` - Isotropic normal 178 | * ``'ln'`` - Logistic normal with normal params N(0, 1) 179 | use_layer_norm 180 | Whether to use layer norm in layers 181 | use_observed_lib_size 182 | Use observed library size for RNA as scaling factor in mean of conditional distribution 183 | var_activation 184 | Callable used to ensure positivity of the variational distributions' variance. 185 | When `None`, defaults to `torch.exp`. 186 | """ 187 | 188 | def __init__( 189 | self, 190 | n_input: int, 191 | true_time_switch: Optional[np.ndarray] = None, 192 | n_hidden: int = 128, 193 | n_latent: int = 10, 194 | n_layers: int = 1, 195 | dropout_rate: float = 0.1, 196 | log_variational: bool = False, 197 | latent_distribution: str = "normal", 198 | use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", 199 | use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both", 200 | use_observed_lib_size: bool = True, 201 | var_activation: Optional[Callable] = torch.nn.Softplus(), 202 | model_steady_states: bool = True, 203 | gamma_unconstr_init: Optional[np.ndarray] = None, 204 | alpha_unconstr_init: Optional[np.ndarray] = None, 205 | alpha_1_unconstr_init: Optional[np.ndarray] = None, 206 | lambda_alpha_unconstr_init: Optional[np.ndarray] = None, 207 | switch_spliced: Optional[np.ndarray] = None, 208 | switch_unspliced: Optional[np.ndarray] = None, 209 | t_max: float = 20, 210 | penalty_scale: float = 0.2, 211 | dirichlet_concentration: float = 0.25, 212 | linear_decoder: bool = False, 213 | time_dep_transcription_rate: bool = False, 214 | ): 215 | super().__init__() 216 | self.n_latent = n_latent 217 | self.log_variational = log_variational 218 | self.latent_distribution = latent_distribution 219 | self.use_observed_lib_size = use_observed_lib_size 220 | self.n_input = n_input 221 | self.model_steady_states = model_steady_states 222 | self.t_max = t_max 223 | self.penalty_scale = penalty_scale 224 | self.dirichlet_concentration = dirichlet_concentration 225 | self.time_dep_transcription_rate = time_dep_transcription_rate 226 | 227 | if switch_spliced is not None: 228 | self.register_buffer("switch_spliced", torch.from_numpy(switch_spliced)) 229 | else: 230 | self.switch_spliced = None 231 | if switch_unspliced is not None: 232 | self.register_buffer("switch_unspliced", torch.from_numpy(switch_unspliced)) 233 | else: 234 | self.switch_unspliced = None 235 | 236 | n_genes = n_input * 2 237 | 238 | # switching time 239 | self.switch_time_unconstr = torch.nn.Parameter(7 + 0.5 * torch.randn(n_input)) 240 | if true_time_switch is not None: 241 | self.register_buffer("true_time_switch", torch.from_numpy(true_time_switch)) 242 | else: 243 | self.true_time_switch = None 244 | 245 | # degradation 246 | if gamma_unconstr_init is None: 247 | self.gamma_mean_unconstr = torch.nn.Parameter(-1 * torch.ones(n_input)) 248 | else: 249 | self.gamma_mean_unconstr = torch.nn.Parameter( 250 | torch.from_numpy(gamma_unconstr_init) 251 | ) 252 | 253 | # splicing 254 | # first samples around 1 255 | self.beta_mean_unconstr = torch.nn.Parameter(0.5 * torch.ones(n_input)) 256 | 257 | # transcription 258 | if alpha_unconstr_init is None: 259 | self.alpha_unconstr = torch.nn.Parameter(0 * torch.ones(n_input)) 260 | else: 261 | self.alpha_unconstr = torch.nn.Parameter( 262 | torch.from_numpy(alpha_unconstr_init) 263 | ) 264 | 265 | # TODO: Add `require_grad` 266 | if alpha_1_unconstr_init is None: 267 | self.alpha_1_unconstr = torch.nn.Parameter(0 * torch.ones(n_input)) 268 | else: 269 | self.alpha_1_unconstr = torch.nn.Parameter( 270 | torch.from_numpy(alpha_1_unconstr_init) 271 | ) 272 | self.alpha_1_unconstr.requires_grad = time_dep_transcription_rate 273 | 274 | if lambda_alpha_unconstr_init is None: 275 | self.lambda_alpha_unconstr = torch.nn.Parameter(0 * torch.ones(n_input)) 276 | else: 277 | self.lambda_alpha_unconstr = torch.nn.Parameter( 278 | torch.from_numpy(lambda_alpha_unconstr_init) 279 | ) 280 | self.lambda_alpha_unconstr.requires_grad = time_dep_transcription_rate 281 | 282 | # likelihood dispersion 283 | # for now, with normal dist, this is just the variance 284 | self.scale_unconstr = torch.nn.Parameter(-1 * torch.ones(n_genes, 4)) 285 | 286 | use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both" 287 | use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both" 288 | use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both" 289 | use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both" 290 | self.use_batch_norm_decoder = use_batch_norm_decoder 291 | 292 | # z encoder goes from the n_input-dimensional data to an n_latent-d 293 | # latent space representation 294 | n_input_encoder = n_genes 295 | self.z_encoder = Encoder( 296 | n_input_encoder, 297 | n_latent, 298 | n_layers=n_layers, 299 | n_hidden=n_hidden, 300 | dropout_rate=dropout_rate, 301 | distribution=latent_distribution, 302 | use_batch_norm=use_batch_norm_encoder, 303 | use_layer_norm=use_layer_norm_encoder, 304 | var_activation=var_activation, 305 | activation_fn=torch.nn.ReLU, 306 | ) 307 | # decoder goes from n_latent-dimensional space to n_input-d data 308 | n_input_decoder = n_latent 309 | self.decoder = DecoderVELOVI( 310 | n_input_decoder, 311 | n_input, 312 | n_layers=n_layers, 313 | n_hidden=n_hidden, 314 | use_batch_norm=use_batch_norm_decoder, 315 | use_layer_norm=use_layer_norm_decoder, 316 | activation_fn=torch.nn.ReLU, 317 | linear_decoder=linear_decoder, 318 | ) 319 | 320 | def _get_inference_input(self, tensors): 321 | spliced = tensors[REGISTRY_KEYS.X_KEY] 322 | unspliced = tensors[REGISTRY_KEYS.U_KEY] 323 | 324 | input_dict = { 325 | "spliced": spliced, 326 | "unspliced": unspliced, 327 | } 328 | return input_dict 329 | 330 | def _get_generative_input(self, tensors, inference_outputs): 331 | z = inference_outputs["z"] 332 | gamma = inference_outputs["gamma"] 333 | beta = inference_outputs["beta"] 334 | alpha = inference_outputs["alpha"] 335 | alpha_1 = inference_outputs["alpha_1"] 336 | lambda_alpha = inference_outputs["lambda_alpha"] 337 | 338 | input_dict = { 339 | "z": z, 340 | "gamma": gamma, 341 | "beta": beta, 342 | "alpha": alpha, 343 | "alpha_1": alpha_1, 344 | "lambda_alpha": lambda_alpha, 345 | } 346 | return input_dict 347 | 348 | @auto_move_data 349 | def inference( 350 | self, 351 | spliced, 352 | unspliced, 353 | n_samples=1, 354 | ): 355 | """High level inference method. 356 | 357 | Runs the inference (encoder) model. 358 | """ 359 | spliced_ = spliced 360 | unspliced_ = unspliced 361 | if self.log_variational: 362 | spliced_ = torch.log(0.01 + spliced) 363 | unspliced_ = torch.log(0.01 + unspliced) 364 | 365 | encoder_input = torch.cat((spliced_, unspliced_), dim=-1) 366 | 367 | qz_m, qz_v, z = self.z_encoder(encoder_input) 368 | 369 | if n_samples > 1: 370 | qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) 371 | qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) 372 | # when z is normal, untran_z == z 373 | untran_z = Normal(qz_m, qz_v.sqrt()).sample() 374 | z = self.z_encoder.z_transformation(untran_z) 375 | 376 | gamma, beta, alpha, alpha_1, lambda_alpha = self._get_rates() 377 | 378 | outputs = { 379 | "z": z, 380 | "qz_m": qz_m, 381 | "qz_v": qz_v, 382 | "gamma": gamma, 383 | "beta": beta, 384 | "alpha": alpha, 385 | "alpha_1": alpha_1, 386 | "lambda_alpha": lambda_alpha, 387 | } 388 | return outputs 389 | 390 | def _get_rates(self): 391 | # globals 392 | # degradation 393 | gamma = torch.clamp(F.softplus(self.gamma_mean_unconstr), 0, 50) 394 | # splicing 395 | beta = torch.clamp(F.softplus(self.beta_mean_unconstr), 0, 50) 396 | # transcription 397 | alpha = torch.clamp(F.softplus(self.alpha_unconstr), 0, 50) 398 | if self.time_dep_transcription_rate: 399 | alpha_1 = torch.clamp(F.softplus(self.alpha_1_unconstr), 0, 50) 400 | lambda_alpha = torch.clamp(F.softplus(self.lambda_alpha_unconstr), 0, 50) 401 | else: 402 | alpha_1 = self.alpha_1_unconstr 403 | lambda_alpha = self.lambda_alpha_unconstr 404 | 405 | return gamma, beta, alpha, alpha_1, lambda_alpha 406 | 407 | @auto_move_data 408 | def generative(self, z, gamma, beta, alpha, alpha_1, lambda_alpha, latent_dim=None): 409 | """Runs the generative model.""" 410 | decoder_input = z 411 | px_pi_alpha, px_rho, px_tau = self.decoder(decoder_input, latent_dim=latent_dim) 412 | px_pi = Dirichlet(px_pi_alpha).rsample() 413 | 414 | scale_unconstr = self.scale_unconstr 415 | scale = F.softplus(scale_unconstr) 416 | 417 | mixture_dist_s, mixture_dist_u, end_penalty = self.get_px( 418 | px_pi, 419 | px_rho, 420 | px_tau, 421 | scale, 422 | gamma, 423 | beta, 424 | alpha, 425 | alpha_1, 426 | lambda_alpha, 427 | ) 428 | 429 | return { 430 | "px_pi": px_pi, 431 | "px_rho": px_rho, 432 | "px_tau": px_tau, 433 | "scale": scale, 434 | "px_pi_alpha": px_pi_alpha, 435 | "mixture_dist_u": mixture_dist_u, 436 | "mixture_dist_s": mixture_dist_s, 437 | "end_penalty": end_penalty, 438 | } 439 | 440 | def loss( 441 | self, 442 | tensors, 443 | inference_outputs, 444 | generative_outputs, 445 | kl_weight: float = 1.0, 446 | n_obs: float = 1.0, 447 | ): 448 | spliced = tensors[REGISTRY_KEYS.X_KEY] 449 | unspliced = tensors[REGISTRY_KEYS.U_KEY] 450 | 451 | qz_m = inference_outputs["qz_m"] 452 | qz_v = inference_outputs["qz_v"] 453 | 454 | px_pi = generative_outputs["px_pi"] 455 | px_pi_alpha = generative_outputs["px_pi_alpha"] 456 | 457 | end_penalty = generative_outputs["end_penalty"] 458 | mixture_dist_s = generative_outputs["mixture_dist_s"] 459 | mixture_dist_u = generative_outputs["mixture_dist_u"] 460 | 461 | kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1) 462 | 463 | reconst_loss_s = -mixture_dist_s.log_prob(spliced) 464 | reconst_loss_u = -mixture_dist_u.log_prob(unspliced) 465 | 466 | reconst_loss = reconst_loss_u.sum(dim=-1) + reconst_loss_s.sum(dim=-1) 467 | 468 | kl_pi = kl( 469 | Dirichlet(px_pi_alpha), 470 | Dirichlet(self.dirichlet_concentration * torch.ones_like(px_pi)), 471 | ).sum(dim=-1) 472 | 473 | # local loss 474 | kl_local = kl_divergence_z + kl_pi 475 | weighted_kl_local = kl_weight * (kl_divergence_z) + kl_pi 476 | 477 | local_loss = torch.mean(reconst_loss + weighted_kl_local) 478 | 479 | loss = local_loss + self.penalty_scale * (1 - kl_weight) * end_penalty 480 | 481 | loss_recoder = LossOutput( 482 | loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local 483 | ) 484 | 485 | return loss_recoder 486 | 487 | @auto_move_data 488 | def get_px( 489 | self, 490 | px_pi, 491 | px_rho, 492 | px_tau, 493 | scale, 494 | gamma, 495 | beta, 496 | alpha, 497 | alpha_1, 498 | lambda_alpha, 499 | ) -> torch.Tensor: 500 | t_s = torch.clamp(F.softplus(self.switch_time_unconstr), 0, self.t_max) 501 | 502 | n_cells = px_pi.shape[0] 503 | 504 | # component dist 505 | comp_dist = Categorical(probs=px_pi) 506 | 507 | # induction 508 | mean_u_ind, mean_s_ind = self._get_induction_unspliced_spliced( 509 | alpha, alpha_1, lambda_alpha, beta, gamma, t_s * px_rho 510 | ) 511 | 512 | if self.time_dep_transcription_rate: 513 | mean_u_ind_steady = (alpha_1 / beta).expand(n_cells, self.n_input) 514 | mean_s_ind_steady = (alpha_1 / gamma).expand(n_cells, self.n_input) 515 | else: 516 | mean_u_ind_steady = (alpha / beta).expand(n_cells, self.n_input) 517 | mean_s_ind_steady = (alpha / gamma).expand(n_cells, self.n_input) 518 | scale_u = scale[: self.n_input, :].expand(n_cells, self.n_input, 4).sqrt() 519 | 520 | # repression 521 | u_0, s_0 = self._get_induction_unspliced_spliced( 522 | alpha, alpha_1, lambda_alpha, beta, gamma, t_s 523 | ) 524 | 525 | tau = px_tau 526 | mean_u_rep, mean_s_rep = self._get_repression_unspliced_spliced( 527 | u_0, 528 | s_0, 529 | beta, 530 | gamma, 531 | (self.t_max - t_s) * tau, 532 | ) 533 | mean_u_rep_steady = torch.zeros_like(mean_u_ind) 534 | mean_s_rep_steady = torch.zeros_like(mean_u_ind) 535 | scale_s = scale[self.n_input :, :].expand(n_cells, self.n_input, 4).sqrt() 536 | 537 | end_penalty = ((u_0 - self.switch_unspliced).pow(2)).sum() + ( 538 | (s_0 - self.switch_spliced).pow(2) 539 | ).sum() 540 | 541 | # unspliced 542 | mean_u = torch.stack( 543 | ( 544 | mean_u_ind, 545 | mean_u_ind_steady, 546 | mean_u_rep, 547 | mean_u_rep_steady, 548 | ), 549 | dim=2, 550 | ) 551 | scale_u = torch.stack( 552 | ( 553 | scale_u[..., 0], 554 | scale_u[..., 0], 555 | scale_u[..., 0], 556 | 0.1 * scale_u[..., 0], 557 | ), 558 | dim=2, 559 | ) 560 | dist_u = Normal(mean_u, scale_u) 561 | mixture_dist_u = MixtureSameFamily(comp_dist, dist_u) 562 | 563 | # spliced 564 | mean_s = torch.stack( 565 | (mean_s_ind, mean_s_ind_steady, mean_s_rep, mean_s_rep_steady), 566 | dim=2, 567 | ) 568 | scale_s = torch.stack( 569 | ( 570 | scale_s[..., 0], 571 | scale_s[..., 0], 572 | scale_s[..., 0], 573 | 0.1 * scale_s[..., 0], 574 | ), 575 | dim=2, 576 | ) 577 | dist_s = Normal(mean_s, scale_s) 578 | mixture_dist_s = MixtureSameFamily(comp_dist, dist_s) 579 | 580 | return mixture_dist_s, mixture_dist_u, end_penalty 581 | 582 | def _get_induction_unspliced_spliced( 583 | self, alpha, alpha_1, lambda_alpha, beta, gamma, t, eps=1e-6 584 | ): 585 | if self.time_dep_transcription_rate: 586 | unspliced = alpha_1 / beta * (1 - torch.exp(-beta * t)) - ( 587 | alpha_1 - alpha 588 | ) / (beta - lambda_alpha) * ( 589 | torch.exp(-lambda_alpha * t) - torch.exp(-beta * t) 590 | ) 591 | 592 | spliced = ( 593 | alpha_1 / gamma * (1 - torch.exp(-gamma * t)) 594 | + alpha_1 595 | / (gamma - beta + eps) 596 | * (torch.exp(-gamma * t) - torch.exp(-beta * t)) 597 | - beta 598 | * (alpha_1 - alpha) 599 | / (beta - lambda_alpha + eps) 600 | / (gamma - lambda_alpha + eps) 601 | * (torch.exp(-lambda_alpha * t) - torch.exp(-gamma * t)) 602 | + beta 603 | * (alpha_1 - alpha) 604 | / (beta - lambda_alpha + eps) 605 | / (gamma - beta + eps) 606 | * (torch.exp(-beta * t) - torch.exp(-gamma * t)) 607 | ) 608 | else: 609 | unspliced = (alpha / beta) * (1 - torch.exp(-beta * t)) 610 | spliced = (alpha / gamma) * (1 - torch.exp(-gamma * t)) + ( 611 | alpha / ((gamma - beta) + eps) 612 | ) * (torch.exp(-gamma * t) - torch.exp(-beta * t)) 613 | 614 | return unspliced, spliced 615 | 616 | def _get_repression_unspliced_spliced(self, u_0, s_0, beta, gamma, t, eps=1e-6): 617 | unspliced = torch.exp(-beta * t) * u_0 618 | spliced = s_0 * torch.exp(-gamma * t) - ( 619 | beta * u_0 / ((gamma - beta) + eps) 620 | ) * (torch.exp(-gamma * t) - torch.exp(-beta * t)) 621 | return unspliced, spliced 622 | 623 | def sample( 624 | self, 625 | ) -> np.ndarray: 626 | """Not implemented.""" 627 | raise NotImplementedError 628 | 629 | @torch.no_grad() 630 | def get_loadings(self) -> np.ndarray: 631 | """Extract per-gene weights (for each Z, shape is genes by dim(Z)) in the linear decoder.""" 632 | # This is BW, where B is diag(b) batch norm, W is weight matrix 633 | if self.decoder.linear_decoder is False: 634 | raise ValueError("Model not trained with linear decoder") 635 | w = self.decoder.rho_first_decoder.fc_layers[0][0].weight 636 | if self.use_batch_norm_decoder: 637 | bn = self.decoder.rho_first_decoder.fc_layers[0][1] 638 | sigma = torch.sqrt(bn.running_var + bn.eps) 639 | gamma = bn.weight 640 | b = gamma / sigma 641 | b_identity = torch.diag(b) 642 | loadings = torch.matmul(b_identity, w) 643 | else: 644 | loadings = w 645 | loadings = loadings.detach().cpu().numpy() 646 | 647 | return loadings 648 | -------------------------------------------------------------------------------- /velovi/_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Union 3 | from urllib.request import urlretrieve 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import scvelo as scv 8 | from anndata import AnnData 9 | from sklearn.preprocessing import MinMaxScaler 10 | 11 | 12 | def get_permutation_scores(save_path: Union[str, Path] = Path("data/")) -> pd.DataFrame: 13 | """Get the reference permutation scores on positive and negative controls. 14 | 15 | Parameters 16 | ---------- 17 | save_path 18 | path to save the csv file 19 | 20 | """ 21 | if isinstance(save_path, str): 22 | save_path = Path(save_path) 23 | save_path.mkdir(parents=True, exist_ok=True) 24 | 25 | if not (save_path / "permutation_scores.csv").is_file(): 26 | URL = "https://figshare.com/ndownloader/files/36658185" 27 | urlretrieve(url=URL, filename=save_path / "permutation_scores.csv") 28 | 29 | return pd.read_csv(save_path / "permutation_scores.csv") 30 | 31 | 32 | def preprocess_data( 33 | adata: AnnData, 34 | spliced_layer: Optional[str] = "Ms", 35 | unspliced_layer: Optional[str] = "Mu", 36 | min_max_scale: bool = True, 37 | filter_on_r2: bool = True, 38 | ) -> AnnData: 39 | """Preprocess data. 40 | 41 | This function removes poorly detected genes and minmax scales the data. 42 | 43 | Parameters 44 | ---------- 45 | adata 46 | Annotated data matrix. 47 | spliced_layer 48 | Name of the spliced layer. 49 | unspliced_layer 50 | Name of the unspliced layer. 51 | min_max_scale 52 | Min-max scale spliced and unspliced 53 | filter_on_r2 54 | Filter out genes according to linear regression fit 55 | 56 | Returns 57 | ------- 58 | Preprocessed adata. 59 | """ 60 | if min_max_scale: 61 | scaler = MinMaxScaler() 62 | adata.layers[spliced_layer] = scaler.fit_transform(adata.layers[spliced_layer]) 63 | 64 | scaler = MinMaxScaler() 65 | adata.layers[unspliced_layer] = scaler.fit_transform( 66 | adata.layers[unspliced_layer] 67 | ) 68 | 69 | if filter_on_r2: 70 | scv.tl.velocity(adata, mode="deterministic") 71 | 72 | adata = adata[ 73 | :, np.logical_and(adata.var.velocity_r2 > 0, adata.var.velocity_gamma > 0) 74 | ].copy() 75 | adata = adata[:, adata.var.velocity_genes].copy() 76 | 77 | return adata 78 | --------------------------------------------------------------------------------