├── .bumpversion.cfg ├── .codecov.yaml ├── .cruft.json ├── .editorconfig ├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml └── workflows │ ├── build.yaml │ ├── sync.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _static │ └── .gitkeep ├── _templates │ ├── .gitkeep │ └── autosummary │ │ └── class.rst ├── api.md ├── changelog.md ├── conf.py ├── contributing.md ├── extensions │ └── typed_returns.py ├── index.md ├── make.bat ├── references.bib ├── references.md └── template_usage.md ├── pyproject.toml ├── src └── simple_scvi │ ├── __init__.py │ ├── _mymodel.py │ ├── _mymodule.py │ ├── _mypyromodel.py │ └── _mypyromodule.py └── tests └── test_basic.py /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.0.1 3 | tag = True 4 | commit = True 5 | 6 | [bumpversion:file:./pyproject.toml] 7 | search = version = "{current_version}" 8 | replace = version = "{new_version}" 9 | -------------------------------------------------------------------------------- /.codecov.yaml: -------------------------------------------------------------------------------- 1 | # Based on pydata/xarray 2 | codecov: 3 | require_ci_to_pass: no 4 | 5 | coverage: 6 | status: 7 | project: 8 | default: 9 | # Require 1% coverage, i.e., always succeed 10 | target: 1 11 | patch: false 12 | changes: false 13 | 14 | comment: 15 | layout: diff, flags, files 16 | behavior: once 17 | require_base: no 18 | -------------------------------------------------------------------------------- /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "https://github.com/scverse/cookiecutter-scverse", 3 | "commit": "7cc5403b05e299d7a4bb169c2bd8c27a2a7676f3", 4 | "checkout": null, 5 | "context": { 6 | "cookiecutter": { 
7 | "project_name": "simple-scvi", 8 | "package_name": "simple_scvi", 9 | "project_description": "External and simple implementation of scVI", 10 | "author_full_name": "Adam Gayoso", 11 | "author_email": "adamgayoso@berkeley.edu", 12 | "github_user": "adamgayoso", 13 | "project_repo": "https://github.com/scverse/simple-scvi", 14 | "license": "BSD 3-Clause License", 15 | "_copy_without_render": [ 16 | ".github/workflows/**.yaml", 17 | "docs/_templates/autosummary/**.rst" 18 | ], 19 | "_template": "https://github.com/scverse/cookiecutter-scverse" 20 | } 21 | }, 22 | "directory": null 23 | } 24 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 4 6 | end_of_line = lf 7 | charset = utf-8 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [Makefile] 12 | indent_style = tab 13 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 119 3 | ignore = 4 | # Unnecessary dict call - rewrite as a literal. 
5 | C408 6 | # line break before a binary operator -> black does not adhere to PEP8 7 | W503 8 | # line break occured after a binary operator -> black does not adhere to PEP8 9 | W504 10 | # line too long -> we accept long comment lines; black gets rid of long code lines 11 | E501 12 | # whitespace before : -> black does not adhere to PEP8 13 | E203 14 | # missing whitespace after ,', ';', or ':' -> black does not adhere to PEP8 15 | E231 16 | # continuation line over-indented for hanging indent -> black does not adhere to PEP8 17 | E126 18 | # too many leading '#' for block comment -> this is fine for indicating sections 19 | E262 20 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient 21 | E731 22 | # allow I, O, l as variable names -> I is the identity matrix 23 | E741 24 | # Missing docstring in public package 25 | D104 26 | # Missing docstring in public module 27 | D100 28 | # Missing docstring in __init__ 29 | D107 30 | # Missing docstring in magic method 31 | D105 32 | # format string does contain unindexed parameters 33 | P101 34 | # first line should end with a period [Bug: doesn't work with single-line docstrings] 35 | D400 36 | # First line should be in imperative mood; try rephrasing 37 | D401 38 | exclude = .git,__pycache__,build,docs/_build,dist,scvi/__init__.py 39 | per-file-ignores = 40 | tests/*: D 41 | */__init__.py: F401 42 | extend-immutable-calls = 43 | # Add functions returning immutable values here to avoid B008 44 | pathlib.Path 45 | Path 46 | rst-roles = 47 | class, 48 | func, 49 | ref, 50 | meth, 51 | doc, 52 | py:class, 53 | method, 54 | attr, 55 | cite:p, 56 | cite:t, 57 | rst-directives = 58 | envvar, 59 | exception, 60 | rst-substitutions = 61 | version, 62 | extend-ignore = 63 | RST307,RST210,RST201,RST203,RST301 64 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: 
-------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Report something that is broken or incorrect 3 | labels: bug 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) 9 | detailing how to provide the necessary information for us to reproduce your bug. In brief: 10 | * Please provide exact steps how to reproduce the bug in a clean Python environment. 11 | * In case it's not clear what's causing this bug, please provide the data or the data generation procedure. 12 | * Sometimes it is not possible to share the data but usually it is possible to replicate problems on publicly 13 | available datasets or to share a subset of your data. 14 | 15 | - type: textarea 16 | id: report 17 | attributes: 18 | label: Report 19 | description: A clear and concise description of what the bug is. 20 | validations: 21 | required: true 22 | 23 | - type: textarea 24 | id: versions 25 | attributes: 26 | label: Version information 27 | description: | 28 | Please paste below the output of 29 | 30 | ```python 31 | import session_info 32 | session_info.show(html=False, dependencies=True) 33 | ``` 34 | placeholder: | 35 | ----- 36 | anndata 0.8.0rc2.dev27+ge524389 37 | session_info 1.0.0 38 | ----- 39 | asttokens NA 40 | awkward 1.8.0 41 | backcall 0.2.0 42 | cython_runtime NA 43 | dateutil 2.8.2 44 | debugpy 1.6.0 45 | decorator 5.1.1 46 | entrypoints 0.4 47 | executing 0.8.3 48 | h5py 3.7.0 49 | ipykernel 6.15.0 50 | jedi 0.18.1 51 | mpl_toolkits NA 52 | natsort 8.1.0 53 | numpy 1.22.4 54 | packaging 21.3 55 | pandas 1.4.2 56 | parso 0.8.3 57 | pexpect 4.8.0 58 | pickleshare 0.7.5 59 | pkg_resources NA 60 | prompt_toolkit 3.0.29 61 | psutil 5.9.1 62 | ptyprocess 0.7.0 63 | pure_eval 0.2.2 64 | pydev_ipython NA 65 | pydevconsole NA 66 | pydevd 2.8.0 67 | pydevd_file_utils NA 68 | pydevd_plugins NA 69 | 
pydevd_tracing NA 70 | pygments 2.12.0 71 | pytz 2022.1 72 | scipy 1.8.1 73 | setuptools 62.5.0 74 | setuptools_scm NA 75 | six 1.16.0 76 | stack_data 0.3.0 77 | tornado 6.1 78 | traitlets 5.3.0 79 | wcwidth 0.2.5 80 | zmq 23.1.0 81 | ----- 82 | IPython 8.4.0 83 | jupyter_client 7.3.4 84 | jupyter_core 4.10.0 85 | ----- 86 | Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50) [GCC 10.3.0] 87 | Linux-5.18.6-arch1-1-x86_64-with-glibc2.35 88 | ----- 89 | Session information updated at 2022-07-07 17:55 90 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Scverse Community Forum 4 | url: https://discourse.scverse.org/ 5 | about: If you have questions about “How to do X”, please ask them here. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Propose a new feature for simple-scvi 3 | labels: enhancement 4 | body: 5 | - type: textarea 6 | id: description 7 | attributes: 8 | label: Description of feature 9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered. 
10 | validations: 11 | required: true 12 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Check Build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | package: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 3.10 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: "3.10" 18 | - name: Install build dependencies 19 | run: python -m pip install --upgrade pip wheel twine build 20 | - name: Build package 21 | run: python -m build 22 | - name: Check package 23 | run: twine check --strict dist/*.whl 24 | -------------------------------------------------------------------------------- /.github/workflows/sync.yaml: -------------------------------------------------------------------------------- 1 | name: Sync Template 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: "0 2 * * *" # every night at 2:00 UTC 7 | 8 | jobs: 9 | sync: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 3.10 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: "3.10" 17 | - name: Install dependencies 18 | # for now, pin cookiecutter version, due to https://github.com/cruft/cruft/issues/166 19 | run: python -m pip install --upgrade cruft "cookiecutter<2" pre-commit toml 20 | - name: Find Latest Tag 21 | uses: oprypin/find-latest-tag@v1.1.0 22 | id: get-latest-tag 23 | with: 24 | repository: scverse/cookiecutter-scverse 25 | releases-only: false 26 | sort-tags: true 27 | regex: '^v\d+\.\d+\.\d+$' # vX.X.X 28 | - name: Sync 29 | run: | 30 | cruft update --checkout ${{ steps.get-latest-tag.outputs.tag }} --skip-apply-ask --project-dir . 
31 | - name: Create Pull Request 32 | uses: peter-evans/create-pull-request@v4 33 | with: 34 | commit-message: Automated template update from cookiecutter-scverse 35 | branch: template-update 36 | title: Automated template update from cookiecutter-scverse 37 | body: | 38 | A new version of the [scverse cookiecutter template](https://github.com/scverse/cookiecutter-scverse/releases) 39 | got released. This PR adds all new changes to your repository and helps you to stay in sync with 40 | the latest best-practice template maintained by the scverse team. 41 | 42 | **If a merge conflict arises, a `.rej` file with the rejected patch is generated. You'll need to 43 | manually merge these changes.** 44 | 45 | For more information about the template sync, please refer to the 46 | [template documentation](https://cookiecutter-scverse-instance.readthedocs.io/en/latest/developer_docs.html#automated-template-sync). 47 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | test: 11 | runs-on: ${{ matrix.os }} 12 | defaults: 13 | run: 14 | shell: bash -e {0} # -e to fail on error 15 | 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python: ["3.8", "3.10"] 20 | os: [ubuntu-latest] 21 | 22 | env: 23 | OS: ${{ matrix.os }} 24 | PYTHON: ${{ matrix.python }} 25 | 26 | steps: 27 | - uses: actions/checkout@v2 28 | - name: Set up Python ${{ matrix.python }} 29 | uses: actions/setup-python@v2 30 | with: 31 | python-version: ${{ matrix.python }} 32 | 33 | - name: Get pip cache dir 34 | id: pip-cache-dir 35 | run: | 36 | echo "::set-output name=dir::$(pip cache dir)" 37 | - name: Restore pip cache 38 | uses: actions/cache@v2 39 | with: 40 | path: ${{ steps.pip-cache-dir.outputs.dir }} 41 | key: pip-${{ runner.os }}-${{ 
env.pythonLocation }}-${{ hashFiles('**/pyproject.toml') }} 42 | restore-keys: | 43 | pip-${{ runner.os }}-${{ env.pythonLocation }}- 44 | - name: Install test dependencies 45 | run: | 46 | python -m pip install --upgrade pip wheel 47 | pip install codecov 48 | - name: Install dependencies 49 | run: | 50 | pip install ".[dev,test]" 51 | - name: Test 52 | env: 53 | MPLBACKEND: agg 54 | PLATFORM: ${{ matrix.os }} 55 | DISPLAY: :42 56 | run: | 57 | pytest -v --cov --color=yes 58 | - name: Upload coverage 59 | env: 60 | CODECOV_NAME: ${{ matrix.python }}-${{ matrix.os }} 61 | run: | 62 | codecov --required --flags=unittests 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temp files 2 | .DS_Store 3 | *~ 4 | 5 | # Compiled files 6 | __pycache__/ 7 | 8 | # Distribution / packaging 9 | /build/ 10 | /dist/ 11 | /*.egg-info/ 12 | 13 | # Tests and coverage 14 | /.pytest_cache/ 15 | /.cache/ 16 | /data/ 17 | 18 | # docs 19 | /docs/generated/ 20 | /docs/_build/ 21 | 22 | # IDEs 23 | /.idea/ 24 | /.vscode/ 25 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | default_stages: 5 | - commit 6 | - push 7 | minimum_pre_commit_version: 2.16.0 8 | repos: 9 | - repo: https://github.com/psf/black 10 | rev: 24.4.2 11 | hooks: 12 | - id: black 13 | - repo: https://github.com/pre-commit/mirrors-prettier 14 | rev: v4.0.0-alpha.8 15 | hooks: 16 | - id: prettier 17 | - repo: https://github.com/asottile/blacken-docs 18 | rev: 1.16.0 19 | hooks: 20 | - id: blacken-docs 21 | - repo: https://github.com/PyCQA/isort 22 | rev: 5.13.2 23 | hooks: 24 | - id: isort 25 | - repo: https://github.com/asottile/yesqa 26 | rev: v1.5.0 27 | hooks: 28 | - id: yesqa 
29 | additional_dependencies: 30 | - flake8-tidy-imports 31 | - flake8-docstrings 32 | - flake8-rst-docstrings 33 | - flake8-comprehensions 34 | - flake8-bugbear 35 | - flake8-blind-except 36 | - repo: https://github.com/pre-commit/pre-commit-hooks 37 | rev: v4.6.0 38 | hooks: 39 | - id: detect-private-key 40 | - id: check-ast 41 | - id: end-of-file-fixer 42 | - id: mixed-line-ending 43 | args: [--fix=lf] 44 | - id: trailing-whitespace 45 | - id: check-case-conflict 46 | - repo: https://github.com/PyCQA/autoflake 47 | rev: v2.3.1 48 | hooks: 49 | - id: autoflake 50 | args: 51 | - --in-place 52 | - --remove-all-unused-imports 53 | - --remove-unused-variable 54 | - --ignore-init-module-imports 55 | - repo: https://github.com/PyCQA/flake8 56 | rev: 7.1.0 57 | hooks: 58 | - id: flake8 59 | additional_dependencies: 60 | - flake8-tidy-imports 61 | - flake8-docstrings 62 | - flake8-rst-docstrings 63 | - flake8-comprehensions 64 | - flake8-bugbear 65 | - flake8-blind-except 66 | - repo: https://github.com/asottile/pyupgrade 67 | rev: v3.16.0 68 | hooks: 69 | - id: pyupgrade 70 | args: [--py3-plus, --py38-plus, --keep-runtime-typing] 71 | - repo: local 72 | hooks: 73 | - id: forbid-to-commit 74 | name: Don't commit rej files 75 | entry: | 76 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. 77 | Fix the merge conflicts manually and remove the .rej files. 78 | language: fail 79 | files: '.*\.rej$' 80 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html 2 | version: 2 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.10" 7 | sphinx: 8 | configuration: docs/conf.py 9 | # disable this for more lenient docs builds 10 | fail_on_warning: false 11 | python: 12 | install: 13 | - method: pip 14 | path: . 
15 | extra_requirements: 16 | - doc 17 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog][], 6 | and this project adheres to [Semantic Versioning][]. 7 | 8 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/ 9 | [semantic versioning]: https://semver.org/spec/v2.0.0.html 10 | 11 | ## [Unreleased] 12 | 13 | ### Added 14 | 15 | - Basic tool, preprocessing and plotting functions 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Adam Gayoso 4 | Copyright (c) 2025, scverse® 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | 1. Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | 2. Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # simple-scvi 2 | 3 | [![Tests][badge-tests]][link-tests] 4 | [![Documentation][badge-docs]][link-docs] 5 | 6 | [badge-tests]: https://img.shields.io/github/actions/workflow/status/scverse/simple-scvi/test.yaml?branch=main 7 | [link-tests]: https://github.com/scverse/simple-scvi/actions/workflows/test.yaml 8 | [badge-docs]: https://img.shields.io/readthedocs/simple-scvi 9 | 10 | External and simple implementation of scVI. This repository shows a minimal implementation of the [scVI](https://www.nature.com/articles/s41592-018-0229-2) model using [scvi-tools](https://scvi-tools.org) in an externally deployed package. 11 | 12 | This package was initialized using the [cookiecutter-scverse](https://github.com/scverse/cookiecutter-scverse) template. We advise all external projects to use the cookiecutter template. 13 | 14 | ## Getting started 15 | 16 | Please refer to the [documentation][link-docs]. In particular, the 17 | 18 | - [API documentation][link-api]. 19 | 20 | ## Installation 21 | 22 | You need to have Python 3.8 or newer installed on your system. If you don't have 23 | Python installed, we recommend installing [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge). 
24 | 25 | There are several alternative options to install simple-scvi: 26 | 27 | 34 | 35 | 1. Install the latest development version: 36 | 37 | ```bash 38 | pip install git+https://github.com/adamgayoso/simple-scvi.git@main 39 | ``` 40 | 41 | ## Release notes 42 | 43 | See the [changelog][changelog]. 44 | 45 | ## Contact 46 | 47 | For questions and help requests, you can reach out in the [scverse discourse][scverse-discourse]. 48 | If you found a bug, please use the [issue tracker][issue-tracker]. 49 | 50 | ## Citation 51 | 52 | ``` 53 | @article{gayoso2022python, 54 | title={A Python library for probabilistic analysis of single-cell omics data}, 55 | author={Gayoso, Adam and Lopez, Romain and Xing, Galen and Boyeau, Pierre and Valiollah Pour Amiri, Valeh and Hong, Justin and Wu, Katherine and Jayasuriya, Michael and Mehlman, Edouard and Langevin, Maxime and others}, 56 | journal={Nature biotechnology}, 57 | volume={40}, 58 | number={2}, 59 | pages={163--166}, 60 | year={2022}, 61 | publisher={Nature Publishing Group US New York} 62 | } 63 | ``` 64 | 65 | [scverse-discourse]: https://discourse.scverse.org/ 66 | [issue-tracker]: https://github.com/adamgayoso/simple-scvi/issues 67 | [changelog]: https://simple-scvi.readthedocs.io/latest/changelog.html 68 | [link-docs]: https://simple-scvi.readthedocs.io 69 | [link-api]: https://simple-scvi.readthedocs.io/latest/api.html 70 | 71 | [//]: # (numfocus-fiscal-sponsor-attribution) 72 | 73 | simple-scvi is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/). 74 | If you like scverse® and want to support our mission, please consider making a tax-deductible [donation](https://numfocus.org/donate-to-scverse) to help the project pay for developer time, professional services, travel, workshops, and a variety of other needs. 75 | 76 |
77 | 78 | 82 | 83 |
84 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/simple-scvi/76eb8c139202b81373f3828b2ce28a0f74f5d381/docs/_static/.gitkeep -------------------------------------------------------------------------------- /docs/_templates/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/simple-scvi/76eb8c139202b81373f3828b2ce28a0f74f5d381/docs/_templates/.gitkeep -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. 
autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | Attributes table 12 | ~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autosummary:: 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | Methods table 24 | ~~~~~~~~~~~~~ 25 | 26 | .. autosummary:: 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | 35 | {% block attributes_documentation %} 36 | {% if attributes %} 37 | Attributes 38 | ~~~~~~~~~~~ 39 | 40 | {% for item in attributes %} 41 | 42 | {{ item }} 43 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 44 | 45 | .. autoattribute:: {{ [objname, item] | join(".") }} 46 | {%- endfor %} 47 | 48 | {% endif %} 49 | {% endblock %} 50 | 51 | {% block methods_documentation %} 52 | {% if methods %} 53 | Methods 54 | ~~~~~~~ 55 | 56 | {% for item in methods %} 57 | {%- if item != '__init__' %} 58 | 59 | {{ item }} 60 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 61 | 62 | .. automethod:: {{ [objname, item] | join(".") }} 63 | {%- endif -%} 64 | {%- endfor %} 65 | 66 | {% endif %} 67 | {% endblock %} 68 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ## Core 4 | 5 | ```{eval-rst} 6 | .. module:: simple_scvi 7 | .. currentmodule:: simple_scvi 8 | 9 | .. 
autosummary:: 10 | :toctree: generated 11 | 12 | MyModel 13 | MyPyroModel 14 | ``` 15 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CHANGELOG.md 2 | 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | import sys 9 | from datetime import datetime 10 | from importlib.metadata import metadata 11 | from pathlib import Path 12 | 13 | HERE = Path(__file__).parent 14 | sys.path.insert(0, str(HERE / "extensions")) 15 | 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | info = metadata("simple-scvi") 20 | project_name = info["Name"] 21 | author = info["Author"] 22 | copyright = f"{datetime.now():%Y}, {author}." 
23 | version = info["Version"] 24 | repository_url = f"https://github.com/adamgayoso/{project_name}" 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = info["Version"] 28 | 29 | bibtex_bibfiles = ["references.bib"] 30 | templates_path = ["_templates"] 31 | nitpicky = True # Warn about broken links 32 | needs_sphinx = "4.0" 33 | 34 | html_context = { 35 | "display_github": True, # Integrate GitHub 36 | "github_user": "adamgayoso", # Username 37 | "github_repo": project_name, # Repo name 38 | "github_version": "main", # Version 39 | "conf_py_path": "/docs/", # Path in the checkout to the docs root 40 | } 41 | 42 | # -- General configuration --------------------------------------------------- 43 | 44 | # Add any Sphinx extension module names here, as strings. 45 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 46 | extensions = [ 47 | "myst_nb", 48 | "sphinx_copybutton", 49 | "sphinx.ext.autodoc", 50 | "sphinx.ext.intersphinx", 51 | "sphinx.ext.autosummary", 52 | "sphinx.ext.napoleon", 53 | "sphinxcontrib.bibtex", 54 | "sphinx_autodoc_typehints", 55 | "sphinx.ext.mathjax", 56 | "IPython.sphinxext.ipython_console_highlighting", 57 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 58 | ] 59 | 60 | autosummary_generate = True 61 | autodoc_member_order = "groupwise" 62 | default_role = "literal" 63 | napoleon_google_docstring = False 64 | napoleon_numpy_docstring = True 65 | napoleon_include_init_with_doc = False 66 | napoleon_use_rtype = True # having a separate entry generally helps readability 67 | napoleon_use_param = True 68 | myst_heading_anchors = 3 # create anchors for h1-h3 69 | myst_enable_extensions = [ 70 | "amsmath", 71 | "colon_fence", 72 | "deflist", 73 | "dollarmath", 74 | "html_image", 75 | "html_admonition", 76 | ] 77 | myst_url_schemes = ("http", "https", "mailto") 78 | nb_output_stderr = "remove" 79 | nb_execution_mode = "off" 80 | nb_merge_streams = True 81 | typehints_defaults = 
"braces" 82 | 83 | source_suffix = { 84 | ".rst": "restructuredtext", 85 | ".ipynb": "myst-nb", 86 | ".myst": "myst-nb", 87 | } 88 | 89 | intersphinx_mapping = { 90 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None), 91 | "numpy": ("https://numpy.org/doc/stable/", None), 92 | } 93 | 94 | # List of patterns, relative to source directory, that match files and 95 | # directories to ignore when looking for source files. 96 | # This pattern also affects html_static_path and html_extra_path. 97 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 98 | 99 | 100 | # -- Options for HTML output ------------------------------------------------- 101 | 102 | # The theme to use for HTML and HTML Help pages. See the documentation for 103 | # a list of builtin themes. 104 | # 105 | html_theme = "sphinx_book_theme" 106 | html_static_path = ["_static"] 107 | html_title = project_name 108 | 109 | html_theme_options = { 110 | "repository_url": repository_url, 111 | "use_repository_button": True, 112 | } 113 | 114 | pygments_style = "default" 115 | 116 | nitpick_ignore = [ 117 | # If building the documentation fails because of a missing link that is outside your control, 118 | # you can add an exception to this list. 119 | # ("py:class", "igraph.Graph"), 120 | ] 121 | 122 | 123 | def setup(app): 124 | """App setup hook.""" 125 | app.add_config_value( 126 | "recommonmark_config", 127 | { 128 | "auto_toc_tree_section": "Contents", 129 | "enable_auto_toc_tree": True, 130 | "enable_math": True, 131 | "enable_inline_math": False, 132 | "enable_eval_rst": True, 133 | }, 134 | True, 135 | ) 136 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing guide 2 | 3 | Scanpy provides extensive [developer documentation][scanpy developer guide], most of which applies to this repo, too. 
4 | This document will not reproduce the entire content from there. Instead, it aims at summarizing the most important 5 | information to get you started on contributing. 6 | 7 | We assume that you are already familiar with git and with making pull requests on GitHub. If not, please refer 8 | to the [scanpy developer guide][]. 9 | 10 | ## Installing dev dependencies 11 | 12 | In addition to the packages needed to _use_ this package, you need additional python packages to _run tests_ and _build 13 | the documentation_. It's easy to install them using `pip`: 14 | 15 | ```bash 16 | cd simple-scvi 17 | pip install -e ".[dev,test,doc]" 18 | ``` 19 | 20 | ## Code-style 21 | 22 | This template uses [pre-commit][] to enforce consistent code-styles. On every commit, pre-commit checks will either 23 | automatically fix issues with the code, or raise an error message. See [pre-commit checks](template_usage.md#pre-commit-checks) for 24 | a full list of checks enabled for this repository. 25 | 26 | To enable pre-commit locally, simply run 27 | 28 | ```bash 29 | pre-commit install 30 | ``` 31 | 32 | in the root of the repository. Pre-commit will automatically download all dependencies when it is run for the first time. 33 | 34 | Alternatively, you can rely on the [pre-commit.ci][] service enabled on GitHub. If you didn't run `pre-commit` before 35 | pushing changes to GitHub it will automatically commit fixes to your pull request, or show an error message. 36 | 37 | If pre-commit.ci added a commit on a branch you still have been working on locally, simply use 38 | 39 | ```bash 40 | git pull --rebase 41 | ``` 42 | 43 | to integrate the changes into yours. 44 | While the [pre-commit.ci][] is useful, we strongly encourage installing and running pre-commit locally first to understand its usage. 45 | 46 | Finally, most editors have an _autoformat on save_ feature. Consider enabling this option for [black][black-editors] 47 | and [prettier][prettier-editors]. 
48 | 49 | [black-editors]: https://black.readthedocs.io/en/stable/integrations/editors.html 50 | [prettier-editors]: https://prettier.io/docs/en/editors.html 51 | 52 | ## Writing tests 53 | 54 | ```{note} 55 | Remember to first install the package with `pip install '-e[dev,test]'` 56 | ``` 57 | 58 | This package uses the [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added 59 | to the package. 60 | 61 | Most IDEs integrate with pytest and provide a GUI to run tests. Alternatively, you can run all tests from the 62 | command line by executing 63 | 64 | ```bash 65 | pytest 66 | ``` 67 | 68 | in the root of the repository. Continuous integration will automatically run the tests on all pull requests. 69 | 70 | [scanpy-test-docs]: https://scanpy.readthedocs.io/en/latest/dev/testing.html#writing-tests 71 | 72 | ## Publishing a release 73 | 74 | ### Updating the version number 75 | 76 | Before making a release, you need to update the version number. Please adhere to [Semantic Versioning][semver], in brief 77 | 78 | > Given a version number MAJOR.MINOR.PATCH, increment the: 79 | > 80 | > 1. MAJOR version when you make incompatible API changes, 81 | > 2. MINOR version when you add functionality in a backwards compatible manner, and 82 | > 3. PATCH version when you make backwards compatible bug fixes. 83 | > 84 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format. 85 | 86 | We use [bump2version][] to automatically update the version number in all places and automatically create a git tag. 87 | Run one of the following commands in the root of the repository 88 | 89 | ```bash 90 | bump2version patch 91 | bump2version minor 92 | bump2version major 93 | ``` 94 | 95 | Once you are done, run 96 | 97 | ``` 98 | git push --tags 99 | ``` 100 | 101 | to publish the created tag on GitHub. 
102 | 103 | [bump2version]: https://github.com/c4urself/bump2version 104 | 105 | ### Building and publishing the package on PyPI 106 | 107 | Python packages are not distributed as source code, but as _distributions_. The most common distribution format is the so-called _wheel_. To build a _wheel_, run 108 | 109 | ```bash 110 | python -m build 111 | ``` 112 | 113 | This command creates a _source archive_ and a _wheel_, which are required for publishing your package to [PyPI][]. These files are created directly in the root of the repository. 114 | 115 | Before uploading them to [PyPI][] you can check that your _distribution_ is valid by running: 116 | 117 | ```bash 118 | twine check dist/* 119 | ``` 120 | 121 | and finally publishing it with: 122 | 123 | ```bash 124 | twine upload dist/* 125 | ``` 126 | 127 | Provide your username and password when requested and then go check out your package on [PyPI][]! 128 | 129 | For more information, follow the [Python packaging tutorial][]. 130 | 131 | It is possible to automate this with GitHub actions, see also [this feature request][pypi-feature-request] 132 | in the cookiecutter-scverse template. 133 | 134 | [python packaging tutorial]: https://packaging.python.org/en/latest/tutorials/packaging-projects/#generating-distribution-archives 135 | [pypi-feature-request]: https://github.com/scverse/cookiecutter-scverse/issues/88 136 | 137 | ## Writing documentation 138 | 139 | Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features: 140 | 141 | - the [myst][] extension allows to write documentation in markdown/Markedly Structured Text 142 | - [Numpy-style docstrings][numpydoc] (through the [napoloen][numpydoc-napoleon] extension). 
143 | - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks)) 144 | - [Sphinx autodoc typehints][], to automatically reference annotated input and output types 145 | 146 | See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information 147 | on how to write documentation. 148 | 149 | ### Tutorials with myst-nb and jupyter notebooks 150 | 151 | The documentation is set-up to render jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][]. 152 | Currently, only notebooks in `.ipynb` format are supported that will be included with both their input and output cells. 153 | It is your responsibility to update and re-run the notebook whenever necessary. 154 | 155 | If you are interested in automatically running notebooks as part of the continuous integration, please check 156 | out [this feature request](https://github.com/scverse/cookiecutter-scverse/issues/40) in the `cookiecutter-scverse` 157 | repository. 158 | 159 | #### Hints 160 | 161 | - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only 162 | if you do so can sphinx automatically create a link to the external documentation.
163 | - If building the documentation fails because of a missing link that is outside your control, you can add an entry to 164 | the `nitpick_ignore` list in `docs/conf.py` 165 | 166 | #### Building the docs locally 167 | 168 | ```bash 169 | cd docs 170 | make html 171 | open _build/html/index.html 172 | ``` 173 | 174 | 175 | 176 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html 177 | [cookiecutter-scverse-instance]: https://cookiecutter-scverse-instance.readthedocs.io/en/latest/template_usage.html 178 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui 179 | [codecov]: https://about.codecov.io/sign-up/ 180 | [codecov docs]: https://docs.codecov.com/docs 181 | [codecov bot]: https://docs.codecov.com/docs/team-bot 182 | [codecov app]: https://github.com/apps/codecov 183 | [pre-commit.ci]: https://pre-commit.ci/ 184 | [readthedocs.org]: https://readthedocs.org/ 185 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/ 186 | [jupytext]: https://jupytext.readthedocs.io/en/latest/ 187 | [pre-commit]: https://pre-commit.com/ 188 | [anndata]: https://github.com/scverse/anndata 189 | [mudata]: https://github.com/scverse/mudata 190 | [pytest]: https://docs.pytest.org/ 191 | [semver]: https://semver.org/ 192 | [sphinx]: https://www.sphinx-doc.org/en/master/ 193 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html 194 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 195 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html 196 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints 197 | [pypi]: https://pypi.org/ 198 | -------------------------------------------------------------------------------- /docs/extensions/typed_returns.py: -------------------------------------------------------------------------------- 1 | # code from 
https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py 2 | # with some minor adjustment 3 | import re 4 | 5 | from sphinx.application import Sphinx 6 | from sphinx.ext.napoleon import NumpyDocstring 7 | 8 | 9 | def _process_return(lines): 10 | for line in lines: 11 | m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line) 12 | if m: 13 | # Once this is in scanpydoc, we can use the fancy hover stuff 14 | yield f'-{m["param"]} (:class:`~{m["type"]}`)' 15 | else: 16 | yield line 17 | 18 | 19 | def _parse_returns_section(self, section): 20 | lines_raw = list(_process_return(self._dedent(self._consume_to_next_section()))) 21 | lines = self._format_block(":returns: ", lines_raw) 22 | if lines and lines[-1]: 23 | lines.append("") 24 | return lines 25 | 26 | 27 | def setup(app: Sphinx): 28 | """Set app.""" 29 | NumpyDocstring._parse_returns_section = _parse_returns_section 30 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ```{include} ../README.md 2 | 3 | ``` 4 | 5 | ```{toctree} 6 | :hidden: true 7 | :maxdepth: 1 8 | 9 | api.md 10 | changelog.md 11 | template_usage.md 12 | contributing.md 13 | references.md 14 | ``` 15 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable.
Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{Lopez18, 2 | title = {Deep generative modeling for single-cell transcriptomics}, 3 | author = {Romain Lopez and Jeffrey Regier and Michael B. Cole and Michael I. Jordan and Nir Yosef}, 4 | doi = {10.1038/s41592-018-0229-2}, 5 | year = {2018}, 6 | month = nov, 7 | journal = {Nature Methods}, 8 | volume = {15}, 9 | number = {12}, 10 | pages = {1053--1058}, 11 | publisher = {Springer Science and Business Media {LLC}} 12 | } 13 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | ```{bibliography} 4 | :cited: 5 | ``` 6 | -------------------------------------------------------------------------------- /docs/template_usage.md: -------------------------------------------------------------------------------- 1 | # Using this template 2 | 3 | Welcome to the developer guidelines! This document is split into two parts: 4 | 5 | 1. The [repository setup](#setting-up-the-repository). This section is relevant primarily for the repository maintainer and shows how to connect 6 | continuous integration services and documents initial set-up of the repository. 7 | 2. The [contributor guide](contributing.md#contributing-guide). It contains information relevant to all developers who want to make a contribution. 
8 | 9 | ## Setting up the repository 10 | 11 | ### First commit 12 | 13 | If you are reading this, you should have just completed the repository creation with : 14 | 15 | ```bash 16 | cruft create https://github.com/scverse/cookiecutter-scverse 17 | ``` 18 | 19 | and you should have 20 | 21 | ``` 22 | cd simple-scvi 23 | ``` 24 | 25 | into the new project directory. Now that you have created a new repository locally, the first step is to push it to github. To do this, you'd have to create a **new repository** on github. 26 | You can follow the instructions directly on [github quickstart guide][]. 27 | Since `cruft` already populated the local repository of your project with all the necessary files, we suggest to _NOT_ initialize the repository with a `README.md` file or `.gitignore`, because you might encounter git conflicts on your first push. 28 | If you are familiar with git and knows how to handle git conflicts, you can go ahead with your preferred choice. 29 | 30 | :::{note} 31 | If you are looking at this document in the [cookiecutter-scverse-instance][] repository documentation, throughout this document the name of the project is `cookiecutter-scverse-instance`. Otherwise it should be replaced by your new project name: `simple-scvi`. 32 | ::: 33 | 34 | Now that your new project repository has been created on github at `https://github.com/adamgayoso/simple-scvi` you can push your first commit to github. 35 | To do this, simply follow the instructions on your github repository page or a more verbose walkthrough here: 36 | 37 | Assuming you are in `/your/path/to/simple-scvi`. Add all files and commit. 38 | 39 | ```bash 40 | # stage all files of your new repo 41 | git add --all 42 | # commit 43 | git commit -m "first commit" 44 | ``` 45 | 46 | You'll notice that the command `git commit` installed a bunch of packages and triggered their execution: those are pre-commit! 
To read more about what they are and what they do, you can go to the related section [Pre-commit checks](#pre-commit-checks) in this document. 47 | 48 | :::{note} 49 | There is a chance that `git commit -m "first commit"` fails due to the `prettier` pre-commit formatting the file `.cruft.json`. No problem, you have just experienced what pre-commit checks do in action. Just go ahead and re-add the modified file and try to commit again: 50 | 51 | ```bash 52 | git add -u # update all tracked file 53 | git commit -m "first commit" 54 | ``` 55 | 56 | ::: 57 | 58 | Now that all the files of the newly created project have been committed, go ahead with the remaining steps: 59 | 60 | ```bash 61 | # update the `origin` of your local repo with the remote github link 62 | git remote add origin https://github.com/adamgayoso/simple-scvi.git 63 | # rename the default branch to main 64 | git branch -M main 65 | # push all your files to remote 66 | git push -u origin main 67 | ``` 68 | 69 | Your project should be now available at `https://github.com/adamgayoso/simple-scvi`. While the repository at this point can be directly used, there are few remaining steps that needs to be done in order to achieve full functionality. 70 | 71 | ### Coverage tests with _Codecov_ 72 | 73 | Coverage tells what fraction of the code is "covered" by unit tests, thereby encouraging contributors to 74 | [write tests](contributing.md#writing-tests). 75 | To enable coverage checks, head over to [codecov][] and sign in with your GitHub account. 76 | You'll find more information in "getting started" section of the [codecov docs][]. 77 | 78 | In the `Actions` tab of your projects' github repository, you can see that the workflows are failing due to the **Upload coverage** step. The error message in the workflow should display something like: 79 | 80 | ``` 81 | ... 82 | Retrying 5/5 in 2s.. 
83 | {'detail': ErrorDetail(string='Could not find a repository, try using repo upload token', code='not_found')} 84 | Error: 404 Client Error: Not Found for url: 85 | ... 86 | ``` 87 | 88 | While [codecov docs][] has a very extensive documentation on how to get started, _if_ you are using the default settings of this template we can assume that you are using [codecov][] in a github action workflow and hence you can make use of the [codecov bot][]. 89 | 90 | To set it up, simply go to the [codecov app][] page and follow the instructions to activate it for your repository. 91 | Once the activation is completed, go back to the `Actions` tab and re-run the failing workflows. 92 | 93 | The workflows should now succeed and you will be able to find the code coverage at this link: `https://app.codecov.io/gh/adamgayoso/simple-scvi`. You might have to wait couple of minutes and the coverage of this repository should be ~60%. 94 | 95 | If your repository is private, you will have to specify an additional token in the repository secrets. In brief, you need to: 96 | 97 | 1. Generate a Codecov Token by clicking _setup repo_ in the codecov dashboard. 98 | - If you have already set up codecov in the repository by following the previous steps, you can directly go to the codecov repo webpage. 99 | 2. Go to _Settings_ and copy **only** the token `_______-____-...`. 100 | 3. Go to _Settings_ of your newly created repository on GitHub. 101 | 4. Go to _Security > Secrets > Actions_. 102 | 5. Create new repository secret with name `CODECOV_TOKEN` and paste the token generated by codecov. 103 | 6. Past these additional lines in `/.github/workflows.test.yaml` under the **Upload coverage** step: 104 | ```bash 105 | - name: Upload coverage 106 | uses: codecov/codecov-action@v3 107 | with: 108 | token: ${{ secrets.CODECOV_TOKEN }} 109 | ``` 110 | 7. Go back to github `Actions` page an re-run previously failed jobs. 
111 | 112 | ### Documentation on _readthedocs_ 113 | 114 | We recommend using [readthedocs.org][] (RTD) to build and host the documentation for your project. 115 | To enable readthedocs, head over to [their website][readthedocs.org] and sign in with your GitHub account. 116 | On the RTD dashboard choose "Import a Project" and follow the instructions to add your repository. 117 | 118 | - Make sure to choose the correct name of the default branch. On GitHub, the name of the default branch should be `main` (it has 119 | recently changed from `master` to `main`). 120 | - We recommend to enable documentation builds for pull requests (PRs). This ensures that a PR doesn't introduce changes 121 | that break the documentation. To do so, got to `Admin -> Advanced Settings`, check the 122 | `Build pull requests for this projects` option, and click `Save`. For more information, please refer to 123 | the [official RTD documentation](https://docs.readthedocs.io/en/stable/pull-requests.html). 124 | - If you find the RTD builds are failing, you can disable the `fail_on_warning` option in `.readthedocs.yaml`. 125 | 126 | If your project is private, there are ways to enable docs rendering on [readthedocs.org][] but it is more cumbersome and requires a different subscription for read the docs. See a guide [here](https://docs.readthedocs.io/en/stable/guides/importing-private-repositories.html). 127 | 128 | ### Pre-commit checks 129 | 130 | [Pre-commit][] checks are fast programs that 131 | check code for errors, inconsistencies and code styles, before the code 132 | is committed. 133 | 134 | We recommend setting up [pre-commit.ci][] to enforce consistency checks on every commit 135 | and pull-request. 136 | 137 | To do so, head over to [pre-commit.ci][] and click "Sign In With GitHub". Follow 138 | the instructions to enable pre-commit.ci for your account or your organization. You 139 | may choose to enable the service for an entire organization or on a per-repository basis. 
140 | 141 | Once authorized, pre-commit.ci should automatically be activated. 142 | 143 | #### Overview of pre-commit hooks used by the template 144 | 145 | The following pre-commit checks are for code style and format: 146 | 147 | - [black](https://black.readthedocs.io/en/stable/): standard code 148 | formatter in Python. 149 | - [isort](https://pycqa.github.io/isort/): sort module imports into 150 | sections and types. 151 | - [prettier](https://prettier.io/docs/en/index.html): standard code 152 | formatter for non-Python files (e.g. YAML). 153 | - [blacken-docs](https://github.com/asottile/blacken-docs): black on 154 | python code in docs. 155 | 156 | The following pre-commit checks are for errors and inconsistencies: 157 | 158 | - [flake8](https://flake8.pycqa.org/en/latest/): standard check for errors in Python files. 159 | - [flake8-tidy-imports](https://github.com/adamchainz/flake8-tidy-imports): 160 | tidy module imports. 161 | - [flake8-docstrings](https://github.com/PyCQA/flake8-docstrings): 162 | pydocstyle extension of flake8. 163 | - [flake8-rst-docstrings](https://github.com/peterjc/flake8-rst-docstrings): 164 | extension of `flake8-docstrings` for `rst` docs. 165 | - [flake8-comprehensions](https://github.com/adamchainz/flake8-comprehensions): 166 | write better list/set/dict comprehensions. 167 | - [flake8-bugbear](https://github.com/PyCQA/flake8-bugbear): 168 | find possible bugs and design issues in program. 169 | - [flake8-blind-except](https://github.com/elijahandrews/flake8-blind-except): 170 | checks for blind, catch-all `except` statements. 171 | - [yesqa](https://github.com/asottile/yesqa): 172 | remove unnecessary `# noqa` comments, follows additional dependencies listed above. 173 | - [autoflake](https://github.com/PyCQA/autoflake): 174 | remove unused imports and variables. 175 | - [pre-commit-hooks](https://github.com/pre-commit/pre-commit-hooks): generic pre-commit hooks. 176 | - **detect-private-key**: checks for the existence of private keys.
177 | - **check-ast**: check whether files parse as valid python. 178 | - **end-of-file-fixer**: check files end in a newline and only a newline. 179 | - **mixed-line-ending**: checks mixed line ending. 180 | - **trailing-whitespace**: trims trailing whitespace. 181 | - **check-case-conflict**: check files that would conflict with case-insensitive file systems. 182 | - [pyupgrade](https://github.com/asottile/pyupgrade): 183 | upgrade syntax for newer versions of the language. 184 | - **forbid-to-commit**: Make sure that `*.rej` files cannot be committed. These files are created by the 185 | [automated template sync](#automated-template-sync) if there's a merge conflict and need to be addressed manually. 186 | 187 | ### How to disable or add pre-commit checks 188 | 189 | - To ignore lint warnings from **flake8**, see [Ignore certain lint warnings](#how-to-ignore-certain-lint-warnings). 190 | - You can add or remove pre-commit checks by simply deleting relevant lines in the `.pre-commit-config.yaml` file. 191 | Some pre-commit checks have additional options that can be specified either in the `pyproject.toml` or tool-specific 192 | config files, such as `.prettierrc.yml` for **prettier** and `.flake8` for **flake8**. 193 | 194 | ### How to ignore certain lint warnings 195 | 196 | The [pre-commit checks](#pre-commit-checks) include [flake8](https://flake8.pycqa.org/en/latest/) which checks 197 | for errors in Python files, including stylistic errors. 198 | 199 | In some cases it might overshoot and you may have good reasons to ignore certain warnings. 200 | 201 | To ignore a specific error on a per-case basis, you can add a comment `# noqa` to the offending line. You can also 202 | specify the error ID to ignore, with e.g. `# noqa: E731`. Check the [flake8 guide][] for reference. 203 | 204 | Alternatively, you can disable certain error messages for the entire project. To do so, edit the `.flake8` 205 | file in the root of the repository.
Add one line per linting code you wish to ignore and don't forget to add a comment. 206 | 207 | ```toml 208 | ... 209 | # line break before a binary operator -> black does not adhere to PEP8 210 | W503 211 | # line break occured after a binary operator -> black does not adhere to PEP8 212 | W504 213 | ... 214 | ``` 215 | 216 | [flake8 guide]: https://flake8.pycqa.org/en/3.1.1/user/ignoring-errors.html 217 | 218 | ### API design 219 | 220 | Scverse ecosystem packages should operate on [AnnData][] and/or [MuData][] data structures and typically use an API 221 | as originally [introduced by scanpy][scanpy-api] with the following submodules: 222 | 223 | - `pp` for preprocessing 224 | - `tl` for tools (that, compared to `pp` generate interpretable output, often associated with a corresponding plotting 225 | function) 226 | - `pl` for plotting functions 227 | 228 | You may add additional submodules as appropriate. While we encourage to follow a scanpy-like API for ecosystem packages, 229 | there may also be good reasons to choose a different approach, e.g. using an object-oriented API. 230 | 231 | [scanpy-api]: https://scanpy.readthedocs.io/en/stable/usage-principles.html 232 | 233 | ### Using VCS-based versioning 234 | 235 | By default, the template uses hard-coded version numbers that are set in `pyproject.toml` and [managed with 236 | bump2version](contributing.md#publishing-a-release). If you prefer to have your project automatically infer version numbers from git 237 | tags, it is straightforward to switch to vcs-based versioning using [hatch-vcs][]. 238 | 239 | In `pyproject.toml` add the following changes, and you are good to go! 
240 | 241 | ```diff 242 | --- a/pyproject.toml 243 | +++ b/pyproject.toml 244 | @@ -1,11 +1,11 @@ 245 | [build-system] 246 | build-backend = "hatchling.build" 247 | -requires = ["hatchling"] 248 | +requires = ["hatchling", "hatch-vcs"] 249 | 250 | 251 | [project] 252 | name = "simple-scvi" 253 | -version = "0.3.1dev" 254 | +dynamic = ["version"] 255 | 256 | @@ -60,6 +60,9 @@ 257 | +[tool.hatch.version] 258 | +source = "vcs" 259 | + 260 | [tool.coverage.run] 261 | source = ["simple-scvi"] 262 | omit = [ 263 | ``` 264 | 265 | Don't forget to update the [Making a release section](contributing.md#publishing-a-release) in this document accordingly, after you are done! 266 | 267 | [hatch-vcs]: https://pypi.org/project/hatch-vcs/ 268 | 269 | ### Automated template sync 270 | 271 | Automated template sync is enabled by default. This means that every night, a GitHub action runs [cruft][] to check 272 | if a new version of the `scverse-cookiecutter` template got released. If there are any new changes, a pull request 273 | proposing these changes is created automatically. This helps keeping the repository up-to-date with the latest 274 | coding standards. 275 | 276 | It may happen that a template sync results in a merge conflict. If this is the case a `*.ref` file with the 277 | diff is created. You need to manually address these changes and remove the `.rej` file when you are done. 278 | The pull request can only be merged after all `*.rej` files have been removed. 279 | 280 | :::{tip} 281 | The following hints may be useful to work with the template sync: 282 | 283 | - GitHub automatically disables scheduled actions if there has been not activity to the repository for 60 days. 284 | You can re-enable or manually trigger the sync by navigating to `Actions` -> `Sync Template` in your GitHub repository. 
285 | - If you want to ignore certain files from the template update, you can add them to the `[tool.cruft]` section in the 286 | `pyproject.toml` file in the root of your repository. More details are described in the 287 | [cruft documentation][cruft-update-project]. 288 | - To disable the sync entirely, simply remove the file `.github/workflows/sync.yaml`. 289 | 290 | ::: 291 | 292 | [cruft]: https://cruft.github.io/cruft/ 293 | [cruft-update-project]: https://cruft.github.io/cruft/#updating-a-project 294 | 295 | ## Moving forward 296 | 297 | You have reached the end of this document. Congratulations! You have successfully set up your project and are ready to start. 298 | For everything else related to documentation, code style, testing and publishing your project ot pypi, please refer to the [contributing docs](contributing.md#contributing-guide). 299 | 300 | 301 | 302 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html 303 | [cookiecutter-scverse-instance]: https://cookiecutter-scverse-instance.readthedocs.io/en/latest/template_usage.html 304 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui 305 | [codecov]: https://about.codecov.io/sign-up/ 306 | [codecov docs]: https://docs.codecov.com/docs 307 | [codecov bot]: https://docs.codecov.com/docs/team-bot 308 | [codecov app]: https://github.com/apps/codecov 309 | [pre-commit.ci]: https://pre-commit.ci/ 310 | [readthedocs.org]: https://readthedocs.org/ 311 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/ 312 | [jupytext]: https://jupytext.readthedocs.io/en/latest/ 313 | [pre-commit]: https://pre-commit.com/ 314 | [anndata]: https://github.com/scverse/anndata 315 | [mudata]: https://github.com/scverse/mudata 316 | [pytest]: https://docs.pytest.org/ 317 | [semver]: https://semver.org/ 318 | [sphinx]: https://www.sphinx-doc.org/en/master/ 319 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html 320 | 
[numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 321 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html 322 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints 323 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | 6 | [project] 7 | name = "simple-scvi" 8 | version = "0.0.1" 9 | description = "External and simple implementation of scVI" 10 | readme = "README.md" 11 | requires-python = ">=3.8" 12 | license = {file = "LICENSE"} 13 | authors = [ 14 | {name = "Adam Gayoso"}, 15 | ] 16 | maintainers = [ 17 | {name = "Adam Gayoso", email = "adamgayoso@berkeley.edu"}, 18 | ] 19 | urls.Documentation = "https://simple-scvi.readthedocs.io/" 20 | urls.Source = "https://github.com/scverse/simple-scvi" 21 | urls.Home-page = "https://github.com/scverse/simple-scvi" 22 | dependencies = [ 23 | "anndata", 24 | # for debug logging (referenced from the issue template) 25 | "session-info", 26 | "rich", 27 | "scvi-tools>=0.20.1", 28 | "torch", 29 | ] 30 | 31 | [project.optional-dependencies] 32 | dev = [ 33 | # CLI for bumping the version number 34 | "bump2version", 35 | "pre-commit", 36 | "twine>=4.0.2" 37 | ] 38 | doc = [ 39 | "sphinx>=4", 40 | "sphinx-book-theme>=0.3.3", 41 | "myst-nb", 42 | "sphinxcontrib-bibtex>=1.0.0", 43 | "sphinx-autodoc-typehints", 44 | # For notebooks 45 | "ipykernel", 46 | "ipython", 47 | "sphinx-copybutton", 48 | ] 49 | test = [ 50 | "pytest", 51 | "pytest-cov", 52 | ] 53 | 54 | [tool.coverage.run] 55 | source = ["simple_scvi"] 56 | omit = [ 57 | "**/test_*.py", 58 | ] 59 | 60 | [tool.pytest.ini_options] 61 | testpaths = ["tests"] 62 | xfail_strict = true 63 | addopts = [ 64 | "--import-mode=importlib", # allow using test files with same name 65 | ] 66 
| 67 | [tool.isort] 68 | include_trailing_comma = true 69 | multi_line_output = 3 70 | profile = "black" 71 | skip_glob = ["docs/*"] 72 | 73 | [tool.black] 74 | line-length = 120 75 | target-version = ['py38'] 76 | include = '\.pyi?$' 77 | exclude = ''' 78 | ( 79 | /( 80 | \.eggs 81 | | \.git 82 | | \.hg 83 | | \.mypy_cache 84 | | \.tox 85 | | \.venv 86 | | _build 87 | | buck-out 88 | | build 89 | | dist 90 | )/ 91 | ) 92 | ''' 93 | 94 | [tool.jupytext] 95 | formats = "ipynb,md" 96 | 97 | [tool.cruft] 98 | skip = [ 99 | "tests", 100 | "src/**/__init__.py", 101 | "src/**/basic.py", 102 | "docs/api.md", 103 | "docs/changelog.md", 104 | "docs/references.bib", 105 | "docs/references.md", 106 | "docs/notebooks/example.ipynb" 107 | ] 108 | -------------------------------------------------------------------------------- /src/simple_scvi/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from importlib.metadata import version 3 | 4 | from rich.console import Console 5 | from rich.logging import RichHandler 6 | 7 | from ._mymodel import MyModel, MyModule 8 | from ._mypyromodel import MyPyroModel, MyPyroModule 9 | 10 | logger = logging.getLogger(__name__) 11 | # set the logging level 12 | logger.setLevel(logging.INFO) 13 | 14 | # nice logging outputs 15 | console = Console(force_terminal=True) 16 | if console.is_jupyter is True: 17 | console.is_jupyter = False 18 | ch = RichHandler(show_path=False, console=console, show_time=False) 19 | formatter = logging.Formatter("simple_scvi: %(message)s") 20 | ch.setFormatter(formatter) 21 | logger.addHandler(ch) 22 | 23 | # this prevents double outputs 24 | logger.propagate = False 25 | 26 | __all__ = ["MyModel", "MyModule", "MyPyroModel", "MyPyroModule"] 27 | 28 | __version__ = version("simple-scvi") 29 | -------------------------------------------------------------------------------- /src/simple_scvi/_mymodel.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional 3 | 4 | from anndata import AnnData 5 | from scvi import REGISTRY_KEYS 6 | from scvi.data import AnnDataManager 7 | from scvi.data.fields import ( 8 | CategoricalJointObsField, 9 | CategoricalObsField, 10 | LayerField, 11 | NumericalJointObsField, 12 | ) 13 | from scvi.model._utils import _init_library_size 14 | from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin 15 | from scvi.utils import setup_anndata_dsp 16 | 17 | from ._mymodule import MyModule 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class MyModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): 23 | """ 24 | Skeleton for an scvi-tools model. 25 | 26 | Please use this skeleton to create new models. This is a simple 27 | implementation of the scVI model :cite:p:`Lopez18`. 28 | 29 | Parameters 30 | ---------- 31 | adata 32 | AnnData object that has been registered via :meth:`~mypackage.MyModel.setup_anndata`. 33 | n_hidden 34 | Number of nodes per hidden layer. 35 | n_latent 36 | Dimensionality of the latent space. 37 | n_layers 38 | Number of hidden layers used for encoder and decoder NNs. 
39 | **model_kwargs 40 | Keyword args for :class:`~mypackage.MyModule` 41 | Examples 42 | -------- 43 | >>> adata = anndata.read_h5ad(path_to_anndata) 44 | >>> mypackage.MyModel.setup_anndata(adata, batch_key="batch") 45 | >>> vae = mypackage.MyModel(adata) 46 | >>> vae.train() 47 | >>> adata.obsm["X_mymodel"] = vae.get_latent_representation() 48 | """ 49 | 50 | def __init__( 51 | self, 52 | adata: AnnData, 53 | n_hidden: int = 128, 54 | n_latent: int = 10, 55 | n_layers: int = 1, 56 | **model_kwargs, 57 | ): 58 | super().__init__(adata) 59 | 60 | library_log_means, library_log_vars = _init_library_size(self.adata_manager, self.summary_stats["n_batch"]) 61 | 62 | # self.summary_stats provides information about anndata dimensions and other tensor info 63 | 64 | self.module = MyModule( 65 | n_input=self.summary_stats["n_vars"], 66 | n_hidden=n_hidden, 67 | n_latent=n_latent, 68 | n_layers=n_layers, 69 | library_log_means=library_log_means, 70 | library_log_vars=library_log_vars, 71 | **model_kwargs, 72 | ) 73 | self._model_summary_string = "Overwrite this attribute to get an informative representation for your model" 74 | # necessary line to get params that will be used for saving/loading 75 | self.init_params_ = self._get_init_params(locals()) 76 | 77 | logger.info("The model has been initialized") 78 | 79 | @classmethod 80 | @setup_anndata_dsp.dedent 81 | def setup_anndata( 82 | cls, 83 | adata: AnnData, 84 | batch_key: Optional[str] = None, 85 | labels_key: Optional[str] = None, 86 | layer: Optional[str] = None, 87 | categorical_covariate_keys: Optional[List[str]] = None, 88 | continuous_covariate_keys: Optional[List[str]] = None, 89 | **kwargs, 90 | ) -> Optional[AnnData]: 91 | """ 92 | %(summary)s. 
93 | 94 | Parameters 95 | ---------- 96 | %(param_adata)s 97 | %(param_batch_key)s 98 | %(param_labels_key)s 99 | %(param_layer)s 100 | %(param_cat_cov_keys)s 101 | %(param_cont_cov_keys)s 102 | Returns 103 | ------- 104 | %(returns)s 105 | """ 106 | setup_method_args = cls._get_setup_method_args(**locals()) 107 | anndata_fields = [ 108 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), 109 | CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), 110 | CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), 111 | CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), 112 | NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), 113 | ] 114 | adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) 115 | adata_manager.register_fields(adata, **kwargs) 116 | cls.register_manager(adata_manager) 117 | -------------------------------------------------------------------------------- /src/simple_scvi/_mymodule.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from scvi import REGISTRY_KEYS 7 | from scvi.distributions import ZeroInflatedNegativeBinomial 8 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data 9 | from scvi.nn import DecoderSCVI, Encoder, one_hot 10 | from torch.distributions import Normal 11 | from torch.distributions import kl_divergence as kl 12 | 13 | TensorDict = Dict[str, torch.Tensor] 14 | 15 | 16 | class MyModule(BaseModuleClass): 17 | """ 18 | Skeleton Variational auto-encoder model. 19 | 20 | Here we implement a basic version of scVI's underlying VAE :cite:p:`Lopez18`. 21 | This implementation is for instructional purposes only. 22 | 23 | Parameters 24 | ---------- 25 | n_input 26 | Number of input genes 27 | library_log_means 28 | 1 x n_batch array of means of the log library sizes. 
Parameterizes prior on library size if 29 | not using observed library size. 30 | library_log_vars 31 | 1 x n_batch array of variances of the log library sizes. Parameterizes prior on library size if 32 | not using observed library size. 33 | n_batch 34 | Number of batches, if 0, no batch correction is performed. 35 | n_hidden 36 | Number of nodes per hidden layer 37 | n_latent 38 | Dimensionality of the latent space 39 | n_layers 40 | Number of hidden layers used for encoder and decoder NNs 41 | dropout_rate 42 | Dropout rate for neural networks 43 | """ 44 | 45 | def __init__( 46 | self, 47 | n_input: int, 48 | library_log_means: np.ndarray, 49 | library_log_vars: np.ndarray, 50 | n_batch: int = 0, 51 | n_hidden: int = 128, 52 | n_latent: int = 10, 53 | n_layers: int = 1, 54 | dropout_rate: float = 0.1, 55 | ): 56 | super().__init__() 57 | self.n_latent = n_latent 58 | self.n_batch = n_batch 59 | # this is needed to comply with some requirement of the VAEMixin class 60 | self.latent_distribution = "normal" 61 | 62 | self.register_buffer("library_log_means", torch.from_numpy(library_log_means).float()) 63 | self.register_buffer("library_log_vars", torch.from_numpy(library_log_vars).float()) 64 | 65 | # setup the parameters of your generative model, as well as your inference model 66 | self.px_r = torch.nn.Parameter(torch.randn(n_input)) 67 | # z encoder goes from the n_input-dimensional data to an n_latent-d 68 | # latent space representation 69 | self.z_encoder = Encoder( 70 | n_input, 71 | n_latent, 72 | n_layers=n_layers, 73 | n_hidden=n_hidden, 74 | dropout_rate=dropout_rate, 75 | ) 76 | # l encoder goes from n_input-dimensional data to 1-d library size 77 | self.l_encoder = Encoder( 78 | n_input, 79 | 1, 80 | n_layers=1, 81 | n_hidden=n_hidden, 82 | dropout_rate=dropout_rate, 83 | ) 84 | # decoder goes from n_latent-dimensional space to n_input-d data 85 | self.decoder = DecoderSCVI( 86 | n_latent, 87 | n_input, 88 | n_layers=n_layers, 89 | n_hidden=n_hidden, 
90 | ) 91 | 92 | def _get_inference_input(self, tensors): 93 | """Parse the dictionary to get appropriate args""" 94 | x = tensors[REGISTRY_KEYS.X_KEY] 95 | 96 | input_dict = dict(x=x) 97 | return input_dict 98 | 99 | def _get_generative_input(self, tensors, inference_outputs): 100 | z = inference_outputs["z"] 101 | library = inference_outputs["library"] 102 | 103 | input_dict = { 104 | "z": z, 105 | "library": library, 106 | } 107 | return input_dict 108 | 109 | @auto_move_data 110 | def inference(self, x): 111 | """ 112 | High level inference method. 113 | 114 | Runs the inference (encoder) model. 115 | """ 116 | # log the input to the variational distribution for numerical stability 117 | x_ = torch.log(1 + x) 118 | # get variational parameters via the encoder networks 119 | qz_m, qz_v, z = self.z_encoder(x_) 120 | ql_m, ql_v, library = self.l_encoder(x_) 121 | 122 | outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, ql_m=ql_m, ql_v=ql_v, library=library) 123 | return outputs 124 | 125 | @auto_move_data 126 | def generative(self, z, library): 127 | """Runs the generative model.""" 128 | # form the parameters of the ZINB likelihood 129 | px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) 130 | px_r = torch.exp(self.px_r) 131 | 132 | return dict(px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout) 133 | 134 | def loss( 135 | self, 136 | tensors, 137 | inference_outputs, 138 | generative_outputs, 139 | kl_weight: float = 1.0, 140 | ): 141 | """Loss function.""" 142 | x = tensors[REGISTRY_KEYS.X_KEY] 143 | qz_m = inference_outputs["qz_m"] 144 | qz_v = inference_outputs["qz_v"] 145 | ql_m = inference_outputs["ql_m"] 146 | ql_v = inference_outputs["ql_v"] 147 | px_rate = generative_outputs["px_rate"] 148 | px_r = generative_outputs["px_r"] 149 | px_dropout = generative_outputs["px_dropout"] 150 | 151 | mean = torch.zeros_like(qz_m) 152 | scale = torch.ones_like(qz_v) 153 | 154 | kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, 
scale)).sum(dim=1) 155 | 156 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 157 | n_batch = self.library_log_means.shape[1] 158 | local_library_log_means = F.linear(one_hot(batch_index, n_batch), self.library_log_means) 159 | local_library_log_vars = F.linear(one_hot(batch_index, n_batch), self.library_log_vars) 160 | 161 | kl_divergence_l = kl( 162 | Normal(ql_m, torch.sqrt(ql_v)), 163 | Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), 164 | ).sum(dim=1) 165 | 166 | reconst_loss = ( 167 | -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout).log_prob(x).sum(dim=-1) 168 | ) 169 | 170 | kl_local_for_warmup = kl_divergence_z 171 | kl_local_no_warmup = kl_divergence_l 172 | 173 | weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup 174 | 175 | loss = torch.mean(reconst_loss + weighted_kl_local) 176 | 177 | kl_local = dict(kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z) 178 | return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local) 179 | 180 | @torch.no_grad() 181 | def sample( 182 | self, 183 | tensors, 184 | n_samples=1, 185 | library_size=1, 186 | ) -> torch.Tensor: 187 | r""" 188 | Generate observation samples from the posterior predictive distribution. 189 | 190 | The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. 
191 | 192 | Parameters 193 | ---------- 194 | tensors 195 | Tensors dict 196 | n_samples 197 | Number of required samples for each cell 198 | library_size 199 | Library size to scale samples to 200 | Returns 201 | ------- 202 | x_new 203 | tensor with shape (n_cells, n_genes, n_samples) 204 | """ 205 | inference_kwargs = dict(n_samples=n_samples) 206 | ( 207 | _, 208 | generative_outputs, 209 | ) = self.forward( 210 | tensors, 211 | inference_kwargs=inference_kwargs, 212 | compute_loss=False, 213 | ) 214 | 215 | px_r = generative_outputs["px_r"] 216 | px_rate = generative_outputs["px_rate"] 217 | px_dropout = generative_outputs["px_dropout"] 218 | 219 | dist = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) 220 | 221 | if n_samples > 1: 222 | exprs = dist.sample().permute([1, 2, 0]) # Shape : (n_cells_batch, n_genes, n_samples) 223 | else: 224 | exprs = dist.sample() 225 | 226 | return exprs.cpu() 227 | 228 | @torch.no_grad() 229 | @auto_move_data 230 | def marginal_ll(self, tensors: TensorDict, n_mc_samples: int): 231 | """Marginal ll.""" 232 | sample_batch = tensors[REGISTRY_KEYS.X_KEY] 233 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 234 | 235 | to_sum = torch.zeros(sample_batch.size()[0], n_mc_samples) 236 | 237 | for i in range(n_mc_samples): 238 | # Distribution parameters and sampled variables 239 | inference_outputs, _, losses = self.forward(tensors) 240 | qz_m = inference_outputs["qz_m"] 241 | qz_v = inference_outputs["qz_v"] 242 | z = inference_outputs["z"] 243 | ql_m = inference_outputs["ql_m"] 244 | ql_v = inference_outputs["ql_v"] 245 | library = inference_outputs["library"] 246 | 247 | # Reconstruction Loss 248 | reconst_loss = losses.dict_sum(losses.reconstruction_loss) 249 | 250 | # Log-probabilities 251 | n_batch = self.library_log_means.shape[1] 252 | local_library_log_means = F.linear(one_hot(batch_index, n_batch), self.library_log_means) 253 | local_library_log_vars = F.linear(one_hot(batch_index, n_batch), 
self.library_log_vars) 254 | p_l = Normal(local_library_log_means, local_library_log_vars.sqrt()).log_prob(library).sum(dim=-1) 255 | 256 | p_z = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)).log_prob(z).sum(dim=-1) 257 | p_x_zl = -reconst_loss 258 | q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) 259 | q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1) 260 | 261 | to_sum[:, i] = p_z + p_l + p_x_zl - q_z_x - q_l_x 262 | 263 | batch_log_lkl = torch.logsumexp(to_sum, dim=-1) - np.log(n_mc_samples) 264 | log_lkl = torch.sum(batch_log_lkl).item() 265 | return log_lkl 266 | -------------------------------------------------------------------------------- /src/simple_scvi/_mypyromodel.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional, Sequence, Union 3 | 4 | import numpy as np 5 | import torch 6 | from anndata import AnnData 7 | from scvi import REGISTRY_KEYS 8 | from scvi.data import AnnDataManager 9 | from scvi.data.fields import ( 10 | CategoricalJointObsField, 11 | CategoricalObsField, 12 | LayerField, 13 | NumericalJointObsField, 14 | ) 15 | from scvi.dataloaders import DataSplitter 16 | from scvi.model.base import BaseModelClass 17 | from scvi.train import PyroTrainingPlan, TrainRunner 18 | from scvi.utils import setup_anndata_dsp 19 | 20 | from ._mypyromodule import MyPyroModule 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class MyPyroModel(BaseModelClass): 26 | """ 27 | Skeleton for a pyro version of a scvi-tools model. 28 | 29 | Please use this skeleton to create new models. 30 | 31 | Parameters 32 | ---------- 33 | adata 34 | AnnData object that has been registered via :meth:`~mypackage.MyPyroModel.setup_anndata`. 35 | n_hidden 36 | Number of nodes per hidden layer. 37 | n_latent 38 | Dimensionality of the latent space. 39 | n_layers 40 | Number of hidden layers used for encoder and decoder NNs. 
41 | **model_kwargs 42 | Keyword args for :class:`~mypackage.MyPyroModule` 43 | Examples 44 | -------- 45 | >>> adata = anndata.read_h5ad(path_to_anndata) 46 | >>> mypackage.MyPyroModel.setup_anndata(adata, batch_key="batch") 47 | >>> vae = mypackage.MyPyroModel(adata) 48 | >>> vae.train() 49 | >>> adata.obsm["X_mymodel"] = vae.get_latent(adata) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | adata: AnnData, 55 | n_hidden: int = 128, 56 | n_latent: int = 10, 57 | n_layers: int = 1, 58 | **model_kwargs, 59 | ): 60 | super().__init__(adata) 61 | 62 | # self.summary_stats provides information about anndata dimensions and other tensor info 63 | 64 | self.module = MyPyroModule( 65 | n_input=self.summary_stats["n_vars"], 66 | n_hidden=n_hidden, 67 | n_latent=n_latent, 68 | n_layers=n_layers, 69 | **model_kwargs, 70 | ) 71 | self._model_summary_string = "Overwrite this attribute to get an informative representation for your model" 72 | # necessary line to get params that will be used for saving/loading 73 | self.init_params_ = self._get_init_params(locals()) 74 | 75 | logger.info("The model has been initialized") 76 | 77 | def get_latent( 78 | self, 79 | adata: Optional[AnnData] = None, 80 | indices: Optional[Sequence[int]] = None, 81 | batch_size: Optional[int] = None, 82 | ): 83 | """ 84 | Return the latent representation for each cell. 85 | 86 | This is denoted as :math:`z_n` in our manuscripts. 87 | 88 | Parameters 89 | ---------- 90 | adata 91 | AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the 92 | AnnData object used to initialize the model. 93 | indices 94 | Indices of cells in adata to use. If `None`, all cells are used. 95 | batch_size 96 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. 
97 | Returns 98 | ------- 99 | latent_representation : np.ndarray 100 | Low-dimensional representation for each cell 101 | """ 102 | adata = self._validate_anndata(adata) 103 | scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) 104 | latent = [] 105 | for tensors in scdl: 106 | qz_m = self.module.get_latent(tensors) 107 | latent += [qz_m.cpu()] 108 | return np.array(torch.cat(latent)) 109 | 110 | def train( 111 | self, 112 | max_epochs: Optional[int] = None, 113 | use_gpu: Optional[Union[str, int, bool]] = None, 114 | train_size: float = 0.9, 115 | validation_size: Optional[float] = None, 116 | batch_size: int = 128, 117 | plan_kwargs: Optional[dict] = None, 118 | **trainer_kwargs, 119 | ): 120 | """ 121 | Train the model. 122 | 123 | Parameters 124 | ---------- 125 | max_epochs 126 | Number of passes through the dataset. If `None`, defaults to 127 | `np.min([round((20000 / n_cells) * 400), 400])` 128 | use_gpu 129 | Use default GPU if available (if None or True), or index of GPU to use (if int), 130 | or name of GPU (if str), or use CPU (if False). 131 | train_size 132 | Size of training set in the range [0.0, 1.0]. 133 | validation_size 134 | Size of the test set. If `None`, defaults to 1 - `train_size`. If 135 | `train_size + validation_size < 1`, the remaining cells belong to a test set. 136 | batch_size 137 | Minibatch size to use during training. 138 | plan_kwargs 139 | Keyword args for :class:`~scvi.lightning.TrainingPlan`. Keyword arguments passed to 140 | `train()` will overwrite values present in `plan_kwargs`, when appropriate. 141 | **trainer_kwargs 142 | Other keyword args for :class:`~scvi.lightning.Trainer`. 
143 | """ 144 | if max_epochs is None: 145 | n_cells = self.adata.n_obs 146 | max_epochs = np.min([round((20000 / n_cells) * 400), 400]) 147 | 148 | plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() 149 | 150 | data_splitter = DataSplitter( 151 | self.adata_manager, 152 | train_size=train_size, 153 | validation_size=validation_size, 154 | batch_size=batch_size, 155 | use_gpu=use_gpu, 156 | ) 157 | training_plan = PyroTrainingPlan(self.module, **plan_kwargs) 158 | runner = TrainRunner( 159 | self, 160 | training_plan=training_plan, 161 | data_splitter=data_splitter, 162 | max_epochs=max_epochs, 163 | use_gpu=use_gpu, 164 | **trainer_kwargs, 165 | ) 166 | return runner() 167 | 168 | @classmethod 169 | @setup_anndata_dsp.dedent 170 | def setup_anndata( 171 | cls, 172 | adata: AnnData, 173 | batch_key: Optional[str] = None, 174 | labels_key: Optional[str] = None, 175 | layer: Optional[str] = None, 176 | categorical_covariate_keys: Optional[List[str]] = None, 177 | continuous_covariate_keys: Optional[List[str]] = None, 178 | **kwargs, 179 | ) -> Optional[AnnData]: 180 | """ 181 | %(summary)s. 
182 | 183 | Parameters 184 | ---------- 185 | %(param_adata)s 186 | %(param_batch_key)s 187 | %(param_labels_key)s 188 | %(param_layer)s 189 | %(param_cat_cov_keys)s 190 | %(param_cont_cov_keys)s 191 | Returns 192 | ------- 193 | %(returns)s 194 | """ 195 | setup_method_args = cls._get_setup_method_args(**locals()) 196 | anndata_fields = [ 197 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), 198 | CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), 199 | CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), 200 | CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), 201 | NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), 202 | ] 203 | adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) 204 | adata_manager.register_fields(adata, **kwargs) 205 | cls.register_manager(adata_manager) 206 | -------------------------------------------------------------------------------- /src/simple_scvi/_mypyromodule.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import pyro 4 | import pyro.distributions as dist 5 | import torch 6 | from pyro import poutine 7 | from scvi import REGISTRY_KEYS 8 | from scvi.module.base import PyroBaseModuleClass, auto_move_data 9 | from scvi.nn import DecoderSCVI, Encoder 10 | 11 | TensorDict = Dict[str, torch.Tensor] 12 | 13 | 14 | class MyPyroModule(PyroBaseModuleClass): 15 | """ 16 | Skeleton Variational auto-encoder Pyro model. 17 | 18 | Here we implement a basic version of scVI's underlying VAE :cite:p:`Lopez18`. 19 | This implementation is for instructional purposes only. 
20 | 21 | Parameters 22 | ---------- 23 | n_input 24 | Number of input genes 25 | n_latent 26 | Dimensionality of the latent space 27 | n_hidden 28 | Number of nodes per hidden layer 29 | n_layers 30 | Number of hidden layers used for encoder and decoder NNs 31 | """ 32 | 33 | def __init__(self, n_input: int, n_latent: int, n_hidden: int, n_layers: int): 34 | super().__init__() 35 | self.n_input = n_input 36 | self.n_latent = n_latent 37 | self.epsilon = 5.0e-3 38 | # z encoder goes from the n_input-dimensional data to an n_latent-d 39 | # latent space representation 40 | self.encoder = Encoder( 41 | n_input, 42 | n_latent, 43 | n_layers=n_layers, 44 | n_hidden=n_hidden, 45 | dropout_rate=0.1, 46 | ) 47 | # decoder goes from n_latent-dimensional space to n_input-d data 48 | self.decoder = DecoderSCVI( 49 | n_latent, 50 | n_input, 51 | n_layers=n_layers, 52 | n_hidden=n_hidden, 53 | ) 54 | # This gene-level parameter modulates the variance of the observation distribution 55 | self.px_r = torch.nn.Parameter(torch.ones(self.n_input)) 56 | 57 | @staticmethod 58 | def _get_fn_args_from_batch(tensor_dict: TensorDict): 59 | x = tensor_dict[REGISTRY_KEYS.X_KEY] 60 | log_library = torch.log(torch.sum(x, dim=1, keepdim=True) + 1e-6) 61 | return (x, log_library), {} 62 | 63 | def model(self, x: torch.Tensor, log_library: torch.Tensor, kl_weight: float = 1.0): 64 | """Pyro model.""" 65 | # register PyTorch module `decoder` with Pyro 66 | pyro.module("scvi", self) 67 | with pyro.plate("data", size=x.shape[0], subsample_size=x.shape[0]): 68 | with poutine.scale(None, kl_weight): 69 | # setup hyperparameters for prior p(z) 70 | z_loc = x.new_zeros(torch.Size((x.shape[0], self.n_latent))) 71 | z_scale = x.new_ones(torch.Size((x.shape[0], self.n_latent))) 72 | # sample from prior (value will be sampled by guide when computing the ELBO) 73 | z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) 74 | # decode the latent code z 75 | px_scale, _, px_rate, px_dropout = 
self.decoder("gene", z, log_library) 76 | # build count distribution 77 | nb_logits = (px_rate + self.epsilon).log() - (self.px_r.exp() + self.epsilon).log() 78 | x_dist = dist.ZeroInflatedNegativeBinomial( 79 | gate_logits=px_dropout, total_count=self.px_r.exp(), logits=nb_logits 80 | ) 81 | # score against actual counts 82 | pyro.sample("obs", x_dist.to_event(1), obs=x) 83 | 84 | def guide(self, x: torch.Tensor, log_library: torch.Tensor, kl_weight: float = 1.0): 85 | """Pyro guide.""" 86 | # define the guide (i.e. variational distribution) q(z|x) 87 | pyro.module("scvi", self) 88 | with pyro.plate("data", x.shape[0]), poutine.scale(None, kl_weight): 89 | # use the encoder to get the parameters used to define q(z|x) 90 | x_ = torch.log(1 + x) 91 | z_loc, z_scale, _ = self.encoder(x_) 92 | # sample the latent code z 93 | pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) 94 | 95 | @torch.no_grad() 96 | @auto_move_data 97 | def get_latent(self, tensor_dict: TensorDict): 98 | """Get the latent representation of the data.""" 99 | x = tensor_dict[REGISTRY_KEYS.X_KEY] 100 | x_ = torch.log(1 + x) 101 | z_loc, _, _ = self.encoder(x_) 102 | return z_loc 103 | -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | import pyro 2 | from scvi.data import synthetic_iid 3 | 4 | from simple_scvi import MyModel, MyPyroModel 5 | 6 | 7 | def test_mymodel(): 8 | n_latent = 5 9 | adata = synthetic_iid() 10 | MyModel.setup_anndata(adata, batch_key="batch", labels_key="labels") 11 | model = MyModel(adata, n_latent=n_latent) 12 | model.train(1, check_val_every_n_epoch=1, train_size=0.5) 13 | model.get_elbo() 14 | model.get_latent_representation() 15 | model.get_marginal_ll(n_mc_samples=5) 16 | model.get_reconstruction_error() 17 | model.history 18 | 19 | # tests __repr__ 20 | print(model) 21 | 22 | 23 | def test_mypyromodel(): 24 | adata = 
synthetic_iid() 25 | pyro.clear_param_store() 26 | MyPyroModel.setup_anndata(adata, batch_key="batch", labels_key="labels") 27 | model = MyPyroModel(adata) 28 | model.train(max_epochs=1, train_size=1) 29 | model.get_latent(adata) 30 | model.history 31 | 32 | # tests __repr__ 33 | print(model) 34 | --------------------------------------------------------------------------------