├── .editorconfig ├── .github └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── codecov.yml ├── docs ├── Makefile ├── _static │ └── img │ │ ├── overview_fig.png │ │ └── perturbation_overview_fig.png ├── _templates │ ├── autosummary │ │ └── class.rst │ └── class_no_inherited.rst ├── about │ └── index.rst ├── api │ ├── .Rhistory │ └── index.md ├── conf.py ├── extensions │ └── typed_returns.py ├── index.md ├── make.bat ├── references.bib ├── references.md ├── release_notes │ ├── index.rst │ └── v0.1.0.rst └── tutorials │ ├── index.md │ ├── index_modelcomp.md │ ├── index_murine.md │ ├── index_zebrafish.md │ ├── modelcomparison │ └── ModelComp.ipynb │ ├── murine │ ├── 01_SCENIC_tutorial.ipynb │ ├── 02_RegVelo_preparation.ipynb │ └── 03_perturbation_tutorial_murine.ipynb │ └── zebrafish │ ├── _static │ └── perturbation_metrics.svg │ └── tutorial.ipynb ├── pyproject.toml ├── readthedocs.yml ├── regvelo ├── ModelComparison.py ├── __init__.py ├── _constants.py ├── _model.py ├── _module.py ├── datasets │ ├── __init__.py │ └── _datasets.py ├── plotting │ ├── __init__.py │ ├── commitment_score.py │ ├── depletion_score.py │ ├── fate_probabilities.py │ ├── get_significance.py │ └── utils.py ├── preprocessing │ ├── __init__.py │ ├── filter_genes.py │ ├── preprocess_data.py │ ├── sanity_check.py │ └── set_prior_grn.py └── tools │ ├── TFScanning_func.py │ ├── TFscreening.py │ ├── __init__.py │ ├── _tsi.py │ ├── abundance_test.py │ ├── depletion_score.py │ ├── in_silico_block_simulation.py │ ├── markov_density_simulation.py │ ├── perturbation_effect.py │ ├── set_output.py │ └── utils.py ├── reproduce_env └── regvelo.yaml ├── setup.py └── tests ├── __init__.py └── test_regvelo.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "*.*.*" 7 | 8 | jobs: 9 | release: 10 | name: Release 11 | runs-on: ubuntu-latest 12 | steps: 13 | # will use ref/SHA that triggered it 14 | - name: Checkout code 15 | uses: actions/checkout@v3 16 | 17 | - name: Set up Python 3.10 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: "3.10" 21 | 22 | - name: Install poetry 23 | uses: abatilo/actions-poetry@v2.0.0 24 | with: 25 | poetry-version: 1.4.2 26 | 27 | - name: Build project for distribution 28 | run: poetry build 29 | 30 | - name: Check Version 31 | id: check-version 32 | run: | 33 | [[ "$(poetry version --short)" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]] \ 34 | || echo "prerelease=true" >> "$GITHUB_OUTPUT" 35 | 36 | - name: Publish to PyPI 37 | env: 38 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }} 39 | run: poetry publish 40 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and
lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: regvelo 5 | 6 | on: 7 | push: 8 | branches: [main] 9 | pull_request: 10 | branches: [main] 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: ["3.9", "3.10", "3.11"] 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Cache pip 26 | uses: actions/cache@v2 27 | with: 28 | path: ~/.cache/pip 29 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 30 | restore-keys: | 31 | ${{ runner.os }}-pip- 32 | - name: Install dependencies 33 | run: | 34 | pip install pytest-cov 35 | pip install .[dev] 36 | - name: Test with pytest 37 | run: | 38 | pytest --cov-report=xml --cov=regvelo 39 | - name: After success 40 | run: | 41 | bash <(curl -s https://codecov.io/bash) 42 | pip list 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # DS_Store 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # vscode 135 | .vscode/settings.json 136 | docs/api/reference/ 137 | *.h5ad 138 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | default_stages: 5 | - commit 6 | - push 7 | minimum_pre_commit_version: 2.16.0 8 | repos: 9 | - repo: https://github.com/psf/black 10 | rev: "23.1.0" 11 | hooks: 12 | - id: black 13 | - repo: https://github.com/asottile/blacken-docs 14 | rev: 1.13.0 15 | hooks: 16 | - id: blacken-docs 17 | - repo: https://github.com/pre-commit/mirrors-prettier 18 | rev: v3.0.0-alpha.6 19 | hooks: 20 | - id: prettier 21 | # Newer versions of node don't work on systems that have an older version of GLIBC 22 | # (in particular Ubuntu 18.04 and Centos 7) 23 | # EOL of Centos 7 is in 2024-06, we can probably get rid of this then. 24 | # See https://github.com/scverse/cookiecutter-scverse/issues/143 and 25 | # https://github.com/jupyterlab/jupyterlab/issues/12675 26 | language_version: "17.9.1" 27 | - repo: https://github.com/charliermarsh/ruff-pre-commit 28 | rev: v0.0.254 29 | hooks: 30 | - id: ruff 31 | args: [--fix, --exit-non-zero-on-fix] 32 | - repo: https://github.com/pre-commit/pre-commit-hooks 33 | rev: v4.4.0 34 | hooks: 35 | - id: detect-private-key 36 | - id: check-ast 37 | - id: end-of-file-fixer 38 | - id: mixed-line-ending 39 | args: [--fix=lf] 40 | - id: trailing-whitespace 41 | - id: check-case-conflict 42 | - repo: local 43 | hooks: 44 | - id: forbid-to-commit 45 | name: Don't commit rej files 46 | entry: | 47 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. 48 | Fix the merge conflicts manually and remove the .rej files. 49 | language: fail 50 | files: '.*\.rej$' 51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Yosef Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RegVelo: gene-regulatory-informed dynamics of single cells 2 | 3 | RegVelo 4 | 5 | **RegVelo** is an end-to-end framework for inferring regulatory cellular dynamics through coupled splicing dynamics. See our [RegVelo manuscript](https://www.biorxiv.org/content/10.1101/2024.12.11.627935v1) to learn more. If you use our tool in your own work, please cite it as 6 | 7 | ``` 8 | @article{wang2024regvelo, 9 | title={RegVelo: gene-regulatory-informed dynamics of single cells}, 10 | author={Wang, Weixu and Hu, Zhiyuan and Weiler, Philipp and Mayes, Sarah and Lange, Marius and Wang, Jingye and Xue, Zhengyuan and Sauka-Spengler, Tatjana and Theis, Fabian J}, 11 | journal={bioRxiv}, 12 | pages={2024--12}, 13 | year={2024}, 14 | publisher={Cold Spring Harbor Laboratory} 15 | } 16 | ``` 17 | ## Getting started 18 | Please refer to the [Tutorials](https://regvelo.readthedocs.io/en/latest/index.html). 19 | 20 | ## Installation 21 | 22 | You need to have Python 3.9 or newer installed on your system. If you don't have 23 | Python installed, we recommend installing [Miniconda](https://docs.conda.io/en/latest/miniconda.html). 24 | 25 | To create and activate a new environment: 26 | 27 | ```bash 28 | conda create -n regvelo-py310 python=3.10 --yes && conda activate regvelo-py310 29 | ``` 30 | 31 | Next, install the package with 32 | 33 | ```bash 34 | pip install git+https://github.com/theislab/regvelo.git@main --no-cache-dir --force-reinstall 35 | ``` 36 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | # Run to check if valid 2 | # curl --data-binary @codecov.yml https://codecov.io/validate 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: 80% 8 | threshold: 1% 9 | patch: off 10 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = scvi 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option.
$(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/img/overview_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/regvelo/5ed133bd37a563390ee4ec909b528f13d75c1b8a/docs/_static/img/overview_fig.png -------------------------------------------------------------------------------- /docs/_static/img/perturbation_overview_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/regvelo/5ed133bd37a563390ee4ec909b528f13d75c1b8a/docs/_static/img/perturbation_overview_fig.png -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | Attributes table 12 | ~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autosummary:: 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | Methods table 24 | ~~~~~~~~~~~~~ 25 | 26 | .. autosummary:: 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | 35 | {% block attributes_documentation %} 36 | {% if attributes %} 37 | Attributes 38 | ~~~~~~~~~~~ 39 | 40 | {% for item in attributes %} 41 | 42 | .. autoattribute:: {{ [objname, item] | join(".") }} 43 | {%- endfor %} 44 | 45 | {% endif %} 46 | {% endblock %} 47 | 48 | {% block methods_documentation %} 49 | {% if methods %} 50 | Methods 51 | ~~~~~~~ 52 | 53 | {% for item in methods %} 54 | {%- if item != '__init__' %} 55 | 56 | .. automethod:: {{ [objname, item] | join(".") }} 57 | {%- endif -%} 58 | {%- endfor %} 59 | 60 | {% endif %} 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /docs/_templates/class_no_inherited.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | :show-inheritance: 9 | 10 | {% block attributes %} 11 | {% if attributes %} 12 | Attributes table 13 | ~~~~~~~~~~~~~~~~ 14 | 15 | .. autosummary:: 16 | {% for item in attributes %} 17 | {%- if item not in inherited_members%} 18 | ~{{ fullname }}.{{ item }} 19 | {%- endif -%} 20 | {%- endfor %} 21 | {% endif %} 22 | {% endblock %} 23 | 24 | 25 | {% block methods %} 26 | {% if methods %} 27 | Methods table 28 | ~~~~~~~~~~~~~~ 29 | 30 | .. 
autosummary:: 31 | {% for item in methods %} 32 | {%- if item != '__init__' and item not in inherited_members%} 33 | ~{{ fullname }}.{{ item }} 34 | {%- endif -%} 35 | 36 | {%- endfor %} 37 | {% endif %} 38 | {% endblock %} 39 | 40 | {% block attributes_documentation %} 41 | {% if attributes %} 42 | Attributes 43 | ~~~~~~~~~~ 44 | 45 | {% for item in attributes %} 46 | {%- if item not in inherited_members%} 47 | 48 | .. autoattribute:: {{ [objname, item] | join(".") }} 49 | {%- endif -%} 50 | {%- endfor %} 51 | 52 | {% endif %} 53 | {% endblock %} 54 | 55 | {% block methods_documentation %} 56 | {% if methods %} 57 | Methods 58 | ~~~~~~~ 59 | 60 | {% for item in methods %} 61 | {%- if item != '__init__' and item not in inherited_members%} 62 | 63 | .. automethod:: {{ [objname, item] | join(".") }} 64 | {%- endif -%} 65 | {%- endfor %} 66 | 67 | {% endif %} 68 | {% endblock %} 69 | -------------------------------------------------------------------------------- /docs/about/index.rst: -------------------------------------------------------------------------------- 1 | About RegVelo 2 | ------------- 3 | 4 | Understanding cellular dynamics and regulatory interactions is crucial for decoding the complex processes that govern cell fate and differentiation. 5 | Traditional RNA velocity methods capture dynamic cellular transitions by modeling changes in spliced and unspliced mRNA but lack integration with gene regulatory networks (GRNs), omitting critical regulatory mechanisms underlying cellular decisions. 6 | Conversely, GRN inference techniques map regulatory connections but fail to account for the temporal dynamics of gene expression. 7 | 8 | With RegVelo, developed by `Wang et al. (biorxiv, 2024) <https://www.biorxiv.org/content/10.1101/2024.12.11.627935v1>`_, 9 | this research gap is bridged by combining RNA velocity's temporal insights with a regulatory framework to model transcriptome-wide splicing kinetics informed by GRNs. 10 | This extends the current RNA velocity framework to a fully mechanistic model, allowing more complex developmental processes to be modeled. 11 | Further, by coupling with CellRank `Weiler et al. (Nature Methods, 2024) <https://www.nature.com/articles/s41592-024-02303-9>`_, RegVelo expands its capabilities to include robust perturbation predictions, linking regulatory changes to long-term cell fate decisions. 12 | CellRank employs velocity-based state transition probabilities to predict terminal cell states, and RegVelo enhances this framework with GRN-informed splicing kinetics, 13 | enabling precise simulations of transcription factor (TF) knockouts. This synergy allows for the identification of lineage drivers and the prediction of cell fate changes upon genetic perturbations. 14 | 15 | RegVelo's application 16 | ~~~~~~~~~~~~~~~~~~~~~ 17 | - estimate RNA velocity governed by gene regulation. 18 | - infer latent time indicating the cellular differentiation process. 19 | - estimate intrinsic and extrinsic velocity uncertainty :cite:p:`gayoso2024deep`. 20 | - estimate regulon perturbation effects via the CellRank framework :cite:p:`lange2022cellrank, weiler2024cellrank`. 21 | 22 | RegVelo model 23 | ~~~~~~~~~~~~~~~~~~~ 24 | RegVelo leverages deep generative modeling to infer splicing kinetics, transcription rates, and latent cellular time while integrating GRN priors derived from multi-omics data or curated databases. 25 | RegVelo incorporates cellular dynamic estimates by first encoding unspliced (*u*) and spliced RNA (*s*) readouts into posterior parameters of a low-dimensional latent variable - the cell representation - with a neural network.
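A minimal sketch of this encoding step (layer sizes, names, and the exact architecture below are illustrative assumptions, not RegVelo's actual implementation):

.. code-block:: python

    import torch
    import torch.nn as nn

    class CellEncoder(nn.Module):
        """Toy VAE-style encoder: (unspliced, spliced) -> posterior of the cell representation z."""

        def __init__(self, n_genes: int, n_latent: int = 10, n_hidden: int = 128):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(2 * n_genes, n_hidden), nn.ReLU())
            self.mean = nn.Linear(n_hidden, n_latent)     # posterior mean of z
            self.log_var = nn.Linear(n_hidden, n_latent)  # posterior log-variance of z

        def forward(self, u: torch.Tensor, s: torch.Tensor):
            h = self.net(torch.cat([u, s], dim=-1))
            mu, log_var = self.mean(h), self.log_var(h)
            # reparameterization trick: a differentiable sample of the cell representation
            z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
            return z, mu, log_var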
26 | An additional neural network takes samples of this cell representation as input to parameterize gene-wise latent time as in our previous model veloVI. 27 | We then model splicing dynamics with ordinary differential equations (ODEs) specified by a base transcription rate *b* and a GRN weight matrix *W*, 28 | which together describe transcription and are inferred by a shallow neural network; by constant splicing and degradation rate parameters *beta* and *gamma*, respectively; 29 | and by estimated cell- and gene-specific latent times. Importantly, existing methods for inferring RNA velocity consider a set of decoupled one-dimensional ODEs for which analytic solutions exist, but RegVelo relies on the single, high-dimensional ODE 30 | 31 | .. math:: 32 | \begin{align} 33 | \frac{\mathrm{d} u_{g}(t)}{\mathrm{d} t} = \alpha_{g}(t) - \beta_{g} u_{g}(t), \\ 34 | \frac{\mathrm{d} s_{g}(t)}{\mathrm{d} t} = \beta_{g} u_{g}(t) - \gamma_{g} s_{g}(t), 35 | \end{align} 36 | 37 | that is now coupled through gene regulation-informed transcription 38 | 39 | .. math:: 40 | \alpha_g = h \left( \left[ W s(t) + b \right]_{g} \right) 41 | 42 | where *g* indicates the gene and *h* is a non-linear activation function. 43 | We predict the gene- and cell-specific spliced and unspliced abundances using a parallelizable ODE solver, 44 | as this new system no longer admits an analytic solution; compared to previous approaches, we solve all gene dynamics at once instead of sequentially for each gene independently of all others. 45 | The forward simulation of the ODE solver allows for computing the likelihood function encompassing all neural network and kinetic parameters. 46 | We assume that the predicted spliced and unspliced abundances are the expected value of the Gaussian likelihood of the observed dataset and use gradient-based optimization to update all parameters. 47 | After optimization, we define cell-gene-specific velocities as splicing velocities based on the estimated splicing and degradation rates and predicted spliced and unspliced abundances. 48 | Overall, RegVelo allows sampling predicted readouts and velocities from the learned posterior distribution. 49 | 50 | 51 | Perturbation prediction 52 | ~~~~~~~~~~~~~~~~~~~~~~~ 53 | 54 | .. image:: https://github.com/theislab/regvelo/blob/main/docs/_static/img/perturbation_overview_fig.png?raw=true 55 | :alt: RegVelo perturbation introduction 56 | :width: 600px 57 | 58 | RegVelo is a generative model that couples cellular dynamics with regulatory networks. 59 | We can thus perform in silico counterfactual inference to test the cellular response upon unseen perturbations of a TF in the regulatory network: for a trained RegVelo model, 60 | we ignore regulatory effects of the TF by removing all its downstream targets from the GRN, i.e., depleting the regulon, and generate the perturbed velocity vector field. 61 | The dissimilarity between the original and perturbed cell velocities - the perturbation effect score - reflects the local changes induced on each cell by the perturbation; we quantify this score with cosine dissimilarity. 62 | 63 | RNA velocity describes a high-dimensional vector field representing cellular change along the phenotypic manifold but lacks interpretability and quantifiable measures of long-term cell behavior. 64 | We recently proposed CellRank to bridge this gap by leveraging gene expression and an estimated vector field to model cell state transitions through Markov chains and infer terminal cell states.
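As a toy illustration of the absorbing Markov-chain computation that underlies this analysis (the random matrix below merely stands in for a velocity-derived transition matrix; CellRank's actual implementation is considerably more involved):

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    n_cells = 5
    P = rng.random((n_cells, n_cells))
    P /= P.sum(axis=1, keepdims=True)  # row-stochastic cell-to-cell transition matrix

    terminal = np.array([False, False, False, True, True])  # cells assigned to terminal states
    Q = P[~terminal][:, ~terminal]  # transitions among transient cells
    R = P[~terminal][:, terminal]   # transient -> terminal transitions

    # fate probabilities = absorption probabilities F = (I - Q)^{-1} R
    F = np.linalg.solve(np.eye(Q.shape[0]) - Q, R)
    print(F.sum(axis=1))  # each row sums to 1: every cell eventually commits to a terminal state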
65 | For each terminal state identified, CellRank calculates the probability of a cell transitioning to this state - the fate probability - that allows us to predict the cell's future state. 66 | By combining RegVelo’s generative model with CellRank, we connect gene regulation with both local cell dynamics and long-term cell fate decisions, and how they change upon in silico perturbations. 67 | In the context of our perturbation analyses, we compare CellRank’s prediction of cell fate probabilities for the original and perturbed vector fields, 68 | to find enrichment (increased cell fate probability) or depletion (decreased cell fate probability) effects towards terminal states. 69 | 70 | See `Wang et al. (biorxiv, 2024) `_ for a detailed description of the methods and applications on different biological systems. 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /docs/api/.Rhistory: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/regvelo/5ed133bd37a563390ee4ec909b528f13d75c1b8a/docs/api/.Rhistory -------------------------------------------------------------------------------- /docs/api/index.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ```{eval-rst} 4 | .. currentmodule:: regvelo 5 | ``` 6 | 7 | ```{eval-rst} 8 | .. autosummary:: 9 | :toctree: reference/ 10 | :nosignatures: 11 | 12 | REGVELOVI 13 | ``` 14 | 15 | ```{eval-rst} 16 | .. autosummary:: 17 | :toctree: reference/ 18 | :template: class_no_inherited.rst 19 | :nosignatures: 20 | 21 | VELOVAE 22 | ``` 23 | 24 | ```{eval-rst} 25 | .. autosummary:: 26 | :toctree: reference/ 27 | :nosignatures: 28 | 29 | TFscreening 30 | ``` 31 | 32 | ```{eval-rst} 33 | .. autosummary:: 34 | :toctree: reference/ 35 | :nosignatures: 36 | 37 | in_silico_block_simulation 38 | ``` 39 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import subprocess 3 | import os 4 | import importlib 5 | import inspect 6 | import re 7 | import sys 8 | from datetime import datetime 9 | from importlib.metadata import metadata 10 | from pathlib import Path 11 | 12 | HERE = Path(__file__).parent 13 | sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")] 14 | 15 | 16 | # -- Project information ----------------------------------------------------- 17 | 18 | project_name = "regvelo" 19 | package_name = "regvelo" 20 | author = "Weixu Wang" 21 | copyright = f"{datetime.now():%Y}, {author}." 22 | version = "0.2.0" 23 | repository_url = "https://github.com/theislab/regvelo" 24 | 25 | # The full version, including alpha/beta/rc tags 26 | release = "0.2.0" 27 | 28 | bibtex_bibfiles = ["references.bib"] 29 | bibtex_reference_style = "author_year" 30 | 31 | templates_path = ["_templates"] 32 | nitpicky = True # Warn about broken links 33 | needs_sphinx = "4.0" 34 | 35 | html_context = { 36 | "display_github": True, # Integrate GitHub 37 | "github_user": "theislab", # Username 38 | "github_repo": project_name, # Repo name 39 | "github_version": "main", # Version 40 | "conf_py_path": "/docs/", # Path in the checkout to the docs root 41 | } 42 | 43 | # -- General configuration --------------------------------------------------- 44 | 45 | # Add any Sphinx extension module names here, as strings. 
46 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 47 | extensions = [ 48 | "myst_nb", 49 | "sphinx.ext.autodoc", 50 | "sphinx.ext.linkcode", 51 | "sphinx.ext.intersphinx", 52 | "sphinx.ext.autosummary", 53 | "sphinx.ext.napoleon", 54 | "sphinxcontrib.bibtex", 55 | "sphinx.ext.mathjax", 56 | "sphinx.ext.extlinks", 57 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 58 | "sphinx_copybutton", 59 | ] 60 | 61 | autosummary_generate = True 62 | autodoc_member_order = "bysource" 63 | default_role = "literal" 64 | autodoc_typehints = "description" 65 | bibtex_reference_style = "author_year" 66 | napoleon_google_docstring = True 67 | napoleon_numpy_docstring = True 68 | napoleon_include_init_with_doc = False 69 | napoleon_use_rtype = True # having a separate entry generally helps readability 70 | napoleon_use_param = True 71 | myst_enable_extensions = [ 72 | "amsmath", 73 | "colon_fence", 74 | "deflist", 75 | "dollarmath", 76 | "html_image", 77 | "html_admonition", 78 | ] 79 | myst_url_schemes = ("http", "https", "mailto") 80 | nb_output_stderr = "remove" 81 | nb_execution_mode = "off" 82 | nb_merge_streams = True 83 | 84 | source_suffix = { 85 | ".rst": "restructuredtext", 86 | ".ipynb": "myst-nb", 87 | ".myst": "myst-nb", 88 | } 89 | 90 | intersphinx_mapping = { 91 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None), 92 | "ipython": ("https://ipython.readthedocs.io/en/stable/", None), 93 | "matplotlib": ("https://matplotlib.org/", None), 94 | "numpy": ("https://numpy.org/doc/stable/", None), 95 | "pandas": ("https://pandas.pydata.org/docs/", None), 96 | "python": ("https://docs.python.org/3", None), 97 | "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), 98 | "sklearn": ("https://scikit-learn.org/stable/", None), 99 | "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), 100 | "jax": ("https://jax.readthedocs.io/en/latest/", None), 101 | "torch": ("https://pytorch.org/docs/master/", None), 102 | "plottable": ("https://plottable.readthedocs.io/en/latest/", None), 103 | "scvi-tools": ("https://docs.scvi-tools.org/en/stable/", None), 104 | "mudata": ("https://mudata.readthedocs.io/en/latest/", None), 105 | } 106 | 107 | # List of patterns, relative to source directory, that match files and 108 | # directories to ignore when looking for source files. 109 | # This pattern also affects html_static_path and html_extra_path. 110 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 111 | 112 | # extlinks config 113 | extlinks = { 114 | "issue": (f"{repository_url}/issues/%s", "#%s"), 115 | "pr": (f"{repository_url}/pull/%s", "#%s"), 116 | "ghuser": ("https://github.com/%s", "@%s"), 117 | } 118 | 119 | 120 | # -- Linkcode settings ------------------------------------------------- 121 | 122 | 123 | def git(*args): 124 | """Run a git command and return the output.""" 125 | return subprocess.check_output(["git", *args]).strip().decode() 126 | 127 | 128 | # https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192 129 | # Current git reference. 
Uses branch/tag name if found, otherwise uses commit hash 130 | git_ref = None 131 | try: 132 | git_ref = git("name-rev", "--name-only", "--no-undefined", "HEAD") 133 | git_ref = re.sub(r"^(remotes/[^/]+|tags)/", "", git_ref) 134 | except Exception: 135 | pass 136 | 137 | # (if no name found or relative ref, use commit hash instead) 138 | if not git_ref or re.search(r"[\^~]", git_ref): 139 | try: 140 | git_ref = git("rev-parse", "HEAD") 141 | except Exception: 142 | git_ref = "main" 143 | 144 | # https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192 145 | github_repo = "https://github.com/" + html_context["github_user"] + "/" + project_name 146 | _project_module_path = os.path.dirname(importlib.util.find_spec(package_name).origin) # type: ignore 147 | 148 | 149 | def linkcode_resolve(domain, info): 150 | """Resolve links for the linkcode extension.""" 151 | if domain != "py": 152 | return None 153 | 154 | try: 155 | obj: Any = sys.modules[info["module"]] 156 | for part in info["fullname"].split("."): 157 | obj = getattr(obj, part) 158 | obj = inspect.unwrap(obj) 159 | 160 | if isinstance(obj, property): 161 | obj = inspect.unwrap(obj.fget) # type: ignore 162 | 163 | path = os.path.relpath(inspect.getsourcefile(obj), start=_project_module_path) # type: ignore 164 | src, lineno = inspect.getsourcelines(obj) 165 | except Exception: 166 | return None 167 | 168 | path = f"{path}#L{lineno}-L{lineno + len(src) - 1}" 169 | return f"{github_repo}/blob/{git_ref}/{package_name}/{path}" 170 | 171 | 172 | # -- Options for HTML output ------------------------------------------------- 173 | 174 | # The theme to use for HTML and HTML Help pages. See the documentation for 175 | # a list of builtin themes. 176 | # 177 | html_theme = "sphinx_book_theme" 178 | html_static_path = ["_static"] 179 | html_title = "RegVelo" 180 | 181 | html_theme_options = { 182 | "repository_url": github_repo, 183 | "use_repository_button": True, 184 | } 185 | 186 | pygments_style = "default" 187 | 188 | nitpick_ignore = [ 189 | # If building the documentation fails because of a missing link that is outside your control, 190 | # you can add an exception to this list. 
191 | ] 192 | 193 | 194 | def setup(app): 195 | """App setup hook.""" 196 | app.add_config_value( 197 | "recommonmark_config", 198 | { 199 | "auto_toc_tree_section": "Contents", 200 | "enable_auto_toc_tree": True, 201 | "enable_math": True, 202 | "enable_inline_math": False, 203 | "enable_eval_rst": True, 204 | }, 205 | True, 206 | ) 207 | -------------------------------------------------------------------------------- /docs/extensions/typed_returns.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py 2 | # with some minor adjustment 3 | import re 4 | 5 | from sphinx.application import Sphinx 6 | from sphinx.ext.napoleon import NumpyDocstring 7 | 8 | 9 | def process_return(lines): 10 | """Process the return section of a docstring.""" 11 | for line in lines: 12 | m = re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line) 13 | if m: 14 | # Once this is in scanpydoc, we can use the fancy hover stuff 15 | yield f'-{m["param"]} (:class:`~{m["type"]}`)' 16 | else: 17 | yield line 18 | 19 | 20 | def scanpy_parse_returns_section(self, section): 21 | """Parse the returns section of the docstring.""" 22 | lines_raw = list(process_return(self._dedent(self._consume_to_next_section()))) 23 | lines = self._format_block(":returns: ", lines_raw) 24 | if lines and lines[-1]: 25 | lines.append("") 26 | return lines 27 | 28 | 29 | def setup(app: Sphinx): 30 | """Setup the extension.""" 31 | NumpyDocstring._parse_returns_section = scanpy_parse_returns_section 32 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ```{include} ../README.md 2 | 3 | ``` 4 | 5 | # Welcome to the RegVelo documentation. 6 | 7 | ```{toctree} 8 | :maxdepth: 3 9 | :titlesonly: true 10 | 11 | about/index 12 | tutorials/index 13 | api/index 14 | release_notes/index 15 | references 16 | 17 | ``` 18 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=scvi 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{wang2024regvelo, 2 | title={RegVelo: gene-regulatory-informed dynamics of single cells}, 3 | author={Wang, Weixu and Hu, Zhiyuan and Weiler, Philipp and Mayes, Sarah and Lange, Marius and Wang, Jingye and Xue, Zhengyuan and Sauka-Spengler, Tatjana and Theis, Fabian J}, 4 | journal={bioRxiv}, 5 | pages={2024--12}, 6 | year={2024}, 7 | publisher={Cold Spring Harbor Laboratory} 8 | } 9 | 10 | @article{lange2022cellrank, 11 | title={CellRank for directed single-cell fate mapping}, 12 | author={Lange, Marius and Bergen, Volker and Klein, Michal and Setty, Manu and Reuter, Bernhard and Bakhti, Mostafa and Lickert, Heiko and Ansari, Meshal and Schniering, Janine and Schiller, Herbert B and others}, 13 | journal={Nature Methods}, 14 | volume={19}, 15 | number={2}, 16 | pages={159--170}, 17 | year={2022}, 18 | publisher={Nature Publishing Group US New York} 19 | } 20 | 21 | @article{weiler2024cellrank, 22 | title={CellRank 2: unified fate mapping in multiview single-cell data}, 23 | author={Weiler, Philipp and Lange, Marius and Klein, Michal and Pe’er, Dana and Theis, Fabian}, 24 | journal={Nature Methods}, 25 | pages={1--10}, 26 | year={2024}, 27 | publisher={Nature Publishing Group US New York} 28 | } 29 | 30 | @article{bergen2020generalizing, 31 | title={Generalizing RNA velocity to transient cell states through dynamical modeling}, 32 | author={Bergen, Volker and Lange, Marius and Peidli, Stefan and Wolf, F Alexander and Theis, Fabian J}, 33 | journal={Nature Biotechnology}, 34 | volume={38}, 35 | number={12}, 36 | pages={1408--1414}, 37 | year={2020}, 38 | publisher={Nature Publishing Group US New York} 39 | } 40 | 41 | @article{gayoso2024deep, 42 | title={Deep generative modeling of transcriptional dynamics for RNA velocity analysis in single cells}, 43 | author={Gayoso, Adam and Weiler, Philipp and Lotfollahi, Mohammad and Klein, Dominik and Hong, Justin and Streets, Aaron and Theis, Fabian J and Yosef, Nir}, 44 | journal={Nature Methods}, 45 | volume={21}, 46 | number={1}, 47 | pages={50--59}, 48 | year={2024}, 49 | publisher={Nature Publishing Group US New York} 50 | } 51 | 52 | @article{hu2024single, 53 | title={Single-cell multi-omics, spatial transcriptomics and systematic perturbation decode circuitry of neural crest fate decisions}, 54 | author={Hu, Zhiyuan and Mayes, Sarah and Wang, Weixu and Santos-Pereira, Jose M and Theis, Fabian and Sauka-Spengler, Tatjana}, 55 | journal={bioRxiv}, 56 | pages={2024--09}, 57 | year={2024}, 58 | publisher={Cold Spring Harbor Laboratory} 59 | } -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | If RegVelo is helpful in your research, please cite {cite:p}`wang2024regvelo`. If you are using the zebrafish Smart-seq3, 10x multiome, or in vivo Perturb-seq data, please cite {cite:p}`wang2024regvelo, hu2024single`.
4 | 5 | ```{bibliography} 6 | :cited: 7 | ``` 8 | -------------------------------------------------------------------------------- /docs/release_notes/index.rst: -------------------------------------------------------------------------------- 1 | Release notes 2 | ============= 3 | 4 | Version 0.1 5 | ----------- 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | v0.1.0 10 | -------------------------------------------------------------------------------- /docs/release_notes/v0.1.0.rst: -------------------------------------------------------------------------------- 1 | New in 0.1.0 (2024-09-03) 2 | ------------------------- 3 | -------------------------------------------------------------------------------- /docs/tutorials/index.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | 4 | ```{toctree} 5 | :maxdepth: 2 6 | 7 | index_murine 8 | index_zebrafish 9 | index_modelcomp 10 | ``` 11 | -------------------------------------------------------------------------------- /docs/tutorials/index_modelcomp.md: -------------------------------------------------------------------------------- 1 | # Model comparison with multi-view information 2 | 3 | ```{toctree} 4 | :maxdepth: 1 5 | :titlesonly: 6 | 7 | modelcomparison/ModelComp.ipynb 8 | 9 | ``` 10 | -------------------------------------------------------------------------------- /docs/tutorials/index_murine.md: -------------------------------------------------------------------------------- 1 | # Murine Neural Crest Development 2 | 3 | ```{toctree} 4 | :maxdepth: 1 5 | :titlesonly: 6 | 7 | murine/01_SCENIC_tutorial 8 | murine/02_RegVelo_preparation 9 | murine/03_perturbation_tutorial_murine 10 | 11 | ``` 12 | 13 | 14 | -------------------------------------------------------------------------------- /docs/tutorials/index_zebrafish.md: -------------------------------------------------------------------------------- 1 | # Zebrafish Neural Crest Development 2 | 3 | ```{toctree} 4 | :maxdepth: 1 5 | :titlesonly: 6 | 7 | zebrafish/tutorial 8 | 9 | ``` 10 | -------------------------------------------------------------------------------- /docs/tutorials/murine/01_SCENIC_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "37f7b867-f3e7-4d1a-ad57-36898e03a0d8", 6 | "metadata": {}, 7 | "source": [ 8 | "# Infer prior GRN from [pySCENIC](https://pyscenic.readthedocs.io/en/latest/installation.html)\n", 9 | "In this notebook, we use [SCENIC](https://scenic.aertslab.org/) to infer a prior gene regulatory network (GRN) for the RegVelo pipeline."
10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "id": "405f8373-e7d0-4ebe-b03f-3937d1aa9d46", 15 | "metadata": {}, 16 | "source": [ 17 | "## Library import" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "id": "b3b40717-9fb5-4ef4-8733-6fb7f4dff658", 24 | "metadata": { 25 | "tags": [] 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "import os\n", 30 | "import numpy as np\n", 31 | "import pandas as pd\n", 32 | "import scanpy as sc\n", 33 | "import loompy as lp\n", 34 | "\n", 35 | "import glob" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "id": "286984e3-5844-418b-bd2e-70eda9c109a1", 42 | "metadata": { 43 | "tags": [] 44 | }, 45 | "outputs": [ 46 | { 47 | "name": "stderr", 48 | "output_type": "stream", 49 | "text": [ 50 | "/home/icb/yifan.chen/miniconda3/envs/pyscenic/lib/python3.10/site-packages/session_info/main.py:213: UserWarning: The '__version__' attribute is deprecated and will be removed in MarkupSafe 3.1. Use feature detection, or `importlib.metadata.version(\"markupsafe\")`, instead.\n", 51 | " mod_version = _find_version(mod.__version__)\n" 52 | ] 53 | }, 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | "-----\n", 59 | "anndata 0.11.4\n", 60 | "scanpy 1.10.4\n", 61 | "-----\n", 62 | "PIL 11.2.1\n", 63 | "asttokens NA\n", 64 | "charset_normalizer 3.4.1\n", 65 | "cloudpickle 3.1.1\n", 66 | "comm 0.2.1\n", 67 | "cycler 0.12.1\n", 68 | "cython_runtime NA\n", 69 | "cytoolz 1.0.1\n", 70 | "dask 2025.4.1\n", 71 | "dateutil 2.9.0.post0\n", 72 | "debugpy 1.8.11\n", 73 | "decorator 5.1.1\n", 74 | "exceptiongroup 1.2.0\n", 75 | "executing 0.8.3\n", 76 | "h5py 3.13.0\n", 77 | "ipykernel 6.29.5\n", 78 | "jedi 0.19.2\n", 79 | "jinja2 3.1.6\n", 80 | "joblib 1.4.2\n", 81 | "kiwisolver 1.4.8\n", 82 | "legacy_api_wrap NA\n", 83 | "llvmlite 0.44.0\n", 84 | "loompy 3.0.8\n", 85 | "lz4 4.4.4\n", 86 | "markupsafe 3.0.2\n", 87 | "matplotlib 3.10.1\n", 88 | "mpl_toolkits NA\n", 89 | "natsort 8.4.0\n", 90 | "numba 0.61.2\n", 91 | "numexpr 2.10.2\n", 92 | "numpy 2.2.5\n", 93 | "numpy_groupies 0.11.2\n", 94 | "packaging 24.2\n", 95 | "pandas 2.2.3\n", 96 | "parso 0.8.4\n", 97 | "platformdirs 4.3.7\n", 98 | "prompt_toolkit 3.0.43\n", 99 | "psutil 5.9.0\n", 100 | "pure_eval 0.2.2\n", 101 | "pyarrow 20.0.0\n", 102 | "pydev_ipython NA\n", 103 | "pydevconsole NA\n", 104 | "pydevd 3.2.3\n", 105 | "pydevd_file_utils NA\n", 106 | "pydevd_plugins NA\n", 107 | "pydevd_tracing NA\n", 108 | "pygments 2.19.1\n", 109 | "pyparsing 3.2.3\n", 110 | "pytz 2025.2\n", 111 | "scipy 1.15.2\n", 112 | "session_info v1.0.1\n", 113 | "six 1.17.0\n", 114 | "sklearn 1.6.1\n", 115 | "stack_data 0.2.0\n", 116 | "tblib 3.1.0\n", 117 | "threadpoolctl 3.6.0\n", 118 | "tlz 1.0.1\n", 119 | "toolz 1.0.0\n", 120 | "tornado 6.4.2\n", 121 | "traitlets 5.14.3\n", 122 | "typing_extensions NA\n", 123 | "wcwidth 0.2.5\n", 124 | "yaml 6.0.2\n", 125 | "zipp NA\n", 126 | "zmq 26.2.0\n", 127 | "zoneinfo NA\n", 128 | "-----\n", 129 | "IPython 8.30.0\n", 130 | "jupyter_client 8.6.3\n", 131 | "jupyter_core 5.7.2\n", 132 | "-----\n", 133 | "Python 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0]\n", 134 | "Linux-5.14.0-427.37.1.el9_4.x86_64-x86_64-with-glibc2.34\n", 135 | "-----\n", 136 | "Session information updated at 2025-04-28 11:30\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "sc.settings.verbosity = 3\n", 142 | "sc.logging.print_versions()" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "id": 
"943265ea-267c-4c47-b00c-28ef7fbd8ab2", 148 | "metadata": {}, 149 | "source": [ 150 | "## Load data and output to loom file\n", 151 | "Read murine neural crest data." 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "id": "6f4a000b-455a-4625-a3e2-c0b2bafc5cea", 158 | "metadata": { 159 | "tags": [] 160 | }, 161 | "outputs": [], 162 | "source": [ 163 | "adata = rgv.datasets.murine_nc(data_type = \"normalized\")" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 4, 169 | "id": "b7ce4c29-e8b3-428d-a930-469a155bd6f0", 170 | "metadata": { 171 | "tags": [] 172 | }, 173 | "outputs": [ 174 | { 175 | "data": { 176 | "text/plain": [ 177 | "AnnData object with n_obs × n_vars = 6788 × 30717\n", 178 | " obs: 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'UMI_count', 'gene_count', 'major_trajectory', 'celltype_update', 'UMAP_1', 'UMAP_2', 'UMAP_3', 'UMAP_2d_1', 'UMAP_2d_2', 'terminal_state', 'nCount_intron', 'nFeature_intron'\n", 179 | " var: 'vf_vst_counts_mean', 'vf_vst_counts_variance', 'vf_vst_counts_variance.expected', 'vf_vst_counts_variance.standardized', 'vf_vst_counts_variable', 'vf_vst_counts_rank', 'var.features', 'var.features.rank'\n", 180 | " obsm: 'X_pca', 'X_umap'" 181 | ] 182 | }, 183 | "execution_count": 4, 184 | "metadata": {}, 185 | "output_type": "execute_result" 186 | } 187 | ], 188 | "source": [ 189 | "adata" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 5, 195 | "id": "56baecec-0ee1-4f30-bddb-31426b2e3b35", 196 | "metadata": { 197 | "tags": [] 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "adata = sc.AnnData(adata.X, obs=adata.obs, var=adata.var)\n", 202 | "adata.var[\"Gene\"] = adata.var_names\n", 203 | "adata.obs[\"CellID\"] = adata.obs_names\n", 204 | "adata.write_loom(\"adata.loom\")" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "id": "2005e02a-abb1-4db9-b944-fe9fdb8ac9f5", 210 | "metadata": {}, 211 | "source": [ 212 | "## SCENIC steps\n", 213 | "In the following, we use [SCENIC](https://scenic.aertslab.org/) to infer prior regulation information. Installation and usage steps are given in [pySCENIC](https://pyscenic.readthedocs.io/en/latest/installation.html) and are demonstrated in [SCENICprotocol](https://github.com/aertslab/SCENICprotocol/tree/master)." 
214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 7, 219 | "id": "64e1477e-6539-4605-abbf-c7f72ddfb8dc", 220 | "metadata": { 221 | "tags": [] 222 | }, 223 | "outputs": [], 224 | "source": [ 225 | "# path to loom file created previously\n", 226 | "f_loom_path_scenic = \"adata.loom\"\n", 227 | "# path to list of transcription factors\n", 228 | "f_tfs = \"allTFs_mm.txt\"" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 9, 234 | "id": "6426c2fe-a664-479e-9113-45fbfb62bcef", 235 | "metadata": { 236 | "tags": [] 237 | }, 238 | "outputs": [ 239 | { 240 | "name": "stdout", 241 | "output_type": "stream", 242 | "text": [ 243 | "/bin/bash: line 1: pyscenic: command not found\n" 244 | ] 245 | } 246 | ], 247 | "source": [ 248 | "!pyscenic grn {f_loom_path_scenic} {f_tfs} -o \"adj.csv\" --num_workers 24" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "id": "1f9fc0ba-0012-48eb-bcbd-3f859dcc37a8", 255 | "metadata": { 256 | "tags": [] 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "# path to ranking databases in feather format\n", 261 | "f_db_glob = \"scenic/cisTarget_databases/*feather\"\n", 262 | "f_db_names = ' '.join( glob.glob(f_db_glob) )\n", 263 | "\n", 264 | "# path to motif databases\n", 265 | "f_motif_path = \"scenic/cisTarget_databases/motifs-v9-nr.mgi-m0.001-o0.0.tbl\"" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "id": "e1bdaf1e-91e2-49ff-91fd-4143bbdc032a", 272 | "metadata": { 273 | "tags": [] 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "!pyscenic ctx \"adj.csv\" \\\n", 278 | " {f_db_names} \\\n", 279 | " --annotations_fname {f_motif_path} \\\n", 280 | " --expression_mtx_fname {f_loom_path_scenic} \\\n", 281 | " --output \"reg.csv\" \\\n", 282 | " --all_modules \\\n", 283 | " --num_workers 24" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "id": "6c2f499b-b032-4afd-bb12-57eae82d9a46", 290 | "metadata": { 291 | "tags": [] 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "f_pyscenic_output = \"pyscenic_output_all_regulon_no_mask.loom\"" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "id": "db083910-2345-4e39-8ddd-ff14f210c498", 302 | "metadata": { 303 | "tags": [] 304 | }, 305 | "outputs": [], 306 | "source": [ 307 | "!pyscenic aucell \\\n", 308 | " {f_loom_path_scenic} \\\n", 309 | " \"reg.csv\" \\\n", 310 | " --output {f_pyscenic_output} \\\n", 311 | " --num_workers 4" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 26, 317 | "id": "fbba0c0c-109a-4112-a1eb-340c3942a3b7", 318 | "metadata": { 319 | "tags": [] 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "# collect SCENIC AUCell output\n", 324 | "lf = lp.connect(f_pyscenic_output, mode='r+', validate=False )\n", 325 | "auc_mtx = pd.DataFrame(lf.ca.RegulonsAUC, index=lf.ca.CellID)\n", 326 | "regulons = lf.ra.Regulons" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 27, 332 | "id": "7a0925e3-d1da-4158-a54a-44535613c6c0", 333 | "metadata": { 334 | "tags": [] 335 | }, 336 | "outputs": [], 337 | "source": [ 338 | "res = pd.concat([pd.Series(r.tolist(), index=regulons.dtype.names) for r in regulons], axis=1)\n", 339 | "res.columns = lf.row_attrs[\"SYMBOL\"]\n", 340 | "res.to_csv(\"regulon_mat_all_regulons.csv\")" 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "id": "512bff3d-6cfb-4dd0-a332-703f867943af", 346 | "metadata": {}, 
347 | "source": [ 348 | "## Create prior GRN for RegVelo\n", 349 | "In the following, we preprocess the GRN inferred from [SCENIC](https://scenic.aertslab.org/), saved as `regulon_mat_all_regulons.csv`. We first read the regulon file, where rows are regulators and columns are target genes. We further extract the names of the transcription factors (TFs) from the row indices using a regex and collapse duplicate TFs by summing their rows." 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": null, 355 | "id": "f5f040df-aff8-43ff-91b4-b4e7ff6dbf18", 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "# load saved regulon-target matrix\n", 360 | "reg = pd.read_csv(\"regulon_mat_all_regulons.csv\", index_col = 0)\n", 361 | "\n", 362 | "reg.index = reg.index.str.extract(r\"(\\w+)\")[0]\n", 363 | "reg = reg.groupby(reg.index).sum()" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "id": "5c18d266-5726-4ecc-ab38-2de5ce0d185b", 369 | "metadata": {}, 370 | "source": [ 371 | "We further binarize the matrix, where 1 indicates the presence of regulation and 0 its absence, and get the list of TFs and genes." 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "id": "bd475a15-b5b7-49a9-b549-c7fece33ba81", 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "reg[reg != 0] = 1\n", 382 | "\n", 383 | "TF = np.unique(list(map(lambda x: x.split(\"(\")[0], reg.index.tolist())))\n", 384 | "genes = np.unique(TF.tolist() + reg.columns.tolist())" 385 | ] 386 | },
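{ "cell_type": "markdown", "id": "added-regulon-size-check-md", "metadata": {}, "source": [ "As an optional, purely illustrative sanity check (this cell is an editorial addition, not part of the original pipeline), we can inspect how many targets the largest regulons retain after binarization:" ] }, { "cell_type": "code", "execution_count": null, "id": "added-regulon-size-check-code", "metadata": {}, "outputs": [], "source": [ "# number of targets per TF in the binarized regulon matrix\n", "reg.sum(axis=1).sort_values(ascending=False).head(10)" ] },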
{ 388 | "cell_type": "markdown", 389 | "id": "4a9633d5-204a-431f-a121-7cd927604b72", 390 | "metadata": {}, 391 | "source": [ 392 | "For the prior GRN, we first construct an empty square matrix and populate it based on the previously binarized regulation information. We further remove the genes that are neither a TF nor a target gene (i.e., remove empty rows and columns) and save the cleaned and structured GRN to a `.parquet` file for RegVelo's downstream pipeline." 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": 76, 398 | "id": "dd55d3a2-f98e-42a3-849a-22471675db2d", 399 | "metadata": { 400 | "tags": [] 401 | }, 402 | "outputs": [ 403 | { 404 | "name": "stdout", 405 | "output_type": "stream", 406 | "text": [ 407 | "Done! processed GRN with 543 TF and 30717 targets\n" 408 | ] 409 | } 410 | ], 411 | "source": [ 412 | "GRN = pd.DataFrame(0, index=genes, columns=genes)\n", 413 | "GRN.loc[TF,reg.columns.tolist()] = np.array(reg)\n", 414 | "\n", 415 | "mask = (GRN.sum(0) != 0) | (GRN.sum(1) != 0)\n", 416 | "GRN = GRN.loc[mask, mask].copy()\n", 417 | "\n", 418 | "GRN.to_parquet(\"regulon_mat_processed_all_regulons.parquet\")\n", 419 | "print(\"Done! processed GRN with \" + str(reg.shape[0]) + \" TFs and \" + str(reg.shape[1]) + \" targets\")" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "id": "5a5ff903-a13d-400c-9dcc-d2d75bd4d4f7", 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [] 429 | } 430 | ], 431 | "metadata": { 432 | "kernelspec": { 433 | "display_name": "Python (pyscenic)", 434 | "language": "python", 435 | "name": "pyscenic" 436 | }, 437 | "language_info": { 438 | "codemirror_mode": { 439 | "name": "ipython", 440 | "version": 3 441 | }, 442 | "file_extension": ".py", 443 | "mimetype": "text/x-python", 444 | "name": "python", 445 | "nbconvert_exporter": "python", 446 | "pygments_lexer": "ipython3", 447 | "version": "3.10.16" 448 | } 449 | }, 450 | "nbformat": 4, 451 | "nbformat_minor": 5 452 | } 453 | -------------------------------------------------------------------------------- /docs/tutorials/murine/02_RegVelo_preparation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e224f114-2918-4942-82d1-801e1429af7b", 6 | "metadata": {}, 7 | "source": [ 8 | "# Preprocess data and add prior GRN information\n", 9 | "In this notebook, we will go through the preprocessing steps needed prior to running the RegVelo pipeline." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "id": "862aad5e-ac38-4ad6-9557-989e8f261fe8", 15 | "metadata": {}, 16 | "source": [ 17 | "## Library import" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "id": "b1205fa6-37b9-45b7-bae4-6f45369db6f6", 24 | "metadata": { 25 | "tags": [] 26 | }, 27 | "outputs": [ 28 | { 29 | "name": "stderr", 30 | "output_type": "stream", 31 | "text": [ 32 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.\n", 33 | " warnings.warn(msg, FutureWarning)\n", 34 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.\n", 35 | " warnings.warn(msg, FutureWarning)\n", 36 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.\n", 37 | " warnings.warn(msg, FutureWarning)\n", 38 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.\n", 39 | " warnings.warn(msg, FutureWarning)\n", 40 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.\n", 41 | " warnings.warn(msg, FutureWarning)\n", 42 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.\n", 43 | " warnings.warn(msg, FutureWarning)\n", 44 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_umi_tools from `anndata` is deprecated.
Import anndata.io.read_umi_tools instead.\n", 45 | " warnings.warn(msg, FutureWarning)\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "import scvelo as scv\n", 51 | "import scanpy as sc\n", 52 | "import pandas as pd\n", 53 | "import numpy as np\n", 54 | "\n", 55 | "import matplotlib.pyplot as plt\n", 56 | "import mplscience\n", 57 | "import seaborn as sns\n", 58 | "\n", 59 | "import regvelo as rgv" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "cd088464-c366-40a8-a34e-8d3947fa4991", 65 | "metadata": {}, 66 | "source": [ 67 | "## Load data\n", 68 | "Read murine neural crest data that contains `.layers['spliced']` and `.layers['unspliced']`." 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "2c1022b2-6133-4072-acb1-df900c0875b7", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "adata = rgv.datasets.murine_nc(data_type = \"velocyto\")" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 3, 84 | "id": "cf07867d-5b7a-4feb-82e1-fa382ad19613", 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "AnnData object with n_obs × n_vars = 6788 × 30717\n", 91 | " obs: 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'UMI_count', 'gene_count', 'major_trajectory', 'celltype_update', 'UMAP_1', 'UMAP_2', 'UMAP_3', 'UMAP_2d_1', 'UMAP_2d_2', 'terminal_state', 'nCount_intron', 'nFeature_intron'\n", 92 | " var: 'vf_vst_counts_mean', 'vf_vst_counts_variance', 'vf_vst_counts_variance.expected', 'vf_vst_counts_variance.standardized', 'vf_vst_counts_variable', 'vf_vst_counts_rank', 'var.features', 'var.features.rank'\n", 93 | " obsm: 'X_pca', 'X_umap'\n", 94 | " layers: 'spliced', 'unspliced'" 95 | ] 96 | }, 97 | "execution_count": 3, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "adata" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "id": "aa1e8c0d-7e11-4a77-bd6f-042e27a09b6e", 109 | "metadata": {}, 110 | "source": [ 111 | "## Data preprocessing" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 4, 117 | "id": "85fa5243-6477-4a86-9640-4fd2f1172d1d", 118 | "metadata": { 119 | "tags": [] 120 | }, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "Filtered out 22217 genes that are detected 20 counts (shared).\n", 127 | "Normalized count data: X, spliced, unspliced.\n", 128 | "Extracted 3000 highly variable genes.\n" 129 | ] 130 | }, 131 | { 132 | "name": "stderr", 133 | "output_type": "stream", 134 | "text": [ 135 | "/tmp/ipykernel_3655946/1257038519.py:4: DeprecationWarning: `log1p` is deprecated since scVelo v0.3.0 and will be removed in a future version. 
Please use `log1p` from `scanpy.pp` instead.\n", 136 | " scv.pp.log1p(adata)\n" 137 | ] 138 | }, 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "computing moments based on connectivities\n", 144 | " finished (0:00:02) --> added \n", 145 | " 'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "scv.pp.filter_genes(adata, min_shared_counts=20)\n", 151 | "scv.pp.normalize_per_cell(adata)\n", 152 | "scv.pp.filter_genes_dispersion(adata, n_top_genes=3000)\n", 153 | "scv.pp.log1p(adata)\n", 154 | "\n", 155 | "sc.pp.neighbors(adata,n_pcs = 30,n_neighbors = 50)\n", 156 | "sc.tl.umap(adata)\n", 157 | "scv.pp.moments(adata, n_pcs=None, n_neighbors=None)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 5, 163 | "id": "91c087a7-ec36-4d86-a891-c297eba29a7c", 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "data": { 168 | "text/plain": [ 169 | "AnnData object with n_obs × n_vars = 6788 × 3000\n", 170 | " obs: 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'UMI_count', 'gene_count', 'major_trajectory', 'celltype_update', 'UMAP_1', 'UMAP_2', 'UMAP_3', 'UMAP_2d_1', 'UMAP_2d_2', 'terminal_state', 'nCount_intron', 'nFeature_intron', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts'\n", 171 | " var: 'vf_vst_counts_mean', 'vf_vst_counts_variance', 'vf_vst_counts_variance.expected', 'vf_vst_counts_variance.standardized', 'vf_vst_counts_variable', 'vf_vst_counts_rank', 'var.features', 'var.features.rank', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'\n", 172 | " uns: 'log1p', 'neighbors', 'umap'\n", 173 | " obsm: 'X_pca', 'X_umap'\n", 174 | " layers: 'spliced', 'unspliced', 'Ms', 'Mu'\n", 175 | " obsp: 'distances', 'connectivities'" 176 | ] 177 | }, 178 | "execution_count": 5, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "adata" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "id": "943f6605-e091-4b8c-ad4b-075dbc4ef803", 190 | "metadata": {}, 191 | "source": [ 192 | "## Load prior GRN created from notebook 'Infer prior GRN from [pySCENIC](https://pyscenic.readthedocs.io/en/latest/installation.html)' for RegVelo\n", 193 | "In the following, we load the processed prior GRN information into our AnnData object. In this step `.uns['skeleton']` and `.var['TF']` are added, which will be needed for RegVelo's velocity pipeline." 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 6, 199 | "id": "f7d9c59f-6238-494b-8649-178ae5dd783b", 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "GRN = pd.read_parquet(\"regulon_mat_processed_all_regulons.parquet\")" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 7, 209 | "id": "63410f40-8ee1-48cb-9a93-5737fb3aa9ef", 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "data": { 214 | "text/html": [ 215 | "
<div>HTML preview of the prior GRN DataFrame omitted (5 rows × 13697 columns); see the text/plain rendering below.</div>
" 381 | ], 382 | "text/plain": [ 383 | " 0610005C13Rik 0610009L18Rik 0610010K14Rik 0610012G03Rik \\\n", 384 | "0610005C13Rik 0 0 0 0 \n", 385 | "0610009L18Rik 0 0 0 0 \n", 386 | "0610010K14Rik 0 0 0 0 \n", 387 | "0610012G03Rik 0 0 0 0 \n", 388 | "0610030E20Rik 0 0 0 0 \n", 389 | "\n", 390 | " 0610030E20Rik 0610038B21Rik 0610040B10Rik 0610040J01Rik \\\n", 391 | "0610005C13Rik 0 0 0 0 \n", 392 | "0610009L18Rik 0 0 0 0 \n", 393 | "0610010K14Rik 0 0 0 0 \n", 394 | "0610012G03Rik 0 0 0 0 \n", 395 | "0610030E20Rik 0 0 0 0 \n", 396 | "\n", 397 | " 0610043K17Rik 1110002L01Rik ... Zswim8 Zw10 Zwilch Zwint \\\n", 398 | "0610005C13Rik 0 0 ... 0 0 0 0 \n", 399 | "0610009L18Rik 0 0 ... 0 0 0 0 \n", 400 | "0610010K14Rik 0 0 ... 0 0 0 0 \n", 401 | "0610012G03Rik 0 0 ... 0 0 0 0 \n", 402 | "0610030E20Rik 0 0 ... 0 0 0 0 \n", 403 | "\n", 404 | " Zxdb Zxdc Zyg11b Zyx Zzef1 Zzz3 \n", 405 | "0610005C13Rik 0 0 0 0 0 0 \n", 406 | "0610009L18Rik 0 0 0 0 0 0 \n", 407 | "0610010K14Rik 0 0 0 0 0 0 \n", 408 | "0610012G03Rik 0 0 0 0 0 0 \n", 409 | "0610030E20Rik 0 0 0 0 0 0 \n", 410 | "\n", 411 | "[5 rows x 13697 columns]" 412 | ] 413 | }, 414 | "execution_count": 7, 415 | "metadata": {}, 416 | "output_type": "execute_result" 417 | } 418 | ], 419 | "source": [ 420 | "GRN.head()" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 8, 426 | "id": "2087dc84-d2cd-4ae1-989f-5691ac6dbbc0", 427 | "metadata": {}, 428 | "outputs": [ 429 | { 430 | "name": "stderr", 431 | "output_type": "stream", 432 | "text": [ 433 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/regvelo/preprocessing/set_prior_grn.py:75: ImplicitModificationWarning: Trying to modify attribute `._uns` of view, initializing view as actual.\n", 434 | " adata.uns[\"regulators\"] = adata.var_names.to_numpy()\n" 435 | ] 436 | } 437 | ], 438 | "source": [ 439 | "adata = rgv.pp.set_prior_grn(adata, GRN.T)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 9, 445 | "id": "652d5597-1abd-4c71-a328-00d3d8f5510f", 446 | "metadata": {}, 447 | "outputs": [ 448 | { 449 | "data": { 450 | "text/plain": [ 451 | "AnnData object with n_obs × n_vars = 6788 × 2112\n", 452 | " obs: 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'UMI_count', 'gene_count', 'major_trajectory', 'celltype_update', 'UMAP_1', 'UMAP_2', 'UMAP_3', 'UMAP_2d_1', 'UMAP_2d_2', 'terminal_state', 'nCount_intron', 'nFeature_intron', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts'\n", 453 | " var: 'vf_vst_counts_mean', 'vf_vst_counts_variance', 'vf_vst_counts_variance.expected', 'vf_vst_counts_variance.standardized', 'vf_vst_counts_variable', 'vf_vst_counts_rank', 'var.features', 'var.features.rank', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'\n", 454 | " uns: 'log1p', 'neighbors', 'umap', 'regulators', 'targets', 'skeleton', 'network'\n", 455 | " obsm: 'X_pca', 'X_umap'\n", 456 | " layers: 'spliced', 'unspliced', 'Ms', 'Mu'\n", 457 | " obsp: 'distances', 'connectivities'" 458 | ] 459 | }, 460 | "execution_count": 9, 461 | "metadata": {}, 462 | "output_type": "execute_result" 463 | } 464 | ], 465 | "source": [ 466 | "adata" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": 10, 472 | "id": "36ccbcad-1c0e-4807-9fd0-3ba512b8ced3", 473 | "metadata": {}, 474 | "outputs": [ 475 | { 476 | "name": "stdout", 477 | "output_type": "stream", 478 | "text": [ 479 | "computing velocities\n", 480 | " finished (0:00:00) --> added \n", 481 | " 'velocity', 
velocity vectors for each individual cell (adata.layers)\n" 482 | ] 483 | } 484 | ], 485 | "source": [ 486 | "velocity_genes = rgv.pp.preprocess_data(adata.copy()).var_names.tolist()\n", 487 | "\n", 488 | "# select TFs that regulate at least one gene\n", 489 | "TF = adata.var_names[adata.uns[\"skeleton\"].sum(1) != 0]\n", 490 | "var_mask = np.union1d(TF, velocity_genes)\n", 491 | "\n", 492 | "adata = adata[:, var_mask].copy()" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": 11, 498 | "id": "575c89de-7fc5-4a3b-9e8a-0287028f90fe", 499 | "metadata": {}, 500 | "outputs": [ 501 | { 502 | "name": "stdout", 503 | "output_type": "stream", 504 | "text": [ 505 | "Number of genes: 1187\n", 506 | "Number of genes: 1164\n" 507 | ] 508 | } 509 | ], 510 | "source": [ 511 | "adata = rgv.pp.filter_genes(adata)\n", 512 | "adata = rgv.pp.preprocess_data(adata, filter_on_r2=False)\n", 513 | "\n", 514 | "adata.var[\"velocity_genes\"] = adata.var_names.isin(velocity_genes)\n", 515 | "adata.var[\"TF\"] = adata.var_names.isin(TF)" 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": 12, 521 | "id": "152b7016-ff2a-435e-9f02-601ceadc31a7", 522 | "metadata": {}, 523 | "outputs": [ 524 | { 525 | "data": { 526 | "text/plain": [ 527 | "AnnData object with n_obs × n_vars = 6788 × 1164\n", 528 | " obs: 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'UMI_count', 'gene_count', 'major_trajectory', 'celltype_update', 'UMAP_1', 'UMAP_2', 'UMAP_3', 'UMAP_2d_1', 'UMAP_2d_2', 'terminal_state', 'nCount_intron', 'nFeature_intron', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts'\n", 529 | " var: 'vf_vst_counts_mean', 'vf_vst_counts_variance', 'vf_vst_counts_variance.expected', 'vf_vst_counts_variance.standardized', 'vf_vst_counts_variable', 'vf_vst_counts_rank', 'var.features', 'var.features.rank', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable', 'velocity_genes', 'TF'\n", 530 | " uns: 'log1p', 'neighbors', 'umap', 'regulators', 'targets', 'skeleton', 'network'\n", 531 | " obsm: 'X_pca', 'X_umap'\n", 532 | " layers: 'spliced', 'unspliced', 'Ms', 'Mu'\n", 533 | " obsp: 'distances', 'connectivities'" 534 | ] 535 | }, 536 | "execution_count": 12, 537 | "metadata": {}, 538 | "output_type": "execute_result" 539 | } 540 | ], 541 | "source": [ 542 | "adata" 543 | ] 544 | }, 545 | { 546 | "cell_type": "markdown", 547 | "id": "ad07a3c3-9086-484b-8447-adea76ce2e45", 548 | "metadata": {}, 549 | "source": [ 550 | "## Save data\n", 551 | "This data, with the parameters chosen in this tutorial, can also be accessed by calling `rgv.datasets.murine_nc(data_type = \"preprocessed\")`." 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": 13, 557 | "id": "534c33f0-197a-424c-928e-9c5215fe2660", 558 | "metadata": {}, 559 | "outputs": [], 560 | "source": [ 561 | "adata.write_h5ad(\"adata_processed_velo.h5ad\")" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": null, 567 | "id": "1d0da013-f4d5-4efd-9145-ba497056581b", 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [] 571 | } 572 | ], 573 | "metadata": { 574 | "kernelspec": { 575 | "display_name": "regvelo-py310-v2", 576 | "language": "python", 577 | "name": "regvelo-py310-v2" 578 | }, 579 | "language_info": { 580 | "codemirror_mode": { 581 | "name": "ipython", 582 | "version": 3 583 | }, 584 | "file_extension": ".py", 585 | "mimetype": "text/x-python", 586 | "name": "python", 587 | "nbconvert_exporter": "python", 588 | "pygments_lexer": 
"ipython3", 589 | "version": "3.10.13" 590 | } 591 | }, 592 | "nbformat": 4, 593 | "nbformat_minor": 5 594 | } 595 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core>=1.0.0"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | authors = ["Weixu Wang "] 7 | classifiers = [ 8 | "Development Status :: 4 - Beta", 9 | "Intended Audience :: Science/Research", 10 | "Natural Language :: English", 11 | "Programming Language :: Python :: 3.9", 12 | "Programming Language :: Python :: 3.10", 13 | "Programming Language :: Python :: 3.11", 14 | "Operating System :: MacOS :: MacOS X", 15 | "Operating System :: Microsoft :: Windows", 16 | "Operating System :: POSIX :: Linux", 17 | "Topic :: Scientific/Engineering :: Bio-Informatics", 18 | ] 19 | description = "Estimation of RNA velocity with variational inference." 20 | documentation = "https://scvi-tools.org" 21 | homepage = "https://github.com/theislab/RegVelo/" 22 | license = "BSD-3-Clause" 23 | name = "regvelo" 24 | packages = [ 25 | {include = "regvelo"}, 26 | ] 27 | readme = "README.md" 28 | version = "0.2.0" 29 | 30 | [tool.poetry.dependencies] 31 | anndata = ">=0.10.8" 32 | black = {version = ">=20.8b1", optional = true} 33 | codecov = {version = ">=2.0.8", optional = true} 34 | ruff = {version = "*", optional = true} 35 | importlib-metadata = {version = "^1.0", python = "<3.8"} 36 | ipython = {version = ">=7.1.1", optional = true} 37 | jupyter = {version = ">=1.0", optional = true} 38 | pre-commit = {version = ">=2.7.1", optional = true} 39 | sphinx-book-theme = {version = ">=1.0.0", optional = true} 40 | myst-nb = {version = "*", optional = true} 41 | sphinx-copybutton = {version = "*", optional = true} 42 | sphinxcontrib-bibtex = {version = "2.6.3", optional = true} 43 | ipykernel = {version = "*", optional = true} 44 | pytest = {version = ">=4.4", optional = true} 45 | pytest-cov = {version = "*", optional = true} 46 | python = ">=3.9,<4.0" 47 | python-igraph = {version = "*", optional = true} 48 | scanpy = {version = ">=1.10.3", optional = true} 49 | scanpydoc = {version = ">=0.5", optional = true} 50 | scvelo = ">=0.3.2" 51 | scvi-tools = ">=1.0.0,<1.2.1" 52 | scikit-learn = ">=0.21.2" 53 | velovi = ">=0.3.1" 54 | torchode = ">=0.1.6" 55 | cellrank = ">=2.0.0" 56 | matplotlib = ">=3.7.3" 57 | sphinx = {version = ">=4.1", optional = true} 58 | sphinx-autodoc-typehints = {version = "*", optional = true} 59 | torch = "<2.6.0" 60 | 61 | 62 | [tool.poetry.extras] 63 | dev = ["black", "pytest", "pytest-cov", "ruff", "codecov", "scanpy", "loompy", "jupyter", "pre-commit"] 64 | docs = [ 65 | "sphinx", 66 | "scanpydoc", 67 | "ipython", 68 | "myst-nb", 69 | "sphinx-book-theme", 70 | "sphinx-copybutton", 71 | "sphinxcontrib-bibtex", 72 | "ipykernel", 73 | "ipython", 74 | ] 75 | tutorials = ["scanpy"] 76 | 77 | [tool.poetry.dev-dependencies] 78 | 79 | 80 | [tool.coverage.run] 81 | source = ["regvelo"] 82 | omit = [ 83 | "**/test_*.py", 84 | ] 85 | 86 | [tool.pytest.ini_options] 87 | testpaths = ["tests"] 88 | xfail_strict = true 89 | 90 | 91 | [tool.black] 92 | include = '\.pyi?$' 93 | exclude = ''' 94 | ( 95 | /( 96 | \.eggs 97 | | \.git 98 | | \.hg 99 | | \.mypy_cache 100 | | \.tox 101 | | \.venv 102 | | _build 103 | | buck-out 104 | | build 105 | | dist 106 | )/ 107 | ) 108 | ''' 109 | 110 | [tool.ruff] 111 | src = ["."] 112 | line-length = 119 113 | 
target-version = "py38" 114 | select = [ 115 | "F", # Errors detected by Pyflakes 116 | "E", # Error detected by Pycodestyle 117 | "W", # Warning detected by Pycodestyle 118 | "I", # isort 119 | "D", # pydocstyle 120 | "B", # flake8-bugbear 121 | "TID", # flake8-tidy-imports 122 | "C4", # flake8-comprehensions 123 | "BLE", # flake8-blind-except 124 | "UP", # pyupgrade 125 | "RUF100", # Report unused noqa directives 126 | ] 127 | ignore = [ 128 | # line too long -> we accept long comment lines; black gets rid of long code lines 129 | "E501", 130 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient 131 | "E731", 132 | # allow I, O, l as variable names -> I is the identity matrix 133 | "E741", 134 | # Missing docstring in public package 135 | "D104", 136 | # Missing docstring in public module 137 | "D100", 138 | # Missing docstring in __init__ 139 | "D107", 140 | # Errors from function calls in argument defaults. These are fine when the result is immutable. 141 | "B008", 142 | # __magic__ methods are often self-explanatory, allow missing docstrings 143 | "D105", 144 | # first line should end with a period [Bug: doesn't work with single-line docstrings] 145 | "D400", 146 | # First line should be in imperative mood; try rephrasing 147 | "D401", 148 | ## Disable one in each pair of mutually incompatible rules 149 | # We don’t want a blank line before a class docstring 150 | "D203", 151 | # We want docstrings to start immediately after the opening triple quote 152 | "D213", 153 | # Missing argument description in the docstring TODO: enable 154 | "D417", 155 | ] 156 | 157 | [tool.ruff.per-file-ignores] 158 | "docs/*" = ["I", "BLE001"] 159 | "tests/*" = ["D"] 160 | "*/__init__.py" = ["F401"] 161 | "regvelo/__init__.py" = ["I"] 162 | 163 | [tool.jupytext] 164 | formats = "ipynb,md" 165 | 166 | [tool.cruft] 167 | skip = [ 168 | "tests", 169 | "src/**/__init__.py", 170 | "src/**/basic.py", 171 | "docs/api.md", 172 | "docs/changelog.md", 173 | "docs/references.bib", 174 | "docs/references.md", 175 | "docs/notebooks/example.ipynb", 176 | ] 177 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | build: 3 | os: ubuntu-20.04 4 | tools: 5 | python: "3.10" 6 | sphinx: 7 | configuration: docs/conf.py 8 | python: 9 | install: 10 | - method: pip 11 | path: . 12 | extra_requirements: 13 | - docs 14 | -------------------------------------------------------------------------------- /regvelo/ModelComparison.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scipy 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | 8 | import scipy.stats 9 | from scipy.stats import ttest_rel 10 | 11 | import cellrank as cr 12 | import scanpy as sc 13 | import scvi 14 | from ._model import REGVELOVI 15 | import scvelo as scv 16 | 17 | 18 | from .tools.set_output import set_output 19 | from .tools._tsi import get_tsi_score 20 | 21 | 22 | # Packages used for data type validation 23 | from typing import List, Optional, Union, Dict 24 | from anndata import AnnData 25 | 26 | class ModelComparison: 27 | """Compare different types of RegVelo models :cite:p:`Wang2025`. 
28 | 29 | This class is used to compare RegVelo models trained with different optimization modes (soft, hard, soft_regularized) and with different values of the regularization factor lam2. 30 | Users can evaluate and visualize the performance of the different model types based on various kinds of cell-level side information (real time, pseudotime, stemness score, terminal state identification, cross-boundary correctness). 31 | Finally, it returns a barplot in which the best-performing model is marked and its performance is highlighted by a significance test. 32 | 33 | Examples 34 | ---------- 35 | See notebook. 36 | 37 | """ 38 | def __init__(self, 39 | terminal_states: List = None, 40 | state_transition: Dict = None, 41 | n_states: int = None): 42 | """Initialize parameters in the comparison object. 43 | 44 | Parameters 45 | ---------- 46 | terminal_states 47 | A list recording all terminal states across all cell types. 48 | This parameter is only needed if you use TSI as side_information. Please make sure it is consistent with the information stored in 'side_key' under TSI mode. 49 | state_transition 50 | A dict recording all possible state transition relationships among all cell types. 51 | This parameter is only needed if you use CBC as side_information. Please make sure it is consistent with the information stored in 'side_key' under CBC mode. 52 | n_states 53 | An integer providing the total number of cell clusters. 54 | This parameter is only needed if you use TSI as side_information. 55 | 56 | Returns 57 | ---------- 58 | A comparison object on which the operations below can be performed. 59 | 60 | """ 61 | self.TERMINAL_STATES = terminal_states 62 | self.STATE_TRANSITION = state_transition 63 | self.N_STATES = n_states 64 | 65 | self.METHOD = None 66 | self.MODEL_LIST = None 67 | 68 | self.side_info_dict = { 69 | 'Pseudo_Time':'dpt_pseudotime', 70 | 'Stemness_Score': 'ct_score', 71 | 'Real_Time': None, 72 | 'TSI': None, 73 | 'CBC': None 74 | } 75 | self.MODEL_TRAINED = {} 76 | 77 | def validate_input(self, 78 | adata: AnnData, 79 | model_list: List[str] = None, 80 | side_information: str = None, 81 | lam2: Union[List[float], float] = None, 82 | side_key: str = None) -> None: 83 | 84 | # 1.Validate adata 85 | if not isinstance(adata, AnnData): 86 | raise TypeError(f"Expected AnnData object, got {type(adata).__name__}") 87 | layers = ['Ms', 'Mu', 'fit_t'] 88 | for layer in layers: 89 | if layer not in adata.layers: 90 | raise ValueError(f"Missing required layer: {layer}") 91 | if 'skeleton' not in adata.uns: 92 | raise ValueError("Missing required 'skeleton' in adata.uns") 93 | 94 | if 'TF' not in adata.var: 95 | raise ValueError("Missing required 'TF' column in adata.var") 96 | 97 | # 2.Validate Model_list 98 | if model_list is not None: 99 | valid_models = ['hard', 'soft', 'soft_regularized'] 100 | if not isinstance(model_list, list) or len(model_list) == 0: 101 | raise ValueError("model_list must be a non-empty list") 102 | for model in model_list: 103 | if model not in valid_models: 104 | raise ValueError(f"Invalid model: {model}. 
Valid models are {valid_models}") 105 | if model == 'soft_regularized' and lam2 is None: 106 | raise ValueError(f"Under 'soft_regularized' mode, lam2 must be given") 107 | if lam2 is not None: 108 | if not isinstance(lam2, (float, list)): 109 | raise TypeError('lam2 must be a float or a list of floats') 110 | if isinstance(lam2, list): 111 | if len(lam2) == 0: 112 | raise ValueError('lam2 list cannot be empty') 113 | for num in lam2: 114 | if not isinstance(num, float): 115 | raise ValueError('All elements in lam2 list must be float') 116 | if not 0.0 < num <= 1.0: 117 | raise ValueError('lam2 is expected to be in the range (0, 1]') 118 | 119 | # 3.Validate side_information 120 | if side_information is not None: 121 | if not isinstance(side_information, str): 122 | raise TypeError(f"side_information must be a string") 123 | 124 | if side_information not in self.side_info_dict.keys(): 125 | raise ValueError(f"Valid side_information are {self.side_info_dict.keys()}") 126 | 127 | if side_key is not None: 128 | if not isinstance(side_key, str): 129 | raise TypeError(f"side_key must be a string") 130 | if side_key not in adata.obs: 131 | raise TypeError(f"side_key: {side_key} not found in adata.obs.") 132 | if side_key is None: 133 | side_key = self.side_info_dict[side_information] 134 | if side_key is not None: 135 | if side_key not in adata.obs: 136 | raise TypeError(f"Default side_key: {side_key} not found in adata.obs, please input it manually with parameter: side_key") 137 | 138 | def min_max_scaling(self,x): 139 | return (x - np.min(x)) / (np.max(x) - np.min(x)) 140 | 141 | def train( 142 | self, 143 | adata: AnnData, 144 | model_list: List[str], 145 | lam2: Union[List[float], float] = None, 146 | n_repeat: int = 1 147 | ) -> List: 148 | """Train all requested models and store them in a dictionary, from which users can retrieve them easily and process them in batch. Models that were already trained and saved before are not removed. 149 | 150 | Parameters 151 | ---------- 152 | adata 153 | The annotated data matrix. Once passed in, the object stores it as an instance variable. 154 | model_list 155 | The list of valid model types, including 'soft', 'hard', 'soft_regularized'. 156 | lam2 157 | Regularization factor used under 'soft_regularized' mode; a float or a list of floats in the range (0, 1]. n_repeat Number of repeated training runs per model; each run is trained with a different random seed. 158 | 159 | Returns 160 | ---------- 161 | A list of dictionary key names identifying all models trained in this step.
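
        Examples
        --------
        A minimal sketch (hypothetical variable names and values; assumes ``adata``
        already carries the ``Ms``/``Mu``/``fit_t`` layers, ``adata.uns["skeleton"]``
        and ``adata.var["TF"]``)::

            comparison = ModelComparison()
            keys = comparison.train(adata, model_list=["soft", "hard"], n_repeat=3)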
162 | 163 | """ 164 | self.validate_input(adata, model_list = model_list, lam2 = lam2) 165 | self.ADATA = adata 166 | 167 | if not isinstance(n_repeat, int) or n_repeat < 1: 168 | raise ValueError("n_repeat must be a positive integer") 169 | 170 | W = adata.uns["skeleton"].copy() 171 | W = torch.tensor(np.array(W)).int() 172 | TF = adata.var_names[adata.var['TF']] 173 | REGVELOVI.setup_anndata(adata, spliced_layer="Ms", unspliced_layer="Mu") 174 | 175 | 176 | for model in model_list: 177 | for nrun in range(n_repeat): 178 | scvi.settings.seed = nrun 179 | if model == 'soft_regularized': 180 | if isinstance(lam2,list): 181 | for lambda2 in lam2: 182 | vae = REGVELOVI( 183 | adata, 184 | W=W.T, 185 | regulators=TF, 186 | lam2 = lambda2 187 | ) 188 | vae.train() 189 | self.MODEL_TRAINED[f"{model}\nlam2:{lambda2}_{nrun}"] = vae 190 | else: 191 | vae = REGVELOVI( 192 | adata, 193 | W=W.T, 194 | regulators=TF, 195 | lam2=lam2 196 | ) 197 | vae.train() 198 | self.MODEL_TRAINED[f"{model}_{nrun}"] = vae 199 | else: 200 | vae = REGVELOVI( 201 | adata, 202 | W=W.T, 203 | regulators=TF, 204 | soft_constraint=(model == 'soft') 205 | ) 206 | vae.train() 207 | self.MODEL_TRAINED[f"{model}_{nrun}"] = vae 208 | 209 | return list(self.MODEL_TRAINED.keys()) 210 | 211 | def evaluate( 212 | self, 213 | side_information: str, 214 | side_key: str = None 215 | ) -> tuple[str, pd.DataFrame]: 216 | """Evaluate all trained models under one specific side_information mode. For example, if the exact time or stage of cells is known, 'Real_Time' can be chosen as reference; if a pseudotime (e.g., diffusion pseudotime) has been computed beforehand, 'Pseudo_Time' can be chosen as reference. 217 | 218 | Parameters 219 | ---------- 220 | side_information 221 | The perspective used to compare RegVelo models; one of 'Real_Time', 'Pseudo_Time', 'Stemness_Score', 'TSI', 'CBC'. 222 | side_key 223 | Column name in adata.obs that stores the selected side information. For 'Pseudo_Time' and 'Stemness_Score' a default side_key is provided, but a custom one can be passed as input. 224 | 225 | Returns 226 | ---------- 227 | The attribute name under which the dataframe is stored, and a dataframe recording the evaluation performance of all models.
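
        Examples
        --------
        A hypothetical follow-up to :meth:`train`, assuming a pseudotime was computed
        beforehand and stored in ``adata.obs["dpt_pseudotime"]``::

            df_name, df = comparison.evaluate(side_information="Pseudo_Time")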
228 | """ 229 | self.validate_input(self.ADATA, side_information=side_information, side_key=side_key) 230 | correlations = [] 231 | 232 | for model, vae in self.MODEL_TRAINED.items(): 233 | set_output(self.ADATA, vae, n_samples = 30, batch_size = self.ADATA.n_obs) 234 | fit_t_mean = self.ADATA.layers['fit_t'].mean(axis = 1) 235 | self.ADATA.obs["latent_time"] = self.min_max_scaling(fit_t_mean) 236 | corr = np.abs(self.calculate( 237 | self.ADATA, side_information, side_key 238 | )) 239 | correlations.append({ 240 | 'Model': model[:-2], 241 | 'Corr':corr, 242 | 'Run':model[-1] 243 | }) 244 | df_name = f"df_{side_information}" 245 | df = pd.DataFrame(correlations) 246 | setattr(self, df_name, df) 247 | return df_name, df 248 | 249 | def calculate( 250 | self, 251 | adata: AnnData, 252 | side_information: str, 253 | side_key: str = None 254 | ): 255 | if side_information in ['Pseudo_Time', 'Stemness_Score', 'Real_Time']: 256 | if side_information in ['Pseudo_Time', 'Stemness_Score'] and side_key is None: 257 | side_key = self.side_info_dict[side_information] 258 | return scipy.stats.spearmanr(self.ADATA.obs[side_key].values, self.ADATA.obs['latent_time'])[0] 259 | elif side_information == 'TSI': 260 | thresholds = np.linspace(0.1,1,21)[:20] 261 | vk = cr.kernels.VelocityKernel(self.ADATA) 262 | vk.compute_transition_matrix() 263 | ck = cr.kernels.ConnectivityKernel(self.ADATA).compute_transition_matrix() 264 | kernel = 0.8 * vk + 0.2 * ck 265 | estimator = cr.estimators.GPCCA(kernel) 266 | estimator.compute_macrostates(n_states=self.N_STATES, n_cells=30, cluster_key=side_key) 267 | return np.mean(get_tsi_score(self.ADATA, thresholds, side_key, self.TERMINAL_STATES, estimator)) 268 | elif side_information == 'CBC': 269 | self.ADATA.obs['CBC_key'] = self.ADATA.obs[side_key].astype(str) 270 | vk = cr.kernels.VelocityKernel(self.ADATA) 271 | vk.compute_transition_matrix() 272 | ck = cr.kernels.ConnectivityKernel(self.ADATA).compute_transition_matrix() 273 | kernel = 0.8 * vk + 0.2 * ck 274 | cbc_values = [] 275 | for source, target in self.STATE_TRANSITION: 276 | cbc = kernel.cbc(source = source, target=target, cluster_key='CBC_key', rep = 'X_pca') 277 | cbc_values.append(np.mean(cbc)) 278 | return np.mean(cbc_values) 279 | 280 | def plot_results( 281 | self, 282 | side_information, 283 | figsize = (6, None), 284 | palette = 'lightpink' 285 | ): 286 | """Visualize the comparison results as a barplot with overlaid scatter points. Significance markers are only shown when there are at least 3 repeated runs and p < 0.05. 287 | 288 | Parameters 289 | ---------- 290 | side_information 291 | The side_information to visualize; it must have been run through the 'evaluate' step in advance. 292 | figsize 293 | Figure size. Default is (6, None), meaning the height of the plot scales with the number of models. 294 | palette 295 | Color of the barplot. 296 | 297 | Returns 298 | ---------- 299 | A barplot with scatter points representing the performance of all models.
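
        Examples
        --------
        A hypothetical call, assuming :meth:`evaluate` was already run for the same mode::

            comparison.plot_results(side_information="Pseudo_Time")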
300 | """ 301 | df_name = f"df_{side_information}" 302 | data = getattr(self, df_name) 303 | 304 | model_order = data.groupby('Model')['Corr'].mean().sort_values(ascending=False).index.tolist() 305 | num_models = len(model_order) 306 | fig_height = 2 + num_models * 0.5 307 | figsize = (figsize[0], fig_height) 308 | 309 | sns.set(style='whitegrid', font_scale=1.2) 310 | fig, ax = plt.subplots(figsize=figsize) 311 | 312 | 313 | sns.barplot( 314 | y="Model", 315 | x="Corr", 316 | data=data, 317 | order = model_order, 318 | width=0.3, 319 | ci="sd", 320 | capsize=0.1, 321 | errwidth=2, 322 | color=palette, 323 | ax = ax) 324 | sns.stripplot( 325 | y="Model", 326 | x="Corr", 327 | data=data, 328 | order = model_order, 329 | dodge=True, 330 | jitter=0.25, 331 | color="black", 332 | size = 4, 333 | alpha=0.8, 334 | ax = ax 335 | ) 336 | 337 | model_means = data.groupby('Model')['Corr'].mean() 338 | ref_model = model_means.idxmax() 339 | 340 | ref_data = data[data["Model"] == ref_model]["Corr"] 341 | y_ticks = ax.get_yticks() 342 | model_positions = dict(zip(model_order, y_ticks)) 343 | 344 | for target_model in model_order: 345 | if target_model == ref_model: 346 | continue 347 | target_data = data[data["Model"] == target_model]["Corr"] 348 | 349 | if len(target_data) < 3: 350 | continue 351 | 352 | try: 353 | t_stat, p_value = scipy.stats.ttest_rel( 354 | ref_data, 355 | target_data, 356 | alternative="greater" 357 | ) 358 | except ValueError as e: 359 | print(f"Significance test: {ref_model} vs {target_model} failed:{str(e)}") 360 | continue 361 | 362 | if p_value < 0.05: 363 | significance = self.get_significance(p_value) 364 | self._draw_significance_marker( 365 | ax=ax, 366 | start=model_positions[ref_model], 367 | end=model_positions[target_model], 368 | significance=significance, 369 | bracket_height=0.05 370 | ) 371 | 372 | ax.set_title( 373 | f"Prediction based on {side_information}", 374 | fontsize=12, 375 | wrap = True 376 | ) 377 | ax.set_ylabel('') 378 | ax.set_xlabel( 379 | "Prediction Score", 380 | fontsize=12, 381 | labelpad=10 382 | ) 383 | plt.xticks(fontsize=10) 384 | plt.tight_layout() 385 | plt.show() 386 | 387 | def get_significance(self, pvalue): 388 | if pvalue < 0.001: 389 | return "***" 390 | elif pvalue < 0.01: 391 | return "**" 392 | elif pvalue < 0.05: 393 | return "*" 394 | else: 395 | return "ns" 396 | 397 | def _draw_significance_marker( 398 | self, 399 | ax, 400 | start, 401 | end, 402 | significance, 403 | bracket_height=0.05, 404 | linewidth=1.2, 405 | text_offset=0.05): 406 | 407 | if start > end: 408 | start, end = end, start 409 | x_max = ax.get_xlim()[1] 410 | bracket_level = max(start, end) + bracket_height 411 | 412 | ax.plot( 413 | [bracket_level-0.02, bracket_level, bracket_level, bracket_level-0.02], 414 | [start, start, end, end], 415 | color='black', 416 | lw=linewidth, 417 | solid_capstyle="butt", 418 | clip_on=False 419 | ) 420 | 421 | ax.text( 422 | bracket_level + text_offset, 423 | (start + end)/2, 424 | significance, 425 | ha='center', 426 | va='baseline', 427 | color='black', 428 | fontsize=10, 429 | fontweight='bold', 430 | rotation = 90 431 | ) 432 | -------------------------------------------------------------------------------- /regvelo/__init__.py: -------------------------------------------------------------------------------- 1 | """regvelo.""" 2 | 3 | import logging 4 | 5 | from rich.console import Console 6 | from rich.logging import RichHandler 7 | 8 | from regvelo import datasets 9 | from regvelo import tools as tl 10 | from regvelo 
import plotting as pl 11 | from regvelo import preprocessing as pp 12 | 13 | from ._constants import REGISTRY_KEYS 14 | from ._model import REGVELOVI, VELOVAE 15 | from .ModelComparison import ModelComparison 16 | 17 | import sys # isort:skip 18 | 19 | sys.modules.update({f"{__name__}.{m}": globals()[m] for m in ["tl", "pl", "pp"]}) 20 | 21 | # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094 22 | # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302 23 | try: 24 | import importlib.metadata as importlib_metadata 25 | except ModuleNotFoundError: 26 | import importlib_metadata 27 | 28 | package_name = "regvelo" 29 | __version__ = importlib_metadata.version(package_name) 30 | 31 | logger = logging.getLogger(__name__) 32 | # set the logging level 33 | logger.setLevel(logging.INFO) 34 | 35 | # nice logging outputs 36 | console = Console(force_terminal=True) 37 | if console.is_jupyter is True: 38 | console.is_jupyter = False 39 | ch = RichHandler(show_path=False, console=console, show_time=False) 40 | formatter = logging.Formatter("regvelo: %(message)s") 41 | ch.setFormatter(formatter) 42 | logger.addHandler(ch) 43 | 44 | # this prevents double outputs 45 | logger.propagate = False 46 | 47 | __all__ = [ 48 | "REGVELOVI", 49 | "VELOVAE", 50 | "REGISTRY_KEYS", 51 | "datasets", 52 | "ModelComparison" 53 | ] 54 | -------------------------------------------------------------------------------- /regvelo/_constants.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | 4 | class _REGISTRY_KEYS_NT(NamedTuple): 5 | X_KEY: str = "X" 6 | U_KEY: str = "U" 7 | 8 | 9 | REGISTRY_KEYS = _REGISTRY_KEYS_NT() 10 | -------------------------------------------------------------------------------- /regvelo/_module.py: -------------------------------------------------------------------------------- 1 | """Main module.""" 2 | from typing import Callable, Iterable, Literal, Optional, Any 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data 8 | from scvi.nn import Encoder, FCLayers 9 | from torch import nn as nn 10 | from torch.distributions import Categorical, Dirichlet, MixtureSameFamily, Normal 11 | from torch.distributions import kl_divergence as kl 12 | import torchode as to 13 | from ._constants import REGISTRY_KEYS 14 | from torch import Tensor 15 | import torch.nn.utils.prune as prune 16 | 17 | torch.backends.cudnn.benchmark = True 18 | 19 | def _softplus_inverse(x: np.ndarray) -> np.ndarray: 20 | x = torch.from_numpy(x) 21 | x_inv = torch.where(x > 20, x, x.expm1().log()).numpy() 22 | return x_inv 23 | 24 | class ThresholdPruning(prune.BasePruningMethod): 25 | PRUNING_TYPE = "unstructured" 26 | 27 | def __init__(self, threshold): 28 | self.threshold = threshold 29 | 30 | def compute_mask(self, tensor, default_mask): 31 | return torch.abs(tensor) > self.threshold 32 | 33 | class DecoderVELOVI(nn.Module): 34 | """Decodes data from a latent space of ``n_input`` dimensions into ``n_output`` dimensions. 35 | 36 | Uses a fully-connected neural network of ``n_hidden`` layers. 37 | 38 | Parameters 39 | ---------- 40 | n_input 41 | The dimensionality of the input (latent space) 42 | n_output 43 | The dimensionality of the output (data space) 44 | n_cat_list 45 | A list containing the number of categories 46 | for each category of interest. 
Each category will be 47 | included using a one-hot encoding 48 | n_layers 49 | The number of fully-connected hidden layers 50 | n_hidden 51 | The number of nodes per hidden layer 52 | dropout_rate 53 | Dropout rate to apply to each of the hidden layers 54 | inject_covariates 55 | Whether to inject covariates in each layer, or just the first (default). 56 | use_batch_norm 57 | Whether to use batch norm in layers 58 | use_layer_norm 59 | Whether to use layer norm in layers 60 | linear_decoder 61 | Whether to use a linear decoder for time 62 | """ 63 | 64 | def __init__( 65 | self, 66 | n_input: int, 67 | n_output: int, 68 | n_cat_list: Iterable[int] = None, 69 | n_layers: int = 1, 70 | n_hidden: int = 128, 71 | inject_covariates: bool = True, 72 | use_batch_norm: bool = True, 73 | use_layer_norm: bool = False, 74 | dropout_rate: float = 0.0, 75 | linear_decoder: bool = False, 76 | **kwargs, 77 | ): 78 | super().__init__() 79 | self.n_output = n_output 80 | self.linear_decoder = linear_decoder 81 | self.rho_first_decoder = FCLayers( 82 | n_in=n_input, 83 | n_out=n_hidden if not linear_decoder else n_output, 84 | n_cat_list=n_cat_list, 85 | n_layers=n_layers if not linear_decoder else 1, 86 | n_hidden=n_hidden, 87 | dropout_rate=dropout_rate, 88 | inject_covariates=inject_covariates, 89 | use_batch_norm=use_batch_norm if not linear_decoder else False, 90 | use_layer_norm=use_layer_norm if not linear_decoder else False, 91 | use_activation=not linear_decoder, 92 | bias=not linear_decoder, 93 | **kwargs, 94 | ) 95 | 96 | # rho for induction 97 | self.px_rho_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Sigmoid()) 98 | 99 | def forward(self, z: torch.Tensor, latent_dim: int = None): 100 | """The forward computation for a single sample. 101 | 102 | #. Decodes the data from the latent space using the decoder network 103 | #. Returns the gene-wise latent time ``px_rho``, squashed to ``(0, 1)`` by a sigmoid 104 | 105 | 106 | Parameters 107 | ---------- 108 | z : 109 | tensor with shape ``(n_input,)`` 110 | latent_dim 111 | optional index of a single latent dimension to decode; all other dimensions are masked out 112 | 113 | Returns 114 | ------- 115 | :py:class:`torch.Tensor` 116 | the decoded latent time ``px_rho`` 117 | 118 | """ 119 | z_in = z 120 | if latent_dim is not None: 121 | mask = torch.zeros_like(z) 122 | mask[..., latent_dim] = 1 123 | z_in = z * mask 124 | # The decoder maps the (optionally masked) latent representation to px_rho 125 | rho_first = self.rho_first_decoder(z_in) 126 | 127 | if not self.linear_decoder: 128 | px_rho = self.px_rho_decoder(rho_first) 129 | else: 130 | px_rho = nn.Sigmoid()(torch.matmul(z_in,torch.ones([1,self.n_output]))) 131 | 132 | return px_rho 133 | 134 | ## define a new class velocity encoder 135 | class velocity_encoder(nn.Module): 136 | """Encode the velocity. 137 | 138 | The time-dependent transcription rate is determined by the upstream regulators; the velocity can be built on top of this. 139 | 140 | Parameters 141 | ---------- 142 | activate 143 | activation function used for modeling the transcription rate 144 | base_alpha 145 | whether to add a base transcription rate 146 | n_int 147 | number of genes 148 | """ 149 | def __init__( 150 | self, 151 | activate: str = "softplus", 152 | base_alpha: bool = True, 153 | n_int: int = 5, 154 | ): 155 | super().__init__() 156 | self.n_int = n_int 157 | self.fc1 = nn.Linear(n_int, n_int) 158 | self.activate = activate 159 | self.base_alpha = base_alpha 160 | 161 | def _set_mask_grad(self): 162 | self.hooks = [] 163 | 164 | def _hook_mask_no_regulator(grad): 165 | return grad * self.mask_m 166 | 167 | w_grn = self.fc1.weight.register_hook(_hook_mask_no_regulator) 168 | self.hooks.append(w_grn) 169 | 170 | 171 | ## TODO: regularizing the jacobian 172 | def GRN_Jacobian(self,s): 173 | 174 | if self.activate != "OR": 175 | if self.base_alpha is not True: 176 | grn = self.fc1.weight 177 | #grn = grn - self.lamb_I 178 | alpha_unconstr = torch.matmul(s,grn.T) 179 | else: 180 | alpha_unconstr = self.fc1(s) 181 | 182 | if self.activate == "softplus": 183 | coef = (torch.sigmoid(alpha_unconstr)) 184 | if self.activate == "sigmoid": 185 | coef = (torch.sigmoid(alpha_unconstr))*(1 - torch.sigmoid(alpha_unconstr))*self.alpha_unconstr_max 186 | else: 187 | coef = (1 / (torch.nn.functional.softsign(s) + 1)) * (1 / (1 + torch.abs(s - 0.5))**2) * torch.exp(self.fc1(torch.log(torch.nn.functional.softsign(s - 0.5)+1))) 188 | 189 | if coef.dim() > 1: 190 | coef = coef.mean(0) 191 | Jaco_m = torch.matmul(torch.diag(coef), self.fc1.weight) 192 | 193 | 194 | return Jaco_m 195 | 196 | def GRN_Jacobian2(self,s): 197 | 198 | if self.base_alpha is not True: 199 | grn = self.fc1.weight 200 | alpha_unconstr = torch.matmul(s,grn.T) 201 | else: 202 | alpha_unconstr = self.fc1(s) 203 | 204 | if self.activate == "softplus": 205 | coef = (torch.sigmoid(alpha_unconstr)) 206 | if self.activate == "sigmoid": 207 | coef = (torch.sigmoid(alpha_unconstr))*(1 - torch.sigmoid(alpha_unconstr))*self.alpha_unconstr_max 208 | 209 | # Perform element-wise multiplication 210 | Jaco = coef.unsqueeze(-1) * self.fc1.weight.unsqueeze(0) 211 | 212 | # Transpose and reshape to get the final 3D tensor with dimensions (m, n, n) 213 | Jaco = Jaco.reshape(s.shape[0], s.shape[1], s.shape[1]) 214 | 215 | return Jaco 216 | 217 | 218 | def transcription_rate(self,s): 219 | if self.activate != "OR": 220 | if self.base_alpha is not True: 
221 | grn = self.fc1.weight 222 | #grn = grn - self.lamb_I 223 | alpha_unconstr = torch.matmul(s,grn.T) 224 | else: 225 | alpha_unconstr = self.fc1(s) 226 | 227 | if self.activate == "softplus": 228 | alpha = torch.clamp(F.softplus(alpha_unconstr),0,50) 229 | elif self.activate == "sigmoid": 230 | alpha = torch.sigmoid(alpha_unconstr)*self.alpha_unconstr_max 231 | else: 232 | raise NotImplementedError 233 | elif self.activate == "OR": 234 | alpha = torch.exp(self.fc1(torch.log(torch.nn.functional.softsign(s - 0.5)+1))) 235 | 236 | return alpha 237 | 238 | ## TODO: introduce sparsity in the model 239 | def forward(self,t, u, s): 240 | ## split x into unspliced and spliced readout 241 | ## x is a matrix with (G*2), in which row is a subgraph (batch) 242 | beta = torch.clamp(F.softplus(self.beta_mean_unconstr), 0, 50) 243 | gamma = torch.clamp(F.softplus(self.gamma_mean_unconstr), 0, 50) 244 | if self.activate != "OR": 245 | if self.base_alpha is not True: 246 | grn = self.fc1.weight 247 | alpha_unconstr = torch.matmul(s,grn.T) 248 | else: 249 | alpha_unconstr = self.fc1(s) 250 | 251 | if self.activate == "softplus": 252 | alpha = torch.clamp(F.softplus(alpha_unconstr),0,50) 253 | elif self.activate == "sigmoid": 254 | alpha = torch.sigmoid(alpha_unconstr)*self.alpha_unconstr_max 255 | elif self.activate == "OR": 256 | alpha = torch.exp(self.fc1(torch.log(torch.nn.functional.softsign(s - 0.5)+1))) 257 | else: 258 | raise NotImplementedError 259 | 260 | ## Predict velocity 261 | du = alpha - beta*u 262 | ds = beta*u - gamma*s 263 | 264 | return du,ds 265 | 266 | class v_encoder_batch(nn.Module): 267 | """Batching the velocity. 268 | 269 | Parameters 270 | ---------- 271 | num_g 272 | number of genes 273 | """ 274 | 275 | def __init__( 276 | self, 277 | num_g: int = 5, 278 | ): 279 | super().__init__() 280 | self.num_g = num_g 281 | 282 | def forward(self,t,y): 283 | """Reshape the state, compute the velocity, and reshape back. 284 | 285 | y is a reshaped matrix of shape (g*n, 2); we first split it into two matrices, unspliced and spliced, 286 | and calculate the velocity, 287 | then reshape back to the vector form (g*n, 2); 288 | the batch size in this case is g*n. 289 | """ 290 | u_v = y[:,0] 291 | s_v = y[:,1] 292 | u = u_v.reshape(-1, self.num_g) 293 | s = s_v.reshape(-1, self.num_g) 294 | du, ds = self.v_encoder_class(t, u, s) 295 | 296 | ## reshape du and ds 297 | du = du.reshape(-1, 1) 298 | ds = ds.reshape(-1, 1) 299 | 300 | v = torch.concatenate([du,ds],axis = 1) 301 | 302 | return v 303 | 304 | 305 | # VAE model 306 | class VELOVAE(BaseModuleClass): 307 | """Variational auto-encoder model. 308 | 309 | This is an implementation of the RegVelo model. 310 | 311 | Parameters 312 | ---------- 313 | n_input 314 | Number of input genes. 315 | regulator_index 316 | list index for all regulators. 317 | target_index 318 | list index for all targets. 319 | skeleton 320 | prior gene regulatory graph. 321 | regulator_list 322 | an integer list indicating the positions of the regulators. 323 | activate 324 | Activation function used for modeling transcription rate. 325 | base_alpha 326 | Whether to add a base transcription rate. 327 | n_hidden 328 | Number of nodes per hidden layer. 329 | n_latent 330 | Dimensionality of the latent space. 331 | n_layers 332 | Number of hidden layers used for encoder and decoder NNs. 333 | lam 334 | Regularization parameter controlling the strength of the prior-knowledge constraint. 335 | lam2 336 | Regularization parameter controlling the strength of the L1 regularization applied to the Jacobian matrix. 
337 | vector_constraint 338 | Regularization on velocity. 339 | bias_constraint 340 | Regularization on bias term (base transcription rate). 341 | dropout_rate 342 | Dropout rate for neural networks 343 | log_variational 344 | Log(data+1) prior to encoding for numerical stability. Not normalization. 345 | latent_distribution 346 | One of 347 | 348 | * ``'normal'`` - Isotropic normal 349 | * ``'ln'`` - Logistic normal with normal params N(0, 1) 350 | use_layer_norm 351 | Whether to use layer norm in layers 352 | var_activation 353 | Callable used to ensure positivity of the variational distributions' variance. 354 | When `None`, defaults to `torch.exp`. 355 | """ 356 | 357 | def __init__( 358 | self, 359 | n_input: int, 360 | regulator_index: list, 361 | target_index: list, 362 | skeleton: torch.Tensor, 363 | regulator_list: list, 364 | activate: Literal["sigmoid", "softplus"] = "softplus", 365 | base_alpha: bool = True, 366 | n_hidden: int = 128, 367 | n_latent: int = 10, 368 | n_layers: int = 1, 369 | lam: float = 1, 370 | lam2: float = 0, 371 | vector_constraint: bool = True, 372 | alpha_constraint: float = 0.1, 373 | bias_constraint: bool = True, 374 | dropout_rate: float = 0.1, 375 | log_variational: bool = False, 376 | latent_distribution: str = "normal", 377 | use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", 378 | use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both", 379 | var_activation: Optional[Callable] = torch.nn.Softplus(), 380 | gamma_unconstr_init: Optional[np.ndarray] = None, 381 | alpha_unconstr_init: Optional[np.ndarray] = None, 382 | alpha_1_unconstr_init: Optional[np.ndarray] = None, 383 | x0: Optional[np.ndarray] = None, 384 | t0: Optional[np.ndarray] = None, 385 | t_max: float = 20, 386 | linear_decoder: bool = False, 387 | soft_constraint: bool = True, 388 | auto_regulation: bool = False, 389 | ): 390 | super().__init__() 391 | self.n_latent = n_latent 392 | self.log_variational = log_variational 393 | self.latent_distribution = latent_distribution 394 | self.n_input = n_input 395 | self.t_max = t_max 396 | self.lamba = lam 397 | self.lamba2 = lam2 398 | self.vector_constraint = vector_constraint 399 | self.alpha_constraint = alpha_constraint 400 | self.bias_constraint = bias_constraint 401 | self.soft_constraint = soft_constraint 402 | 403 | 404 | n_genes = n_input * 2 405 | n_targets = sum(target_index) 406 | n_regulators = sum(regulator_index) 407 | self.n_targets = int(n_targets) 408 | self.n_regulators = int(n_regulators) 409 | self.regulator_index = regulator_index 410 | self.target_index = target_index 411 | 412 | # degradation for each target gene 413 | if gamma_unconstr_init is None: 414 | self.gamma_mean_unconstr = torch.nn.Parameter(-1 * torch.ones(n_targets)) 415 | else: 416 | self.gamma_mean_unconstr = torch.nn.Parameter( 417 | torch.from_numpy(gamma_unconstr_init) 418 | ) 419 | 420 | # splicing for each target gene 421 | # first samples around 1 422 | self.beta_mean_unconstr = torch.nn.Parameter(0.5 * torch.ones(n_targets)) 423 | 424 | # transcription (bias term for target gene transcription rate function) 425 | if alpha_unconstr_init is None: 426 | self.alpha_unconstr = torch.nn.Parameter(0 * torch.ones(n_targets)) 427 | else: 428 | #self.alpha_unconstr = torch.nn.Parameter(0 * torch.ones(n_targets)) 429 | self.alpha_unconstr = torch.nn.Parameter( 430 | torch.from_numpy(alpha_unconstr_init) 431 | ) 432 | 433 | # TODO: Add `require_grad` 434 | ## The maximum transcription rate (alpha_1) for each target gene 435 | 
if alpha_1_unconstr_init is None: 436 | self.alpha_1_unconstr = torch.nn.Parameter(torch.ones(n_targets)) 437 | else: 438 | self.alpha_1_unconstr = torch.nn.Parameter( 439 | torch.from_numpy(alpha_1_unconstr_init) 440 | ) 441 | self.alpha_1_unconstr.data = self.alpha_1_unconstr.data.float() 442 | 443 | # likelihood dispersion 444 | # for now, with normal dist, this is just the variance for target genes 445 | self.scale_unconstr_targets = torch.nn.Parameter(-1 * torch.ones(n_targets*2, 3)) 446 | 447 | use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both" 448 | use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both" 449 | use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both" 450 | use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both" 451 | self.use_batch_norm_decoder = use_batch_norm_decoder 452 | 453 | # z encoder goes from the n_input-dimensional data to an n_latent-d 454 | # latent space representation 455 | n_input_encoder = n_genes 456 | self.z_encoder = Encoder( 457 | n_input_encoder, 458 | n_latent, 459 | n_layers=n_layers, 460 | n_hidden=n_hidden, 461 | dropout_rate=dropout_rate, 462 | distribution=latent_distribution, 463 | use_batch_norm=use_batch_norm_encoder, 464 | use_layer_norm=use_layer_norm_encoder, 465 | var_activation=var_activation, 466 | activation_fn=torch.nn.ReLU, 467 | ) 468 | # decoder goes from n_latent-dimensional space to n_target-d data 469 | n_input_decoder = n_latent 470 | self.decoder = DecoderVELOVI( 471 | n_input_decoder, 472 | n_targets, 473 | n_layers=n_layers, 474 | n_hidden=n_hidden, 475 | use_batch_norm=use_batch_norm_decoder, 476 | use_layer_norm=use_layer_norm_decoder, 477 | activation_fn=torch.nn.ReLU, 478 | linear_decoder=linear_decoder, 479 | ) 480 | 481 | # define velocity encoder, define velocity vector for target genes 482 | self.v_encoder = velocity_encoder(n_int = n_targets,activate = activate,base_alpha = base_alpha) 483 | self.v_encoder.fc1.weight = torch.nn.Parameter(0 * torch.ones(self.v_encoder.fc1.weight.shape)) 484 | # saved kinetic parameter in velocity encoder module 485 | self.v_encoder.regulator_index = self.regulator_index 486 | self.v_encoder.beta_mean_unconstr = self.beta_mean_unconstr 487 | self.v_encoder.gamma_mean_unconstr = self.gamma_mean_unconstr 488 | self.v_encoder.register_buffer("alpha_unconstr_max", torch.tensor(10.0)) 489 | 490 | # initilize grn (masked parameters) 491 | if self.soft_constraint is not True: 492 | self.v_encoder.register_buffer("mask_m", skeleton) 493 | self.v_encoder._set_mask_grad() 494 | else: 495 | if regulator_list is not None: 496 | skeleton_ref = torch.zeros(skeleton.shape) 497 | skeleton_ref[:,regulator_list] = 1 498 | else: 499 | skeleton_ref = torch.ones(skeleton.shape) 500 | if not auto_regulation: 501 | skeleton_ref[range(skeleton_ref.shape[0]), range(skeleton_ref.shape[1])] = 0 502 | self.v_encoder.register_buffer("mask_m", skeleton_ref) 503 | self.v_encoder._set_mask_grad() 504 | self.v_encoder.register_buffer("mask_m_raw", skeleton) 505 | 506 | 507 | ## define batch velocity vector for numerical solver 508 | self.v_batch = v_encoder_batch(num_g = n_targets) 509 | self.v_batch.v_encoder_class = self.v_encoder 510 | 511 | ## register variable for torchode 512 | if x0 is not None: 513 | self.register_buffer("x0", torch.tensor(x0)) 514 | else: 515 | self.register_buffer("x0", torch.zeros([n_targets,2])) 516 | 517 | ## TODO: follow xiaojie suggestion, update x0 estimate 518 | 519 | if t0 
is not None: 520 | self.register_buffer("t0", torch.tensor(t0).reshape(-1,1)) 521 | else: 522 | self.register_buffer("t0", torch.zeros([n_targets,1])) 523 | 524 | self.register_buffer("dt0", torch.ones([1])) 525 | #self.register_buffer("t0", torch.zeros([1])) 526 | self.register_buffer("target_m",torch.zeros(self.v_encoder.fc1.weight.data.shape)) 527 | 528 | 529 | def _get_inference_input(self, tensors): 530 | spliced = tensors[REGISTRY_KEYS.X_KEY] 531 | unspliced = tensors[REGISTRY_KEYS.U_KEY] 532 | 533 | input_dict = { 534 | "spliced": spliced, 535 | "unspliced": unspliced, 536 | } 537 | return input_dict 538 | 539 | def _get_generative_input(self, tensors, inference_outputs): 540 | z = inference_outputs["z"] 541 | gamma = inference_outputs["gamma"] 542 | beta = inference_outputs["beta"] 543 | alpha_1 = inference_outputs["alpha_1"] 544 | 545 | input_dict = { 546 | "z": z, 547 | "gamma": gamma, 548 | "beta": beta, 549 | "alpha_1": alpha_1, 550 | } 551 | return input_dict 552 | 553 | @auto_move_data 554 | def inference( 555 | self, 556 | spliced, 557 | unspliced, 558 | n_samples=1, 559 | ): 560 | """High level inference method. 561 | 562 | Runs the inference (encoder) model. 563 | """ 564 | spliced_ = spliced 565 | unspliced_ = unspliced 566 | if self.log_variational: 567 | spliced_ = torch.log(0.01 + spliced) 568 | unspliced_ = torch.log(0.01 + unspliced) 569 | 570 | encoder_input = torch.cat((spliced_, unspliced_), dim=-1) 571 | 572 | qz_m, qz_v, z = self.z_encoder(encoder_input) 573 | 574 | if n_samples > 1: 575 | qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) 576 | qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) 577 | # when z is normal, untran_z == z 578 | untran_z = Normal(qz_m, qz_v.sqrt()).sample() 579 | z = self.z_encoder.z_transformation(untran_z) 580 | 581 | gamma, beta, alpha_1 = self._get_rates() 582 | 583 | outputs = { 584 | "z": z, 585 | "qz_m": qz_m, 586 | "qz_v": qz_v, 587 | "qzm": qz_m, 588 | "qzv": qz_v, 589 | "gamma": gamma, 590 | "beta": beta, 591 | "alpha_1": alpha_1, 592 | } 593 | return outputs 594 | 595 | def _get_rates(self): 596 | # globals 597 | # degradation for each target gene 598 | gamma = torch.clamp(F.softplus(self.v_encoder.gamma_mean_unconstr), 0, 50) 599 | # splicing for each target gene 600 | beta = torch.clamp(F.softplus(self.v_encoder.beta_mean_unconstr), 0, 50) 601 | # transcription for each target gene (bias term) 602 | alpha_1 = self.alpha_1_unconstr 603 | 604 | return gamma, beta, alpha_1 605 | 606 | @auto_move_data 607 | def generative(self, z, gamma, beta, alpha_1, latent_dim=None): 608 | """Runs the generative model.""" 609 | decoder_input = z 610 | 611 | ## decoder directly decode the latent time of each gene 612 | px_rho = self.decoder(decoder_input, latent_dim=latent_dim) 613 | 614 | scale_unconstr = self.scale_unconstr_targets 615 | scale = F.softplus(scale_unconstr) 616 | 617 | dist_s, dist_u, index, ind_t = self.get_px( 618 | px_rho, 619 | scale, 620 | gamma, 621 | beta, 622 | alpha_1, 623 | ) 624 | 625 | return { 626 | "px_rho": px_rho, 627 | "scale": scale, 628 | "dist_u": dist_u, 629 | "dist_s": dist_s, 630 | "t_sort": index, 631 | "ind_t": ind_t, 632 | } 633 | 634 | def pearson_correlation_loss(self, tensor1, tensor2, eps=1e-6): 635 | # Calculate means 636 | mean1 = torch.mean(tensor1, dim=0) 637 | mean2 = torch.mean(tensor2, dim=0) 638 | 639 | # Calculate covariance 640 | covariance = torch.mean((tensor1 - mean1) * (tensor2 - mean2), dim=0) 641 | 642 | # Calculate standard 
deviations
643 |         std1 = torch.std(tensor1, dim=0, correction=0)
644 |         std2 = torch.std(tensor2, dim=0, correction=0)
645 | 
646 |         # Calculate correlation coefficients
647 |         correlation_coefficients = covariance / (std1 * std2 + eps)
648 | 
649 |         # Convert NaNs to 0 (when std1 or std2 are 0)
650 |         correlation_coefficients[torch.isnan(correlation_coefficients)] = 0
651 | 
652 |         # The loss is the negative correlation, so minimizing the loss maximizes the correlation
653 |         loss = -correlation_coefficients
654 | 
655 |         return loss
656 | 
657 |     def loss(
658 |         self,
659 |         tensors,
660 |         inference_outputs,
661 |         generative_outputs,
662 |         kl_weight: float = 1.0,
663 |         n_obs: float = 1.0,
664 |     ):
665 |         spliced = tensors[REGISTRY_KEYS.X_KEY]
666 |         unspliced = tensors[REGISTRY_KEYS.U_KEY]
667 | 
668 |         ### extract spliced and unspliced readouts
669 |         regulator_spliced = spliced[:, self.regulator_index]
670 |         target_spliced = spliced[:, self.target_index]
671 |         target_unspliced = unspliced[:, self.target_index]
672 | 
673 |         qz_m = inference_outputs["qz_m"]
674 |         qz_v = inference_outputs["qz_v"]
675 |         beta = inference_outputs["beta"]
676 | 
677 |         dist_s = generative_outputs["dist_s"]
678 |         dist_u = generative_outputs["dist_u"]
679 |         t = generative_outputs["ind_t"]
680 |         t_sort = generative_outputs["t_sort"].T
681 | 
682 |         kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1)
683 | 
684 |         reconst_loss_s = -dist_s.log_prob(target_spliced)
685 |         reconst_loss_u = -dist_u.log_prob(target_unspliced)
686 | 
687 |         reconst_loss_target = reconst_loss_u.sum(dim=-1) + reconst_loss_s.sum(dim=-1)
688 | 
689 | 
690 |         alpha = self.v_encoder.transcription_rate(regulator_spliced)
691 |         target_unspliced_sort = torch.gather(target_unspliced, 0, t_sort)  ## lagged target unspliced readout (t+1)
692 |         alpha_sort = torch.gather(alpha, 0, t_sort)  ## transcription activity (t)
693 |         alpha_loss = self.alpha_constraint * self.pearson_correlation_loss(target_unspliced_sort, alpha_sort).sum() / alpha.shape[1]
694 | 
695 |         ## add velocity constraint:
696 |         ## regularize the inferred velocity to have both negative and positive components
697 | 
698 |         if self.vector_constraint:
699 |             # the unspliced velocity is computed directly from the transcription rate
700 |             alpha = self.v_encoder.transcription_rate(s = spliced)
701 |             du = alpha - beta * unspliced
702 |             velo_loss = 100 * torch.norm(du, dim=1)
703 |         else:
704 |             velo_loss = 0
705 | 
706 |         ## add graph constraint
707 |         if self.soft_constraint:
708 |             ## use a norm penalty to perform graph regularization
709 |             mask_m = 1 - self.v_encoder.mask_m_raw
710 |             graph_constr_loss = self.lamba * torch.norm(self.v_encoder.fc1.weight * mask_m)
711 |         else:
712 |             graph_constr_loss = 0
713 | 
714 |         Jaco = self.v_encoder.GRN_Jacobian(dist_s.mean.mean(0))
715 |         l1_loss_fn = torch.nn.L1Loss(reduction="sum")
716 |         L1_loss = self.lamba2 * l1_loss_fn(Jaco, self.target_m)
717 | 
718 |         ## regularize the bias to be negative
719 |         if self.bias_constraint:
720 |             bias_regularize = torch.norm(self.v_encoder.fc1.bias + 10)
721 |         else:
722 |             bias_regularize = 0
723 | 
724 |         # local loss
725 |         kl_local = kl_divergence_z
726 |         weighted_kl_local = kl_divergence_z * kl_weight
727 |         local_loss = torch.mean(reconst_loss_target + weighted_kl_local + velo_loss)
728 | 
729 |         # total loss
730 |         loss = local_loss + alpha_loss + L1_loss + graph_constr_loss + bias_regularize
731 | 
732 |         loss_recorder = LossOutput(
733 |             loss=loss, reconstruction_loss=reconst_loss_target, kl_local=kl_local
734 |         )
735 | 
736 |         return loss_recorder
737 | 
738 |     @auto_move_data
739 |     def get_px(
740 |         self,
741 |         px_rho,
742 |         scale,
743 |         gamma,
744 |         beta,
745 |         alpha_1,
746 |     ) -> tuple:
747 | 
748 |         # predict the abundance in the induction phase for target genes
749 |         ind_t = self.t_max * px_rho
750 |         n_cells = px_rho.shape[0]
751 |         mean_u, mean_s, index = self._get_induction_unspliced_spliced(
752 |             ind_t
753 |         )
754 | 
755 |         ### only consider the induction phase
756 |         scale_u = scale[: self.n_targets, 0].expand(n_cells, self.n_targets).sqrt()
757 |         scale_s = scale[self.n_targets :, 0].expand(n_cells, self.n_targets).sqrt()
758 | 
759 |         dist_u = Normal(mean_u, scale_u)
760 |         dist_s = Normal(mean_s, scale_s)
761 | 
762 |         return dist_s, dist_u, index, ind_t
763 | 
764 |     def root_time(self, t, root=None):
765 |         """Re-root latent time at `root`, wrapping times earlier than the root to the end of the trajectory."""
766 |         t_root = 0 if root is None else t[root]
767 |         o = (t >= t_root).int()
768 |         t_after = (t - t_root) * o
769 |         t_origin, _ = torch.max(t_after, axis=0)
770 |         t_before = (t + t_origin) * (1 - o)
771 | 
772 |         t_rooted = t_after + t_before
773 | 
774 |         return t_rooted
775 | 
776 |     def _get_induction_unspliced_spliced(
777 |         self, t, eps=1e-6
778 |     ):
779 |         """
780 |         Compute the spliced and unspliced abundance for target genes.
781 | 
782 |         alpha_1: the maximum transcription rate during the induction phase for each target gene
783 |         beta: the splicing parameter for each target gene
784 |         gamma: the degradation parameter for each target gene
785 | 
786 |         ** the above parameters are saved in v_encoder
787 |         t: target gene-specific latent time
788 |         """
789 |         device = self.device
790 |         #t = t.T
791 | 
792 |         if t.shape[0] > 1:
793 |             ## define parameters
794 |             _, index = torch.sort(t, dim=0)
795 |             index = index.T
796 |             dim = t.shape[0] * t.shape[1]
797 |             t0 = self.t0.repeat(t.shape[0], 1)
798 |             dt0 = self.dt0.expand(dim)
799 |             x0 = self.x0.repeat(t.shape[0], 1)
800 | 
801 |             t_eval = t.reshape(-1, 1)
802 |             t_eval = torch.cat((t0, t_eval), dim=1)
803 | 
804 |             ## set up G batches; each batch G represents a module (a target-gene-centered regulon)
805 |             ## infer the observed gene expression through the ODE solver based on x0, t, and the velocity encoder
806 | 
807 |             term = to.ODETerm(self.v_batch)
808 |             step_method = to.Dopri5(term=term)
809 |             #step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
810 |             step_size_controller = to.FixedStepController()
811 |             solver = to.AutoDiffAdjoint(step_method, step_size_controller)
812 |             #jit_solver = torch.jit.script(solver)
813 |             sol = solver.solve(to.InitialValueProblem(y0=x0, t_eval=t_eval), dt0=dt0)
814 |         else:
815 |             t_eval = t
816 |             t_eval = torch.cat((self.t0, t_eval), dim=1)
817 |             ## set up G batches; each batch G represents a module (a target-gene-centered regulon)
818 |             ## infer the observed gene expression through the ODE solver based on x0, t, and the velocity encoder
819 |             #x0 = x0.double()
820 | 
821 |             term = to.ODETerm(self.v_encoder)
822 |             step_method = to.Dopri5(term=term)
823 |             #step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
824 |             step_size_controller = to.FixedStepController()
825 |             solver = to.AutoDiffAdjoint(step_method, step_size_controller)
826 |             #jit_solver = torch.jit.script(solver)
827 |             sol = solver.solve(to.InitialValueProblem(y0=self.x0, t_eval=t_eval), dt0=self.dt0)
828 | 
829 |         ## collect the predicted results
830 |         # the solved results are saved in sol.ys with shape [number of subsystems, time stamps, [u, s]]
831 |         pre_u = sol.ys[:, 1:, 0]
832 |         pre_s = sol.ys[:, 1:, 1]
833 | 
834 |         if t.shape[1] > 1:
835 |             unspliced = pre_u.reshape(-1, t.shape[1])
836 |             spliced = 
pre_s.reshape(-1,t.shape[1]) 837 | else: 838 | unspliced = pre_u.ravel() 839 | spliced = pre_s.ravel() 840 | 841 | return unspliced, spliced, index 842 | 843 | def _get_repression_unspliced_spliced(self, u_0, s_0, beta, gamma, t, eps=1e-6): 844 | unspliced = torch.exp(-beta * t) * u_0 845 | spliced = s_0 * torch.exp(-gamma * t) - ( 846 | beta * u_0 / ((gamma - beta) + eps) 847 | ) * (torch.exp(-gamma * t) - torch.exp(-beta * t)) 848 | return unspliced, spliced 849 | 850 | def sample( 851 | self, 852 | ) -> np.ndarray: 853 | """Not implemented.""" 854 | raise NotImplementedError 855 | 856 | @torch.no_grad() 857 | def get_loadings(self) -> np.ndarray: 858 | """Extract per-gene weights (for each Z, shape is genes by dim(Z)) in the linear decoder.""" 859 | # This is BW, where B is diag(b) batch norm, W is weight matrix 860 | if self.decoder.linear_decoder is False: 861 | raise ValueError("Model not trained with linear decoder") 862 | w = self.decoder.rho_first_decoder.fc_layers[0][0].weight 863 | if self.use_batch_norm_decoder: 864 | bn = self.decoder.rho_first_decoder.fc_layers[0][1] 865 | sigma = torch.sqrt(bn.running_var + bn.eps) 866 | gamma = bn.weight 867 | b = gamma / sigma 868 | b_identity = torch.diag(b) 869 | loadings = torch.matmul(b_identity, w) 870 | else: 871 | loadings = w 872 | loadings = loadings.detach().cpu().numpy() 873 | 874 | return loadings 875 | 876 | def freeze_mapping(self): 877 | for param in self.z_encoder.parameters(): 878 | param.requires_grad = False 879 | 880 | for param in self.decoder.parameters(): 881 | param.requires_grad = False -------------------------------------------------------------------------------- /regvelo/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from ._datasets import ( 2 | zebrafish_nc, 3 | zebrafish_grn, 4 | murine_nc, 5 | ) 6 | 7 | __all__ = [ 8 | "zebrafish_nc", 9 | "zebrafish_grn", 10 | "murine_nc", 11 | ] -------------------------------------------------------------------------------- /regvelo/datasets/_datasets.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Union 3 | 4 | import pandas as pd 5 | 6 | from scanpy import read 7 | 8 | from scvelo.core import cleanup 9 | from scvelo.read_load import load 10 | 11 | url_adata = "https://drive.google.com/uc?id=1Nzq1F6dGw-nR9lhRLfZdHOG7dcYq7P0i&export=download" 12 | url_grn = "https://drive.google.com/uc?id=1ci_gCwdgGlZ0xSn6gSa_-LlIl9-aDa1c&export=download/" 13 | url_adata_murine_processed = "https://drive.usercontent.google.com/download?id=19bNQfW3jMKEEjpjNdUkVd7KDTjJfqxa5&export=download&authuser=1&confirm=t&uuid=4fdf3051-229b-4ce2-b644-cb390424570a&at=APcmpoxgcuZ5r6m6Fb6N_2Og6tEO:1745354679573" 14 | url_adata_murine_normalized = "https://drive.usercontent.google.com/download?id=1xy2FNYi6Y2o_DzXjRmmCtARjoZ97Ro_w&export=download&authuser=1&confirm=t&uuid=12cf5d23-f549-48d9-b7ec-95411a58589f&at=APcmpoyexgouf243lNygF9yRUkmi:1745997349046" 15 | url_adata_murine_velocyto = "https://drive.usercontent.google.com/download?id=18Bhtb7ruoUxpNt8WMYSaJ1RyoiHOCEjd&export=download&authuser=1&confirm=t&uuid=ecc42202-bc82-4ab1-b2c3-bfc31c99f0df&at=APcmpozsh6tBzkv8NSIZW0VipDJa:1745997422108" 16 | 17 | 18 | def zebrafish_nc(file_path: Union[str, Path] = "data/zebrafish_nc/adata_zebrafish_preprocessed.h5ad"): 19 | """Zebrafish neural crest cells. 
20 | 
21 |     Single-cell RNA-seq datasets of zebrafish neural crest cell development across
22 |     seven distinct time points, generated with the ultra-deep Smart-seq3 technique.
23 | 
24 |     There are four distinct phases of NC cell development: 1) specification at the NPB, 2) epithelial-to-mesenchymal
25 |     transition (EMT) from the neural tube, 3) migration throughout the periphery, and 4) differentiation into distinct cell types.
26 | 
27 |     Parameters
28 |     ----------
29 |     file_path
30 |         Path where the dataset is saved to and read from.
31 | 
32 |     Returns
33 |     -------
34 |     Annotated data matrix (an `AnnData` object).
35 |     """
36 |     adata = read(file_path, backup_url=url_adata, sparse=True, cache=True)
37 |     return adata
38 | 
39 | def zebrafish_grn(file_path: Union[str, Path] = "data/zebrafish_nc/prior_GRN.csv"):
40 |     """Prior gene regulatory network (GRN) for zebrafish neural crest cells.
41 | 
42 |     Prior GRN accompanying the zebrafish neural crest scRNA-seq datasets
43 |     (seven time points, ultra-deep Smart-seq3).
44 | 
45 |     The network is downloaded as a table of regulatory interactions and a
46 |     copy is written to `file_path`.
47 | 
48 |     Parameters
49 |     ----------
50 |     file_path
51 |         Path where the downloaded GRN is saved.
52 | 
53 |     Returns
54 |     -------
55 |     Prior GRN as a `pandas.DataFrame`.
56 |     """
57 |     grn = pd.read_csv(url_grn, index_col=0)
58 |     grn.to_csv(file_path)
59 |     return grn
60 | 
61 | def murine_nc(data_type: str = "preprocessed"):
62 |     """
63 |     Mouse neural crest cells.
64 | 
65 |     Single-cell RNA-seq datasets of mouse neural crest cell development,
66 |     subset from Qiu, Chengxiang et al.
67 | 
68 |     The gene regulatory network (GRN) is saved in `adata.uns["skeleton"]`,
69 |     which is learned via pySCENIC.
70 | 
71 |     Parameters
72 |     ----------
73 |     data_type : str
74 |         Which version of the dataset to load. Must be one of:
75 |         - "preprocessed"
76 |         - "normalized"
77 |         - "velocyto"
78 | 
79 |     Returns
80 |     -------
81 |     AnnData
82 |         Annotated data matrix (an `AnnData` object).
83 |     """
84 |     valid_types = ["preprocessed", "normalized", "velocyto"]
85 |     if data_type not in valid_types:
86 |         raise ValueError(f"Invalid data_type: '{data_type}'. 
Must be one of {valid_types}.") 87 | 88 | file_path = ["data/murine_nc/adata_preprocessed.h5ad","data/murine_nc/adata_gex_normalized.h5ad","data/murine_nc/adata_velocity.h5ad"] 89 | 90 | if data_type == "preprocessed": 91 | adata = read(file_path[0], backup_url=url_adata_murine_processed, sparse=True, cache=True) 92 | elif data_type == "normalized": 93 | adata = read(file_path[1], backup_url=url_adata_murine_normalized, sparse=True, cache=True) 94 | elif data_type == "velocyto": 95 | adata = read(file_path[2], backup_url=url_adata_murine_velocyto, sparse=True, cache=True) 96 | 97 | return adata -------------------------------------------------------------------------------- /regvelo/plotting/__init__.py: -------------------------------------------------------------------------------- 1 | from .fate_probabilities import fate_probabilities 2 | from .commitment_score import commitment_score 3 | from .depletion_score import depletion_score 4 | from .get_significance import get_significance 5 | 6 | __all__ = [ 7 | "fate_probabilities", 8 | "commitment_score", 9 | "depletion_score", 10 | "get_significance", 11 | ] 12 | -------------------------------------------------------------------------------- /regvelo/plotting/commitment_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scanpy as sc 4 | from anndata import AnnData 5 | from typing import Any 6 | 7 | from .utils import calculate_entropy 8 | 9 | 10 | def commitment_score(adata : AnnData, 11 | lineage_key : str = "lineages_fwd", 12 | **kwargs : Any 13 | ) -> None: 14 | """ 15 | Compute and plot cell fate commitment scores based on fate probabilities. 16 | 17 | Parameters 18 | ---------- 19 | adata : AnnData 20 | Dataset containing fate probabilities. Original dataset or perturbed dataset. 21 | lineage_key : str 22 | The key in .obsm that stores the fate probabilities. 23 | kwargs : Any 24 | Optional 25 | Additional keyword arguments passed to scanpy.pl.umap function. 26 | """ 27 | 28 | if lineage_key not in adata.obsm: 29 | raise KeyError(f"Key '{lineage_key}' not found in `adata.obsm`.") 30 | 31 | p = pd.DataFrame(adata.obsm[lineage_key], columns=adata.obsm[lineage_key].names.tolist()) 32 | score = calculate_entropy(p) 33 | adata.obs["commitment_score"] = np.array(score) 34 | 35 | sc.pl.umap( 36 | adata, 37 | color="commitment_score", 38 | **kwargs 39 | ) 40 | 41 | -------------------------------------------------------------------------------- /regvelo/plotting/depletion_score.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import pandas as pd 4 | from anndata import AnnData 5 | from typing import Union, Sequence, Any, Optional 6 | import cellrank as cr 7 | 8 | import seaborn as sns 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | from ..tools.abundance_test import abundance_test 13 | 14 | def depletion_score(adata : AnnData, 15 | df : pd.DataFrame, 16 | color_label : str = "celltype_cluster", 17 | **kwargs : Any, 18 | ) -> None: 19 | """ 20 | Plot depletion scores. 21 | 22 | Parameters 23 | ---------- 24 | adata : AnnData 25 | Annotated data matrix of original model. 26 | df : pandas.DataFrame 27 | color_label : str 28 | Used for color palette 29 | kwargs : Any 30 | Optional 31 | Additional keyword arguments passed to CellRank and plot functions. 
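
    Examples
    --------
    Illustrative sketch (not part of the package's documented tests): assumes
    ``df`` is the summary frame returned by ``regvelo.tools.depletion_score``
    (columns ``"TF"``, ``"Depletion score"``, ``"Terminal state"``) and that
    ``adata.obs["celltype_cluster"]`` is categorical with matching colors in
    ``adata.uns["celltype_cluster_colors"]``.

    >>> from regvelo.plotting import depletion_score
    >>> depletion_score(adata, df, color_label="celltype_cluster")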
32 | """ 33 | 34 | fontsize = kwargs.get("fontsize", 14) 35 | figsize = kwargs.get("figsize", (12, 6)) 36 | 37 | legend_loc = kwargs.get("legend_loc", "center left") 38 | legend_bbox = kwargs.get("legend_bbox", (1.02, 0.5)) 39 | 40 | xlabel = kwargs.get("xlabel", "TF") 41 | ylabel = kwargs.get("ylabel", "Depletion score") 42 | 43 | plot_kwargs = {k: kwargs[k] for k in ("ax",) if k in kwargs} 44 | 45 | plt.figure(figsize=figsize) 46 | 47 | palette = dict(zip(adata.obs[color_label].cat.categories, adata.uns[f"{color_label}_colors"])) 48 | sns.barplot(x=xlabel, y=ylabel, hue='Terminal state', data=df, palette=palette, dodge=True, **plot_kwargs) 49 | 50 | for i in range(len(df['TF'].unique()) - 1): 51 | plt.axvline(x=i + 0.5, color='gray', linestyle='--') 52 | 53 | plt.ylabel(ylabel, fontsize=fontsize) 54 | plt.xlabel(xlabel, fontsize=fontsize) 55 | plt.xticks(fontsize=fontsize) 56 | plt.yticks(fontsize=fontsize) 57 | 58 | plt.legend( 59 | title='Terminal state', 60 | bbox_to_anchor=legend_bbox, 61 | loc=legend_loc, 62 | borderaxespad=0 63 | ) 64 | 65 | 66 | plt.tight_layout() 67 | plt.show() 68 | -------------------------------------------------------------------------------- /regvelo/plotting/fate_probabilities.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import cellrank as cr 3 | from anndata import AnnData 4 | from typing import Union, Sequence, Any 5 | 6 | def fate_probabilities(adata : AnnData, 7 | terminal_state : Union[str, Sequence[str]], 8 | n_states : int, 9 | save_kernel : bool = True, 10 | **kwargs : Any 11 | ) -> None: 12 | """ 13 | Compute transition matrix and fate probabilities toward the terminal states and plot these for each of the 14 | terminal states. 15 | 16 | Parameters 17 | ---------- 18 | adata : AnnData 19 | Annotated data matrix. 20 | terminal_state : str or Sequence[str] 21 | List of terminal states to compute probabilities for. 22 | n_states : int 23 | Number of states to compute probabilities for. 24 | save_kernel : bool 25 | Whether to write the kernel to adata. Default is True. 26 | kwargs : Any 27 | Optional 28 | Additional keyword arguments passed to CellRank functions 29 | """ 30 | 31 | macro_kwargs = {k: kwargs[k] for k in ("cluster_key", "method") if k in kwargs} 32 | compute_fate_probabilities_kwargs = {k: kwargs[k] for k in ("solver", "tol") if k in kwargs} 33 | plot_kwargs = {k: kwargs[k] for k in ("basis", "same_plot", "states", "title") if k in kwargs} 34 | 35 | vk = cr.kernels.VelocityKernel(adata).compute_transition_matrix() 36 | if save_kernel: 37 | vk.write_to_adata() 38 | 39 | estimator = cr.estimators.GPCCA(vk) 40 | estimator.compute_macrostates(n_states=n_states, **macro_kwargs) 41 | estimator.set_terminal_states(terminal_state) 42 | estimator.compute_fate_probabilities(**compute_fate_probabilities_kwargs) 43 | estimator.plot_fate_probabilities(**plot_kwargs) 44 | estimator.compute_lineage_drivers() 45 | plt.show() 46 | 47 | -------------------------------------------------------------------------------- /regvelo/plotting/get_significance.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def get_significance(pvalue: float) -> str: 4 | """ 5 | Return significance annotation for a p-value. 6 | 7 | Parameters 8 | ---------- 9 | pvalue : float 10 | P-value to interpret. 
11 | 12 | Returns 13 | ------- 14 | str 15 | A string indicating the level of significance: 16 | "***" for p < 0.001, 17 | "**" for p < 0.01, 18 | "*" for p < 0.1, 19 | "n.s." (not significant) otherwise. 20 | """ 21 | if pvalue < 0.001: 22 | return "***" 23 | elif pvalue < 0.01: 24 | return "**" 25 | elif pvalue < 0.1: 26 | return "*" 27 | else: 28 | return "n.s." -------------------------------------------------------------------------------- /regvelo/plotting/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import scvelo as scv 4 | 5 | 6 | def calculate_entropy(prob_matrix : np.array 7 | ) -> np.array: 8 | """ 9 | Calculate entropy for each row in a cell fate probability matrix. 10 | 11 | Parameters 12 | ---------- 13 | prob_matrix : np.ndarray 14 | A 2D NumPy array of shape (n_cells, n_lineages) where each row represents 15 | a cell's fate probabilities across different lineages. Each row should sum to 1. 16 | 17 | Returns 18 | ------- 19 | entropy : np.ndarray 20 | A 1D NumPy array of length n_cells containing the entropy values for each cell. 21 | """ 22 | log_probs = np.zeros_like(prob_matrix) 23 | mask = prob_matrix != 0 24 | np.log2(prob_matrix, where=mask, out=log_probs) 25 | 26 | entropy = -np.sum(prob_matrix * log_probs, axis=1) 27 | return entropy 28 | 29 | -------------------------------------------------------------------------------- /regvelo/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocess_data import preprocess_data 2 | from .sanity_check import sanity_check 3 | from .set_prior_grn import set_prior_grn 4 | from .filter_genes import filter_genes 5 | 6 | __all__ = [ 7 | "preprocess_data", 8 | "set_prior_grn", 9 | "sanity_check", 10 | "filter_genes" 11 | ] 12 | -------------------------------------------------------------------------------- /regvelo/preprocessing/filter_genes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from anndata import AnnData 5 | 6 | def filter_genes(adata: AnnData) -> AnnData: 7 | """Filter genes in an AnnData object to ensure each gene has upstream regulators. 8 | 9 | The function iteratively refines the skeleton matrix to maintain only genes with regulatory connections. Only used 10 | by `soft_constraint=False` RegVelo model. 11 | 12 | Parameters 13 | ---------- 14 | adata 15 | Annotated data object (AnnData) containing gene expression data, a skeleton matrix of regulatory interactions, 16 | and a list of regulators. 17 | 18 | Returns 19 | ------- 20 | adata 21 | Updated AnnData object with filtered genes and a refined skeleton matrix where all genes have at least one 22 | upstream regulator. 
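
    Examples
    --------
    A minimal sketch, assuming ``adata`` already carries the prior GRN set up by
    ``set_prior_grn`` (i.e. ``adata.uns["skeleton"]`` and ``adata.uns["regulators"]``):

    >>> from regvelo.preprocessing import filter_genes
    >>> adata = filter_genes(adata)
    >>> bool((adata.uns["skeleton"].sum(0) > 0).all())  # every gene keeps a regulator
    True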
23 | """ 24 | # Initial filtering based on regulators 25 | var_mask = adata.var_names.isin(adata.uns["regulators"]) 26 | 27 | # Filter genes based on `full_names` 28 | adata = adata[:, var_mask].copy() 29 | 30 | # Update skeleton matrix 31 | skeleton = adata.uns["skeleton"].loc[adata.var_names.tolist(), adata.var_names.tolist()] 32 | adata.uns["skeleton"] = skeleton 33 | 34 | # Iterative refinement 35 | while adata.uns["skeleton"].sum(0).min() == 0: 36 | # Update filtering based on skeleton 37 | skeleton = adata.uns["skeleton"] 38 | mask = skeleton.sum(0) > 0 39 | 40 | regulators = adata.var_names[mask].tolist() 41 | print(f"Number of genes: {len(regulators)}") 42 | 43 | # Filter skeleton and update `adata` 44 | skeleton = skeleton.loc[regulators, regulators] 45 | adata.uns["skeleton"] = skeleton 46 | 47 | # Update adata with filtered genes 48 | adata = adata[:, mask].copy() 49 | adata.uns["regulators"] = regulators 50 | adata.uns["targets"] = regulators 51 | 52 | # Re-index skeleton with updated gene names 53 | skeleton_df = pd.DataFrame( 54 | adata.uns["skeleton"], 55 | index=adata.uns["regulators"], 56 | columns=adata.uns["targets"], 57 | ) 58 | adata.uns["skeleton"] = skeleton_df 59 | 60 | return adata -------------------------------------------------------------------------------- /regvelo/preprocessing/preprocess_data.py: -------------------------------------------------------------------------------- 1 | # the folloing code is adapted from velovi 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import scvelo as scv 7 | from anndata import AnnData 8 | from sklearn.preprocessing import MinMaxScaler 9 | 10 | 11 | def preprocess_data( 12 | adata: AnnData, 13 | spliced_layer: Optional[str] = "Ms", 14 | unspliced_layer: Optional[str] = "Mu", 15 | min_max_scale: bool = True, 16 | filter_on_r2: bool = True, 17 | ) -> AnnData: 18 | """Preprocess data. 19 | 20 | This function removes poorly detected genes and minmax scales the data. 21 | 22 | Parameters 23 | ---------- 24 | adata 25 | Annotated data matrix. 26 | spliced_layer 27 | Name of the spliced layer. 28 | unspliced_layer 29 | Name of the unspliced layer. 30 | min_max_scale 31 | Min-max scale spliced and unspliced 32 | filter_on_r2 33 | Filter out genes according to linear regression fit 34 | 35 | Returns 36 | ------- 37 | Preprocessed adata. 
38 | """ 39 | if min_max_scale: 40 | scaler = MinMaxScaler() 41 | adata.layers[spliced_layer] = scaler.fit_transform(adata.layers[spliced_layer]) 42 | 43 | scaler = MinMaxScaler() 44 | adata.layers[unspliced_layer] = scaler.fit_transform( 45 | adata.layers[unspliced_layer] 46 | ) 47 | 48 | if filter_on_r2: 49 | scv.tl.velocity(adata, mode="deterministic") 50 | 51 | adata = adata[ 52 | :, np.logical_and(adata.var.velocity_r2 > 0, adata.var.velocity_gamma > 0) 53 | ].copy() 54 | adata = adata[:, adata.var.velocity_genes].copy() 55 | 56 | return adata -------------------------------------------------------------------------------- /regvelo/preprocessing/sanity_check.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | from urllib.request import urlretrieve 5 | 6 | import torch 7 | import numpy as np 8 | import pandas as pd 9 | import scvelo as scv 10 | from anndata import AnnData 11 | from sklearn.preprocessing import MinMaxScaler 12 | from scipy.spatial.distance import cdist 13 | 14 | def sanity_check( 15 | adata : AnnData, 16 | ) -> AnnData: 17 | 18 | """ 19 | Sanity check 20 | 21 | This function helps to ensure each gene will have at least one regulator. 22 | 23 | Parameters 24 | ---------- 25 | adata 26 | Annotated data matrix. 27 | """ 28 | 29 | gene_name = adata.var.index.tolist() 30 | full_name = adata.uns["regulators"] 31 | index = [i in gene_name for i in full_name] 32 | full_name = full_name[index] 33 | adata = adata[:,full_name].copy() 34 | 35 | W = adata.uns["skeleton"] 36 | W = W[index,:] 37 | W = W[:,index] 38 | 39 | adata.uns["skeleton"] = W 40 | W = adata.uns["network"] 41 | W = W[index,:] 42 | W = W[:,index] 43 | #csgn = csgn[index,:,:] 44 | #csgn = csgn[:,index,:] 45 | adata.uns["network"] = W 46 | 47 | ### 48 | for i in range(1000): 49 | if adata.uns["skeleton"].sum(0).min()>0: 50 | break 51 | else: 52 | W = np.array(adata.uns["skeleton"]) 53 | gene_name = adata.var.index.tolist() 54 | 55 | indicator = W.sum(0) > 0 ## every gene would need to have a upstream regulators 56 | regulators = [gene for gene, boolean in zip(gene_name, indicator) if boolean] 57 | targets = [gene for gene, boolean in zip(gene_name, indicator) if boolean] 58 | print("num regulators: "+str(len(regulators))) 59 | print("num targets: "+str(len(targets))) 60 | W = np.array(adata.uns["skeleton"]) 61 | W = W[indicator,:] 62 | W = W[:,indicator] 63 | adata.uns["skeleton"] = W 64 | 65 | W = np.array(adata.uns["network"]) 66 | W = W[indicator,:] 67 | W = W[:,indicator] 68 | adata.uns["network"] = W 69 | 70 | #csgn = csgn[indicator,:,:] 71 | #csgn = csgn[:,indicator,:] 72 | #adata.uns["csgn"] = csgn 73 | 74 | adata.uns["regulators"] = regulators 75 | adata.uns["targets"] = targets 76 | 77 | W = pd.DataFrame(adata.uns["skeleton"],index = adata.uns["regulators"],columns = adata.uns["targets"]) 78 | W = W.loc[regulators,targets] 79 | adata.uns["skeleton"] = W 80 | W = pd.DataFrame(adata.uns["network"],index = adata.uns["regulators"],columns = adata.uns["targets"]) 81 | W = W.loc[regulators,targets] 82 | adata.uns["network"] = W 83 | adata = adata[:,indicator].copy() 84 | 85 | return adata -------------------------------------------------------------------------------- /regvelo/preprocessing/set_prior_grn.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | from urllib.request import urlretrieve 5 | 6 
| import torch 7 | import numpy as np 8 | import pandas as pd 9 | import scvelo as scv 10 | from anndata import AnnData 11 | from sklearn.preprocessing import MinMaxScaler 12 | from scipy.spatial.distance import cdist 13 | 14 | def set_prior_grn(adata: AnnData, gt_net: pd.DataFrame, keep_dim: bool = False) -> AnnData: 15 | """Adds prior gene regulatory network (GRN) information to an AnnData object. 16 | 17 | Parameters 18 | ---------- 19 | adata : AnnData 20 | Annotated data matrix with gene expression data. 21 | gt_net : pd.DataFrame 22 | Prior gene regulatory network (targets as rows, regulators as columns). 23 | keep_dim : bool, optional 24 | If True, output AnnData retains original dimensions. Default is False. 25 | 26 | Returns 27 | ------- 28 | AnnData 29 | Updated AnnData object with GRN stored in .uns["skeleton"]. 30 | """ 31 | # Identify regulators and targets present in adata 32 | regulator_mask = adata.var_names.isin(gt_net.columns) 33 | target_mask = adata.var_names.isin(gt_net.index) 34 | regulators = adata.var_names[regulator_mask] 35 | targets = adata.var_names[target_mask] 36 | 37 | if keep_dim: 38 | skeleton = pd.DataFrame(0, index=adata.var_names, columns=adata.var_names, dtype=float) 39 | common_targets = list(set(adata.var_names).intersection(gt_net.index)) 40 | common_regulators = list(set(adata.var_names).intersection(gt_net.columns)) 41 | skeleton.loc[common_targets, common_regulators] = gt_net.loc[common_targets, common_regulators] 42 | gt_net = skeleton.copy() 43 | 44 | # Compute correlation matrix based on gene expression layer "Ms" 45 | gex = adata.layers["Ms"] 46 | correlation = 1 - cdist(gex.T, gex.T, metric="correlation") 47 | #correlation = torch.tensor(correlation).float() 48 | correlation = correlation[np.ix_(target_mask, regulator_mask)] 49 | correlation[np.isnan(correlation)] = 0 50 | 51 | # Align and combine ground-truth GRN with expression correlation 52 | filtered_gt = gt_net.loc[targets, regulators] 53 | grn = filtered_gt * correlation 54 | 55 | # Binarize the GRN 56 | grn = (np.abs(grn) >= 0.01).astype(int) 57 | np.fill_diagonal(grn.values, 0) # Remove self-loops 58 | 59 | if keep_dim: 60 | skeleton = pd.DataFrame(0, index=adata.var_names, columns=adata.var_names, dtype=int) 61 | skeleton.loc[grn.columns, grn.index] = grn.T 62 | else: 63 | # Prune genes with no edges 64 | grn = grn.loc[grn.sum(axis=1) > 0, grn.sum(axis=0) > 0] 65 | genes = grn.index.union(grn.columns) 66 | skeleton = pd.DataFrame(0, index=genes, columns=genes, dtype=int) 67 | skeleton.loc[grn.columns, grn.index] = grn.T 68 | 69 | # Subset the adata to GRN genes and store in .uns 70 | adata = adata[:, skeleton.index].copy() 71 | skeleton = skeleton.loc[adata.var_names, adata.var_names] 72 | 73 | adata.uns["regulators"] = adata.var_names.to_numpy() 74 | adata.uns["targets"] = adata.var_names.to_numpy() 75 | adata.uns["skeleton"] = skeleton 76 | adata.uns["network"] = skeleton.copy() 77 | 78 | return adata 79 | -------------------------------------------------------------------------------- /regvelo/tools/TFScanning_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | 5 | import cellrank as cr 6 | from anndata import AnnData 7 | from scvelo import logging as logg 8 | import os, shutil 9 | from typing import Dict, Optional, Sequence, Tuple, Union 10 | 11 | from .._model import REGVELOVI 12 | from .utils import split_elements, combine_elements 13 | from .abundance_test import 
abundance_test 14 | 15 | 16 | def TFScanning_func( 17 | model : str, 18 | adata : AnnData, 19 | cluster_label : Optional[str] = None, 20 | terminal_states : Optional[Union[str, Sequence[str], Dict[str, Sequence[str]], pd.Series]] = None, 21 | KO_list : Optional[Union[str, Sequence[str], Dict[str, Sequence[str]], pd.Series]] = None, 22 | n_states : Optional[Union[int, Sequence[int]]] = None, 23 | cutoff : Optional[Union[int, Sequence[int]]] = 1e-3, 24 | method : Optional[str] = "likelihood", 25 | combined_kernel : Optional[bool] = False, 26 | ) -> Dict[str, Union[float, pd.DataFrame]]: 27 | 28 | """ 29 | Perform in silico TF regulon knock-out screening 30 | 31 | Parameters 32 | ---------- 33 | model 34 | The saved address for the RegVelo model. 35 | adata 36 | Anndata objects. 37 | cluster_label 38 | Key in :attr:`~anndata.AnnData.obs` to associate names and colors with :attr:`terminal_states`. 39 | terminal_states 40 | subset of :attr:`macrostates`. 41 | KO_list 42 | List of TF combinations to simulate knock-out (KO) effects 43 | Can be single TF e.g. geneA 44 | or double TFs e.g. geneB_geneC 45 | example input: ["geneA","geneB_geneC"] 46 | n_states 47 | Number of macrostates to compute. 48 | cutoff 49 | The threshold for determing which links need to be muted, 50 | method 51 | Quantify perturbation effects via `likelihood` or `t-statistics` 52 | combined_kernel 53 | Use combined kernel (0.8*VelocityKernel + 0.2*ConnectivityKernel) 54 | """ 55 | 56 | reg_vae = REGVELOVI.load(model, adata) 57 | adata = reg_vae.add_regvelo_outputs_to_adata(adata = adata) 58 | raw_GRN = reg_vae.module.v_encoder.fc1.weight.detach().clone() 59 | perturb_GRN = reg_vae.module.v_encoder.fc1.weight.detach().clone() 60 | 61 | ## define kernel matrix 62 | vk = cr.kernels.VelocityKernel(adata) 63 | vk.compute_transition_matrix() 64 | ck = cr.kernels.ConnectivityKernel(adata).compute_transition_matrix() 65 | 66 | if combined_kernel: 67 | g2 = cr.estimators.GPCCA(0.8 * vk + 0.2 * ck) 68 | else: 69 | g2 = cr.estimators.GPCCA(vk) 70 | 71 | ## evaluate the fate prob on original space 72 | g2.compute_macrostates(n_states=n_states, n_cells = 30, cluster_key=cluster_label) 73 | # set a high number of states, and merge some of them and rename 74 | if terminal_states is None: 75 | g2.predict_terminal_states() 76 | terminal_states = g2.terminal_states.cat.categories.tolist() 77 | g2.set_terminal_states( 78 | terminal_states 79 | ) 80 | g2.compute_fate_probabilities(solver="direct") 81 | fate_prob = g2.fate_probabilities 82 | sampleID = adata.obs.index.tolist() 83 | fate_name = fate_prob.names.tolist() 84 | fate_prob = pd.DataFrame(fate_prob,index= sampleID,columns=fate_name) 85 | fate_prob_original = fate_prob.copy() 86 | 87 | ## create dictionary 88 | terminal_id = terminal_states.copy() 89 | terminal_type = terminal_states.copy() 90 | for i in terminal_states: 91 | for j in [1,2,3,4,5,6,7,8,9,10]: 92 | terminal_id.append(i+"_"+str(j)) 93 | terminal_type.append(i) 94 | terminal_dict = dict(zip(terminal_id, terminal_type)) 95 | n_states = len(g2.macrostates.cat.categories.tolist()) 96 | 97 | coef = [] 98 | pvalue = [] 99 | for tf in split_elements(KO_list): 100 | perturb_GRN = raw_GRN.clone() 101 | vec = perturb_GRN[:,[i in tf for i in adata.var.index.tolist()]].clone() 102 | vec[vec.abs() > cutoff] = 0 103 | perturb_GRN[:,[i in tf for i in adata.var.index.tolist()]]= vec 104 | reg_vae_perturb = REGVELOVI.load(model, adata) 105 | reg_vae_perturb.module.v_encoder.fc1.weight.data = perturb_GRN 106 | 107 | adata_target = 
reg_vae_perturb.add_regvelo_outputs_to_adata(adata = adata)
108 |         ## perturb the regulations
109 |         vk = cr.kernels.VelocityKernel(adata_target)
110 |         vk.compute_transition_matrix()
111 |         ck = cr.kernels.ConnectivityKernel(adata_target).compute_transition_matrix()
112 | 
113 |         if combined_kernel:
114 |             g2 = cr.estimators.GPCCA(0.8 * vk + 0.2 * ck)
115 |         else:
116 |             g2 = cr.estimators.GPCCA(vk)
117 |         ## evaluate the fate probabilities on the perturbed space
118 |         n_states_perturb = n_states
119 |         while True:
120 |             try:
121 |                 # try to compute macrostates with the current number of states
122 |                 g2.compute_macrostates(n_states=n_states_perturb, n_cells = 30, cluster_key=cluster_label)
123 |                 break
124 |             except Exception:
125 |                 # on failure, increment n_states_perturb and try again; the baseline fate probabilities must then be recomputed to match
126 |                 n_states_perturb += 1
127 |                 vk = cr.kernels.VelocityKernel(adata)
128 |                 vk.compute_transition_matrix()
129 |                 ck = cr.kernels.ConnectivityKernel(adata).compute_transition_matrix()
130 |                 if combined_kernel:
131 |                     g = cr.estimators.GPCCA(0.8 * vk + 0.2 * ck)
132 |                 else:
133 |                     g = cr.estimators.GPCCA(vk)
134 |                 ## re-evaluate the fate probabilities on the original space
135 |                 g.compute_macrostates(n_states=n_states_perturb, n_cells = 30, cluster_key=cluster_label)
136 |                 ## set a high number of states, then merge and rename some of them
137 |                 if terminal_states is None:
138 |                     g.predict_terminal_states()
139 |                     terminal_states = g.terminal_states.cat.categories.tolist()
140 |                 g.set_terminal_states(
141 |                     terminal_states
142 |                 )
143 |                 g.compute_fate_probabilities(solver="direct")
144 |                 fate_prob = g.fate_probabilities
145 |                 sampleID = adata.obs.index.tolist()
146 |                 fate_name = fate_prob.names.tolist()
147 |                 fate_prob = pd.DataFrame(fate_prob, index=sampleID, columns=fate_name)
148 | 
149 |         ## intersect the perturbed macrostates with the baseline terminal states
150 |         terminal_states_perturb = g2.macrostates.cat.categories.tolist()
151 |         terminal_states_perturb = list(set(terminal_states_perturb).intersection(terminal_states))
152 | 
153 |         g2.set_terminal_states(
154 |             terminal_states_perturb
155 |         )
156 |         g2.compute_fate_probabilities(solver="direct")
157 |         fb = g2.fate_probabilities
158 |         sampleID = adata.obs.index.tolist()
159 |         fate_name = fb.names.tolist()
160 |         fb = pd.DataFrame(fb, index=sampleID, columns=fate_name)
161 |         fate_prob2 = pd.DataFrame(columns=terminal_states, index=sampleID)
162 | 
163 |         for i in terminal_states_perturb:
164 |             fate_prob2.loc[:, i] = fb.loc[:, i]
165 | 
166 |         fate_prob2 = fate_prob2.fillna(0)
167 |         arr = np.array(fate_prob2.sum(0))
168 |         arr[arr != 0] = 1
169 |         fate_prob = fate_prob * arr
170 | 
171 |         y = [0] * fate_prob.shape[0] + [1] * fate_prob2.shape[0]  # NOTE: unused; abundance_test builds its own labels
172 |         fate_prob2.index = [i + "_perturb" for i in fate_prob2.index]
173 |         test_result = abundance_test(fate_prob, fate_prob2, method)
174 |         coef.append(test_result.loc[:, "coefficient"])
175 |         pvalue.append(test_result.loc[:, "FDR adjusted p-value"])
176 |         logg.info("Done " + combine_elements([tf])[0])
177 |         fate_prob = fate_prob_original.copy()
178 | 
179 |     d = {'TF': KO_list, 'coefficient': coef, 'pvalue': pvalue}
180 |     return d
--------------------------------------------------------------------------------
/regvelo/tools/TFscreening.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | import numpy as np
4 | 
5 | import cellrank as cr
6 | from anndata import AnnData
7 | from scvelo import logging as logg
8 | import os, shutil
9 | from typing import Dict, Optional, Sequence, Tuple, Union
10 | 
11 | from .._model import REGVELOVI
12 | from
.utils import split_elements, get_list_name 13 | from .TFScanning_func import TFScanning_func 14 | 15 | 16 | def TFscreening( 17 | adata : AnnData, 18 | prior_graph : torch.Tensor, 19 | lam : Optional[int] = 1, 20 | lam2 : Optional[int] = 0, 21 | soft_constraint : Optional[bool] = True, 22 | TF_list : Optional[Union[str, Sequence[str], Dict[str, Sequence[str]], pd.Series]] = None, 23 | cluster_label : Optional[str] = None, 24 | terminal_states : Optional[Union[str, Sequence[str], Dict[str, Sequence[str]], pd.Series]] = None, 25 | KO_list : Optional[Union[str, Sequence[str], Dict[str, Sequence[str]], pd.Series]] = None, 26 | n_states : Optional[Union[int, Sequence[int]]] = 8, 27 | cutoff : Optional[float] = 1e-3, 28 | max_nruns : Optional[float] = 5, 29 | method : Optional[str] = "likelihood", 30 | dir : Optional[str] = None 31 | ) -> Tuple[pd.DataFrame, pd.DataFrame]: 32 | """ 33 | Perform in silico TF regulon knock-out screening 34 | 35 | Parameters 36 | ---------- 37 | adata 38 | Anndata objects. 39 | prior_graph 40 | A prior graph for RegVelo inference 41 | lam 42 | Regularization parameter for controling the strengths of adding prior knowledge. 43 | lam2 44 | Regularization parameter for controling the strengths of L1 regularization to the Jacobian matrix. 45 | soft_constraint 46 | Apply soft constraint mode RegVelo. 47 | TF_list 48 | The TF list used for RegVelo inference. 49 | cluster_label 50 | Key in :attr:`~anndata.AnnData.obs` to associate names and colors with :attr:`terminal_states`. 51 | terminal_states 52 | subset of :attr:`macrostates`. 53 | KO_list 54 | List of TF combinations to simulate knock-out (KO) effects 55 | can be single TF e.g. geneA 56 | or double TFs e.g. geneB_geneC 57 | example input: ["geneA","geneB_geneC"] 58 | n_states 59 | Number of macrostates to compute. 60 | cutoff 61 | The threshold for determing which links need to be muted ( List[int]: 7 | """Generates a sequence of numbers from 1 to k. If the length of the sequence is less than n, the remaining positions are filled with the value k. 8 | 9 | Parameters 10 | ---------- 11 | k: The last value to appear in the initial sequence. 12 | n: The target length of the sequence. 13 | 14 | Returns 15 | ------- 16 | List[int]: A list of integers from 1 to k, padded with k to length n if necessary. 17 | """ 18 | sequence = list(range(1, k + 1)) 19 | 20 | # If the length of the sequence is already >= n, trim it to n 21 | if len(sequence) >= n: 22 | return sequence[:n] 23 | 24 | # Fill the rest of the sequence with the number k 25 | sequence.extend([k] * (n - len(sequence))) 26 | 27 | return sequence 28 | 29 | 30 | def plot_tsi( 31 | adata: AnnData, 32 | kernel: Any, 33 | threshold: float, 34 | terminal_states: Set[str], 35 | cluster_key: str, 36 | max_states: int = 12, 37 | ) -> List[int]: 38 | """Calculate the number of unique terminal states for each macrostate count. 39 | 40 | Parameters 41 | ---------- 42 | adata 43 | Annotated data matrix (e.g., from single-cell experiments). 44 | kernel 45 | Computational kernel used to compute macrostates and predict terminal states. 46 | threshold 47 | Stability threshold for predicting terminal states. 48 | terminal_states 49 | Set of known terminal state names to match against. 50 | cluster_key 51 | Key in `adata.obs` that identifies cluster assignments for cells. 52 | max_states 53 | Maximum number of macrostates to consider. Default is 12. 
54 | 
55 |     Returns
56 |     -------
57 |     list of int
58 |         A list where each entry represents the count of unique terminal states found
59 |         at each macrostate count from 1 to `max_states`.
60 |     """
61 |     # Create a mapping of state identifiers to their corresponding types
62 |     all_states = list(set(adata.obs[cluster_key].tolist()))
63 |     all_id = all_states.copy()
64 |     all_type = all_states.copy()
65 |     for state in all_states:
66 |         for i in range(1, max_states + 1):
67 |             all_id.append(f"{state}_{i}")
68 |             all_type.append(state)
69 |     all_dict = dict(zip(all_id, all_type))
70 | 
71 |     pre_value = []
72 | 
73 |     for num_macro in range(1, max_states):
74 |         try:
75 |             # Compute macrostates and predict terminal states
76 |             kernel.compute_macrostates(n_states=num_macro, cluster_key=cluster_key)
77 |             kernel.predict_terminal_states(stability_threshold=threshold)
78 | 
79 |             # Map terminal states to their types using `all_dict`
80 |             pre_terminal = kernel.terminal_states.cat.categories.tolist()
81 |             subset_dict = {key: all_dict[key] for key in pre_terminal}
82 |             pre_terminal_names = list(set(subset_dict.values()))
83 | 
84 |             # Count overlap with known terminal states
85 |             pre_value.append(len(set(pre_terminal_names).intersection(terminal_states)))
86 | 
87 |         except:  # noqa
88 |             # On failure, repeat the last valid value (or use 0 if none yet)
89 |             pre_value.append(pre_value[-1] if pre_value else 0)
90 | 
91 |     return pre_value
92 | 
93 | 
94 | def get_tsi_score(
95 |     adata: AnnData,
96 |     points: List[float],
97 |     cluster_key: str,
98 |     terminal_states: Set[str],
99 |     kernel: Any,
100 |     max_states: int = 12,
101 | ) -> List[float]:
102 |     """Calculate the terminal state identification (TSI) score for given thresholds.
103 | 
104 |     Parameters
105 |     ----------
106 |     adata
107 |         Annotated data matrix (e.g., from single-cell experiments).
108 |     points
109 |         List of threshold values to evaluate for stability of terminal states.
110 |     cluster_key
111 |         Key in `adata.obs` to access cluster assignments for cells.
112 |     terminal_states
113 |         Set of known terminal state names to match against.
114 |     kernel
115 |         CellRank estimator (e.g. `GPCCA`) used to compute macrostates and predict terminal states.
116 |     max_states
117 |         Maximum number of macrostates to consider. Default is 12.
118 | 
119 |     Returns
120 |     -------
121 |     list of float
122 |         A list of TSI scores, one for each threshold in `points`. Each score represents
123 |         the normalized area under the staircase function compared to the goal sequence.
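
    Examples
    --------
    Illustrative sketch; despite its name, ``kernel`` must expose
    ``compute_macrostates`` and ``predict_terminal_states``, i.e. a CellRank
    ``GPCCA`` estimator. ``adata`` (with a precomputed velocity field), the
    cluster key, and the terminal-state names below are assumptions:

    >>> import cellrank as cr
    >>> vk = cr.kernels.VelocityKernel(adata).compute_transition_matrix()
    >>> estimator = cr.estimators.GPCCA(vk)
    >>> thresholds = [0.6, 0.7, 0.8, 0.9]
    >>> scores = get_tsi_score(adata, thresholds, "cell_type", {"Neuron", "Pigment"}, estimator)
    >>> len(scores) == len(thresholds)
    True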
124 | """ 125 | # Define the goal sequence and calculate its area 126 | x_values = range(max_states) 127 | y_values = [0] + generate_sequence(len(terminal_states), max_states - 1) 128 | area_gs = sum((x_values[i + 1] - x_values[i]) * y_values[i] for i in range(len(x_values) - 1)) 129 | 130 | tsi_score = [] 131 | 132 | for threshold in points: 133 | # Compute the staircase function for the current threshold 134 | pre_value = plot_tsi(adata, kernel, threshold, terminal_states, cluster_key, max_states) 135 | y_values = [0] + pre_value 136 | 137 | # Calculate the area under the staircase function 138 | area_velo = sum((x_values[i + 1] - x_values[i]) * y_values[i] for i in range(len(x_values) - 1)) 139 | 140 | # Compute the normalized TSI score and append to results 141 | tsi_score.append(area_velo / area_gs if area_gs else 0) 142 | 143 | return tsi_score 144 | -------------------------------------------------------------------------------- /regvelo/tools/abundance_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import pandas as pd 4 | import numpy as np 5 | 6 | from scipy.stats import ranksums, ttest_ind 7 | from sklearn.metrics import roc_auc_score 8 | 9 | import os, shutil 10 | 11 | from .._model import REGVELOVI 12 | from .utils import p_adjust_bh 13 | 14 | 15 | def abundance_test( 16 | prob_raw : pd.DataFrame, 17 | prob_pert : pd.DataFrame, 18 | method : str = "likelihood" 19 | ) -> pd.DataFrame: 20 | """ 21 | Perform an abundance test between two probability datasets. 22 | 23 | Parameters 24 | ---------- 25 | prob_raw : pd.DataFrame 26 | Raw probabilities dataset. 27 | prob_pert : pd.DataFrame 28 | Perturbed probabilities dataset. 29 | method : str, optional (default="likelihood") 30 | Method to calculate scores: "likelihood" or "t-statistics". 31 | 32 | Returns 33 | ------- 34 | pd.DataFrame 35 | Dataframe with coefficients, p-values, and FDR adjusted p-values. 
36 | """ 37 | y = [1] * prob_raw.shape[0] + [0] * prob_pert.shape[0] 38 | X = pd.concat([prob_raw, prob_pert], axis=0) 39 | 40 | table = [] 41 | for i in range(prob_raw.shape[1]): 42 | pred = np.array(X.iloc[:, i]) 43 | if np.sum(pred) == 0: 44 | score, pval = np.nan, np.nan 45 | else: 46 | pval = ranksums(pred[np.array(y) == 0], pred[np.array(y) == 1])[1] 47 | if method == "t-statistics": 48 | score = ttest_ind(pred[np.array(y) == 0], pred[np.array(y) == 1])[0] 49 | elif method == "likelihood": 50 | score = roc_auc_score(y, pred) 51 | else: 52 | raise NotImplementedError("Supported methods are 't-statistics' and 'likelihood'.") 53 | 54 | table.append(np.expand_dims(np.array([score, pval]), 0)) 55 | 56 | table = np.concatenate(table, axis=0) 57 | table = pd.DataFrame(table, index=prob_raw.columns, columns=["coefficient", "p-value"]) 58 | table["FDR adjusted p-value"] = p_adjust_bh(table["p-value"].tolist()) 59 | return table -------------------------------------------------------------------------------- /regvelo/tools/depletion_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | from anndata import AnnData 4 | from typing import Union, Sequence, Any, Optional, Tuple 5 | import cellrank as cr 6 | 7 | from .abundance_test import abundance_test 8 | 9 | def depletion_score(perturbed : dict[str, AnnData], 10 | baseline : AnnData, 11 | terminal_state : Union[str, Sequence[str]], 12 | **kwargs : Any, 13 | ) -> Tuple[pd.DataFrame, dict[str, AnnData]]: 14 | """ 15 | Compute depletion scores. 16 | 17 | Parameters 18 | ---------- 19 | perturbed : dict[str, AnnData] 20 | Dictionary mapping TF candidates to their perturbed AnnData objects. 21 | baseline : AnnData 22 | Annotated data matrix. Fate probabilities already computed. 23 | terminal_state : str or Sequence[str] 24 | List of terminal states to compute probabilities for. 25 | kwargs : Any 26 | Optional 27 | Additional keyword arguments passed to CellRank and plot functions. 28 | 29 | Returns 30 | ------- 31 | Tuple[pd.DataFrame, dict[str, AnnData]] 32 | A tuple containing: 33 | 34 | - **df** – Summary of depletion scores and associated statistics. 35 | - **adata_perturb_dict** – Dictionary mapping TFs to their perturbed AnnData objects. 36 | """ 37 | 38 | macro_kwargs = {k: kwargs[k] for k in ("n_states", "n_cells", "cluster_key", "method") if k in kwargs} 39 | compute_fate_probabilities_kwargs = {k: kwargs[k] for k in ("solver", "tol") if k in kwargs} 40 | 41 | if "lineages_fwd" not in baseline.obsm: 42 | raise KeyError("Lineages not found in baseline.obsm. 
Please compute lineages first.") 43 | 44 | if isinstance(terminal_state, str): 45 | terminal_state = [terminal_state] 46 | 47 | # selecting indices of cells that have reached a terminal state 48 | ct_indices = { 49 | ct: baseline.obs["term_states_fwd"][baseline.obs["term_states_fwd"] == ct].index.tolist() 50 | for ct in terminal_state 51 | } 52 | 53 | fate_prob_perturb = {} 54 | for TF, adata_target_perturb in perturbed.items(): 55 | vk = cr.kernels.VelocityKernel(adata_target_perturb).compute_transition_matrix() 56 | vk.write_to_adata() 57 | estimator = cr.estimators.GPCCA(vk) 58 | 59 | estimator.compute_macrostates(**macro_kwargs) 60 | 61 | estimator.set_terminal_states(ct_indices) 62 | estimator.compute_fate_probabilities(**compute_fate_probabilities_kwargs) 63 | 64 | perturbed[TF] = adata_target_perturb 65 | 66 | perturbed_prob = pd.DataFrame( 67 | adata_target_perturb.obsm["lineages_fwd"], 68 | columns=adata_target_perturb.obsm["lineages_fwd"].names.tolist() 69 | )[terminal_state] 70 | 71 | fate_prob_perturb[TF] = perturbed_prob 72 | 73 | fate_prob_raw = pd.DataFrame( 74 | baseline.obsm["lineages_fwd"], 75 | columns=baseline.obsm["lineages_fwd"].names.tolist() 76 | ) 77 | 78 | dfs = [] 79 | for TF, perturbed_prob in fate_prob_perturb.items(): 80 | stats = abundance_test(prob_raw=fate_prob_raw, prob_pert=perturbed_prob) 81 | df = pd.DataFrame( 82 | { 83 | "Depletion score": stats.iloc[:, 0].tolist(), 84 | "p-value": stats.iloc[:, 1].tolist(), 85 | "FDR adjusted p-value": stats.iloc[:, 2].tolist(), 86 | "Terminal state": stats.index.tolist(), 87 | "TF": [TF] * stats.shape[0], 88 | } 89 | ) 90 | dfs.append(df) 91 | 92 | df = pd.concat(dfs) 93 | 94 | df["Depletion score"] = 2 * (0.5 - df["Depletion score"]) 95 | 96 | return df, perturbed 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /regvelo/tools/in_silico_block_simulation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | 5 | from anndata import AnnData 6 | 7 | import os, shutil 8 | 9 | from .._model import REGVELOVI 10 | 11 | def in_silico_block_simulation( 12 | model : str, 13 | adata : AnnData, 14 | TF : str, 15 | effects : int = 0, 16 | cutoff : int = 1e-3, 17 | customized_GRN : torch.Tensor = None 18 | ) -> tuple: 19 | """ Perform in silico TF regulon knock-out 20 | 21 | Parameters 22 | ---------- 23 | model 24 | The saved address for the RegVelo model. 25 | adata 26 | Anndata objects. 27 | TF 28 | The candidate TF, need to knockout its regulon. 
29 |     effects
30 |         The coefficient used to replace the targeted weights in the GRN.
31 |     cutoff
32 |         The threshold for determining which links to mute.
33 |     customized_GRN
34 |         A customized, perturbed GRN used to replace the learned weights.
35 |     """
36 | 
37 |     reg_vae_perturb = REGVELOVI.load(model, adata)
38 | 
39 |     perturb_GRN = reg_vae_perturb.module.v_encoder.fc1.weight.detach().clone()
40 | 
41 |     if customized_GRN is None:
42 |         perturb_GRN[(perturb_GRN[:, [i == TF for i in adata.var.index]].abs() > cutoff).cpu().numpy().reshape(-1), [i == TF for i in adata.var.index]] = effects
43 |         reg_vae_perturb.module.v_encoder.fc1.weight.data = perturb_GRN
44 |     else:
45 |         device = perturb_GRN.device
46 |         customized_GRN = customized_GRN.to(device)
47 |         reg_vae_perturb.module.v_encoder.fc1.weight.data = customized_GRN
48 | 
49 |     adata_target_perturb = reg_vae_perturb.add_regvelo_outputs_to_adata(adata = adata)
50 | 
51 |     return adata_target_perturb, reg_vae_perturb
--------------------------------------------------------------------------------
/regvelo/tools/markov_density_simulation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | 
4 | def markov_density_simulation(
5 |     adata : "AnnData",
6 |     T : np.ndarray,
7 |     start_indices,
8 |     terminal_indices,
9 |     terminal_states,
10 |     n_steps : int = 100,
11 |     n_simulations : int = 200,
12 |     method: str = "stepwise",
13 |     seed : int = 0,
14 | ):
15 | 
16 |     """
17 |     Simulate transitions on a velocity-derived Markov transition matrix.
18 | 
19 |     Parameters
20 |     ----------
21 |     adata : AnnData
22 |         Annotated data object.
23 |     T : np.ndarray
24 |         Transition matrix of shape (n_cells, n_cells).
25 |     start_indices : array-like
26 |         Indices of starting cells.
27 |     terminal_indices : array-like
28 |         Indices of terminal (absorbing) cells.
29 |     terminal_states : array-like
30 |         Labels of terminal states corresponding to cells in `adata.obs["term_states_fwd"]`.
31 |     n_steps : int, optional
32 |         Maximum number of steps per simulation (default: 100).
33 |     n_simulations : int, optional
34 |         Number of simulations per starting cell (default: 200).
35 |     method : {'stepwise', 'one-step'}, optional
36 |         Simulation method to use:
37 |         - 'stepwise': simulate trajectories step by step.
38 |         - 'one-step': sample directly from T^n.
39 |     seed : int, optional
40 |         Random seed for reproducibility (default: 0).
41 | 
42 |     Returns
43 |     -------
44 |     visits : pd.Series
45 |         Number of simulations that ended in each terminal cell.
46 |     visits_dens : pd.Series
47 |         Proportion of simulations that ended in each terminal cell.
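
    Examples
    --------
    Illustrative sketch; the ``obsp`` key, cluster column, and state names are
    assumptions (``"T_fwd"`` is where a CellRank ``VelocityKernel`` writes its
    transition matrix via ``write_to_adata``):

    >>> import numpy as np
    >>> T = np.asarray(adata.obsp["T_fwd"].todense())
    >>> starts = np.flatnonzero(adata.obs["cell_type"] == "Progenitor")
    >>> terms = np.flatnonzero(adata.obs["term_states_fwd"].notna())
    >>> states = adata.obs["term_states_fwd"].cat.categories.tolist()
    >>> visits, visits_dens = markov_density_simulation(adata, T, starts, terms, states)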
48 | """ 49 | np.random.seed(seed) 50 | 51 | T = np.asarray(T) 52 | start_indices = np.asarray(start_indices) 53 | terminal_indices = np.asarray(terminal_indices) 54 | terminal_set = set(terminal_indices) 55 | n_cells = T.shape[0] 56 | 57 | arrivals_array = np.zeros(n_cells, dtype=int) 58 | 59 | if method == "stepwise": 60 | row_sums = T.sum(axis=1) 61 | cum_T = np.cumsum(T, axis=1) 62 | 63 | for start in start_indices: 64 | for _ in range(n_simulations): 65 | current = start 66 | for _ in range(n_steps): 67 | if row_sums[current] == 0: 68 | break # dead end 69 | r = np.random.rand() 70 | next_state = np.searchsorted(cum_T[current], r * row_sums[current]) 71 | current = next_state 72 | if current in terminal_set: 73 | arrivals_array[current] += 1 74 | break 75 | 76 | elif method == "one-step": 77 | T_end = np.linalg.matrix_power(T, n_steps) 78 | for start in start_indices: 79 | x0 = np.zeros(n_cells) 80 | x0[start] = 1 81 | x_end = x0 @ T_end # final distribution 82 | if x_end.sum() > 0: 83 | samples = np.random.choice(n_cells, size=n_simulations, p=x_end) 84 | for s in samples: 85 | if s in terminal_set: 86 | arrivals_array[s] += 1 87 | else: 88 | raise ValueError(f"Invalid probability distribution: x_end sums to 0 for start index {start}") 89 | else: 90 | raise ValueError("method must be 'stepwise' or 'one-step'") 91 | 92 | total_simulations = n_simulations * len(start_indices) 93 | visits = pd.Series({tid: arrivals_array[tid] for tid in terminal_indices}, dtype=int) 94 | visits_dens = pd.Series({tid: arrivals_array[tid] / total_simulations for tid in terminal_indices}) 95 | 96 | adata.obs[f"visits_{method}"] = np.nan 97 | adata.obs[f"visits_{method}"].iloc[terminal_indices] = visits_dens 98 | 99 | dens_cum = [] 100 | for ts in terminal_states: 101 | ts_cells = np.where(adata.obs["term_states_fwd"] == ts)[0] 102 | density = visits_dens.loc[ts_cells].sum() 103 | dens_cum.append(density) 104 | 105 | print("Proportion of simulations reaching a terminal cell", sum(dens_cum)) 106 | 107 | return visits, visits_dens -------------------------------------------------------------------------------- /regvelo/tools/perturbation_effect.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import pandas as pd 4 | from anndata import AnnData 5 | from typing import Union, Sequence 6 | 7 | def perturbation_effect( 8 | adata_perturb : AnnData, 9 | adata : AnnData, 10 | terminal_state : Union[str, Sequence[str]], 11 | ) -> AnnData: 12 | """ 13 | Compute change in fate probabilities towards terminal states after perturbation. Negative values correspond to a decrease in 14 | probabilities, while positive values indicate an increase. 15 | 16 | Parameters 17 | ---------- 18 | adata_perturb : AnnData 19 | Annotated data matrix of perturbed GRN. 20 | adata : AnnData 21 | Annotated data matrix of unperturbed GRN. 22 | terminal_state : str or Sequence[str] 23 | List of terminal states to compute probabilities for. 24 | 25 | Returns 26 | ------- 27 | AnnData 28 | Annotated data matrix with the following added: 29 | - `terminal state perturbation` : Change in fate probabilities towards terminal state after perturbation. 
39 | """
40 | if isinstance(terminal_state, str):
41 | terminal_state = [terminal_state]
42 | 
43 | if "lineages_fwd" not in adata.obsm or "lineages_fwd" not in adata_perturb.obsm:
44 | raise ValueError("Lineages not computed. Please compute lineages before using this function.")
45 | 
46 | perturb_df = pd.DataFrame(
47 | adata_perturb.obsm["lineages_fwd"], columns=adata_perturb.obsm["lineages_fwd"].names.tolist()
48 | )
49 | original_df = pd.DataFrame(
50 | adata.obsm["lineages_fwd"], columns=adata.obsm["lineages_fwd"].names.tolist()
51 | )
52 | 
53 | for state in terminal_state:
54 | adata.obs[f"perturbation effect on {state}"] = np.array(perturb_df[state] - original_df[state])
55 | 
56 | return adata
-------------------------------------------------------------------------------- /regvelo/tools/set_output.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from anndata import AnnData
4 | from typing import Any
5 | 
6 | 
7 | # Code mostly taken from the veloVI reproducibility repo:
8 | # https://yoseflab.github.io/velovi_reproducibility/estimation_comparison/simulation_w_inferred_rates.html
9 | 
10 | 
11 | def set_output(
12 | adata: AnnData,
13 | vae: Any,
14 | n_samples: int = 30,
15 | batch_size: int | None = None
16 | ) -> None:
17 | """
18 | Add inference results (velocity, latent time, kinetic rates) to `adata`.
19 | 
20 | Parameters
21 | ----------
22 | adata : AnnData
23 | Annotated data matrix.
24 | vae : Any
25 | Trained RegVelo model.
26 | n_samples : int, optional
27 | Number of posterior samples to use for estimation. Default is 30.
28 | batch_size : int, optional
29 | Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
30 | """
31 | latent_time = vae.get_latent_time(n_samples=n_samples, batch_size=batch_size)
32 | velocities = vae.get_velocity(n_samples=n_samples, batch_size=batch_size)
33 | 
34 | t = latent_time.values
35 | scaling = 20 / t.max(0)  # per-gene factor mapping latent time onto a fixed [0, 20] scale
36 | 
37 | adata.layers["velocity"] = velocities / scaling
38 | adata.layers["latent_time_velovi"] = latent_time
39 | 
40 | rates = vae.get_rates()
41 | if "alpha" in rates:
42 | adata.var["fit_alpha"] = rates["alpha"] / scaling
43 | adata.var["fit_beta"] = rates["beta"] / scaling
44 | adata.var["fit_gamma"] = rates["gamma"] / scaling
45 | 
46 | adata.layers["fit_t"] = latent_time * scaling[np.newaxis, :]
47 | adata.var["fit_scaling"] = 1.0
-------------------------------------------------------------------------------- /regvelo/tools/utils.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def split_elements(character_list):
5 | """Split each underscore-joined element of a list into a list of its parts."""
6 | result_list = []
7 | for element in character_list:
8 | if '_' in element:
9 | result_list.append(element.split('_'))
10 | else:
11 | result_list.append([element])
12 | return result_list
13 | 
14 | 
15 | def combine_elements(split_list):
16 | """Join each list of parts back into a single underscore-separated string."""
17 | return ["_".join(parts) for parts in split_list]
18 | 
19 | 
20 | def get_list_name(lst):
21 | """Return the keys of a dict-like object as a list."""
22 | return list(lst.keys())
23 | 
24 | 
25 | def p_adjust_bh(p):
26 | """Benjamini-Hochberg p-value correction for multiple hypothesis testing.
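27 | 
28 | Examples
29 | --------
30 | A small worked example with hypothetical values; with three tests, the two
31 | smallest p-values share the same adjusted value after the monotonicity step:
32 | 
33 | >>> p_adjust_bh([0.01, 0.02, 0.04])
34 | array([0.03, 0.03, 0.04])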
35 | """
36 | p = np.asarray(p, dtype=float)  # np.asfarray was removed in NumPy 2.0
37 | by_descend = p.argsort()[::-1]
38 | by_orig = by_descend.argsort()
39 | steps = float(len(p)) / np.arange(len(p), 0, -1)
40 | q = np.minimum(1, np.minimum.accumulate(steps * p[by_descend]))
41 | return q[by_orig]
42 | 
-------------------------------------------------------------------------------- /reproduce_env/regvelo.yaml: --------------------------------------------------------------------------------
1 | name: regvelo_test 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - bioconda 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1 10 | - _openmp_mutex=4.5 11 | - absl-py=2.1.0 12 | - alsa-lib=1.2.13 13 | - aom=3.9.1 14 | - arboreto=0.1.6 15 | - attr=2.5.1 16 | - aws-c-auth=0.8.0 17 | - aws-c-cal=0.8.1 18 | - aws-c-common=0.10.5 19 | - aws-c-compression=0.3.0 20 | - aws-c-event-stream=0.5.0 21 | - aws-c-http=0.9.2 22 | - aws-c-io=0.15.3 23 | - aws-c-mqtt=0.11.0 24 | - aws-c-s3=0.7.5 25 | - aws-c-sdkutils=0.2.1 26 | - aws-checksums=0.2.2 27 | - aws-crt-cpp=0.29.7 28 | - aws-sdk-cpp=1.11.458 29 | - azure-core-cpp=1.14.0 30 | - azure-identity-cpp=1.10.0 31 | - azure-storage-blobs-cpp=12.13.0 32 | - azure-storage-common-cpp=12.8.0 33 | - azure-storage-files-datalake-cpp=12.12.0 34 | - bedtools=2.31.1 35 | - binutils=2.43 36 | - binutils_impl_linux-64=2.43 37 | - binutils_linux-64=2.43 38 | - blosc=1.21.6 39 | - bokeh=3.6.2 40 | - brotli=1.1.0 41 | - brotli-bin=1.1.0 42 | - brotli-python=1.1.0 43 | - bzip2=1.0.8 44 | - c-ares=1.34.3 45 | - c-compiler=1.7.0 46 | - ca-certificates=2025.1.31 47 | - cached-property=1.5.2 48 | - cached_property=1.5.2 49 | - cccl=2.5.0 50 | - cellrank=2.0.6 51 | - certifi=2025.1.31 52 | - cffi=1.17.1 53 | - charset-normalizer=3.4.0 54 | - click=8.1.7 55 | - cloudpickle=3.1.0 56 | - colorama=0.4.6 57 | - cpython=3.10.15 58 | - cuda-cccl=12.4.127 59 | - cuda-cccl_linux-64=12.4.127 60 | - cuda-command-line-tools=12.6.2 61 | - cuda-compiler=12.6.2 62 | - cuda-crt-dev_linux-64=12.4.131 63 | - cuda-crt-tools=12.4.131 64 | - cuda-cudart=12.4.127 65 | - cuda-cudart-dev=12.4.127 66 | - cuda-cudart-dev_linux-64=12.4.127 67 | - cuda-cudart-static=12.4.127 68 | - cuda-cudart-static_linux-64=12.4.127 69 | - cuda-cudart_linux-64=12.4.127 70 | - cuda-cuobjdump=12.4.127 71 | - cuda-cupti=12.4.127 72 | - cuda-cupti-dev=12.4.127 73 | - cuda-cuxxfilt=12.4.127 74 | - cuda-documentation=12.4.127 75 | - cuda-driver-dev=12.4.127 76 | - cuda-driver-dev_linux-64=12.4.127 77 | - cuda-gdb=12.4.127 78 | - cuda-libraries=12.4.1 79 | - cuda-libraries-dev=12.6.2 80 | - cuda-libraries-static=12.6.2 81 | - cuda-nsight=12.4.127 82 | - cuda-nvcc=12.4.131 83 | - cuda-nvcc-dev_linux-64=12.4.131 84 | - cuda-nvcc-impl=12.4.131 85 | - cuda-nvcc-tools=12.4.131 86 | - cuda-nvcc_linux-64=12.4.1 87 | - cuda-nvdisasm=12.4.127 88 | - cuda-nvml-dev=12.4.127 89 | - cuda-nvprof=12.4.127 90 | - cuda-nvprune=12.4.127 91 | - cuda-nvrtc=12.4.127 92 | - cuda-nvrtc-dev=12.4.127 93 | - cuda-nvrtc-static=12.4.127 94 | - cuda-nvtx=12.4.127 95 | - cuda-nvvm-dev_linux-64=12.4.131 96 | - cuda-nvvm-impl=12.4.131 97 | - cuda-nvvm-tools=12.4.131 98 | - cuda-nvvp=12.4.127 99 | - cuda-opencl=12.4.127 100 | - cuda-opencl-dev=12.4.127 101 | - cuda-profiler-api=12.4.127 102 | - cuda-runtime=12.4.1 103 | - cuda-sanitizer-api=12.4.127 104 | - cuda-toolkit=12.4.1 105 | - cuda-tools=12.6.2 106 | - cuda-version=12.4 107 | - cuda-visual-tools=12.6.2 108 | - cxx-compiler=1.7.0 109 | - cycler=0.12.1 110 | - cytoolz=1.0.0 111 | - 
dav1d=1.2.1 112 | - dbus=1.13.6 113 | - docrep=0.3.2 114 | - et_xmlfile=2.0.0 115 | - exceptiongroup=1.2.2 116 | - expat=2.6.4 117 | - ffmpeg=4.4.0 118 | - fftw=3.3.10 119 | - filelock=3.16.1 120 | - font-ttf-dejavu-sans-mono=2.37 121 | - font-ttf-inconsolata=3.000 122 | - font-ttf-source-code-pro=2.038 123 | - font-ttf-ubuntu=0.83 124 | - fontconfig=2.15.0 125 | - fonts-conda-ecosystem=1 126 | - fonts-conda-forge=1 127 | - freetype=2.12.1 128 | - future=1.0.0 129 | - gcc=12.4.0 130 | - gcc_impl_linux-64=12.4.0 131 | - gcc_linux-64=12.4.0 132 | - gds-tools=1.9.1.3 133 | - get-annotations=0.1.2 134 | - gflags=2.2.2 135 | - giflib=5.2.2 136 | - glog=0.7.1 137 | - gmp=6.3.0 138 | - gmpy2=2.1.5 139 | - gnutls=3.6.13 140 | - gxx=12.4.0 141 | - gxx_impl_linux-64=12.4.0 142 | - gxx_linux-64=12.4.0 143 | - h2=4.1.0 144 | - h5py=3.12.1 145 | - hdf5=1.14.3 146 | - hpack=4.0.0 147 | - htslib=1.21 148 | - humanize=4.11.0 149 | - hyperframe=6.0.1 150 | - hypre=2.31.0 151 | - idna=3.10 152 | - importlib-metadata=8.5.0 153 | - importlib_metadata=8.5.0 154 | - importlib_resources=6.4.5 155 | - jinja2=3.1.4 156 | - joblib=1.4.2 157 | - kernel-headers_linux-64=3.10.0 158 | - keyutils=1.6.1 159 | - kiwisolver=1.4.7 160 | - krb5=1.21.3 161 | - lame=3.100 162 | - lcms2=2.16 163 | - ld_impl_linux-64=2.43 164 | - lerc=4.0.0 165 | - libabseil=20240722.0 166 | - libaec=1.1.3 167 | - libarrow=18.1.0 168 | - libarrow-acero=18.1.0 169 | - libarrow-dataset=18.1.0 170 | - libarrow-substrait=18.1.0 171 | - libavif16=1.1.1 172 | - libblas=3.9.0 173 | - libbrotlicommon=1.1.0 174 | - libbrotlidec=1.1.0 175 | - libbrotlienc=1.1.0 176 | - libcap=2.69 177 | - libcblas=3.9.0 178 | - libcrc32c=1.1.2 179 | - libcublas=12.4.5.8 180 | - libcublas-dev=12.4.5.8 181 | - libcublas-static=12.4.5.8 182 | - libcufft=11.2.1.3 183 | - libcufft-dev=11.2.1.3 184 | - libcufft-static=11.2.1.3 185 | - libcufile=1.9.1.3 186 | - libcufile-dev=1.9.1.3 187 | - libcufile-static=1.9.1.3 188 | - libcurand=10.3.5.147 189 | - libcurand-dev=10.3.5.147 190 | - libcurand-static=10.3.5.147 191 | - libcurl=8.10.1 192 | - libcusolver=11.6.1.9 193 | - libcusolver-dev=11.6.1.9 194 | - libcusolver-static=11.6.1.9 195 | - libcusparse=12.3.1.170 196 | - libcusparse-dev=12.3.1.170 197 | - libcusparse-static=12.3.1.170 198 | - libdeflate=1.22 199 | - libedit=3.1.20191231 200 | - libev=4.33 201 | - libevent=2.1.12 202 | - libexpat=2.6.4 203 | - libfabric=1.22.0 204 | - libffi=3.4.2 205 | - libgcc=14.1.0 206 | - libgcc-devel_linux-64=12.4.0 207 | - libgcc-ng=14.1.0 208 | - libgcrypt=1.11.0 209 | - libgcrypt-devel=1.11.0 210 | - libgcrypt-lib=1.11.0 211 | - libgcrypt-tools=1.11.0 212 | - libgfortran=14.1.0 213 | - libgfortran-ng=14.1.0 214 | - libgfortran5=14.1.0 215 | - libglib=2.82.2 216 | - libgomp=14.1.0 217 | - libgoogle-cloud=2.31.0 218 | - libgoogle-cloud-storage=2.31.0 219 | - libgpg-error=1.51 220 | - libgrpc=1.67.1 221 | - libhwloc=2.11.2 222 | - libiconv=1.17 223 | - libjpeg-turbo=3.0.0 224 | - liblapack=3.9.0 225 | - libllvm14=14.0.6 226 | - libnghttp2=1.64.0 227 | - libnl=3.11.0 228 | - libnpp=12.2.5.30 229 | - libnpp-dev=12.2.5.30 230 | - libnpp-static=12.2.5.30 231 | - libnsl=2.0.1 232 | - libnvfatbin=12.4.127 233 | - libnvfatbin-dev=12.4.127 234 | - libnvfatbin-static=12.4.127 235 | - libnvjitlink=12.4.127 236 | - libnvjitlink-dev=12.4.127 237 | - libnvjitlink-static=12.4.127 238 | - libnvjpeg=12.3.1.117 239 | - libnvjpeg-dev=12.3.1.117 240 | - libnvjpeg-static=12.3.1.117 241 | - libopenblas=0.3.28 242 | - libparquet=18.1.0 243 | - libpng=1.6.44 244 | - 
libprotobuf=5.28.2 245 | - libptscotch=7.0.4 246 | - libre2-11=2024.07.02 247 | - libsanitizer=12.4.0 248 | - libscotch=7.0.4 249 | - libsqlite=3.46.1 250 | - libssh2=1.11.1 251 | - libstdcxx=14.1.0 252 | - libstdcxx-devel_linux-64=12.4.0 253 | - libstdcxx-ng=14.1.0 254 | - libsystemd0=256.7 255 | - libthrift=0.21.0 256 | - libtiff=4.7.0 257 | - libtorch=2.5.1 258 | - libudev1=256.7 259 | - libutf8proc=2.9.0 260 | - libuuid=2.38.1 261 | - libuv=1.49.2 262 | - libvpx=1.11.0 263 | - libwebp=1.4.0 264 | - libwebp-base=1.4.0 265 | - libxcb=1.17.0 266 | - libxcrypt=4.4.36 267 | - libxkbcommon=1.7.0 268 | - libxkbfile=1.1.0 269 | - libxml2=2.13.5 270 | - libzlib=1.3.1 271 | - locket=1.0.0 272 | - lz4=4.3.3 273 | - lz4-c=1.9.4 274 | - markdown=3.6 275 | - markdown-it-py=3.0.0 276 | - markupsafe=3.0.2 277 | - mdurl=0.1.2 278 | - metis=5.1.0 279 | - mpc=1.3.1 280 | - mpfr=4.2.1 281 | - mpi=1.0.1 282 | - mpi4py=4.0.1 283 | - mpich=4.2.3 284 | - mpmath=1.3.0 285 | - msgpack-python=1.1.0 286 | - mudata=0.3.1 287 | - mumps-include=5.7.3 288 | - mumps-mpi=5.7.3 289 | - munkres=1.1.4 290 | - natsort=8.4.0 291 | - ncurses=6.5 292 | - nest-asyncio=1.6.0 293 | - nettle=3.6 294 | - nomkl=1.0 295 | - nsight-compute=2024.1.1.4 296 | - nspr=4.36 297 | - nss=3.105 298 | - numba=0.60.0 299 | - numpy=1.26.4 300 | - numpy_groupies=0.11.2 301 | - numpyro=0.15.3 302 | - ocl-icd=2.3.2 303 | - openh264=2.1.1 304 | - openjpeg=2.5.2 305 | - openpyxl=3.1.5 306 | - openssl=3.4.1 307 | - opt-einsum=3.4.0 308 | - opt_einsum=3.4.0 309 | - orc=2.0.3 310 | - pandas=2.2.3 311 | - parmetis=4.0.3 312 | - partd=1.4.2 313 | - pcre2=10.44 314 | - petsc=3.21.5 315 | - petsc4py=3.21.5 316 | - pip=24.2 317 | - progressbar2=4.5.0 318 | - protobuf=5.28.2 319 | - pthread-stubs=0.4 320 | - pyarrow-core=18.1.0 321 | - pybind11-abi=4 322 | - pycparser=2.22 323 | - pygam=0.9.1 324 | - pygments=2.18.0 325 | - pygpcca=1.0.4 326 | - pynndescent=0.5.13 327 | - pyro-api=0.1.2 328 | - pysocks=1.7.1 329 | - python=3.10.15 330 | - python-dateutil=2.9.0.post0 331 | - python-tzdata=2024.2 332 | - python-utils=3.9.0 333 | - python_abi=3.10 334 | - pytorch-cuda=12.4 335 | - pytorch-lightning=2.4.0 336 | - pytorch-mutex=1.0 337 | - pyyaml=6.0.2 338 | - qhull=2020.2 339 | - rav1e=0.6.6 340 | - rdma-core=54.0 341 | - re2=2024.07.02 342 | - readline=8.2 343 | - requests=2.32.3 344 | - s2n=1.5.9 345 | - samtools=1.21 346 | - scalapack=2.2.0 347 | - scanpy=1.10.3 348 | - scvi-tools=1.2.0 349 | - seaborn=0.13.2 350 | - seaborn-base=0.13.2 351 | - session-info=1.0.0 352 | - setuptools=75.1.0 353 | - six=1.16.0 354 | - sleef=3.7 355 | - slepc=3.21.2 356 | - slepc4py=3.21.2 357 | - snappy=1.2.1 358 | - sortedcontainers=2.4.0 359 | - sparse=0.15.4 360 | - statsmodels=0.14.4 361 | - suitesparse=7.8.3 362 | - superlu=5.2.2 363 | - superlu_dist=9.1.0 364 | - svt-av1=2.3.0 365 | - sysroot_linux-64=2.17 366 | - tbb=2022.0.0 367 | - tblib=3.0.0 368 | - tensorboard=2.18.0 369 | - threadpoolctl=3.5.0 370 | - tk=8.6.13 371 | - toolz=1.0.0 372 | - tornado=6.4.2 373 | - typing-extensions=4.12.2 374 | - typing_extensions=4.12.2 375 | - tzdata=2024b 376 | - ucx=1.17.0 377 | - unicodedata2=15.1.0 378 | - urllib3=2.2.3 379 | - wayland=1.23.1 380 | - wheel=0.44.0 381 | - x264=1!161.3030 382 | - x265=3.5 383 | - xcb-util=0.4.1 384 | - xcb-util-cursor=0.1.5 385 | - xcb-util-image=0.4.0 386 | - xcb-util-keysyms=0.4.1 387 | - xcb-util-renderutil=0.3.10 388 | - xcb-util-wm=0.4.2 389 | - xkeyboard-config=2.43 390 | - xorg-libice=1.1.1 391 | - xorg-libsm=1.2.4 392 | - xorg-libx11=1.8.10 
393 | - xorg-libxau=1.0.11 394 | - xorg-libxcomposite=0.4.6 395 | - xorg-libxdamage=1.1.6 396 | - xorg-libxdmcp=1.1.5 397 | - xorg-libxext=1.3.6 398 | - xorg-libxfixes=6.0.1 399 | - xorg-libxi=1.8.2 400 | - xorg-libxrandr=1.5.4 401 | - xorg-libxrender=0.9.11 402 | - xorg-libxtst=1.2.5 403 | - xorg-xorgproto=2024.1 404 | - xyzservices=2024.9.0 405 | - xz=5.2.6 406 | - yaml=0.2.5 407 | - zict=3.0.0 408 | - zlib=1.3.1 409 | - zstandard=0.23.0 410 | - zstd=1.5.6 411 | - pip: 412 | - aiohappyeyeballs==2.4.3 413 | - aiohttp==3.10.10 414 | - aiosignal==1.3.1 415 | - alabaster==0.7.16 416 | - anndata==0.11.0rc2 417 | - annotated-types==0.7.0 418 | - anyio==4.6.0 419 | - appdirs==1.4.4 420 | - argon2-cffi==23.1.0 421 | - argon2-cffi-bindings==21.2.0 422 | - array-api-compat==1.9 423 | - arrow==1.3.0 424 | - asciitree==0.3.3 425 | - asttokens==2.4.1 426 | - astunparse==1.6.3 427 | - async-lru==2.0.4 428 | - async-timeout==4.0.3 429 | - babel==2.16.0 430 | - backoff==2.2.1 431 | - biofluff==3.0.4 432 | - biopython==1.84 433 | - biothings-client==0.3.1 434 | - blessed==1.20.0 435 | - boltons==24.1.0 436 | - boto3==1.35.63 437 | - botocore==1.35.63 438 | - bucketcache==0.12.1 439 | - celloracle==0.20.0 440 | - chex==0.1.87 441 | - colorcet==3.1.0 442 | - comm==0.2.2 443 | - configparser==7.1.0 444 | - contextlib2==21.6.0 445 | - contourpy==1.3.0 446 | - croniter==1.4.1 447 | - csaps==1.2.0 448 | - cython==3.0.11 449 | - dask==2024.2.1 450 | - dask-expr==0.5.3 451 | - dateutils==0.6.12 452 | - debugpy==1.8.7 453 | - deepdiff==7.0.1 454 | - diskcache==5.6.3 455 | - distributed==2024.2.1 456 | - dm-tree==0.1.8 457 | - docopt==0.6.2 458 | - docutils==0.21.2 459 | - editor==1.6.6 460 | - etils==1.9.4 461 | - executing==2.1.0 462 | - fa2-modified==0.3.10 463 | - fastapi==0.115.5 464 | - fasteners==0.19 465 | - feather-format==0.4.1 466 | - flatbuffers==24.3.25 467 | - flax==0.9.0 468 | - fonttools==4.54.1 469 | - fqdn==1.5.1 470 | - frozenlist==1.4.1 471 | - fsspec==2024.9.0 472 | - ftpretty==0.4.0 473 | - gast==0.6.0 474 | - gdown==5.2.0 475 | - genomepy==0.16.1 476 | - gimmemotifs==0.17.2 477 | - goatools==1.4.12 478 | - google-pasta==0.2.0 479 | - grpcio==1.66.2 480 | - h11==0.14.0 481 | - hnswlib==0.8.0 482 | - htseq==2.0.9 483 | - httpcore==1.0.6 484 | - httpx==0.27.2 485 | - igraph==0.11.6 486 | - imagesize==1.4.1 487 | - inquirer==3.4.0 488 | - iprogress==0.4 489 | - ipykernel==6.29.5 490 | - ipython==8.28.0 491 | - ipywidgets==8.1.5 492 | - isoduration==20.11.0 493 | - iteround==1.0.4 494 | - itsdangerous==2.2.0 495 | - jax==0.4.34 496 | - jaxlib==0.4.34 497 | - jedi==0.19.1 498 | - jmespath==1.0.1 499 | - joypy==0.2.6 500 | - json5==0.9.25 501 | - jsonpickle==4.0.0 502 | - jsonpointer==3.0.0 503 | - jupyter==1.1.1 504 | - jupyter-console==6.6.3 505 | - jupyter-events==0.10.0 506 | - jupyter-lsp==2.2.5 507 | - jupyter-server==2.14.2 508 | - jupyter-server-terminals==0.5.3 509 | - jupyterlab==4.2.5 510 | - jupyterlab-server==2.27.3 511 | - jupyterlab-widgets==3.0.13 512 | - keras==3.7.0 513 | - legacy-api-wrap==1.4 514 | - leidenalg==0.10.2 515 | - libclang==18.1.1 516 | - lightning==2.0.9.post0 517 | - lightning-cloud==0.5.70 518 | - lightning-utilities==0.11.7 519 | - llvmlite==0.43.0 520 | - logbook==1.8.0 521 | - logomaker==0.8 522 | - loguru==0.7.2 523 | - loompy==3.0.7 524 | - louvain==0.8.2 525 | - matplotlib==3.6.3 526 | - matplotlib-inline==0.1.7 527 | - ml-collections==0.1.1 528 | - ml-dtypes==0.4.1 529 | - mplscience==0.0.7 530 | - multidict==6.1.0 531 | - multipledispatch==1.0.0 532 | - 
mygene==3.2.2 533 | - mysql-connector-python==9.1.0 534 | - namex==0.0.8 535 | - networkx==3.4.1 536 | - norns==0.1.6 537 | - nose==1.3.7 538 | - notebook==7.2.2 539 | - notebook-shim==0.2.4 540 | - numcodecs==0.13.1 541 | - numdifftools==0.9.41 542 | - nvidia-cublas-cu12==12.4.5.8 543 | - nvidia-cuda-cupti-cu12==12.4.127 544 | - nvidia-cuda-nvrtc-cu12==12.4.127 545 | - nvidia-cuda-runtime-cu12==12.4.127 546 | - nvidia-cudnn-cu12==9.1.0.70 547 | - nvidia-cufft-cu12==11.2.1.3 548 | - nvidia-curand-cu12==10.3.5.147 549 | - nvidia-cusolver-cu12==11.6.1.9 550 | - nvidia-cusparse-cu12==12.3.1.170 551 | - nvidia-nccl-cu12==2.21.5 552 | - nvidia-nvjitlink-cu12==12.4.127 553 | - nvidia-nvtx-cu12==12.4.127 554 | - optax==0.2.3 555 | - optree==0.13.1 556 | - orbax-checkpoint==0.7.0 557 | - ordered-set==4.1.0 558 | - overrides==7.7.0 559 | - packaging==24.1 560 | - parso==0.8.4 561 | - patsy==0.5.6 562 | - pexpect==4.9.0 563 | - pillow==10.4.0 564 | - platformdirs==4.3.6 565 | - plotly==5.24.1 566 | - prometheus-client==0.21.0 567 | - prompt-toolkit==3.0.48 568 | - propcache==0.2.0 569 | - psutil==6.0.0 570 | - ptyprocess==0.7.0 571 | - pure-eval==0.2.3 572 | - pyarrow==18.0.0 573 | - pybedtools==0.10.0 574 | - pybigwig==0.3.23 575 | - pydantic==2.1.1 576 | - pydantic-core==2.4.0 577 | - pydot==3.0.2 578 | - pyfaidx==0.8.1.3 579 | - pyjwt==2.10.0 580 | - pymde==0.1.18 581 | - pyparsing==3.1.4 582 | - pyro-ppl==1.9.1 583 | - pysam==0.22.1 584 | - python-igraph==0.11.6 585 | - python-json-logger==2.0.7 586 | - python-multipart==0.0.17 587 | - pytz==2024.2 588 | - pyvis==0.3.2 589 | - qnorm==0.8.1 590 | - readchar==4.2.1 591 | - regvelo==0.1.0 592 | - represent==2.1 593 | - rfc3339-validator==0.1.4 594 | - rfc3986-validator==0.1.1 595 | - rich==13.9.2 596 | - rpy2==3.5.17 597 | - runs==1.2.2 598 | - s3transfer==0.10.3 599 | - scib==1.1.5 600 | - scikit-learn==1.5.2 601 | - scipy==1.10.1 602 | - scvelo==0.3.2 603 | - send2trash==1.8.3 604 | - sniffio==1.3.1 605 | - snowballstemmer==2.2.0 606 | - sphinx==7.3.7 607 | - sphinx-autodoc-typehints==2.3.0 608 | - sphinxcontrib-applehelp==2.0.0 609 | - sphinxcontrib-devhelp==2.0.0 610 | - sphinxcontrib-htmlhelp==2.1.0 611 | - sphinxcontrib-jsmath==1.0.1 612 | - sphinxcontrib-qthelp==2.0.0 613 | - sphinxcontrib-serializinghtml==2.0.0 614 | - splicejac==0.0.1 615 | - stack-data==0.6.3 616 | - starlette==0.41.2 617 | - starsessions==1.3.0 618 | - stdlib-list==0.10.0 619 | - sympy==1.13.1 620 | - tenacity==9.0.0 621 | - tensorboard-data-server==0.7.2 622 | - tensorflow==2.18.0 623 | - tensorflow-io-gcs-filesystem==0.37.1 624 | - tensorstore==0.1.66 625 | - termcolor==2.5.0 626 | - terminado==0.18.1 627 | - texttable==1.7.0 628 | - tf-keras==2.18.0 629 | - tomli==2.0.2 630 | - torch==2.5.1 631 | - torchaudio==2.5.1 632 | - torchmetrics==1.4.3 633 | - torchode==1.0.0 634 | - torchtyping==0.1.5 635 | - torchvision==0.20.1 636 | - tqdm==4.66.5 637 | - traitlets==5.14.3 638 | - triton==3.1.0 639 | - typeguard==2.13.3 640 | - types-python-dateutil==2.9.0.20241003 641 | - tzlocal==5.3.1 642 | - umap-learn==0.5.6 643 | - unitvelo==0.2.5.2 644 | - uri-template==1.3.0 645 | - uvicorn==0.32.0 646 | - velocyto==0.17.17 647 | - velovi==0.3.1 648 | - wcwidth==0.2.13 649 | - webcolors==24.8.0 650 | - websocket-client==1.8.0 651 | - websockets==12.0 652 | - werkzeug==3.0.4 653 | - widgetsnbextension==4.0.13 654 | - wrapt==1.16.0 655 | - xarray==2024.9.0 656 | - xdg==6.0.0 657 | - xlsxwriter==3.2.0 658 | - xmod==1.8.1 659 | - xxhash==3.5.0 660 | - yarl==1.15.0 661 | - zarr==2.18.3 
662 | - zipp==3.20.2 663 | prefix: /home/icb/weixu.wang/miniconda3/envs/regvelo_test 664 | 
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | # This is a shim so that GitHub can detect the package; the actual build is done with Poetry.
4 | 
5 | import setuptools
6 | 
7 | if __name__ == "__main__":
8 | setuptools.setup(name="regvelo")
9 | 
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/regvelo/5ed133bd37a563390ee4ec909b528f13d75c1b8a/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_regvelo.py: --------------------------------------------------------------------------------
1 | # Smoke test for RegVelo.
2 | import numpy as np
3 | import pandas as pd
4 | from scvi.data import synthetic_iid
5 | from regvelo import REGVELOVI
6 | import torch
7 | 
8 | def test_regvelo():
9 | adata = synthetic_iid()
10 | adata.layers["spliced"] = adata.X.copy()
11 | adata.layers["unspliced"] = adata.X.copy()
12 | adata.var_names = "Gene" + adata.var_names
13 | n_gene = len(adata.var_names)
14 | # Create a random GRN skeleton W with ~20% of links present.
15 | grn_matrix = np.random.choice([0, 1], size=(n_gene, n_gene), p=[0.8, 0.2]).T
16 | W = pd.DataFrame(grn_matrix, index=adata.var_names, columns=adata.var_names)
17 | adata.uns["skeleton"] = W
18 | TF_list = adata.var_names.tolist()
19 | 
20 | # Prepare the prior GRN tensor and register the data with the model.
21 | W = adata.uns["skeleton"].copy()
22 | W = torch.tensor(np.array(W))
23 | REGVELOVI.setup_anndata(adata, spliced_layer="spliced", unspliced_layer="unspliced")
24 | 
25 | # Train with a hard constraint on the GRN skeleton.
26 | reg_vae = REGVELOVI(adata, W=W.T, regulators=TF_list, soft_constraint=False)
27 | reg_vae.train()
28 | # Train with a soft constraint.
29 | reg_vae = REGVELOVI(adata, W=W.T, regulators=TF_list, soft_constraint=True)
30 | reg_vae.train()
31 | 
32 | reg_vae.get_latent_representation()
33 | reg_vae.get_velocity()
34 | reg_vae.get_latent_time()
35 | 
36 | _ = reg_vae.history  # accessing training history should not raise
37 | print(reg_vae)
--------------------------------------------------------------------------------