├── .editorconfig
├── .github
│   └── workflows
│       ├── release.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── codecov.yml
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── img
│   │       ├── overview_fig.png
│   │       └── perturbation_overview_fig.png
│   ├── _templates
│   │   ├── autosummary
│   │   │   └── class.rst
│   │   └── class_no_inherited.rst
│   ├── about
│   │   └── index.rst
│   ├── api
│   │   ├── .Rhistory
│   │   └── index.md
│   ├── conf.py
│   ├── extensions
│   │   └── typed_returns.py
│   ├── index.md
│   ├── make.bat
│   ├── references.bib
│   ├── references.md
│   ├── release_notes
│   │   ├── index.rst
│   │   └── v0.1.0.rst
│   └── tutorials
│       ├── index.md
│       ├── index_modelcomp.md
│       ├── index_murine.md
│       ├── index_zebrafish.md
│       ├── modelcomparison
│       │   └── ModelComp.ipynb
│       ├── murine
│       │   ├── 01_SCENIC_tutorial.ipynb
│       │   ├── 02_RegVelo_preparation.ipynb
│       │   └── 03_perturbation_tutorial_murine.ipynb
│       └── zebrafish
│           ├── _static
│           │   └── perturbation_metrics.svg
│           └── tutorial.ipynb
├── pyproject.toml
├── readthedocs.yml
├── regvelo
│   ├── ModelComparison.py
│   ├── __init__.py
│   ├── _constants.py
│   ├── _model.py
│   ├── _module.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── _datasets.py
│   ├── plotting
│   │   ├── __init__.py
│   │   ├── commitment_score.py
│   │   ├── depletion_score.py
│   │   ├── fate_probabilities.py
│   │   ├── get_significance.py
│   │   └── utils.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── filter_genes.py
│   │   ├── preprocess_data.py
│   │   ├── sanity_check.py
│   │   └── set_prior_grn.py
│   └── tools
│       ├── TFScanning_func.py
│       ├── TFscreening.py
│       ├── __init__.py
│       ├── _tsi.py
│       ├── abundance_test.py
│       ├── depletion_score.py
│       ├── in_silico_block_simulation.py
│       ├── markov_density_simulation.py
│       ├── perturbation_effect.py
│       ├── set_output.py
│       └── utils.py
├── reproduce_env
│   └── regvelo.yaml
├── setup.py
└── tests
    ├── __init__.py
    └── test_regvelo.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = false
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | push:
5 | tags:
6 | - "*.*.*"
7 |
8 | jobs:
9 | release:
10 | name: Release
11 | runs-on: ubuntu-latest
12 | steps:
13 | # will use ref/SHA that triggered it
14 | - name: Checkout code
15 | uses: actions/checkout@v3
16 |
17 | - name: Set up Python 3.10
18 | uses: actions/setup-python@v4
19 | with:
 20 |           python-version: "3.10"
21 |
22 | - name: Install poetry
23 | uses: abatilo/actions-poetry@v2.0.0
24 | with:
25 | poetry-version: 1.4.2
26 |
27 | - name: Build project for distribution
28 | run: poetry build
29 |
30 | - name: Check Version
31 | id: check-version
32 | run: |
33 | [[ "$(poetry version --short)" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]] \
 34 |             || echo "prerelease=true" >> "$GITHUB_OUTPUT"
35 |
36 | - name: Publish to PyPI
37 | env:
38 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }}
39 | run: poetry publish
40 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: regvelo
5 |
6 | on:
7 | push:
8 | branches: [main]
9 | pull_request:
10 | branches: [main]
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-latest
15 | strategy:
16 | matrix:
17 | python-version: ["3.9", "3.10", "3.11"]
18 |
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Cache pip
26 | uses: actions/cache@v2
27 | with:
28 | path: ~/.cache/pip
29 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
30 | restore-keys: |
31 | ${{ runner.os }}-pip-
32 | - name: Install dependencies
33 | run: |
34 | pip install pytest-cov
35 | pip install .[dev]
36 | - name: Test with pytest
37 | run: |
 38 |         pytest --cov-report=xml --cov=regvelo
39 | - name: After success
40 | run: |
41 | bash <(curl -s https://codecov.io/bash)
42 | pip list
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # DS_Store
2 | .DS_Store
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # vscode
135 | .vscode/settings.json
136 | docs/api/reference/
137 | *.h5ad
138 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | fail_fast: false
2 | default_language_version:
3 | python: python3
4 | default_stages:
5 | - commit
6 | - push
7 | minimum_pre_commit_version: 2.16.0
8 | repos:
9 | - repo: https://github.com/psf/black
10 | rev: "23.1.0"
11 | hooks:
12 | - id: black
13 | - repo: https://github.com/asottile/blacken-docs
14 | rev: 1.13.0
15 | hooks:
16 | - id: blacken-docs
17 | - repo: https://github.com/pre-commit/mirrors-prettier
18 | rev: v3.0.0-alpha.6
19 | hooks:
20 | - id: prettier
21 | # Newer versions of node don't work on systems that have an older version of GLIBC
22 | # (in particular Ubuntu 18.04 and Centos 7)
23 | # EOL of Centos 7 is in 2024-06, we can probably get rid of this then.
24 | # See https://github.com/scverse/cookiecutter-scverse/issues/143 and
25 | # https://github.com/jupyterlab/jupyterlab/issues/12675
26 | language_version: "17.9.1"
27 | - repo: https://github.com/charliermarsh/ruff-pre-commit
28 | rev: v0.0.254
29 | hooks:
30 | - id: ruff
31 | args: [--fix, --exit-non-zero-on-fix]
32 | - repo: https://github.com/pre-commit/pre-commit-hooks
33 | rev: v4.4.0
34 | hooks:
35 | - id: detect-private-key
36 | - id: check-ast
37 | - id: end-of-file-fixer
38 | - id: mixed-line-ending
39 | args: [--fix=lf]
40 | - id: trailing-whitespace
41 | - id: check-case-conflict
42 | - repo: local
43 | hooks:
44 | - id: forbid-to-commit
45 | name: Don't commit rej files
46 | entry: |
47 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates.
48 | Fix the merge conflicts manually and remove the .rej files.
49 | language: fail
50 | files: '.*\.rej$'
51 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, Yosef Lab
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RegVelo: gene-regulatory-informed dynamics of single cells
2 |
3 |
4 |
5 | **RegVelo** is an end-to-end framework to infer regulatory cellular dynamics through coupled splicing dynamics. See our [RegVelo manuscript](https://www.biorxiv.org/content/10.1101/2024.12.11.627935v1) to learn more. If you use our tool in your own work, please cite it as
6 |
7 | ```
8 | @article{wang2024regvelo,
9 | title={RegVelo: gene-regulatory-informed dynamics of single cells},
10 | author={Wang, Weixu and Hu, Zhiyuan and Weiler, Philipp and Mayes, Sarah and Lange, Marius and Wang, Jingye and Xue, Zhengyuan and Sauka-Spengler, Tatjana and Theis, Fabian J},
11 | journal={bioRxiv},
12 | pages={2024--12},
13 | year={2024},
14 | publisher={Cold Spring Harbor Laboratory}
15 | }
16 | ```
17 | ## Getting started
18 | Please refer to the [Tutorials](https://regvelo.readthedocs.io/en/latest/index.html)
19 |
20 | ## Installation
21 |
22 | You need to have Python 3.9 or newer installed on your system. If you don't have
23 | Python installed, we recommend installing [Miniconda](https://docs.conda.io/en/latest/miniconda.html).
24 |
25 | To create and activate a new environment
26 |
27 | ```bash
28 | conda create -n regvelo-py310 python=3.10 --yes && conda activate regvelo-py310
29 | ```
30 |
31 | Next, install the package with
32 |
33 | ```bash
34 | pip install git+https://github.com/theislab/regvelo.git@main --no-cache-dir --force-reinstall
35 | ```
36 |
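## Quick start

After installation, a minimal run looks like the sketch below. This is a simplified
illustration, not a drop-in recipe: `adata` is assumed to be a preprocessed `AnnData` with
smoothed `Ms`/`Mu` layers, and `W` a prior GRN weight matrix prepared upstream (see the
tutorials for the exact preprocessing steps and constructor arguments).

```python
from regvelo import REGVELOVI

# register the smoothed unspliced/spliced layers (scvi-tools-style API)
REGVELOVI.setup_anndata(adata, spliced_layer="Ms", unspliced_layer="Mu")

vae = REGVELOVI(adata, W=W)  # W: prior GRN weight matrix (assumed prepared upstream)
vae.train()
```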
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | # Run to check if valid
2 | # curl --data-binary @codecov.yml https://codecov.io/validate
3 | coverage:
4 | status:
5 | project:
6 | default:
7 | target: 80%
8 | threshold: 1%
9 | patch: off
10 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python -msphinx
 7 | SPHINXPROJ    = regvelo
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/img/overview_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/regvelo/5ed133bd37a563390ee4ec909b528f13d75c1b8a/docs/_static/img/overview_fig.png
--------------------------------------------------------------------------------
/docs/_static/img/perturbation_overview_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/regvelo/5ed133bd37a563390ee4ec909b528f13d75c1b8a/docs/_static/img/perturbation_overview_fig.png
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | Attributes table
12 | ~~~~~~~~~~~~~~~~~~
13 |
14 | .. autosummary::
15 | {% for item in attributes %}
16 | ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | Methods table
24 | ~~~~~~~~~~~~~
25 |
26 | .. autosummary::
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 | ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
35 | {% block attributes_documentation %}
36 | {% if attributes %}
37 | Attributes
38 | ~~~~~~~~~~~
39 |
40 | {% for item in attributes %}
41 |
42 | .. autoattribute:: {{ [objname, item] | join(".") }}
43 | {%- endfor %}
44 |
45 | {% endif %}
46 | {% endblock %}
47 |
48 | {% block methods_documentation %}
49 | {% if methods %}
50 | Methods
51 | ~~~~~~~
52 |
53 | {% for item in methods %}
54 | {%- if item != '__init__' %}
55 |
56 | .. automethod:: {{ [objname, item] | join(".") }}
57 | {%- endif -%}
58 | {%- endfor %}
59 |
60 | {% endif %}
61 | {% endblock %}
62 |
--------------------------------------------------------------------------------
/docs/_templates/class_no_inherited.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 | :show-inheritance:
9 |
10 | {% block attributes %}
11 | {% if attributes %}
12 | Attributes table
13 | ~~~~~~~~~~~~~~~~
14 |
15 | .. autosummary::
16 | {% for item in attributes %}
17 | {%- if item not in inherited_members%}
18 | ~{{ fullname }}.{{ item }}
19 | {%- endif -%}
20 | {%- endfor %}
21 | {% endif %}
22 | {% endblock %}
23 |
24 |
25 | {% block methods %}
26 | {% if methods %}
27 | Methods table
28 | ~~~~~~~~~~~~~~
29 |
30 | .. autosummary::
31 | {% for item in methods %}
32 | {%- if item != '__init__' and item not in inherited_members%}
33 | ~{{ fullname }}.{{ item }}
34 | {%- endif -%}
35 |
36 | {%- endfor %}
37 | {% endif %}
38 | {% endblock %}
39 |
40 | {% block attributes_documentation %}
41 | {% if attributes %}
42 | Attributes
43 | ~~~~~~~~~~
44 |
45 | {% for item in attributes %}
46 | {%- if item not in inherited_members%}
47 |
48 | .. autoattribute:: {{ [objname, item] | join(".") }}
49 | {%- endif -%}
50 | {%- endfor %}
51 |
52 | {% endif %}
53 | {% endblock %}
54 |
55 | {% block methods_documentation %}
56 | {% if methods %}
57 | Methods
58 | ~~~~~~~
59 |
60 | {% for item in methods %}
61 | {%- if item != '__init__' and item not in inherited_members%}
62 |
63 | .. automethod:: {{ [objname, item] | join(".") }}
64 | {%- endif -%}
65 | {%- endfor %}
66 |
67 | {% endif %}
68 | {% endblock %}
69 |
--------------------------------------------------------------------------------
/docs/about/index.rst:
--------------------------------------------------------------------------------
1 | About RegVelo
 2 | -------------
3 |
4 | Understanding cellular dynamics and regulatory interactions is crucial for decoding the complex processes that govern cell fate and differentiation.
5 | Traditional RNA velocity methods capture dynamic cellular transitions by modeling changes in spliced and unspliced mRNA but lack integration with gene regulatory networks (GRNs), omitting critical regulatory mechanisms underlying cellular decisions.
6 | Conversely, GRN inference techniques map regulatory connections but fail to account for the temporal dynamics of gene expression.
7 |
 8 | With RegVelo, developed by `Wang et al. (bioRxiv, 2024) <https://www.biorxiv.org/content/10.1101/2024.12.11.627935v1>`_,
 9 | this gap is bridged by combining RNA velocity's temporal insights with a regulatory framework to model transcriptome-wide splicing kinetics informed by GRNs.
10 | This extends the current RNA velocity framework to a fully mechanistic model and allows modeling of more complex developmental processes.
11 | Further, by coupling with CellRank :cite:p:`weiler2024cellrank`, RegVelo expands its capabilities to include robust perturbation predictions, linking regulatory changes to long-term cell fate decisions.
12 | CellRank employs velocity-based state transition probabilities to predict terminal cell states, and RegVelo enhances this framework with GRN-informed splicing kinetics,
13 | enabling precise simulations of transcription factor (TF) knockouts. This synergy allows for the identification of lineage drivers and the prediction of cell fate changes upon genetic perturbations.
14 |
15 | RegVelo's application
16 | ~~~~~~~~~~~~~~~~~~~~~
17 | - estimate RNA velocity governed by gene regulation.
18 | - infer latent time indicating the cellular differentiation process.
19 | - estimate velocity intrinsic and extrinsic uncertainty :cite:p:`gayoso2024deep`.
20 | - estimate regulon perturbation effects via CellRank framework :cite:p:`lange2022cellrank, weiler2024cellrank`.
21 |
22 | RegVelo model
23 | ~~~~~~~~~~~~~~~~~~~
24 | RegVelo leverages deep generative modeling to infer splicing kinetics, transcription rates, and latent cellular time while integrating GRN priors derived from multi-omics data or curated databases.
25 | RegVelo incorporates cellular dynamic estimates by first encoding unspliced (*u*) and spliced RNA (*s*) readouts into posterior parameters of a low-dimensional latent variable - the cell representation - with a neural network.
26 | An additional neural network takes samples of this cell representation as input to parameterize gene-wise latent time as in our previous model veloVI.
27 | We then model splicing dynamics with ordinary differential equations (ODEs) specified by a base transcription *b* and a GRN weight matrix *W*,
28 | which describe transcription and are inferred by a shallow neural network, constant splicing and degradation rate parameters *beta* and *gamma*, respectively,
29 | and estimated cell and gene-specific latent times. Importantly, existing methods for inferring RNA velocity consider a set of decoupled one-dimensional ODEs for which analytic solutions exist, but RegVelo relies on the single, high-dimensional ODE
30 |
31 | .. math::
32 | \begin{align}
33 | \frac{\mathrm{d} u_{g}(t)}{\mathrm{d} t} =\alpha_{g}(t) - \beta_{g} u_{g}(t), \\
34 | \frac{\mathrm{d} s_{g}(t)}{\mathrm{d} t} = \beta_{g} u_{g}(t) - \gamma_{g} s_{g}(t),
35 | \end{align}
36 |
37 | that is now coupled through gene regulation-informed transcription
38 |
39 | .. math::
40 | \alpha_g = h \left( \left [ W s(t) +b \right ] _{g} \right)
41 |
42 | where *g* indicates the gene and *h* is a non-linear activation function.
43 | We predict the gene and cell-specific spliced and unspliced abundances using a parallelizable ODE solver,
44 | as this new system no longer admits an analytic solution; compared to previous approaches, we solve all gene dynamics at once instead of sequentially for each gene independently of the others.
45 | The forward simulation of the ODE solver allows for computing the likelihood function encompassing all neural network and kinetic parameters.
46 | We assume that the predicted spliced and unspliced abundances are the expected value of the Gaussian likelihood of the observed dataset and use gradient-based optimization to update all parameters.
47 | After optimization, we define cell-gene-specific velocities as splicing velocities based on the estimated splicing and degradation rates and predicted spliced and unspliced abundance.
48 | Overall, RegVelo allows sampling predicted readouts and velocities from the learned posterior distribution.
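
To make the coupled system concrete, the snippet below sketches its right-hand side in
PyTorch. This is a simplified illustration rather than RegVelo's actual module code: the
activation *h* is arbitrarily chosen as softplus here, and the tensor names mirror the
symbols above.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    def velocity_rhs(u, s, W, b, beta, gamma):
        """Right-hand side of the coupled RegVelo ODE for a single cell.

        u, s        : (n_genes,) unspliced/spliced abundances
        W           : (n_genes, n_genes) GRN weight matrix
        b           : (n_genes,) base transcription
        beta, gamma : (n_genes,) splicing/degradation rates
        """
        alpha = F.softplus(W @ s + b)  # GRN-informed transcription: alpha_g = h([W s + b]_g)
        du = alpha - beta * u          # du/dt
        ds = beta * u - gamma * s      # ds/dt, which also defines the splicing velocity
        return du, ds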
49 |
50 |
51 | Perturbation prediction
52 | ~~~~~~~~~~~~~~~~~~~~~~~
53 |
54 | .. image:: https://github.com/theislab/regvelo/blob/main/docs/_static/img/perturbation_overview_fig.png?raw=true
55 | :alt: RegVelo perturbation introduction
56 | :width: 600px
57 |
58 | RegVelo is a generative model that couples cellular dynamics with regulatory networks.
59 | We can, thus, perform in silico counterfactual inference to test the cellular response upon unseen perturbations of a TF in the regulatory network: for a trained RegVelo model,
60 | we ignore regulatory effects of the TF by removing all its downstream targets from the GRN, i.e., depleting the regulon, and generate the perturbed velocity vector field.
61 | The dissimilarity between the original and perturbed cell velocities - the perturbation effect score - reflects the local change induced in each cell by the perturbation; we quantify this score with cosine dissimilarity.
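
As a toy illustration of this score (the array names are placeholders, not RegVelo's API),
per-cell cosine dissimilarity between two velocity fields can be computed as:

.. code-block:: python

    import numpy as np

    def perturbation_effect(v_orig, v_pert, eps=1e-12):
        """Per-cell cosine dissimilarity of two (n_cells, n_genes) velocity fields."""
        cos = np.sum(v_orig * v_pert, axis=1) / (
            np.linalg.norm(v_orig, axis=1) * np.linalg.norm(v_pert, axis=1) + eps
        )
        return 1.0 - cos  # 0: direction unchanged, 2: direction reversed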
62 |
63 | RNA velocity describes a high-dimensional vector field representing cellular change along the phenotypic manifold but lacks interpretability and quantifiable measures of long-term cell behavior.
64 | We recently proposed CellRank to bridge this gap by leveraging gene expression and an estimated vector field to model cell state transitions through Markov chains and infer terminal cell states.
65 | For each terminal state identified, CellRank calculates the probability of a cell transitioning to this state - the fate probability - that allows us to predict the cell's future state.
66 | By combining RegVelo’s generative model with CellRank, we connect gene regulation with both local cell dynamics and long-term cell fate decisions, and study how they change upon in silico perturbations.
67 | In the context of our perturbation analyses, we compare CellRank’s prediction of cell fate probabilities for the original and perturbed vector fields,
68 | to find enrichment (increased cell fate probability) or depletion (decreased cell fate probability) effects towards terminal states.
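
Assuming fate-probability matrices as returned by CellRank (rows as cells, columns as
terminal states; the function name is hypothetical), this readout reduces to a simple
comparison:

.. code-block:: python

    import numpy as np

    def fate_shift(fate_orig, fate_pert):
        """Mean change in fate probability per terminal state.

        Positive entries indicate enrichment towards a terminal state upon
        perturbation; negative entries indicate depletion.
        """
        return (fate_pert - fate_orig).mean(axis=0)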
69 |
70 | See `Wang et al. (bioRxiv, 2024) <https://www.biorxiv.org/content/10.1101/2024.12.11.627935v1>`_ for a detailed description of the methods and applications on different biological systems.
71 |
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
/docs/api/.Rhistory:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/regvelo/5ed133bd37a563390ee4ec909b528f13d75c1b8a/docs/api/.Rhistory
--------------------------------------------------------------------------------
/docs/api/index.md:
--------------------------------------------------------------------------------
1 | # API
2 |
3 | ```{eval-rst}
4 | .. currentmodule:: regvelo
5 | ```
6 |
7 | ```{eval-rst}
8 | .. autosummary::
9 | :toctree: reference/
10 | :nosignatures:
11 |
12 | REGVELOVI
13 | ```
14 |
15 | ```{eval-rst}
16 | .. autosummary::
17 | :toctree: reference/
18 | :template: class_no_inherited.rst
19 | :nosignatures:
20 |
21 | VELOVAE
22 | ```
23 |
24 | ```{eval-rst}
25 | .. autosummary::
26 | :toctree: reference/
27 | :nosignatures:
28 |
29 | TFscreening
30 | ```
31 |
32 | ```{eval-rst}
33 | .. autosummary::
34 | :toctree: reference/
35 | :nosignatures:
36 |
37 | in_silico_block_simulation
38 | ```
39 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import subprocess
3 | import os
 4 | import importlib.util
5 | import inspect
6 | import re
7 | import sys
8 | from datetime import datetime
9 | from importlib.metadata import metadata
10 | from pathlib import Path
11 |
12 | HERE = Path(__file__).parent
13 | sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")]
14 |
15 |
16 | # -- Project information -----------------------------------------------------
17 |
18 | project_name = "regvelo"
19 | package_name = "regvelo"
20 | author = "Weixu Wang"
21 | copyright = f"{datetime.now():%Y}, {author}."
22 | version = "0.2.0"
23 | repository_url = "https://github.com/theislab/regvelo"
24 |
25 | # The full version, including alpha/beta/rc tags
26 | release = "0.2.0"
27 |
28 | bibtex_bibfiles = ["references.bib"]
29 | bibtex_reference_style = "author_year"
30 |
31 | templates_path = ["_templates"]
32 | nitpicky = True # Warn about broken links
33 | needs_sphinx = "4.0"
34 |
35 | html_context = {
36 | "display_github": True, # Integrate GitHub
37 | "github_user": "theislab", # Username
38 | "github_repo": project_name, # Repo name
39 | "github_version": "main", # Version
40 | "conf_py_path": "/docs/", # Path in the checkout to the docs root
41 | }
42 |
43 | # -- General configuration ---------------------------------------------------
44 |
45 | # Add any Sphinx extension module names here, as strings.
46 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
47 | extensions = [
48 | "myst_nb",
49 | "sphinx.ext.autodoc",
50 | "sphinx.ext.linkcode",
51 | "sphinx.ext.intersphinx",
52 | "sphinx.ext.autosummary",
53 | "sphinx.ext.napoleon",
54 | "sphinxcontrib.bibtex",
55 | "sphinx.ext.mathjax",
56 | "sphinx.ext.extlinks",
57 | *[p.stem for p in (HERE / "extensions").glob("*.py")],
58 | "sphinx_copybutton",
59 | ]
60 |
61 | autosummary_generate = True
62 | autodoc_member_order = "bysource"
63 | default_role = "literal"
64 | autodoc_typehints = "description"
65 | bibtex_reference_style = "author_year"
66 | napoleon_google_docstring = True
67 | napoleon_numpy_docstring = True
68 | napoleon_include_init_with_doc = False
69 | napoleon_use_rtype = True # having a separate entry generally helps readability
70 | napoleon_use_param = True
71 | myst_enable_extensions = [
72 | "amsmath",
73 | "colon_fence",
74 | "deflist",
75 | "dollarmath",
76 | "html_image",
77 | "html_admonition",
78 | ]
79 | myst_url_schemes = ("http", "https", "mailto")
80 | nb_output_stderr = "remove"
81 | nb_execution_mode = "off"
82 | nb_merge_streams = True
83 |
84 | source_suffix = {
85 | ".rst": "restructuredtext",
86 | ".ipynb": "myst-nb",
87 | ".myst": "myst-nb",
88 | }
89 |
90 | intersphinx_mapping = {
91 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
92 | "ipython": ("https://ipython.readthedocs.io/en/stable/", None),
93 | "matplotlib": ("https://matplotlib.org/", None),
94 | "numpy": ("https://numpy.org/doc/stable/", None),
95 | "pandas": ("https://pandas.pydata.org/docs/", None),
96 | "python": ("https://docs.python.org/3", None),
97 | "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
98 | "sklearn": ("https://scikit-learn.org/stable/", None),
99 | "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None),
100 | "jax": ("https://jax.readthedocs.io/en/latest/", None),
101 | "torch": ("https://pytorch.org/docs/master/", None),
102 | "plottable": ("https://plottable.readthedocs.io/en/latest/", None),
103 | "scvi-tools": ("https://docs.scvi-tools.org/en/stable/", None),
104 | "mudata": ("https://mudata.readthedocs.io/en/latest/", None),
105 | }
106 |
107 | # List of patterns, relative to source directory, that match files and
108 | # directories to ignore when looking for source files.
109 | # This pattern also affects html_static_path and html_extra_path.
110 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]
111 |
112 | # extlinks config
113 | extlinks = {
114 | "issue": (f"{repository_url}/issues/%s", "#%s"),
115 | "pr": (f"{repository_url}/pull/%s", "#%s"),
116 | "ghuser": ("https://github.com/%s", "@%s"),
117 | }
118 |
119 |
120 | # -- Linkcode settings -------------------------------------------------
121 |
122 |
123 | def git(*args):
124 | """Run a git command and return the output."""
125 | return subprocess.check_output(["git", *args]).strip().decode()
126 |
127 |
128 | # https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192
129 | # Current git reference. Uses branch/tag name if found, otherwise uses commit hash
130 | git_ref = None
131 | try:
132 | git_ref = git("name-rev", "--name-only", "--no-undefined", "HEAD")
133 | git_ref = re.sub(r"^(remotes/[^/]+|tags)/", "", git_ref)
134 | except Exception:
135 | pass
136 |
137 | # (if no name found or relative ref, use commit hash instead)
138 | if not git_ref or re.search(r"[\^~]", git_ref):
139 | try:
140 | git_ref = git("rev-parse", "HEAD")
141 | except Exception:
142 | git_ref = "main"
143 |
144 | # https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192
145 | github_repo = "https://github.com/" + html_context["github_user"] + "/" + project_name
146 | _project_module_path = os.path.dirname(importlib.util.find_spec(package_name).origin) # type: ignore
147 |
148 |
149 | def linkcode_resolve(domain, info):
150 | """Resolve links for the linkcode extension."""
151 | if domain != "py":
152 | return None
153 |
154 | try:
155 | obj: Any = sys.modules[info["module"]]
156 | for part in info["fullname"].split("."):
157 | obj = getattr(obj, part)
158 | obj = inspect.unwrap(obj)
159 |
160 | if isinstance(obj, property):
161 | obj = inspect.unwrap(obj.fget) # type: ignore
162 |
163 | path = os.path.relpath(inspect.getsourcefile(obj), start=_project_module_path) # type: ignore
164 | src, lineno = inspect.getsourcelines(obj)
165 | except Exception:
166 | return None
167 |
168 | path = f"{path}#L{lineno}-L{lineno + len(src) - 1}"
169 | return f"{github_repo}/blob/{git_ref}/{package_name}/{path}"
170 |
171 |
172 | # -- Options for HTML output -------------------------------------------------
173 |
174 | # The theme to use for HTML and HTML Help pages. See the documentation for
175 | # a list of builtin themes.
176 | #
177 | html_theme = "sphinx_book_theme"
178 | html_static_path = ["_static"]
179 | html_title = "RegVelo"
180 |
181 | html_theme_options = {
182 | "repository_url": github_repo,
183 | "use_repository_button": True,
184 | }
185 |
186 | pygments_style = "default"
187 |
188 | nitpick_ignore = [
189 | # If building the documentation fails because of a missing link that is outside your control,
190 | # you can add an exception to this list.
191 | ]
192 |
193 |
194 | def setup(app):
195 | """App setup hook."""
196 | app.add_config_value(
197 | "recommonmark_config",
198 | {
199 | "auto_toc_tree_section": "Contents",
200 | "enable_auto_toc_tree": True,
201 | "enable_math": True,
202 | "enable_inline_math": False,
203 | "enable_eval_rst": True,
204 | },
205 | True,
206 | )
207 |
--------------------------------------------------------------------------------
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | import re
4 |
5 | from sphinx.application import Sphinx
6 | from sphinx.ext.napoleon import NumpyDocstring
7 |
8 |
9 | def process_return(lines):
10 | """Process the return section of a docstring."""
11 | for line in lines:
12 |         m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line)
13 | if m:
14 | # Once this is in scanpydoc, we can use the fancy hover stuff
15 | yield f'-{m["param"]} (:class:`~{m["type"]}`)'
16 | else:
17 | yield line
18 |
19 |
20 | def scanpy_parse_returns_section(self, section):
21 | """Parse the returns section of the docstring."""
22 | lines_raw = list(process_return(self._dedent(self._consume_to_next_section())))
23 | lines = self._format_block(":returns: ", lines_raw)
24 | if lines and lines[-1]:
25 | lines.append("")
26 | return lines
27 |
28 |
29 | def setup(app: Sphinx):
30 | """Setup the extension."""
31 | NumpyDocstring._parse_returns_section = scanpy_parse_returns_section
32 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ```{include} ../README.md
2 |
3 | ```
4 |
5 | # Welcome to the RegVelo documentation.
6 |
7 | ```{toctree}
8 | :maxdepth: 3
9 | :titlesonly: true
10 |
11 | about/index
12 | tutorials/index
13 | api/index
14 | release_notes/index
15 | references
16 |
17 | ```
18 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=python -msphinx
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=regvelo
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 | echo.then set the SPHINXBUILD environment variable to point to the full
21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 | echo.Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/references.bib:
--------------------------------------------------------------------------------
1 | @article{wang2024regvelo,
2 | title={RegVelo: gene-regulatory-informed dynamics of single cells},
3 | author={Wang, Weixu and Hu, Zhiyuan and Weiler, Philipp and Mayes, Sarah and Lange, Marius and Wang, Jingye and Xue, Zhengyuan and Sauka-Spengler, Tatjana and Theis, Fabian J},
4 | journal={bioRxiv},
5 | pages={2024--12},
6 | year={2024},
7 | publisher={Cold Spring Harbor Laboratory}
8 | }
9 |
10 | @article{lange2022cellrank,
11 | title={CellRank for directed single-cell fate mapping},
12 | author={Lange, Marius and Bergen, Volker and Klein, Michal and Setty, Manu and Reuter, Bernhard and Bakhti, Mostafa and Lickert, Heiko and Ansari, Meshal and Schniering, Janine and Schiller, Herbert B and others},
13 | journal={Nature methods},
14 | volume={19},
15 | number={2},
16 | pages={159--170},
17 | year={2022},
18 | publisher={Nature Publishing Group US New York}
19 | }
20 |
21 | @article{weiler2024cellrank,
22 | title={CellRank 2: unified fate mapping in multiview single-cell data},
23 | author={Weiler, Philipp and Lange, Marius and Klein, Michal and Pe’er, Dana and Theis, Fabian},
24 | journal={Nature Methods},
25 | pages={1--10},
26 | year={2024},
27 | publisher={Nature Publishing Group US New York}
28 | }
29 |
30 | @article{bergen2020generalizing,
31 | title={Generalizing RNA velocity to transient cell states through dynamical modeling},
32 | author={Bergen, Volker and Lange, Marius and Peidli, Stefan and Wolf, F Alexander and Theis, Fabian J},
33 | journal={Nature biotechnology},
34 | volume={38},
35 | number={12},
36 | pages={1408--1414},
37 | year={2020},
38 | publisher={Nature Publishing Group US New York}
39 | }
40 |
41 | @article{gayoso2024deep,
42 | title={Deep generative modeling of transcriptional dynamics for RNA velocity analysis in single cells},
43 | author={Gayoso, Adam and Weiler, Philipp and Lotfollahi, Mohammad and Klein, Dominik and Hong, Justin and Streets, Aaron and Theis, Fabian J and Yosef, Nir},
44 | journal={Nature methods},
45 | volume={21},
46 | number={1},
47 | pages={50--59},
48 | year={2024},
49 | publisher={Nature Publishing Group US New York}
50 | }
51 |
52 | @article{hu2024single,
53 | title={Single-cell multi-omics, spatial transcriptomics and systematic perturbation decode circuitry of neural crest fate decisions},
54 | author={Hu, Zhiyuan and Mayes, Sarah and Wang, Weixu and Santos-Pereira, Jose M and Theis, Fabian and Sauka-Spengler, Tatjana},
55 | journal={bioRxiv},
56 | pages={2024--09},
57 | year={2024},
58 | publisher={Cold Spring Harbor Laboratory}
59 | }
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
1 | # References
2 |
3 | If RegVelo is helpful in your research, please cite {cite:p}`wang2024regvelo`. If you are using the zebrafish Smart-seq3, 10x Multiome, or in vivo Perturb-seq data, please cite {cite:p}`wang2024regvelo, hu2024single`.
4 |
5 | ```{bibliography}
6 | :cited:
7 | ```
8 |
--------------------------------------------------------------------------------
/docs/release_notes/index.rst:
--------------------------------------------------------------------------------
1 | Release notes
2 | =============
3 |
4 | Version 0.1
5 | -----------
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | v0.1.0
10 |
--------------------------------------------------------------------------------
/docs/release_notes/v0.1.0.rst:
--------------------------------------------------------------------------------
1 | New in 0.1.0 (2024-09-03)
2 | -------------------------
3 |
--------------------------------------------------------------------------------
/docs/tutorials/index.md:
--------------------------------------------------------------------------------
1 | # Tutorials
2 |
3 |
4 | ```{toctree}
5 | :maxdepth: 2
6 |
7 | index_murine
8 | index_zebrafish
9 | index_modelcomp
10 | ```
11 |
--------------------------------------------------------------------------------
/docs/tutorials/index_modelcomp.md:
--------------------------------------------------------------------------------
1 | # Model comparison with multi-view information
2 |
3 | ```{toctree}
4 | :maxdepth: 1
5 | :titlesonly:
6 |
7 | modelcomparison/ModelComp.ipynb
 8 | ```
--------------------------------------------------------------------------------
/docs/tutorials/index_murine.md:
--------------------------------------------------------------------------------
1 | # Murine Neural Crest Development
2 |
3 | ```{toctree}
4 | :maxdepth: 1
5 | :titlesonly:
6 |
7 | murine/01_SCENIC_tutorial
8 | murine/02_RegVelo_preparation
9 | murine/03_perturbation_tutorial_murine
10 |
11 | ```
12 |
13 |
14 |
--------------------------------------------------------------------------------
/docs/tutorials/index_zebrafish.md:
--------------------------------------------------------------------------------
1 | # Zebrafish Neural Crest Development
2 |
3 | ```{toctree}
4 | :maxdepth: 1
5 | :titlesonly:
6 |
7 | zebrafish/tutorial
 8 | ```
--------------------------------------------------------------------------------
/docs/tutorials/murine/01_SCENIC_tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "37f7b867-f3e7-4d1a-ad57-36898e03a0d8",
6 | "metadata": {},
7 | "source": [
8 | "# Infer prior GRN from [pySCENIC](https://pyscenic.readthedocs.io/en/latest/installation.html)\n",
9 | "In this notebook, we use [SCENIC](https://scenic.aertslab.org/) to infer a prior gene regulatory network (GRN) for the RegVelo pipeline."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "id": "405f8373-e7d0-4ebe-b03f-3937d1aa9d46",
15 | "metadata": {},
16 | "source": [
17 | "## Library import"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 1,
23 | "id": "b3b40717-9fb5-4ef4-8733-6fb7f4dff658",
24 | "metadata": {
25 | "tags": []
26 | },
27 | "outputs": [],
28 | "source": [
29 | "import os\n",
30 | "import numpy as np\n",
31 | "import pandas as pd\n",
32 | "import scanpy as sc\n",
33 | "import loompy as lp\n",
34 | "\n",
35 | "import glob"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "id": "286984e3-5844-418b-bd2e-70eda9c109a1",
42 | "metadata": {
43 | "tags": []
44 | },
45 | "outputs": [
46 | {
47 | "name": "stderr",
48 | "output_type": "stream",
49 | "text": [
50 | "/home/icb/yifan.chen/miniconda3/envs/pyscenic/lib/python3.10/site-packages/session_info/main.py:213: UserWarning: The '__version__' attribute is deprecated and will be removed in MarkupSafe 3.1. Use feature detection, or `importlib.metadata.version(\"markupsafe\")`, instead.\n",
51 | " mod_version = _find_version(mod.__version__)\n"
52 | ]
53 | },
54 | {
55 | "name": "stdout",
56 | "output_type": "stream",
57 | "text": [
58 | "-----\n",
59 | "anndata 0.11.4\n",
60 | "scanpy 1.10.4\n",
61 | "-----\n",
62 | "PIL 11.2.1\n",
63 | "asttokens NA\n",
64 | "charset_normalizer 3.4.1\n",
65 | "cloudpickle 3.1.1\n",
66 | "comm 0.2.1\n",
67 | "cycler 0.12.1\n",
68 | "cython_runtime NA\n",
69 | "cytoolz 1.0.1\n",
70 | "dask 2025.4.1\n",
71 | "dateutil 2.9.0.post0\n",
72 | "debugpy 1.8.11\n",
73 | "decorator 5.1.1\n",
74 | "exceptiongroup 1.2.0\n",
75 | "executing 0.8.3\n",
76 | "h5py 3.13.0\n",
77 | "ipykernel 6.29.5\n",
78 | "jedi 0.19.2\n",
79 | "jinja2 3.1.6\n",
80 | "joblib 1.4.2\n",
81 | "kiwisolver 1.4.8\n",
82 | "legacy_api_wrap NA\n",
83 | "llvmlite 0.44.0\n",
84 | "loompy 3.0.8\n",
85 | "lz4 4.4.4\n",
86 | "markupsafe 3.0.2\n",
87 | "matplotlib 3.10.1\n",
88 | "mpl_toolkits NA\n",
89 | "natsort 8.4.0\n",
90 | "numba 0.61.2\n",
91 | "numexpr 2.10.2\n",
92 | "numpy 2.2.5\n",
93 | "numpy_groupies 0.11.2\n",
94 | "packaging 24.2\n",
95 | "pandas 2.2.3\n",
96 | "parso 0.8.4\n",
97 | "platformdirs 4.3.7\n",
98 | "prompt_toolkit 3.0.43\n",
99 | "psutil 5.9.0\n",
100 | "pure_eval 0.2.2\n",
101 | "pyarrow 20.0.0\n",
102 | "pydev_ipython NA\n",
103 | "pydevconsole NA\n",
104 | "pydevd 3.2.3\n",
105 | "pydevd_file_utils NA\n",
106 | "pydevd_plugins NA\n",
107 | "pydevd_tracing NA\n",
108 | "pygments 2.19.1\n",
109 | "pyparsing 3.2.3\n",
110 | "pytz 2025.2\n",
111 | "scipy 1.15.2\n",
112 | "session_info v1.0.1\n",
113 | "six 1.17.0\n",
114 | "sklearn 1.6.1\n",
115 | "stack_data 0.2.0\n",
116 | "tblib 3.1.0\n",
117 | "threadpoolctl 3.6.0\n",
118 | "tlz 1.0.1\n",
119 | "toolz 1.0.0\n",
120 | "tornado 6.4.2\n",
121 | "traitlets 5.14.3\n",
122 | "typing_extensions NA\n",
123 | "wcwidth 0.2.5\n",
124 | "yaml 6.0.2\n",
125 | "zipp NA\n",
126 | "zmq 26.2.0\n",
127 | "zoneinfo NA\n",
128 | "-----\n",
129 | "IPython 8.30.0\n",
130 | "jupyter_client 8.6.3\n",
131 | "jupyter_core 5.7.2\n",
132 | "-----\n",
133 | "Python 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0]\n",
134 | "Linux-5.14.0-427.37.1.el9_4.x86_64-x86_64-with-glibc2.34\n",
135 | "-----\n",
136 | "Session information updated at 2025-04-28 11:30\n"
137 | ]
138 | }
139 | ],
140 | "source": [
141 | "sc.settings.verbosity = 3\n",
142 | "sc.logging.print_versions()"
143 | ]
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "id": "943265ea-267c-4c47-b00c-28ef7fbd8ab2",
148 | "metadata": {},
149 | "source": [
150 | "## Load data and output to loom file\n",
151 | "Read murine neural crest data."
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "id": "6f4a000b-455a-4625-a3e2-c0b2bafc5cea",
158 | "metadata": {
159 | "tags": []
160 | },
161 | "outputs": [],
162 | "source": [
163 | "adata = rgv.datasets.murine_nc(data_type = \"normalized\")"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 4,
169 | "id": "b7ce4c29-e8b3-428d-a930-469a155bd6f0",
170 | "metadata": {
171 | "tags": []
172 | },
173 | "outputs": [
174 | {
175 | "data": {
176 | "text/plain": [
177 | "AnnData object with n_obs × n_vars = 6788 × 30717\n",
178 | " obs: 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'UMI_count', 'gene_count', 'major_trajectory', 'celltype_update', 'UMAP_1', 'UMAP_2', 'UMAP_3', 'UMAP_2d_1', 'UMAP_2d_2', 'terminal_state', 'nCount_intron', 'nFeature_intron'\n",
179 | " var: 'vf_vst_counts_mean', 'vf_vst_counts_variance', 'vf_vst_counts_variance.expected', 'vf_vst_counts_variance.standardized', 'vf_vst_counts_variable', 'vf_vst_counts_rank', 'var.features', 'var.features.rank'\n",
180 | " obsm: 'X_pca', 'X_umap'"
181 | ]
182 | },
183 | "execution_count": 4,
184 | "metadata": {},
185 | "output_type": "execute_result"
186 | }
187 | ],
188 | "source": [
189 | "adata"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 5,
195 | "id": "56baecec-0ee1-4f30-bddb-31426b2e3b35",
196 | "metadata": {
197 | "tags": []
198 | },
199 | "outputs": [],
200 | "source": [
201 | "adata = sc.AnnData(adata.X, obs=adata.obs, var=adata.var)\n",
202 | "adata.var[\"Gene\"] = adata.var_names\n",
203 | "adata.obs[\"CellID\"] = adata.obs_names\n",
204 | "adata.write_loom(\"adata.loom\")"
205 | ]
206 | },
207 | {
208 | "cell_type": "markdown",
209 | "id": "2005e02a-abb1-4db9-b944-fe9fdb8ac9f5",
210 | "metadata": {},
211 | "source": [
212 | "## SCENIC steps\n",
213 | "In the following, we use [SCENIC](https://scenic.aertslab.org/) to infer prior regulation information. Installation and usage steps are given in [pySCENIC](https://pyscenic.readthedocs.io/en/latest/installation.html) and are demonstrated in [SCENICprotocol](https://github.com/aertslab/SCENICprotocol/tree/master)."
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 7,
219 | "id": "64e1477e-6539-4605-abbf-c7f72ddfb8dc",
220 | "metadata": {
221 | "tags": []
222 | },
223 | "outputs": [],
224 | "source": [
225 | "# path to loom file created previously\n",
226 | "f_loom_path_scenic = \"adata.loom\"\n",
227 | "# path to list of transcription factors\n",
228 | "f_tfs = \"allTFs_mm.txt\""
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": 9,
234 | "id": "6426c2fe-a664-479e-9113-45fbfb62bcef",
235 | "metadata": {
236 | "tags": []
237 | },
238 | "outputs": [
239 | {
240 | "name": "stdout",
241 | "output_type": "stream",
242 | "text": [
243 | "/bin/bash: line 1: pyscenic: command not found\n"
244 | ]
245 | }
246 | ],
247 | "source": [
248 | "!pyscenic grn {f_loom_path_scenic} {f_tfs} -o \"adj.csv\" --num_workers 24"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "id": "1f9fc0ba-0012-48eb-bcbd-3f859dcc37a8",
255 | "metadata": {
256 | "tags": []
257 | },
258 | "outputs": [],
259 | "source": [
260 | "# path to ranking databases in feather format\n",
261 | "f_db_glob = \"scenic/cisTarget_databases/*feather\"\n",
262 | "f_db_names = ' '.join( glob.glob(f_db_glob) )\n",
263 | "\n",
264 | "# path to motif databases\n",
265 | "f_motif_path = \"scenic/cisTarget_databases/motifs-v9-nr.mgi-m0.001-o0.0.tbl\""
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": null,
271 | "id": "e1bdaf1e-91e2-49ff-91fd-4143bbdc032a",
272 | "metadata": {
273 | "tags": []
274 | },
275 | "outputs": [],
276 | "source": [
277 | "!pyscenic ctx \"adj.csv\" \\\n",
278 | " {f_db_names} \\\n",
279 | " --annotations_fname {f_motif_path} \\\n",
280 | " --expression_mtx_fname {f_loom_path_scenic} \\\n",
281 | " --output \"reg.csv\" \\\n",
282 | " --all_modules \\\n",
283 | " --num_workers 24"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": null,
289 | "id": "6c2f499b-b032-4afd-bb12-57eae82d9a46",
290 | "metadata": {
291 | "tags": []
292 | },
293 | "outputs": [],
294 | "source": [
295 | "f_pyscenic_output = \"pyscenic_output_all_regulon_no_mask.loom\""
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": null,
301 | "id": "db083910-2345-4e39-8ddd-ff14f210c498",
302 | "metadata": {
303 | "tags": []
304 | },
305 | "outputs": [],
306 | "source": [
307 | "!pyscenic aucell \\\n",
308 | " {f_loom_path_scenic} \\\n",
309 | " \"reg.csv\" \\\n",
310 | " --output {f_pyscenic_output} \\\n",
311 | " --num_workers 4"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": 26,
317 | "id": "fbba0c0c-109a-4112-a1eb-340c3942a3b7",
318 | "metadata": {
319 | "tags": []
320 | },
321 | "outputs": [],
322 | "source": [
323 | "# collect SCENIC AUCell output\n",
324 | "lf = lp.connect(f_pyscenic_output, mode='r+', validate=False )\n",
325 | "auc_mtx = pd.DataFrame(lf.ca.RegulonsAUC, index=lf.ca.CellID)\n",
326 | "regulons = lf.ra.Regulons"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": 27,
332 | "id": "7a0925e3-d1da-4158-a54a-44535613c6c0",
333 | "metadata": {
334 | "tags": []
335 | },
336 | "outputs": [],
337 | "source": [
338 | "res = pd.concat([pd.Series(r.tolist(), index=regulons.dtype.names) for r in regulons], axis=1)\n",
339 | "res.columns = lf.row_attrs[\"SYMBOL\"]\n",
340 | "res.to_csv(\"regulon_mat_all_regulons.csv\")"
341 | ]
342 | },
343 | {
344 | "cell_type": "markdown",
345 | "id": "512bff3d-6cfb-4dd0-a332-703f867943af",
346 | "metadata": {},
347 | "source": [
348 | "## Create prior GRN for RegVelo\n",
349 | "In the following, we preprocess the GRN inferred from [SCENIC](https://scenic.aertslab.org/), saved as `regulon_mat_all_regulons.csv`. We first read the regulon file, where rows are regulators and columns are target genes. We further extract the names of the transcription factors (TFs) from the row indices using a regex and collapse dublicte TFs by summing their rows."
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": null,
355 | "id": "f5f040df-aff8-43ff-91b4-b4e7ff6dbf18",
356 | "metadata": {},
357 | "outputs": [],
358 | "source": [
359 | "# load saved regulon-target matrix\n",
360 | "reg = pd.read_csv(\"regulon_mat_all_regulons.csv\", index_col = 0)\n",
361 | "\n",
362 | "reg.index = reg.index.str.extract(r\"(\\w+)\")[0]\n",
363 | "reg = reg.groupby(reg.index).sum()"
364 | ]
365 | },
366 | {
367 | "cell_type": "markdown",
368 | "id": "5c18d266-5726-4ecc-ab38-2de5ce0d185b",
369 | "metadata": {},
370 | "source": [
371 | "We further binarize the matrix, where 1 indicates the presence of regulation and 0 indicates otherwise and get the list of TFs and genes."
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "id": "bd475a15-b5b7-49a9-b549-c7fece33ba81",
378 | "metadata": {},
379 | "outputs": [],
380 | "source": [
381 | "reg[reg != 0] = 1\n",
382 | "\n",
383 | "TF = np.unique(list(map(lambda x: x.split(\"(\")[0], reg.index.tolist())))\n",
384 | "genes = np.unique(TF.tolist() + reg.columns.tolist())"
385 | ]
386 | },
387 | {
388 | "cell_type": "markdown",
389 | "id": "4a9633d5-204a-431f-a121-7cd927604b72",
390 | "metadata": {},
391 | "source": [
392 | "For the prior GRN, we first construct an empty square matrix and populate it based on the previously binarized regulation information. We further remove the genes that are neither a TF nor a target gene (i.e. remove empty rows and comlumns) and save the cleaned and structured GRN to a `.parquet` file for RegVelo's downstream pipeline."
393 | ]
394 | },
395 | {
396 | "cell_type": "code",
397 | "execution_count": 76,
398 | "id": "dd55d3a2-f98e-42a3-849a-22471675db2d",
399 | "metadata": {
400 | "tags": []
401 | },
402 | "outputs": [
403 | {
404 | "name": "stdout",
405 | "output_type": "stream",
406 | "text": [
407 | "Done! processed GRN with 543 TF and 30717 targets\n"
408 | ]
409 | }
410 | ],
411 | "source": [
412 | "GRN = pd.DataFrame(0, index=genes, columns=genes)\n",
413 | "GRN.loc[TF,reg.columns.tolist()] = np.array(reg)\n",
414 | "\n",
415 | "mask = (GRN.sum(0) != 0) | (GRN.sum(1) != 0)\n",
416 | "GRN = GRN.loc[mask, mask].copy()\n",
417 | "\n",
418 | "GRN.to_parquet(\"regulon_mat_processed_all_regulons.parquet\")\n",
419 | "print(\"Done! processed GRN with \" + str(reg.shape[0]) + \" TFs and \" + str(reg.shape[1]) + \" targets\")"
420 | ]
421 | },
422 | {
423 | "cell_type": "code",
424 | "execution_count": null,
425 | "id": "5a5ff903-a13d-400c-9dcc-d2d75bd4d4f7",
426 | "metadata": {},
427 | "outputs": [],
428 | "source": []
429 | }
430 | ],
431 | "metadata": {
432 | "kernelspec": {
433 | "display_name": "Python (pyscenic)",
434 | "language": "python",
435 | "name": "pyscenic"
436 | },
437 | "language_info": {
438 | "codemirror_mode": {
439 | "name": "ipython",
440 | "version": 3
441 | },
442 | "file_extension": ".py",
443 | "mimetype": "text/x-python",
444 | "name": "python",
445 | "nbconvert_exporter": "python",
446 | "pygments_lexer": "ipython3",
447 | "version": "3.10.16"
448 | }
449 | },
450 | "nbformat": 4,
451 | "nbformat_minor": 5
452 | }
453 |
--------------------------------------------------------------------------------
/docs/tutorials/murine/02_RegVelo_preparation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "e224f114-2918-4942-82d1-801e1429af7b",
6 | "metadata": {},
7 | "source": [
8 | "# Preprocess data and add prior GRN information\n",
9 | "In this notebook, we will go through the preprocessing steps needed prior to running RegVelo pipeline."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "id": "862aad5e-ac38-4ad6-9557-989e8f261fe8",
15 | "metadata": {},
16 | "source": [
17 | "## Library import "
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 1,
23 | "id": "b1205fa6-37b9-45b7-bae4-6f45369db6f6",
24 | "metadata": {
25 | "tags": []
26 | },
27 | "outputs": [
28 | {
29 | "name": "stderr",
30 | "output_type": "stream",
31 | "text": [
32 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.\n",
33 | " warnings.warn(msg, FutureWarning)\n",
34 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.\n",
35 | " warnings.warn(msg, FutureWarning)\n",
36 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.\n",
37 | " warnings.warn(msg, FutureWarning)\n",
38 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.\n",
39 | " warnings.warn(msg, FutureWarning)\n",
40 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.\n",
41 | " warnings.warn(msg, FutureWarning)\n",
42 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.\n",
43 | " warnings.warn(msg, FutureWarning)\n",
44 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.\n",
45 | " warnings.warn(msg, FutureWarning)\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "import scvelo as scv\n",
51 | "import scanpy as sc\n",
52 | "import pandas as pd\n",
53 | "import numpy as np\n",
54 | "\n",
55 | "import matplotlib.pyplot as plt\n",
56 | "import mplscience\n",
57 | "import seaborn as sns\n",
58 | "\n",
59 | "import regvelo as rgv"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "id": "cd088464-c366-40a8-a34e-8d3947fa4991",
65 | "metadata": {},
66 | "source": [
67 | "## Load data\n",
68 | "Read murine neural crest data that contains `.layers['spliced']` and `.layers['unspliced']`."
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "id": "2c1022b2-6133-4072-acb1-df900c0875b7",
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "adata = rgv.datasets.murine_nc(data_type = \"velocyto\")"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 3,
84 | "id": "cf07867d-5b7a-4feb-82e1-fa382ad19613",
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "data": {
89 | "text/plain": [
90 | "AnnData object with n_obs × n_vars = 6788 × 30717\n",
91 | " obs: 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'UMI_count', 'gene_count', 'major_trajectory', 'celltype_update', 'UMAP_1', 'UMAP_2', 'UMAP_3', 'UMAP_2d_1', 'UMAP_2d_2', 'terminal_state', 'nCount_intron', 'nFeature_intron'\n",
92 | " var: 'vf_vst_counts_mean', 'vf_vst_counts_variance', 'vf_vst_counts_variance.expected', 'vf_vst_counts_variance.standardized', 'vf_vst_counts_variable', 'vf_vst_counts_rank', 'var.features', 'var.features.rank'\n",
93 | " obsm: 'X_pca', 'X_umap'\n",
94 | " layers: 'spliced', 'unspliced'"
95 | ]
96 | },
97 | "execution_count": 3,
98 | "metadata": {},
99 | "output_type": "execute_result"
100 | }
101 | ],
102 | "source": [
103 | "adata"
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "id": "aa1e8c0d-7e11-4a77-bd6f-042e27a09b6e",
109 | "metadata": {},
110 | "source": [
111 | "## Data preprocessing"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 4,
117 | "id": "85fa5243-6477-4a86-9640-4fd2f1172d1d",
118 | "metadata": {
119 | "tags": []
120 | },
121 | "outputs": [
122 | {
123 | "name": "stdout",
124 | "output_type": "stream",
125 | "text": [
126 | "Filtered out 22217 genes that are detected 20 counts (shared).\n",
127 | "Normalized count data: X, spliced, unspliced.\n",
128 | "Extracted 3000 highly variable genes.\n"
129 | ]
130 | },
131 | {
132 | "name": "stderr",
133 | "output_type": "stream",
134 | "text": [
135 | "/tmp/ipykernel_3655946/1257038519.py:4: DeprecationWarning: `log1p` is deprecated since scVelo v0.3.0 and will be removed in a future version. Please use `log1p` from `scanpy.pp` instead.\n",
136 | " scv.pp.log1p(adata)\n"
137 | ]
138 | },
139 | {
140 | "name": "stdout",
141 | "output_type": "stream",
142 | "text": [
143 | "computing moments based on connectivities\n",
144 | " finished (0:00:02) --> added \n",
145 | " 'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)\n"
146 | ]
147 | }
148 | ],
149 | "source": [
150 | "scv.pp.filter_genes(adata, min_shared_counts=20)\n",
151 | "scv.pp.normalize_per_cell(adata)\n",
152 | "scv.pp.filter_genes_dispersion(adata, n_top_genes=3000)\n",
153 | "scv.pp.log1p(adata)\n",
154 | "\n",
155 | "sc.pp.neighbors(adata,n_pcs = 30,n_neighbors = 50)\n",
156 | "sc.tl.umap(adata)\n",
157 | "scv.pp.moments(adata, n_pcs=None, n_neighbors=None)"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 5,
163 | "id": "91c087a7-ec36-4d86-a891-c297eba29a7c",
164 | "metadata": {},
165 | "outputs": [
166 | {
167 | "data": {
168 | "text/plain": [
169 | "AnnData object with n_obs × n_vars = 6788 × 3000\n",
170 | " obs: 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'UMI_count', 'gene_count', 'major_trajectory', 'celltype_update', 'UMAP_1', 'UMAP_2', 'UMAP_3', 'UMAP_2d_1', 'UMAP_2d_2', 'terminal_state', 'nCount_intron', 'nFeature_intron', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts'\n",
171 | " var: 'vf_vst_counts_mean', 'vf_vst_counts_variance', 'vf_vst_counts_variance.expected', 'vf_vst_counts_variance.standardized', 'vf_vst_counts_variable', 'vf_vst_counts_rank', 'var.features', 'var.features.rank', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'\n",
172 | " uns: 'log1p', 'neighbors', 'umap'\n",
173 | " obsm: 'X_pca', 'X_umap'\n",
174 | " layers: 'spliced', 'unspliced', 'Ms', 'Mu'\n",
175 | " obsp: 'distances', 'connectivities'"
176 | ]
177 | },
178 | "execution_count": 5,
179 | "metadata": {},
180 | "output_type": "execute_result"
181 | }
182 | ],
183 | "source": [
184 | "adata"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "id": "943f6605-e091-4b8c-ad4b-075dbc4ef803",
190 | "metadata": {},
191 | "source": [
192 | "## Load prior GRN created from notebook 'Infer prior GRN from [pySCENIC](https://pyscenic.readthedocs.io/en/latest/installation.html)' for RegVelo\n",
193 | "In the following, we load the processed prior GRN infromation into our AnnData object. In this step `.uns['skeleton']` and `.var['TF']` are added, which will be needed for RegVelo's velocity pipeline."
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 6,
199 | "id": "f7d9c59f-6238-494b-8649-178ae5dd783b",
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "GRN = pd.read_parquet(\"regulon_mat_processed_all_regulons.parquet\")"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": 7,
209 | "id": "63410f40-8ee1-48cb-9a93-5737fb3aa9ef",
210 | "metadata": {},
211 | "outputs": [
212 | {
213 | "data": {
214 | "text/html": [
215 | "\n",
216 | "\n",
229 | "
\n",
230 | " \n",
231 | " \n",
232 | " | \n",
233 | " 0610005C13Rik | \n",
234 | " 0610009L18Rik | \n",
235 | " 0610010K14Rik | \n",
236 | " 0610012G03Rik | \n",
237 | " 0610030E20Rik | \n",
238 | " 0610038B21Rik | \n",
239 | " 0610040B10Rik | \n",
240 | " 0610040J01Rik | \n",
241 | " 0610043K17Rik | \n",
242 | " 1110002L01Rik | \n",
243 | " ... | \n",
244 | " Zswim8 | \n",
245 | " Zw10 | \n",
246 | " Zwilch | \n",
247 | " Zwint | \n",
248 | " Zxdb | \n",
249 | " Zxdc | \n",
250 | " Zyg11b | \n",
251 | " Zyx | \n",
252 | " Zzef1 | \n",
253 | " Zzz3 | \n",
254 | "
\n",
255 | " \n",
256 | " \n",
257 | " \n",
258 | " 0610005C13Rik | \n",
259 | " 0 | \n",
260 | " 0 | \n",
261 | " 0 | \n",
262 | " 0 | \n",
263 | " 0 | \n",
264 | " 0 | \n",
265 | " 0 | \n",
266 | " 0 | \n",
267 | " 0 | \n",
268 | " 0 | \n",
269 | " ... | \n",
270 | " 0 | \n",
271 | " 0 | \n",
272 | " 0 | \n",
273 | " 0 | \n",
274 | " 0 | \n",
275 | " 0 | \n",
276 | " 0 | \n",
277 | " 0 | \n",
278 | " 0 | \n",
279 | " 0 | \n",
280 | "
\n",
281 | " \n",
282 | " 0610009L18Rik | \n",
283 | " 0 | \n",
284 | " 0 | \n",
285 | " 0 | \n",
286 | " 0 | \n",
287 | " 0 | \n",
288 | " 0 | \n",
289 | " 0 | \n",
290 | " 0 | \n",
291 | " 0 | \n",
292 | " 0 | \n",
293 | " ... | \n",
294 | " 0 | \n",
295 | " 0 | \n",
296 | " 0 | \n",
297 | " 0 | \n",
298 | " 0 | \n",
299 | " 0 | \n",
300 | " 0 | \n",
301 | " 0 | \n",
302 | " 0 | \n",
303 | " 0 | \n",
304 | "
\n",
305 | " \n",
306 | " 0610010K14Rik | \n",
307 | " 0 | \n",
308 | " 0 | \n",
309 | " 0 | \n",
310 | " 0 | \n",
311 | " 0 | \n",
312 | " 0 | \n",
313 | " 0 | \n",
314 | " 0 | \n",
315 | " 0 | \n",
316 | " 0 | \n",
317 | " ... | \n",
318 | " 0 | \n",
319 | " 0 | \n",
320 | " 0 | \n",
321 | " 0 | \n",
322 | " 0 | \n",
323 | " 0 | \n",
324 | " 0 | \n",
325 | " 0 | \n",
326 | " 0 | \n",
327 | " 0 | \n",
328 | "
\n",
329 | " \n",
330 | " 0610012G03Rik | \n",
331 | " 0 | \n",
332 | " 0 | \n",
333 | " 0 | \n",
334 | " 0 | \n",
335 | " 0 | \n",
336 | " 0 | \n",
337 | " 0 | \n",
338 | " 0 | \n",
339 | " 0 | \n",
340 | " 0 | \n",
341 | " ... | \n",
342 | " 0 | \n",
343 | " 0 | \n",
344 | " 0 | \n",
345 | " 0 | \n",
346 | " 0 | \n",
347 | " 0 | \n",
348 | " 0 | \n",
349 | " 0 | \n",
350 | " 0 | \n",
351 | " 0 | \n",
352 | "
\n",
353 | " \n",
354 | " 0610030E20Rik | \n",
355 | " 0 | \n",
356 | " 0 | \n",
357 | " 0 | \n",
358 | " 0 | \n",
359 | " 0 | \n",
360 | " 0 | \n",
361 | " 0 | \n",
362 | " 0 | \n",
363 | " 0 | \n",
364 | " 0 | \n",
365 | " ... | \n",
366 | " 0 | \n",
367 | " 0 | \n",
368 | " 0 | \n",
369 | " 0 | \n",
370 | " 0 | \n",
371 | " 0 | \n",
372 | " 0 | \n",
373 | " 0 | \n",
374 | " 0 | \n",
375 | " 0 | \n",
376 | "
\n",
377 | " \n",
378 | "
\n",
379 | "
5 rows × 13697 columns
\n",
380 | "
"
381 | ],
382 | "text/plain": [
383 | " 0610005C13Rik 0610009L18Rik 0610010K14Rik 0610012G03Rik \\\n",
384 | "0610005C13Rik 0 0 0 0 \n",
385 | "0610009L18Rik 0 0 0 0 \n",
386 | "0610010K14Rik 0 0 0 0 \n",
387 | "0610012G03Rik 0 0 0 0 \n",
388 | "0610030E20Rik 0 0 0 0 \n",
389 | "\n",
390 | " 0610030E20Rik 0610038B21Rik 0610040B10Rik 0610040J01Rik \\\n",
391 | "0610005C13Rik 0 0 0 0 \n",
392 | "0610009L18Rik 0 0 0 0 \n",
393 | "0610010K14Rik 0 0 0 0 \n",
394 | "0610012G03Rik 0 0 0 0 \n",
395 | "0610030E20Rik 0 0 0 0 \n",
396 | "\n",
397 | " 0610043K17Rik 1110002L01Rik ... Zswim8 Zw10 Zwilch Zwint \\\n",
398 | "0610005C13Rik 0 0 ... 0 0 0 0 \n",
399 | "0610009L18Rik 0 0 ... 0 0 0 0 \n",
400 | "0610010K14Rik 0 0 ... 0 0 0 0 \n",
401 | "0610012G03Rik 0 0 ... 0 0 0 0 \n",
402 | "0610030E20Rik 0 0 ... 0 0 0 0 \n",
403 | "\n",
404 | " Zxdb Zxdc Zyg11b Zyx Zzef1 Zzz3 \n",
405 | "0610005C13Rik 0 0 0 0 0 0 \n",
406 | "0610009L18Rik 0 0 0 0 0 0 \n",
407 | "0610010K14Rik 0 0 0 0 0 0 \n",
408 | "0610012G03Rik 0 0 0 0 0 0 \n",
409 | "0610030E20Rik 0 0 0 0 0 0 \n",
410 | "\n",
411 | "[5 rows x 13697 columns]"
412 | ]
413 | },
414 | "execution_count": 7,
415 | "metadata": {},
416 | "output_type": "execute_result"
417 | }
418 | ],
419 | "source": [
420 | "GRN.head()"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": 8,
426 | "id": "2087dc84-d2cd-4ae1-989f-5691ac6dbbc0",
427 | "metadata": {},
428 | "outputs": [
429 | {
430 | "name": "stderr",
431 | "output_type": "stream",
432 | "text": [
433 | "/home/icb/yifan.chen/miniconda3/envs/regvelo-py310-v2/lib/python3.10/site-packages/regvelo/preprocessing/set_prior_grn.py:75: ImplicitModificationWarning: Trying to modify attribute `._uns` of view, initializing view as actual.\n",
434 | " adata.uns[\"regulators\"] = adata.var_names.to_numpy()\n"
435 | ]
436 | }
437 | ],
438 | "source": [
439 | "adata = rgv.pp.set_prior_grn(adata, GRN.T)"
440 | ]
441 | },
442 | {
443 | "cell_type": "code",
444 | "execution_count": 9,
445 | "id": "652d5597-1abd-4c71-a328-00d3d8f5510f",
446 | "metadata": {},
447 | "outputs": [
448 | {
449 | "data": {
450 | "text/plain": [
451 | "AnnData object with n_obs × n_vars = 6788 × 2112\n",
452 | " obs: 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'UMI_count', 'gene_count', 'major_trajectory', 'celltype_update', 'UMAP_1', 'UMAP_2', 'UMAP_3', 'UMAP_2d_1', 'UMAP_2d_2', 'terminal_state', 'nCount_intron', 'nFeature_intron', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts'\n",
453 | " var: 'vf_vst_counts_mean', 'vf_vst_counts_variance', 'vf_vst_counts_variance.expected', 'vf_vst_counts_variance.standardized', 'vf_vst_counts_variable', 'vf_vst_counts_rank', 'var.features', 'var.features.rank', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'\n",
454 | " uns: 'log1p', 'neighbors', 'umap', 'regulators', 'targets', 'skeleton', 'network'\n",
455 | " obsm: 'X_pca', 'X_umap'\n",
456 | " layers: 'spliced', 'unspliced', 'Ms', 'Mu'\n",
457 | " obsp: 'distances', 'connectivities'"
458 | ]
459 | },
460 | "execution_count": 9,
461 | "metadata": {},
462 | "output_type": "execute_result"
463 | }
464 | ],
465 | "source": [
466 | "adata"
467 | ]
468 | },
469 | {
470 | "cell_type": "code",
471 | "execution_count": 10,
472 | "id": "36ccbcad-1c0e-4807-9fd0-3ba512b8ced3",
473 | "metadata": {},
474 | "outputs": [
475 | {
476 | "name": "stdout",
477 | "output_type": "stream",
478 | "text": [
479 | "computing velocities\n",
480 | " finished (0:00:00) --> added \n",
481 | " 'velocity', velocity vectors for each individual cell (adata.layers)\n"
482 | ]
483 | }
484 | ],
485 | "source": [
486 | "velocity_genes = rgv.pp.preprocess_data(adata.copy()).var_names.tolist()\n",
487 | "\n",
488 | "# select TFs that regulate at least one gene\n",
489 | "TF = adata.var_names[adata.uns[\"skeleton\"].sum(1) != 0]\n",
490 | "var_mask = np.union1d(TF, velocity_genes)\n",
491 | "\n",
492 | "adata = adata[:, var_mask].copy()"
493 | ]
494 | },
495 | {
496 | "cell_type": "code",
497 | "execution_count": 11,
498 | "id": "575c89de-7fc5-4a3b-9e8a-0287028f90fe",
499 | "metadata": {},
500 | "outputs": [
501 | {
502 | "name": "stdout",
503 | "output_type": "stream",
504 | "text": [
505 | "Number of genes: 1187\n",
506 | "Number of genes: 1164\n"
507 | ]
508 | }
509 | ],
510 | "source": [
511 | "adata = rgv.pp.filter_genes(adata)\n",
512 | "adata = rgv.pp.preprocess_data(adata, filter_on_r2=False)\n",
513 | "\n",
514 | "adata.var[\"velocity_genes\"] = adata.var_names.isin(velocity_genes)\n",
515 | "adata.var[\"TF\"] = adata.var_names.isin(TF)"
516 | ]
517 | },
518 | {
519 | "cell_type": "code",
520 | "execution_count": 12,
521 | "id": "152b7016-ff2a-435e-9f02-601ceadc31a7",
522 | "metadata": {},
523 | "outputs": [
524 | {
525 | "data": {
526 | "text/plain": [
527 | "AnnData object with n_obs × n_vars = 6788 × 1164\n",
528 | " obs: 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'UMI_count', 'gene_count', 'major_trajectory', 'celltype_update', 'UMAP_1', 'UMAP_2', 'UMAP_3', 'UMAP_2d_1', 'UMAP_2d_2', 'terminal_state', 'nCount_intron', 'nFeature_intron', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts'\n",
529 | " var: 'vf_vst_counts_mean', 'vf_vst_counts_variance', 'vf_vst_counts_variance.expected', 'vf_vst_counts_variance.standardized', 'vf_vst_counts_variable', 'vf_vst_counts_rank', 'var.features', 'var.features.rank', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable', 'velocity_genes', 'TF'\n",
530 | " uns: 'log1p', 'neighbors', 'umap', 'regulators', 'targets', 'skeleton', 'network'\n",
531 | " obsm: 'X_pca', 'X_umap'\n",
532 | " layers: 'spliced', 'unspliced', 'Ms', 'Mu'\n",
533 | " obsp: 'distances', 'connectivities'"
534 | ]
535 | },
536 | "execution_count": 12,
537 | "metadata": {},
538 | "output_type": "execute_result"
539 | }
540 | ],
541 | "source": [
542 | "adata"
543 | ]
544 | },
545 | {
546 | "cell_type": "markdown",
547 | "id": "ad07a3c3-9086-484b-8447-adea76ce2e45",
548 | "metadata": {},
549 | "source": [
550 | "## Save data\n",
551 | "This data, with the parameters chosen in this tutorial, can also be assessed by calling `rgv.datasets.murine_nc(data_type = \"preprocessed\")`"
552 | ]
553 | },
554 | {
555 | "cell_type": "code",
556 | "execution_count": 13,
557 | "id": "534c33f0-197a-424c-928e-9c5215fe2660",
558 | "metadata": {},
559 | "outputs": [],
560 | "source": [
561 | "adata.write_h5ad(\"adata_processed_velo.h5ad\")"
562 | ]
563 |   }
572 | ],
573 | "metadata": {
574 | "kernelspec": {
575 | "display_name": "regvelo-py310-v2",
576 | "language": "python",
577 | "name": "regvelo-py310-v2"
578 | },
579 | "language_info": {
580 | "codemirror_mode": {
581 | "name": "ipython",
582 | "version": 3
583 | },
584 | "file_extension": ".py",
585 | "mimetype": "text/x-python",
586 | "name": "python",
587 | "nbconvert_exporter": "python",
588 | "pygments_lexer": "ipython3",
589 | "version": "3.10.13"
590 | }
591 | },
592 | "nbformat": 4,
593 | "nbformat_minor": 5
594 | }
595 |
--------------------------------------------------------------------------------
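For readers who prefer a plain script over the notebook, the steps of `02_RegVelo_preparation.ipynb` condense into a short pipeline. A minimal sketch under the same assumptions as the notebook (the regulon parquet from the SCENIC tutorial sits in the working directory):

```python
import numpy as np
import pandas as pd
import scanpy as sc
import scvelo as scv
import regvelo as rgv

adata = rgv.datasets.murine_nc(data_type="velocyto")

# Standard scVelo preprocessing: gene filtering, normalization,
# highly variable gene selection, and first/second moments.
scv.pp.filter_genes(adata, min_shared_counts=20)
scv.pp.normalize_per_cell(adata)
scv.pp.filter_genes_dispersion(adata, n_top_genes=3000)
scv.pp.log1p(adata)
sc.pp.neighbors(adata, n_pcs=30, n_neighbors=50)
sc.tl.umap(adata)
scv.pp.moments(adata, n_pcs=None, n_neighbors=None)

# Attach the prior GRN; this adds .uns['skeleton'] among other keys.
GRN = pd.read_parquet("regulon_mat_processed_all_regulons.parquet")
adata = rgv.pp.set_prior_grn(adata, GRN.T)

# Keep TFs that regulate at least one gene, plus the velocity genes.
velocity_genes = rgv.pp.preprocess_data(adata.copy()).var_names.tolist()
TF = adata.var_names[adata.uns["skeleton"].sum(1) != 0]
adata = adata[:, np.union1d(TF, velocity_genes)].copy()

adata = rgv.pp.filter_genes(adata)
adata = rgv.pp.preprocess_data(adata, filter_on_r2=False)
adata.var["velocity_genes"] = adata.var_names.isin(velocity_genes)
adata.var["TF"] = adata.var_names.isin(TF)

adata.write_h5ad("adata_processed_velo.h5ad")
```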
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["poetry-core>=1.0.0"]
3 | build-backend = "poetry.core.masonry.api"
4 |
5 | [tool.poetry]
6 | authors = ["Weixu Wang "]
7 | classifiers = [
8 | "Development Status :: 4 - Beta",
9 | "Intended Audience :: Science/Research",
10 | "Natural Language :: English",
11 | "Programming Language :: Python :: 3.9",
12 | "Programming Language :: Python :: 3.10",
13 | "Programming Language :: Python :: 3.11",
14 | "Operating System :: MacOS :: MacOS X",
15 | "Operating System :: Microsoft :: Windows",
16 | "Operating System :: POSIX :: Linux",
17 | "Topic :: Scientific/Engineering :: Bio-Informatics",
18 | ]
19 | description = "Estimation of RNA velocity with variational inference."
20 | documentation = "https://scvi-tools.org"
21 | homepage = "https://github.com/theislab/RegVelo/"
22 | license = "BSD-3-Clause"
23 | name = "regvelo"
24 | packages = [
25 | {include = "regvelo"},
26 | ]
27 | readme = "README.md"
28 | version = "0.2.0"
29 |
30 | [tool.poetry.dependencies]
31 | anndata = ">=0.10.8"
32 | black = {version = ">=20.8b1", optional = true}
33 | codecov = {version = ">=2.0.8", optional = true}
34 | ruff = {version = "*", optional = true}
35 | importlib-metadata = {version = "^1.0", python = "<3.8"}
36 | ipython = {version = ">=7.1.1", optional = true}
37 | jupyter = {version = ">=1.0", optional = true}
38 | pre-commit = {version = ">=2.7.1", optional = true}
39 | sphinx-book-theme = {version = ">=1.0.0", optional = true}
40 | myst-nb = {version = "*", optional = true}
41 | sphinx-copybutton = {version = "*", optional = true}
42 | sphinxcontrib-bibtex = {version = "2.6.3", optional = true}
43 | ipykernel = {version = "*", optional = true}
44 | pytest = {version = ">=4.4", optional = true}
45 | pytest-cov = {version = "*", optional = true}
46 | python = ">=3.9,<4.0"
47 | python-igraph = {version = "*", optional = true}
48 | scanpy = {version = ">=1.10.3", optional = true}
49 | scanpydoc = {version = ">=0.5", optional = true}
50 | scvelo = ">=0.3.2"
51 | scvi-tools = ">=1.0.0,<1.2.1"
52 | scikit-learn = ">=0.21.2"
53 | velovi = ">=0.3.1"
54 | torchode = ">=0.1.6"
55 | cellrank = ">=2.0.0"
56 | matplotlib = ">=3.7.3"
57 | sphinx = {version = ">=4.1", optional = true}
58 | sphinx-autodoc-typehints = {version = "*", optional = true}
59 | torch = "<2.6.0"
60 |
61 |
62 | [tool.poetry.extras]
63 | dev = ["black", "pytest", "pytest-cov", "ruff", "codecov", "scanpy", "loompy", "jupyter", "pre-commit"]
64 | docs = [
65 | "sphinx",
66 | "scanpydoc",
67 | "ipython",
68 | "myst-nb",
69 | "sphinx-book-theme",
70 | "sphinx-copybutton",
71 | "sphinxcontrib-bibtex",
72 | "ipykernel",
73 | "ipython",
74 | ]
75 | tutorials = ["scanpy"]
76 |
77 | [tool.poetry.dev-dependencies]
78 |
79 |
80 | [tool.coverage.run]
81 | source = ["regvelo"]
82 | omit = [
83 | "**/test_*.py",
84 | ]
85 |
86 | [tool.pytest.ini_options]
87 | testpaths = ["tests"]
88 | xfail_strict = true
89 |
90 |
91 | [tool.black]
92 | include = '\.pyi?$'
93 | exclude = '''
94 | (
95 | /(
96 | \.eggs
97 | | \.git
98 | | \.hg
99 | | \.mypy_cache
100 | | \.tox
101 | | \.venv
102 | | _build
103 | | buck-out
104 | | build
105 | | dist
106 | )/
107 | )
108 | '''
109 |
110 | [tool.ruff]
111 | src = ["."]
112 | line-length = 119
113 | target-version = "py39"
114 | select = [
115 | "F", # Errors detected by Pyflakes
116 | "E", # Error detected by Pycodestyle
117 | "W", # Warning detected by Pycodestyle
118 | "I", # isort
119 | "D", # pydocstyle
120 | "B", # flake8-bugbear
121 | "TID", # flake8-tidy-imports
122 | "C4", # flake8-comprehensions
123 | "BLE", # flake8-blind-except
124 | "UP", # pyupgrade
125 | "RUF100", # Report unused noqa directives
126 | ]
127 | ignore = [
128 | # line too long -> we accept long comment lines; black gets rid of long code lines
129 | "E501",
130 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
131 | "E731",
132 | # allow I, O, l as variable names -> I is the identity matrix
133 | "E741",
134 | # Missing docstring in public package
135 | "D104",
136 | # Missing docstring in public module
137 | "D100",
138 | # Missing docstring in __init__
139 | "D107",
140 | # Errors from function calls in argument defaults. These are fine when the result is immutable.
141 | "B008",
142 |   # __magic__ methods are often self-explanatory, allow missing docstrings
143 | "D105",
144 | # first line should end with a period [Bug: doesn't work with single-line docstrings]
145 | "D400",
146 | # First line should be in imperative mood; try rephrasing
147 | "D401",
148 | ## Disable one in each pair of mutually incompatible rules
149 | # We don’t want a blank line before a class docstring
150 | "D203",
151 | # We want docstrings to start immediately after the opening triple quote
152 | "D213",
153 | # Missing argument description in the docstring TODO: enable
154 | "D417",
155 | ]
156 |
157 | [tool.ruff.per-file-ignores]
158 | "docs/*" = ["I", "BLE001"]
159 | "tests/*" = ["D"]
160 | "*/__init__.py" = ["F401"]
161 | "regvelo/__init__.py" = ["I"]
162 |
163 | [tool.jupytext]
164 | formats = "ipynb,md"
165 |
166 | [tool.cruft]
167 | skip = [
168 | "tests",
169 | "src/**/__init__.py",
170 | "src/**/basic.py",
171 | "docs/api.md",
172 | "docs/changelog.md",
173 | "docs/references.bib",
174 | "docs/references.md",
175 | "docs/notebooks/example.ipynb",
176 | ]
177 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | build:
3 | os: ubuntu-20.04
4 | tools:
5 | python: "3.10"
6 | sphinx:
7 | configuration: docs/conf.py
8 | python:
9 | install:
10 | - method: pip
11 | path: .
12 | extra_requirements:
13 | - docs
14 |
--------------------------------------------------------------------------------
/regvelo/ModelComparison.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import scipy
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import seaborn as sns
7 |
8 | import scipy.stats
9 | from scipy.stats import ttest_rel
10 |
11 | import cellrank as cr
12 | import scanpy as sc
13 | import scvi
14 | from ._model import REGVELOVI
15 | import scvelo as scv
16 |
17 |
18 | from .tools.set_output import set_output
19 | from .tools._tsi import get_tsi_score
20 |
21 |
22 | # Packages used for data type validation
23 | from typing import List, Optional, Union, Dict
24 | from anndata import AnnData
25 |
26 | class ModelComparison:
27 |     """Compare different types of RegVelo models :cite:p:`Wang2025`.
28 | 
29 |     This class compares RegVelo models trained with different optimization modes (soft, hard, soft_regularized) and different regularization factors lam2.
30 |     Users can evaluate and visualize the performance of the different models based on various kinds of side information (real time, pseudotime, stemness score, terminal state identification, cross-boundary correctness).
31 |     Finally, it produces a barplot in which the best-performing model is marked, and its advantage is backed by a significance test.
32 |
33 | Examples
34 | ----------
35 | See notebook.
36 |
37 | """
38 | def __init__(self,
39 | terminal_states: List = None,
40 | state_transition: Dict = None,
41 | n_states: int = None):
42 |         """Initialize parameters of the comparison object.
43 | 
44 |         Parameters
45 |         ----------
46 |         terminal_states
47 |             A list of all terminal states across cell types.
48 |             Only required when using TSI as side_information. Please make sure they are consistent with the information stored in 'side_key' under TSI mode.
49 |         state_transition
50 |             A dict recording all possible state transitions among cell types.
51 |             Only required when using CBC as side_information. Please make sure they are consistent with the information stored in 'side_key' under CBC mode.
52 |         n_states
53 |             An integer giving the total number of cell clusters.
54 |             Only required when using TSI as side_information.
55 | 
56 |         Returns
57 |         ----------
58 |         A comparison object on which the operations below can be performed.
59 | 
60 |         """
61 | self.TERMINAL_STATES = terminal_states
62 | self.STATE_TRANSITION = state_transition
63 | self.N_STATES = n_states
64 |
65 | self.METHOD = None
66 | self.MODEL_LIST = None
67 |
68 | self.side_info_dict = {
69 | 'Pseudo_Time':'dpt_pseudotime',
70 | 'Stemness_Score': 'ct_score',
71 | 'Real_Time': None,
72 | 'TSI': None,
73 | 'CBC': None
74 | }
75 | self.MODEL_TRAINED = {}
76 |
77 | def validate_input(self,
78 | adata: AnnData,
79 | model_list: List[str] = None,
80 | side_information: str = None,
81 | lam2: Union[List[float], float] = None,
82 | side_key: str = None) -> None:
83 |
84 | # 1.Validate adata
85 | if not isinstance(adata, AnnData):
86 | raise TypeError(f"Expected AnnData object, got {type(adata).__name__}")
87 | layers = ['Ms', 'Mu', 'fit_t']
88 | for layer in layers:
89 | if layer not in adata.layers:
90 | raise ValueError(f"Missing required layer: {layer}")
91 | if 'skeleton' not in adata.uns:
92 | raise ValueError("Missing required 'skeleton' in adata.uns")
93 |
94 | if 'TF' not in adata.var:
95 | raise ValueError("Missing required 'TF' column in adata.var")
96 |
97 | # 2.Validate Model_list
98 | if model_list is not None:
99 | valid_models = ['hard', 'soft', 'soft_regularized']
100 | if not isinstance(model_list, list) or len(model_list) == 0:
101 | raise ValueError("model_list must be a non-empty list")
102 | for model in model_list:
103 | if model not in valid_models:
104 | raise ValueError(f"Invalid model: {model}. Valid models are {valid_models}")
105 | if model == 'soft_regularized' and lam2 is None:
106 |                     raise ValueError("Under 'soft_regularized' mode, lam2 must be given")
107 | if lam2 is not None:
108 | if not isinstance(lam2, (float, list)):
109 | raise TypeError('lam2 must be a float or a list of floats')
110 | if isinstance(lam2, list):
111 | if len(lam2) == 0:
112 |                         raise ValueError('lam2 list cannot be empty')
113 | for num in lam2:
114 | if not isinstance(num, float):
115 | raise ValueError('All elements in lam2 list must be float')
116 | if not 0.0 < num <= 1.0:
117 |                             raise ValueError('lam2 is expected to be in the range (0, 1]')
118 |
119 | # 3.Validate side_information
120 | if side_information is not None:
121 | if not isinstance(side_information, str):
122 |                 raise TypeError("side_information must be a string")
123 |
124 | if side_information not in self.side_info_dict.keys():
125 |                 raise ValueError(f"Valid side_information values are {list(self.side_info_dict.keys())}")
126 |
127 | if side_key is not None:
128 | if not isinstance(side_key, str):
129 | raise TypeError(f"side_key must be a string")
130 | if side_key not in adata.obs:
131 | raise TypeError(f"side_key: {side_key} not found in adata.obs.")
132 | if side_key is None:
133 | side_key = self.side_info_dict[side_information]
134 | if side_key is not None:
135 | if side_key not in adata.obs:
136 |                     raise TypeError(f"Default side_key: {side_key} not found in adata.obs, please provide it manually via the parameter: side_key")
137 |
138 |     def min_max_scaling(self, x):
139 | return (x - np.min(x)) / (np.max(x) - np.min(x))
140 |
141 | def train(
142 | self,
143 | adata: AnnData,
144 | model_list: List[str],
145 | lam2: Union[List[float], float] = None,
146 | n_repeat: int = 1
147 | ) -> List:
148 |         """Train all requested models and store them in a dictionary, where users can access them easily and process them in batch. Models trained and saved previously are kept.
149 | 
150 |         Parameters
151 |         ----------
152 |         adata
153 |             The annotated data matrix; after input, the object stores it internally.
154 |         model_list
155 |             The list of model types to train, chosen from 'soft', 'hard', 'soft_regularized'.
156 |         lam2
157 |             Regularization factor used under 'soft_regularized' mode: a float or a list of floats in the range (0, 1].
158 |         n_repeat
159 |             Number of repeated training runs per model, each with a different random seed.
160 | 
161 |         Returns
162 |         ----------
163 |         A list of dictionary keys naming all models trained so far.
164 | 
165 |         """
164 | self.validate_input(adata, model_list = model_list, lam2 = lam2)
165 | self.ADATA = adata
166 |
167 | if not isinstance(n_repeat, int) or n_repeat < 1:
168 | raise ValueError("n_repeat must be a positive integer")
169 |
170 | W = adata.uns["skeleton"].copy()
171 | W = torch.tensor(np.array(W)).int()
172 | TF = adata.var_names[adata.var['TF']]
173 | REGVELOVI.setup_anndata(adata, spliced_layer="Ms", unspliced_layer="Mu")
174 |
175 |
176 | for model in model_list:
177 | for nrun in range(n_repeat):
178 | scvi.settings.seed = nrun
179 | if model == 'soft_regularized':
180 | if isinstance(lam2,list):
181 | for lambda2 in lam2:
182 | vae = REGVELOVI(
183 | adata,
184 | W=W.T,
185 | regulators=TF,
186 | lam2 = lambda2
187 | )
188 | vae.train()
189 | self.MODEL_TRAINED[f"{model}\nlam2:{lambda2}_{nrun}"] = vae
190 | else:
191 | vae = REGVELOVI(
192 | adata,
193 | W=W.T,
194 | regulators=TF,
195 | lam2=lam2
196 | )
197 | vae.train()
198 | self.MODEL_TRAINED[f"{model}_{nrun}"] = vae
199 | else:
200 | vae = REGVELOVI(
201 | adata,
202 | W=W.T,
203 | regulators=TF,
204 | soft_constraint=(model == 'soft')
205 | )
206 | vae.train()
207 | self.MODEL_TRAINED[f"{model}_{nrun}"] = vae
208 |
209 | return list(self.MODEL_TRAINED.keys())
210 |
211 | def evaluate(
212 | self,
213 | side_information: str,
214 | side_key:str = None
215 | ) -> pd.DataFrame:
216 |         """Evaluate all trained models under one specific side_information mode. For example, if the exact time or stage of cells is known, 'Real_Time' can be used as reference; if a pseudotime estimator such as CellRank has been run beforehand, 'Pseudo_Time' can be chosen instead.
217 | 
218 |         Parameters
219 |         ----------
220 |         side_information
221 |             The perspective used to compare RegVelo models: one of 'Real_Time', 'Pseudo_Time', 'Stemness_Score', 'TSI', 'CBC'.
222 |         side_key
223 |             Column name in adata.obs that stores the selected side information. For 'Pseudo_Time' and 'Stemness_Score' a default side_key is provided, but a custom side_key can also be given.
224 | 
225 |         Returns
226 |         ----------
227 |         The DataFrame attribute name and a DataFrame recording the evaluation performance of all models.
228 |         """
229 | self.validate_input(self.ADATA, side_information=side_information, side_key=side_key)
230 | correlations = []
231 |
232 | for model, vae in self.MODEL_TRAINED.items():
233 |             set_output(self.ADATA, vae, n_samples=30, batch_size=self.ADATA.n_obs)
234 |             fit_t_mean = self.ADATA.layers['fit_t'].mean(axis=1)
235 | self.ADATA.obs["latent_time"] = self.min_max_scaling(fit_t_mean)
236 | corr = np.abs(self.calculate(
237 | self.ADATA, side_information, side_key
238 | ))
239 |             correlations.append({
240 |                 'Model': model[:-2],  # strip the trailing '_{run}' suffix (single-digit runs)
241 |                 'Corr': corr,
242 |                 'Run': model[-1]
243 |             })
244 | df_name = f"df_{side_information}"
245 | df = pd.DataFrame(correlations)
246 | setattr(self, df_name, df)
247 | return df_name, df
248 |
249 | def calculate(
250 | self,
251 | adata: AnnData,
252 | side_information: str,
253 | side_key: str = None
254 | ):
255 | if side_information in ['Pseudo_Time', 'Stemness_Score', 'Real_Time']:
256 | if side_information in ['Pseudo_Time', 'Stemness_Score'] and side_key is None:
257 | side_key = self.side_info_dict[side_information]
258 | return scipy.stats.spearmanr(self.ADATA.obs[side_key].values, self.ADATA.obs['latent_time'])[0]
259 | elif side_information == 'TSI':
260 | thresholds = np.linspace(0.1,1,21)[:20]
261 | vk = cr.kernels.VelocityKernel(self.ADATA)
262 | vk.compute_transition_matrix()
263 | ck = cr.kernels.ConnectivityKernel(self.ADATA).compute_transition_matrix()
264 | kernel = 0.8 * vk + 0.2 * ck
265 | estimator = cr.estimators.GPCCA(kernel)
266 | estimator.compute_macrostates(n_states=self.N_STATES, n_cells=30, cluster_key=side_key)
267 | return np.mean(get_tsi_score(self.ADATA, thresholds, side_key, self.TERMINAL_STATES, estimator))
268 | elif side_information == 'CBC':
269 | self.ADATA.obs['CBC_key'] = self.ADATA.obs[side_key].astype(str)
270 | vk = cr.kernels.VelocityKernel(self.ADATA)
271 | vk.compute_transition_matrix()
272 | ck = cr.kernels.ConnectivityKernel(self.ADATA).compute_transition_matrix()
273 | kernel = 0.8 * vk + 0.2 * ck
274 | cbc_values = []
275 |             for source, target in self.STATE_TRANSITION.items():
276 | cbc = kernel.cbc(source = source, target=target, cluster_key='CBC_key', rep = 'X_pca')
277 | cbc_values.append(np.mean(cbc))
278 | return np.mean(cbc_values)
279 |
280 | def plot_results(
281 | self,
282 | side_information,
283 | figsize = (6, None),
284 | palette = 'lightpink'
285 | ):
286 |         """Visualize the comparison results as a barplot with overlaid scatter points. Significance markers are only shown when n_repeat is at least 3 and p < 0.05.
287 | 
288 |         Parameters
289 |         ----------
290 |         side_information
291 |             The side_information to visualize; it must have been run through the 'evaluate' step beforehand.
292 |         figsize
293 |             Figure size. Default is (6, None), meaning the height of the plot scales with the number of models.
294 |         palette
295 |             Color of the bars.
296 | 
297 |         Returns
298 |         ----------
299 |         A barplot with scatter points representing the performance of all models.
300 |         """
301 | df_name = f"df_{side_information}"
302 | data = getattr(self, df_name)
303 |
304 | model_order = data.groupby('Model')['Corr'].mean().sort_values(ascending=False).index.tolist()
305 | num_models = len(model_order)
306 | fig_height = 2 + num_models * 0.5
307 | figsize = (figsize[0], fig_height)
308 |
309 | sns.set(style='whitegrid', font_scale=1.2)
310 | fig, ax = plt.subplots(figsize=figsize)
311 |
312 |
313 | sns.barplot(
314 | y="Model",
315 | x="Corr",
316 | data=data,
317 | order = model_order,
318 | width=0.3,
319 | ci="sd",
320 | capsize=0.1,
321 | errwidth=2,
322 | color=palette,
323 | ax = ax)
324 | sns.stripplot(
325 | y="Model",
326 | x="Corr",
327 | data=data,
328 | order = model_order,
329 | dodge=True,
330 | jitter=0.25,
331 | color="black",
332 | size = 4,
333 | alpha=0.8,
334 | ax = ax
335 | )
336 |
337 | model_means = data.groupby('Model')['Corr'].mean()
338 | ref_model = model_means.idxmax()
339 |
340 | ref_data = data[data["Model"] == ref_model]["Corr"]
341 | y_ticks = ax.get_yticks()
342 | model_positions = dict(zip(model_order, y_ticks))
343 |
344 | for target_model in model_order:
345 | if target_model == ref_model:
346 | continue
347 | target_data = data[data["Model"] == target_model]["Corr"]
348 |
349 | if len(target_data) < 3:
350 | continue
351 |
352 | try:
353 | t_stat, p_value = scipy.stats.ttest_rel(
354 | ref_data,
355 | target_data,
356 | alternative="greater"
357 | )
358 | except ValueError as e:
359 |                 print(f"Significance test: {ref_model} vs {target_model} failed: {e}")
360 | continue
361 |
362 | if p_value < 0.05:
363 | significance = self.get_significance(p_value)
364 | self._draw_significance_marker(
365 | ax=ax,
366 | start=model_positions[ref_model],
367 | end=model_positions[target_model],
368 | significance=significance,
369 | bracket_height=0.05
370 | )
371 |
372 | ax.set_title(
373 | f"Prediction based on {side_information}",
374 | fontsize=12,
375 | wrap = True
376 | )
377 | ax.set_ylabel('')
378 | ax.set_xlabel(
379 | "Prediction Score",
380 | fontsize=12,
381 | labelpad=10
382 | )
383 | plt.xticks(fontsize=10)
384 | plt.tight_layout()
385 | plt.show()
386 |
387 | def get_significance(self, pvalue):
388 | if pvalue < 0.001:
389 | return "***"
390 | elif pvalue < 0.01:
391 | return "**"
392 | elif pvalue < 0.05:
393 | return "*"
394 | else:
395 | return "ns"
396 |
397 | def _draw_significance_marker(
398 | self,
399 | ax,
400 | start,
401 | end,
402 | significance,
403 | bracket_height=0.05,
404 | linewidth=1.2,
405 | text_offset=0.05):
406 |
407 | if start > end:
408 | start, end = end, start
409 |         x_max = ax.get_xlim()[1]
410 |         # place the bracket just beyond the plotted bars on the x-axis
411 |         bracket_level = x_max + bracket_height
411 |
412 | ax.plot(
413 | [bracket_level-0.02, bracket_level, bracket_level, bracket_level-0.02],
414 | [start, start, end, end],
415 | color='black',
416 | lw=linewidth,
417 | solid_capstyle="butt",
418 | clip_on=False
419 | )
420 |
421 | ax.text(
422 | bracket_level + text_offset,
423 | (start + end)/2,
424 | significance,
425 | ha='center',
426 | va='baseline',
427 | color='black',
428 | fontsize=10,
429 | fontweight='bold',
430 | rotation = 90
431 | )
432 |
--------------------------------------------------------------------------------
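The class docstring above points to the notebooks for examples; for orientation, a minimal usage sketch, assuming `adata` was prepared as in the preprocessing tutorial and already carries the layers checked by `validate_input` ('Ms', 'Mu', 'fit_t'), `.uns['skeleton']`, `.var['TF']`, and a `dpt_pseudotime` column in `.obs`:

```python
import regvelo as rgv

comparer = rgv.ModelComparison()

# Train soft- and hard-constraint models, three random seeds each.
keys = comparer.train(adata, model_list=["soft", "hard"], n_repeat=3)

# Correlate each model's latent time with the precomputed pseudotime
# (adata.obs['dpt_pseudotime'] is the default side_key for this mode).
df_name, df = comparer.evaluate(side_information="Pseudo_Time")

# Barplot with the best model on top; significance markers require
# at least three runs per model and p < 0.05.
comparer.plot_results(side_information="Pseudo_Time")
```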
/regvelo/__init__.py:
--------------------------------------------------------------------------------
1 | """regvelo."""
2 |
3 | import logging
4 |
5 | from rich.console import Console
6 | from rich.logging import RichHandler
7 |
8 | from regvelo import datasets
9 | from regvelo import tools as tl
10 | from regvelo import plotting as pl
11 | from regvelo import preprocessing as pp
12 |
13 | from ._constants import REGISTRY_KEYS
14 | from ._model import REGVELOVI, VELOVAE
15 | from .ModelComparison import ModelComparison
16 |
17 | import sys # isort:skip
18 |
19 | sys.modules.update({f"{__name__}.{m}": globals()[m] for m in ["tl", "pl", "pp"]})
20 |
21 | # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
22 | # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
23 | try:
24 | import importlib.metadata as importlib_metadata
25 | except ModuleNotFoundError:
26 | import importlib_metadata
27 |
28 | package_name = "regvelo"
29 | __version__ = importlib_metadata.version(package_name)
30 |
31 | logger = logging.getLogger(__name__)
32 | # set the logging level
33 | logger.setLevel(logging.INFO)
34 |
35 | # nice logging outputs
36 | console = Console(force_terminal=True)
37 | if console.is_jupyter is True:
38 | console.is_jupyter = False
39 | ch = RichHandler(show_path=False, console=console, show_time=False)
40 | formatter = logging.Formatter("regvelo: %(message)s")
41 | ch.setFormatter(formatter)
42 | logger.addHandler(ch)
43 |
44 | # this prevents double outputs
45 | logger.propagate = False
46 |
47 | __all__ = [
48 | "REGVELOVI",
49 | "VELOVAE",
50 | "REGISTRY_KEYS",
51 | "datasets",
52 | "ModelComparison"
53 | ]
54 |
--------------------------------------------------------------------------------
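The `sys.modules.update` call above registers the `tl`, `pl`, and `pp` aliases as importable submodules, so attribute access and direct imports resolve to the same module objects:

```python
import regvelo as rgv
import regvelo.pp  # importable only because __init__ registered "regvelo.pp" in sys.modules

# The alias and the underlying package are the same module object.
assert rgv.pp is regvelo.pp
assert rgv.pp.__name__ == "regvelo.preprocessing"
```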
/regvelo/_constants.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 |
4 | class _REGISTRY_KEYS_NT(NamedTuple):
5 | X_KEY: str = "X"
6 | U_KEY: str = "U"
7 |
8 |
9 | REGISTRY_KEYS = _REGISTRY_KEYS_NT()
10 |
--------------------------------------------------------------------------------
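The frozen `NamedTuple` above gives immutable, dot-accessible registry keys; the module code indexes scvi-tools tensor batches with them, e.g. `tensors[REGISTRY_KEYS.X_KEY]`:

```python
from regvelo import REGISTRY_KEYS

# Attribute access returns plain strings; the NamedTuple cannot be mutated.
assert REGISTRY_KEYS.X_KEY == "X"
assert REGISTRY_KEYS.U_KEY == "U"
```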
/regvelo/_module.py:
--------------------------------------------------------------------------------
1 | """Main module."""
2 | from typing import Callable, Iterable, Literal, Optional, Any
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
8 | from scvi.nn import Encoder, FCLayers
9 | from torch import nn as nn
10 | from torch.distributions import Categorical, Dirichlet, MixtureSameFamily, Normal
11 | from torch.distributions import kl_divergence as kl
12 | import torchode as to
13 | from ._constants import REGISTRY_KEYS
14 | from torch import Tensor
15 | import torch.nn.utils.prune as prune
16 |
17 | torch.backends.cudnn.benchmark = True
18 |
19 | def _softplus_inverse(x: np.ndarray) -> np.ndarray:
20 | x = torch.from_numpy(x)
21 | x_inv = torch.where(x > 20, x, x.expm1().log()).numpy()
22 | return x_inv
23 |
24 | class ThresholdPruning(prune.BasePruningMethod):
25 | PRUNING_TYPE = "unstructured"
26 |
27 | def __init__(self, threshold):
28 | self.threshold = threshold
29 |
30 | def compute_mask(self, tensor, default_mask):
31 | return torch.abs(tensor) > self.threshold
32 |
33 | class DecoderVELOVI(nn.Module):
34 |     """Decodes data from latent space of ``n_input`` dimensions into ``n_output`` dimensions.
35 |
36 | Uses a fully-connected neural network of ``n_hidden`` layers.
37 |
38 | Parameters
39 | ----------
40 | n_input
41 | The dimensionality of the input (latent space)
42 | n_output
43 | The dimensionality of the output (data space)
44 | n_cat_list
45 | A list containing the number of categories
46 | for each category of interest. Each category will be
47 | included using a one-hot encoding
48 | n_layers
49 | The number of fully-connected hidden layers
50 | n_hidden
51 | The number of nodes per hidden layer
52 | dropout_rate
53 | Dropout rate to apply to each of the hidden layers
54 | inject_covariates
55 | Whether to inject covariates in each layer, or just the first (default).
56 | use_batch_norm
57 | Whether to use batch norm in layers
58 | use_layer_norm
59 | Whether to use layer norm in layers
60 | linear_decoder
61 | Whether to use linear decoder for time
62 | """
63 |
64 | def __init__(
65 | self,
66 | n_input: int,
67 | n_output: int,
68 | n_cat_list: Iterable[int] = None,
69 | n_layers: int = 1,
70 | n_hidden: int = 128,
71 | inject_covariates: bool = True,
72 | use_batch_norm: bool = True,
73 | use_layer_norm: bool = False,
74 | dropout_rate: float = 0.0,
75 | linear_decoder: bool = False,
76 | **kwargs,
77 | ):
78 | super().__init__()
79 | self.n_output = n_output
80 | self.linear_decoder = linear_decoder
81 | self.rho_first_decoder = FCLayers(
82 | n_in=n_input,
83 | n_out=n_hidden if not linear_decoder else n_output,
84 | n_cat_list=n_cat_list,
85 | n_layers=n_layers if not linear_decoder else 1,
86 | n_hidden=n_hidden,
87 | dropout_rate=dropout_rate,
88 | inject_covariates=inject_covariates,
89 | use_batch_norm=use_batch_norm if not linear_decoder else False,
90 | use_layer_norm=use_layer_norm if not linear_decoder else False,
91 | use_activation=not linear_decoder,
92 | bias=not linear_decoder,
93 | **kwargs,
94 | )
95 |
96 | # rho for induction
97 | self.px_rho_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Sigmoid())
98 |
99 | def forward(self, z: torch.Tensor, latent_dim: int = None):
100 |         """The forward computation for a single sample.
101 | 
102 |         #. Decodes the data from the latent space using the decoder network
103 |         #. Returns the latent time ``px_rho`` for each target gene
104 |         #. If ``latent_dim`` is given, all other latent dimensions are masked out
105 | 
106 |         Parameters
107 |         ----------
108 |         z :
109 |             tensor with shape ``(n_input,)``
110 |         latent_dim
111 |             index of a single latent dimension to decode in isolation
112 | 
113 |         Returns
114 |         -------
115 |         :py:class:`torch.Tensor`
116 |             latent time ``px_rho`` for each target gene
117 | 
118 |         """
119 | z_in = z
120 | if latent_dim is not None:
121 | mask = torch.zeros_like(z)
122 | mask[..., latent_dim] = 1
123 | z_in = z * mask
124 | # The decoder returns values for the parameters of the ZINB distribution
125 | rho_first = self.rho_first_decoder(z_in)
126 |
127 | if not self.linear_decoder:
128 | px_rho = self.px_rho_decoder(rho_first)
129 | else:
130 | px_rho = nn.Sigmoid()(torch.matmul(z_in,torch.ones([1,self.n_output])))
131 |
132 | return px_rho
133 |
134 | ## define a new class velocity encoder
135 | class velocity_encoder(nn.Module):
136 |     """Encode the velocity.
137 | 
138 |     The time-dependent transcription rate is determined by upstream regulators; the velocity is built on top of it.
139 | 
140 |     Parameters
141 |     ----------
142 |     activate
143 |         activation function used for modeling the transcription rate
144 |     base_alpha
145 |         whether to add a base transcription rate
146 |     n_int
147 |         number of target genes
148 |     """
149 | def __init__(
150 | self,
151 | activate: str = "softplus",
152 | base_alpha: bool = True,
153 | n_int: int = 5,
154 | ):
155 | super().__init__()
156 | self.n_int = n_int
157 | self.fc1 = nn.Linear(n_int, n_int)
158 | self.activate = activate
159 | self.base_alpha = base_alpha
160 |
161 | def _set_mask_grad(self):
162 | self.hooks = []
163 |
164 | def _hook_mask_no_regulator(grad):
165 | return grad * self.mask_m
166 |
167 | w_grn = self.fc1.weight.register_hook(_hook_mask_no_regulator)
168 | self.hooks.append(w_grn)
169 |
170 |
171 | ## TODO: regularizing the jacobian
172 | def GRN_Jacobian(self,s):
173 |
174 |         if self.activate != "OR":
175 |             if not self.base_alpha:
176 | grn = self.fc1.weight
177 | #grn = grn - self.lamb_I
178 | alpha_unconstr = torch.matmul(s,grn.T)
179 | else:
180 | alpha_unconstr = self.fc1(s)
181 |
182 | if self.activate == "softplus":
183 | coef = (torch.sigmoid(alpha_unconstr))
184 |         elif self.activate == "sigmoid":
185 | coef = (torch.sigmoid(alpha_unconstr))*(1 - torch.sigmoid(alpha_unconstr))*self.alpha_unconstr_max
186 | else:
187 | coef = (1 / (torch.nn.functional.softsign(s) + 1)) * (1 / (1 + torch.abs(s - 0.5))**2) * torch.exp(self.fc1(torch.log(torch.nn.functional.softsign(s - 0.5)+1)))
188 |
189 | if coef.dim() > 1:
190 | coef = coef.mean(0)
191 | Jaco_m = torch.matmul(torch.diag(coef), self.fc1.weight)
192 |
193 |
194 | return Jaco_m
195 |
196 | def GRN_Jacobian2(self,s):
197 |
198 |         if not self.base_alpha:
199 | grn = self.fc1.weight
200 | alpha_unconstr = torch.matmul(s,grn.T)
201 | else:
202 | alpha_unconstr = self.fc1(s)
203 |
204 | if self.activate == "softplus":
205 | coef = (torch.sigmoid(alpha_unconstr))
206 |         elif self.activate == "sigmoid":
207 | coef = (torch.sigmoid(alpha_unconstr))*(1 - torch.sigmoid(alpha_unconstr))*self.alpha_unconstr_max
208 |
209 | # Perform element-wise multiplication
210 | Jaco = coef.unsqueeze(-1) * self.fc1.weight.unsqueeze(0)
211 |
212 | # Transpose and reshape to get the final 3D tensor with dimensions (m, n, n)
213 | Jaco = Jaco.reshape(s.shape[0], s.shape[1], s.shape[1])
214 |
215 | return Jaco
216 |
217 |
218 | def transcription_rate(self,s):
219 |         if self.activate != "OR":
220 |             if not self.base_alpha:
221 | grn = self.fc1.weight
222 | #grn = grn - self.lamb_I
223 | alpha_unconstr = torch.matmul(s,grn.T)
224 | else:
225 | alpha_unconstr = self.fc1(s)
226 |
227 | if self.activate == "softplus":
228 | alpha = torch.clamp(F.softplus(alpha_unconstr),0,50)
229 | elif self.activate == "sigmoid":
230 | alpha = torch.sigmoid(alpha_unconstr)*self.alpha_unconstr_max
231 | else:
232 | raise NotImplementedError
233 |         elif self.activate == "OR":
234 | alpha = torch.exp(self.fc1(torch.log(torch.nn.functional.softsign(s - 0.5)+1)))
235 |
236 | return alpha
237 |
238 | ## TODO: introduce sparsity in the model
239 |     def forward(self, t, u, s):
240 | ## split x into unspliced and spliced readout
241 | ## x is a matrix with (G*2), in which row is a subgraph (batch)
242 | beta = torch.clamp(F.softplus(self.beta_mean_unconstr), 0, 50)
243 | gamma = torch.clamp(F.softplus(self.gamma_mean_unconstr), 0, 50)
244 |         if self.activate != "OR":
245 |             if not self.base_alpha:
246 | grn = self.fc1.weight
247 | alpha_unconstr = torch.matmul(s,grn.T)
248 | else:
249 | alpha_unconstr = self.fc1(s)
250 |
251 | if self.activate == "softplus":
252 | alpha = torch.clamp(F.softplus(alpha_unconstr),0,50)
253 | elif self.activate == "sigmoid":
254 | alpha = torch.sigmoid(alpha_unconstr)*self.alpha_unconstr_max
255 |         elif self.activate == "OR":
256 | alpha = torch.exp(self.fc1(torch.log(torch.nn.functional.softsign(s - 0.5)+1)))
257 | else:
258 | raise NotImplementedError
259 |
260 | ## Predict velocity
261 | du = alpha - beta*u
262 | ds = beta*u - gamma*s
263 |
264 | return du,ds
265 |
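# NOTE: in the notation of velocity_encoder.forward above, the system handed to
# the ODE solver is the standard unspliced/spliced kinetics with a GRN-dependent
# transcription rate. For the default softplus activation with a base rate:
#     alpha(s) = clamp(softplus(W @ s + b), 0, 50)
#     du/dt = alpha(s) - beta * u
#     ds/dt = beta * u - gamma * s
# Here W is the fc1 weight matrix (the learned GRN), b its bias (the base
# transcription rate), and beta, gamma are the softplus-transformed, clamped
# splicing and degradation rates.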
266 | class v_encoder_batch(nn.Module):
267 | """Batching the velocity
268 |
269 | Parameters
270 | ----------
271 | num_g
272 | number of genes
273 | """
274 |
275 | def __init__(
276 | self,
277 | num_g: int = 5,
278 | ):
279 | super().__init__()
280 | self.num_g = num_g
281 |
282 |     def forward(self, t, y):
283 |         """Compute velocities for the batched state.
284 | 
285 |         y is a reshaped matrix of shape (g*n) x 2; we first split it into
286 |         unspliced (n x g) and spliced (n x g) matrices and calculate the velocity,
287 |         then reshape back into a (g*n) x 2 matrix.
288 |         The effective batch number in this case is g*n.
289 |         """
290 | u_v = y[:,0]
291 | s_v = y[:,1]
292 | u = u_v.reshape(-1, self.num_g)
293 | s = s_v.reshape(-1, self.num_g)
294 | du, ds = self.v_encoder_class(t, u, s)
295 |
296 | ## reshape du and ds
297 | du = du.reshape(-1, 1)
298 | ds = ds.reshape(-1, 1)
299 |
300 |         v = torch.cat([du, ds], dim=1)
301 |
302 | return v
303 |
304 |
305 | # VAE model
306 | class VELOVAE(BaseModuleClass):
307 | """Variational auto-encoder model.
308 |
309 | This is an implementation of the RegVelo model.
310 |
311 | Parameters
312 | ----------
313 | n_input
314 | Number of input genes.
315 | regulator_index
316 | list index for all regulators.
317 | target_index
318 | list index for all targets.
319 | skeleton
320 | prior gene regulatory graph.
321 | regulator_list
322 |         An integer list giving the indices of the regulator genes.
323 | activate
324 | Activation function used for modeling transcription rate.
325 | base_alpha
326 | Adding base transcription rate.
327 | n_hidden
328 | Number of nodes per hidden layer.
329 | n_latent
330 | Dimensionality of the latent space.
331 | n_layers
332 | Number of hidden layers used for encoder and decoder NNs.
333 |     lam
334 |         Regularization parameter controlling the strength of the prior-knowledge constraint.
335 |     lam2
336 |         Regularization parameter controlling the strength of the L1 regularization on the Jacobian matrix.
337 | vector_constraint
338 | Regularization on velocity.
339 | bias_constraint
340 | Regularization on bias term (base transcription rate).
341 | dropout_rate
342 | Dropout rate for neural networks
343 | log_variational
344 | Log(data+1) prior to encoding for numerical stability. Not normalization.
345 | latent_distribution
346 | One of
347 |
348 | * ``'normal'`` - Isotropic normal
349 | * ``'ln'`` - Logistic normal with normal params N(0, 1)
350 | use_layer_norm
351 | Whether to use layer norm in layers
352 | var_activation
353 | Callable used to ensure positivity of the variational distributions' variance.
354 | When `None`, defaults to `torch.exp`.
355 | """
356 |
357 | def __init__(
358 | self,
359 | n_input: int,
360 | regulator_index: list,
361 | target_index: list,
362 | skeleton: torch.Tensor,
363 | regulator_list: list,
364 | activate: Literal["sigmoid", "softplus"] = "softplus",
365 | base_alpha: bool = True,
366 | n_hidden: int = 128,
367 | n_latent: int = 10,
368 | n_layers: int = 1,
369 | lam: float = 1,
370 | lam2: float = 0,
371 | vector_constraint: bool = True,
372 | alpha_constraint: float = 0.1,
373 | bias_constraint: bool = True,
374 | dropout_rate: float = 0.1,
375 | log_variational: bool = False,
376 | latent_distribution: str = "normal",
377 | use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
378 | use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both",
379 | var_activation: Optional[Callable] = torch.nn.Softplus(),
380 | gamma_unconstr_init: Optional[np.ndarray] = None,
381 | alpha_unconstr_init: Optional[np.ndarray] = None,
382 | alpha_1_unconstr_init: Optional[np.ndarray] = None,
383 | x0: Optional[np.ndarray] = None,
384 | t0: Optional[np.ndarray] = None,
385 | t_max: float = 20,
386 | linear_decoder: bool = False,
387 | soft_constraint: bool = True,
388 | auto_regulation: bool = False,
389 | ):
390 | super().__init__()
391 | self.n_latent = n_latent
392 | self.log_variational = log_variational
393 | self.latent_distribution = latent_distribution
394 | self.n_input = n_input
395 | self.t_max = t_max
396 | self.lamba = lam
397 | self.lamba2 = lam2
398 | self.vector_constraint = vector_constraint
399 | self.alpha_constraint = alpha_constraint
400 | self.bias_constraint = bias_constraint
401 | self.soft_constraint = soft_constraint
402 |
403 |
404 | n_genes = n_input * 2
405 | n_targets = sum(target_index)
406 | n_regulators = sum(regulator_index)
407 | self.n_targets = int(n_targets)
408 | self.n_regulators = int(n_regulators)
409 | self.regulator_index = regulator_index
410 | self.target_index = target_index
411 |
412 | # degradation for each target gene
413 | if gamma_unconstr_init is None:
414 | self.gamma_mean_unconstr = torch.nn.Parameter(-1 * torch.ones(n_targets))
415 | else:
416 | self.gamma_mean_unconstr = torch.nn.Parameter(
417 | torch.from_numpy(gamma_unconstr_init)
418 | )
419 |
420 | # splicing for each target gene
421 | # first samples around 1
422 | self.beta_mean_unconstr = torch.nn.Parameter(0.5 * torch.ones(n_targets))
423 |
424 | # transcription (bias term for target gene transcription rate function)
425 | if alpha_unconstr_init is None:
426 | self.alpha_unconstr = torch.nn.Parameter(0 * torch.ones(n_targets))
427 | else:
428 | #self.alpha_unconstr = torch.nn.Parameter(0 * torch.ones(n_targets))
429 | self.alpha_unconstr = torch.nn.Parameter(
430 | torch.from_numpy(alpha_unconstr_init)
431 | )
432 |
433 | # TODO: Add `require_grad`
434 | ## The maximum transcription rate (alpha_1) for each target gene
435 | if alpha_1_unconstr_init is None:
436 | self.alpha_1_unconstr = torch.nn.Parameter(torch.ones(n_targets))
437 | else:
438 | self.alpha_1_unconstr = torch.nn.Parameter(
439 | torch.from_numpy(alpha_1_unconstr_init)
440 | )
441 | self.alpha_1_unconstr.data = self.alpha_1_unconstr.data.float()
442 |
443 | # likelihood dispersion
444 | # for now, with normal dist, this is just the variance for target genes
445 | self.scale_unconstr_targets = torch.nn.Parameter(-1 * torch.ones(n_targets*2, 3))
446 |
447 | use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both"
448 | use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both"
449 | use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both"
450 | use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both"
451 | self.use_batch_norm_decoder = use_batch_norm_decoder
452 |
453 | # z encoder goes from the n_input-dimensional data to an n_latent-d
454 | # latent space representation
455 | n_input_encoder = n_genes
456 | self.z_encoder = Encoder(
457 | n_input_encoder,
458 | n_latent,
459 | n_layers=n_layers,
460 | n_hidden=n_hidden,
461 | dropout_rate=dropout_rate,
462 | distribution=latent_distribution,
463 | use_batch_norm=use_batch_norm_encoder,
464 | use_layer_norm=use_layer_norm_encoder,
465 | var_activation=var_activation,
466 | activation_fn=torch.nn.ReLU,
467 | )
468 | # decoder goes from n_latent-dimensional space to n_target-d data
469 | n_input_decoder = n_latent
470 | self.decoder = DecoderVELOVI(
471 | n_input_decoder,
472 | n_targets,
473 | n_layers=n_layers,
474 | n_hidden=n_hidden,
475 | use_batch_norm=use_batch_norm_decoder,
476 | use_layer_norm=use_layer_norm_decoder,
477 | activation_fn=torch.nn.ReLU,
478 | linear_decoder=linear_decoder,
479 | )
480 |
481 | # define velocity encoder, define velocity vector for target genes
482 | self.v_encoder = velocity_encoder(n_int = n_targets,activate = activate,base_alpha = base_alpha)
483 | self.v_encoder.fc1.weight = torch.nn.Parameter(0 * torch.ones(self.v_encoder.fc1.weight.shape))
484 | # saved kinetic parameter in velocity encoder module
485 | self.v_encoder.regulator_index = self.regulator_index
486 | self.v_encoder.beta_mean_unconstr = self.beta_mean_unconstr
487 | self.v_encoder.gamma_mean_unconstr = self.gamma_mean_unconstr
488 | self.v_encoder.register_buffer("alpha_unconstr_max", torch.tensor(10.0))
489 |
490 |         # initialize GRN (masked parameters)
491 | if self.soft_constraint is not True:
492 | self.v_encoder.register_buffer("mask_m", skeleton)
493 | self.v_encoder._set_mask_grad()
494 | else:
495 | if regulator_list is not None:
496 | skeleton_ref = torch.zeros(skeleton.shape)
497 | skeleton_ref[:,regulator_list] = 1
498 | else:
499 | skeleton_ref = torch.ones(skeleton.shape)
500 | if not auto_regulation:
501 | skeleton_ref[range(skeleton_ref.shape[0]), range(skeleton_ref.shape[1])] = 0
502 | self.v_encoder.register_buffer("mask_m", skeleton_ref)
503 | self.v_encoder._set_mask_grad()
504 | self.v_encoder.register_buffer("mask_m_raw", skeleton)
505 |
506 |
507 | ## define batch velocity vector for numerical solver
508 | self.v_batch = v_encoder_batch(num_g = n_targets)
509 | self.v_batch.v_encoder_class = self.v_encoder
510 |
511 | ## register variable for torchode
512 | if x0 is not None:
513 | self.register_buffer("x0", torch.tensor(x0))
514 | else:
515 | self.register_buffer("x0", torch.zeros([n_targets,2]))
516 |
517 | ## TODO: follow xiaojie suggestion, update x0 estimate
518 |
519 | if t0 is not None:
520 | self.register_buffer("t0", torch.tensor(t0).reshape(-1,1))
521 | else:
522 | self.register_buffer("t0", torch.zeros([n_targets,1]))
523 |
524 | self.register_buffer("dt0", torch.ones([1]))
525 | #self.register_buffer("t0", torch.zeros([1]))
526 | self.register_buffer("target_m",torch.zeros(self.v_encoder.fc1.weight.data.shape))
527 |
528 |
529 | def _get_inference_input(self, tensors):
530 | spliced = tensors[REGISTRY_KEYS.X_KEY]
531 | unspliced = tensors[REGISTRY_KEYS.U_KEY]
532 |
533 | input_dict = {
534 | "spliced": spliced,
535 | "unspliced": unspliced,
536 | }
537 | return input_dict
538 |
539 | def _get_generative_input(self, tensors, inference_outputs):
540 | z = inference_outputs["z"]
541 | gamma = inference_outputs["gamma"]
542 | beta = inference_outputs["beta"]
543 | alpha_1 = inference_outputs["alpha_1"]
544 |
545 | input_dict = {
546 | "z": z,
547 | "gamma": gamma,
548 | "beta": beta,
549 | "alpha_1": alpha_1,
550 | }
551 | return input_dict
552 |
553 | @auto_move_data
554 | def inference(
555 | self,
556 | spliced,
557 | unspliced,
558 | n_samples=1,
559 | ):
560 | """High level inference method.
561 |
562 | Runs the inference (encoder) model.
563 | """
564 | spliced_ = spliced
565 | unspliced_ = unspliced
566 | if self.log_variational:
567 | spliced_ = torch.log(0.01 + spliced)
568 | unspliced_ = torch.log(0.01 + unspliced)
569 |
570 | encoder_input = torch.cat((spliced_, unspliced_), dim=-1)
571 |
572 | qz_m, qz_v, z = self.z_encoder(encoder_input)
573 |
574 | if n_samples > 1:
575 | qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
576 | qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
577 | # when z is normal, untran_z == z
578 | untran_z = Normal(qz_m, qz_v.sqrt()).sample()
579 | z = self.z_encoder.z_transformation(untran_z)
580 |
581 | gamma, beta, alpha_1 = self._get_rates()
582 |
583 | outputs = {
584 | "z": z,
585 | "qz_m": qz_m,
586 | "qz_v": qz_v,
587 | "qzm": qz_m,
588 | "qzv": qz_v,
589 | "gamma": gamma,
590 | "beta": beta,
591 | "alpha_1": alpha_1,
592 | }
593 | return outputs
594 |
595 | def _get_rates(self):
596 | # globals
597 | # degradation for each target gene
598 | gamma = torch.clamp(F.softplus(self.v_encoder.gamma_mean_unconstr), 0, 50)
599 | # splicing for each target gene
600 | beta = torch.clamp(F.softplus(self.v_encoder.beta_mean_unconstr), 0, 50)
601 | # transcription for each target gene (bias term)
602 | alpha_1 = self.alpha_1_unconstr
603 |
604 | return gamma, beta, alpha_1
605 |
606 | @auto_move_data
607 | def generative(self, z, gamma, beta, alpha_1, latent_dim=None):
608 | """Runs the generative model."""
609 | decoder_input = z
610 |
611 |         ## the decoder directly decodes the per-gene latent time
612 | px_rho = self.decoder(decoder_input, latent_dim=latent_dim)
613 |
614 | scale_unconstr = self.scale_unconstr_targets
615 | scale = F.softplus(scale_unconstr)
616 |
617 | dist_s, dist_u, index, ind_t = self.get_px(
618 | px_rho,
619 | scale,
620 | gamma,
621 | beta,
622 | alpha_1,
623 | )
624 |
625 | return {
626 | "px_rho": px_rho,
627 | "scale": scale,
628 | "dist_u": dist_u,
629 | "dist_s": dist_s,
630 | "t_sort": index,
631 | "ind_t": ind_t,
632 | }
633 |
634 | def pearson_correlation_loss(self, tensor1, tensor2, eps=1e-6):
635 | # Calculate means
636 | mean1 = torch.mean(tensor1, dim=0)
637 | mean2 = torch.mean(tensor2, dim=0)
638 |
639 | # Calculate covariance
640 | covariance = torch.mean((tensor1 - mean1) * (tensor2 - mean2), dim=0)
641 |
642 |         # Calculate standard deviations
643 |         std1 = torch.std(tensor1, dim=0, correction=0)
644 |         std2 = torch.std(tensor2, dim=0, correction=0)
645 |
646 | # Calculate correlation coefficients
647 | correlation_coefficients = covariance / (std1 * std2 + eps)
648 |
649 | # Convert NaNs to 0 (when std1 or std2 are 0)
650 | correlation_coefficients[torch.isnan(correlation_coefficients)] = 0
651 |
652 |         # The loss is the negative correlation, so minimizing it maximizes correlation
653 |         loss = -correlation_coefficients
654 |
655 | return loss
656 |
657 | def loss(
658 | self,
659 | tensors,
660 | inference_outputs,
661 | generative_outputs,
662 | kl_weight: float = 1.0,
663 | n_obs: float = 1.0,
664 | ):
665 | spliced = tensors[REGISTRY_KEYS.X_KEY]
666 | unspliced = tensors[REGISTRY_KEYS.U_KEY]
667 |
668 |         ### extract spliced and unspliced readouts
669 | regulator_spliced = spliced[:,self.regulator_index]
670 | target_spliced = spliced[:,self.target_index]
671 | target_unspliced = unspliced[:,self.target_index]
672 |
673 | qz_m = inference_outputs["qz_m"]
674 | qz_v = inference_outputs["qz_v"]
675 | beta = inference_outputs["beta"]
676 |
677 | dist_s = generative_outputs["dist_s"]
678 | dist_u = generative_outputs["dist_u"]
679 | t = generative_outputs["ind_t"]
680 | t_sort = generative_outputs["t_sort"].T
681 |
682 | kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1)
683 |
684 | reconst_loss_s = -dist_s.log_prob(target_spliced)
685 | reconst_loss_u = -dist_u.log_prob(target_unspliced)
686 |
687 | reconst_loss_target = reconst_loss_u.sum(dim=-1) + reconst_loss_s.sum(dim=-1)
688 |
689 |
690 | alpha = self.v_encoder.transcription_rate(regulator_spliced)
691 |         target_unspliced_sort = torch.gather(target_unspliced, 0, t_sort)  ## lagged target unspliced readout (time t+1)
692 |         alpha_sort = torch.gather(alpha, 0, t_sort)  ## transcription activity (time t)
693 |         alpha_loss = self.alpha_constraint * self.pearson_correlation_loss(target_unspliced_sort, alpha_sort).sum() / alpha.shape[1]
694 |
695 |         ## add velocity constraint:
696 |         ## regularize the inferred velocity to have both negative and positive components
697 |
698 |         if self.vector_constraint:
699 |             ## unspliced dynamics: du = alpha - beta * u
700 |             alpha = self.v_encoder.transcription_rate(s=spliced)
701 |             du = alpha - beta * unspliced
702 |             velo_loss = 100 * torch.norm(du, dim=1)
703 |         else:
704 |             velo_loss = 0
705 |
706 | ## add graph constraint
707 | if self.soft_constraint:
708 | ## Using norm function to perform graph regularization
709 | mask_m = 1 - self.v_encoder.mask_m_raw
710 | graph_constr_loss = self.lamba * torch.norm(self.v_encoder.fc1.weight * mask_m)
711 | else:
712 | graph_constr_loss = 0
713 |
714 |         Jaco = self.v_encoder.GRN_Jacobian(dist_s.mean.mean(0))
715 |         l1_loss_fn = torch.nn.L1Loss(reduction="sum")
716 |         L1_loss = self.lamba2 * l1_loss_fn(Jaco, self.target_m)
717 |
718 |         ## regularize the bias toward -10, keeping it negative
719 | if self.bias_constraint:
720 | bias_regularize = torch.norm(self.v_encoder.fc1.bias + 10)
721 | else:
722 | bias_regularize = 0
723 |
724 | # local loss
725 | kl_local = kl_divergence_z
726 | weighted_kl_local = (kl_divergence_z) * kl_weight
727 | local_loss = torch.mean(reconst_loss_target + weighted_kl_local + velo_loss)
728 |
729 | # total loss
730 | loss = local_loss + alpha_loss + L1_loss + graph_constr_loss + bias_regularize
731 |
732 |         loss_recorder = LossOutput(
733 |             loss=loss, reconstruction_loss=reconst_loss_target, kl_local=kl_local
734 |         )
735 |
736 |         return loss_recorder
737 |
738 | @auto_move_data
739 | def get_px(
740 | self,
741 | px_rho,
742 | scale,
743 | gamma,
744 | beta,
745 | alpha_1,
746 | ) -> torch.Tensor:
747 |
748 | # predict the abundance in induction phase for target genes
749 | ind_t = self.t_max * px_rho
750 | n_cells = px_rho.shape[0]
751 | mean_u, mean_s, index = self._get_induction_unspliced_spliced(
752 | ind_t
753 | )
754 |
755 | ### only consider induction phase
756 | scale_u = scale[: self.n_targets, 0].expand(n_cells, self.n_targets).sqrt()
757 | scale_s = scale[self.n_targets :, 0].expand(n_cells, self.n_targets).sqrt()
758 |
759 | dist_u = Normal(mean_u, scale_u)
760 | dist_s = Normal(mean_s, scale_s)
761 |
762 | return dist_s, dist_u, index, ind_t
763 |
764 |     def root_time(self, t, root=None):
765 |         """Shift latent time so that the root cell (or t = 0 if no root is given) defines time zero."""
766 | t_root = 0 if root is None else t[root]
767 | o = (t >= t_root).int()
768 | t_after = (t - t_root) * o
769 | t_origin,_ = torch.max(t_after, axis=0)
770 | t_before = (t + t_origin) * (1 - o)
771 |
772 | t_rooted = t_after + t_before
773 |
774 | return t_rooted
775 |
776 | def _get_induction_unspliced_spliced(
777 | self, t, eps=1e-6
778 | ):
779 |         """
780 |         Calculate the spliced and unspliced abundances for target genes.
781 |
782 |         alpha_1: the maximum transcription rate during the induction phase for each target gene
783 |         beta: the splicing parameter for each target gene
784 |         gamma: the degradation parameter for each target gene
785 |
786 |         (the parameters above are stored in v_encoder)
787 |         t: target-gene-specific latent time
788 |         """
789 | device = self.device
790 | #t = t.T
791 |
792 | if t.shape[0] > 1:
793 | ## define parameters
794 | _, index = torch.sort(t, dim=0)
795 | index = index.T
796 | dim = t.shape[0] * t.shape[1]
797 | t0 = self.t0.repeat(t.shape[0],1)
798 | dt0 = self.dt0.expand(dim)
799 | x0 = self.x0.repeat(t.shape[0],1)
800 |
801 | t_eval = t.reshape(-1,1)
802 | t_eval = torch.cat((t0,t_eval),dim=1)
803 |
804 |         ## set up G batches; each G represents a module (a target-gene-centered regulon)
805 |         ## infer the observed gene expression via the ODE solver from x0, t, and the velocity encoder
806 |
807 | term = to.ODETerm(self.v_batch)
808 | step_method = to.Dopri5(term=term)
809 | #step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
810 | step_size_controller = to.FixedStepController()
811 | solver = to.AutoDiffAdjoint(step_method, step_size_controller)
812 | #jit_solver = torch.jit.script(solver)
813 | sol = solver.solve(to.InitialValueProblem(y0=x0, t_eval=t_eval), dt0 = dt0)
814 | else:
815 | t_eval = t
816 | t_eval = torch.cat((self.t0,t_eval),dim=1)
817 |             ## set up G batches; each G represents a module (a target-gene-centered regulon)
818 |             ## infer the observed gene expression via the ODE solver from x0, t, and the velocity encoder
819 | #x0 = x0.double()
820 |
821 | term = to.ODETerm(self.v_encoder)
822 | step_method = to.Dopri5(term=term)
823 | #step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
824 | step_size_controller = to.FixedStepController()
825 | solver = to.AutoDiffAdjoint(step_method, step_size_controller)
826 | #jit_solver = torch.jit.script(solver)
827 | sol = solver.solve(to.InitialValueProblem(y0=self.x0, t_eval=t_eval), dt0 = self.dt0)
828 |
829 |         ## generate predicted results
830 |         # solutions are stored in sol.ys with shape [n_subsystems, n_timestamps, 2], last dim = (u, s)
831 | pre_u = sol.ys[:,1:,0]
832 | pre_s = sol.ys[:,1:,1]
833 |
834 | if t.shape[1] > 1:
835 | unspliced = pre_u.reshape(-1,t.shape[1])
836 | spliced = pre_s.reshape(-1,t.shape[1])
837 | else:
838 | unspliced = pre_u.ravel()
839 | spliced = pre_s.ravel()
840 |
841 | return unspliced, spliced, index
842 |
843 | def _get_repression_unspliced_spliced(self, u_0, s_0, beta, gamma, t, eps=1e-6):
844 | unspliced = torch.exp(-beta * t) * u_0
845 | spliced = s_0 * torch.exp(-gamma * t) - (
846 | beta * u_0 / ((gamma - beta) + eps)
847 | ) * (torch.exp(-gamma * t) - torch.exp(-beta * t))
848 | return unspliced, spliced
849 |
850 | def sample(
851 | self,
852 | ) -> np.ndarray:
853 | """Not implemented."""
854 | raise NotImplementedError
855 |
856 | @torch.no_grad()
857 | def get_loadings(self) -> np.ndarray:
858 | """Extract per-gene weights (for each Z, shape is genes by dim(Z)) in the linear decoder."""
859 | # This is BW, where B is diag(b) batch norm, W is weight matrix
860 | if self.decoder.linear_decoder is False:
861 | raise ValueError("Model not trained with linear decoder")
862 | w = self.decoder.rho_first_decoder.fc_layers[0][0].weight
863 | if self.use_batch_norm_decoder:
864 | bn = self.decoder.rho_first_decoder.fc_layers[0][1]
865 | sigma = torch.sqrt(bn.running_var + bn.eps)
866 | gamma = bn.weight
867 | b = gamma / sigma
868 | b_identity = torch.diag(b)
869 | loadings = torch.matmul(b_identity, w)
870 | else:
871 | loadings = w
872 | loadings = loadings.detach().cpu().numpy()
873 |
874 | return loadings
875 |
876 | def freeze_mapping(self):
877 | for param in self.z_encoder.parameters():
878 | param.requires_grad = False
879 |
880 | for param in self.decoder.parameters():
881 | param.requires_grad = False
--------------------------------------------------------------------------------
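A note on the `alpha_loss` term above: `pearson_correlation_loss` returns the negative column-wise Pearson correlation, so minimizing it pushes the lagged unspliced readout and the transcription activity toward maximal correlation. Below is a minimal, self-contained sketch of the same computation in plain PyTorch; the helper name is ours, not part of the module:

```python
import torch

def columnwise_pearson_loss(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Negative per-column Pearson correlation of two (n_obs, n_genes) tensors."""
    cov = ((a - a.mean(dim=0)) * (b - b.mean(dim=0))).mean(dim=0)
    corr = cov / (a.std(dim=0, correction=0) * b.std(dim=0, correction=0) + eps)
    corr[torch.isnan(corr)] = 0  # columns with zero variance contribute nothing
    return -corr

t = torch.linspace(0.0, 1.0, 50).unsqueeze(1)  # 50 "cells", one "gene"
print(columnwise_pearson_loss(t, 2 * t + 1))   # ~ -1.0: perfectly correlated
print(columnwise_pearson_loss(t, -t))          # ~ +1.0: perfectly anti-correlated
```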
/regvelo/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from ._datasets import (
2 | zebrafish_nc,
3 | zebrafish_grn,
4 | murine_nc,
5 | )
6 |
7 | __all__ = [
8 | "zebrafish_nc",
9 | "zebrafish_grn",
10 | "murine_nc",
11 | ]
--------------------------------------------------------------------------------
/regvelo/datasets/_datasets.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional, Union
3 |
4 | import pandas as pd
5 |
6 | from scanpy import read
7 |
8 | from scvelo.core import cleanup
9 | from scvelo.read_load import load
10 |
11 | url_adata = "https://drive.google.com/uc?id=1Nzq1F6dGw-nR9lhRLfZdHOG7dcYq7P0i&export=download"
12 | url_grn = "https://drive.google.com/uc?id=1ci_gCwdgGlZ0xSn6gSa_-LlIl9-aDa1c&export=download/"
13 | url_adata_murine_processed = "https://drive.usercontent.google.com/download?id=19bNQfW3jMKEEjpjNdUkVd7KDTjJfqxa5&export=download&authuser=1&confirm=t&uuid=4fdf3051-229b-4ce2-b644-cb390424570a&at=APcmpoxgcuZ5r6m6Fb6N_2Og6tEO:1745354679573"
14 | url_adata_murine_normalized = "https://drive.usercontent.google.com/download?id=1xy2FNYi6Y2o_DzXjRmmCtARjoZ97Ro_w&export=download&authuser=1&confirm=t&uuid=12cf5d23-f549-48d9-b7ec-95411a58589f&at=APcmpoyexgouf243lNygF9yRUkmi:1745997349046"
15 | url_adata_murine_velocyto = "https://drive.usercontent.google.com/download?id=18Bhtb7ruoUxpNt8WMYSaJ1RyoiHOCEjd&export=download&authuser=1&confirm=t&uuid=ecc42202-bc82-4ab1-b2c3-bfc31c99f0df&at=APcmpozsh6tBzkv8NSIZW0VipDJa:1745997422108"
16 |
17 |
18 | def zebrafish_nc(file_path: Union[str, Path] = "data/zebrafish_nc/adata_zebrafish_preprocessed.h5ad"):
19 | """Zebrafish neural crest cells.
20 |
21 |     Single-cell RNA-seq dataset of zebrafish neural crest cell development across
22 |     seven distinct time points, profiled with the ultra-deep Smart-seq3 protocol.
23 |
24 |     There are four distinct phases of NC cell development: 1) specification at the NPB, 2) epithelial-to-mesenchymal
25 |     transition (EMT) from the neural tube, 3) migration throughout the periphery, and 4) differentiation into distinct cell types.
26 |
27 | Arguments:
28 | ---------
29 | file_path
30 |         Path where the dataset is saved to and read from.
31 |
32 | Returns
33 | -------
34 | Returns `adata` object
35 | """
36 | adata = read(file_path, backup_url=url_adata, sparse=True, cache=True)
37 | return adata
38 |
39 | def zebrafish_grn(file_path: Union[str, Path] = "data/zebrafish_nc/prior_GRN.csv"):
40 |     """Prior gene regulatory network for zebrafish neural crest cells.
41 |
42 |     Regulator-target interactions used as the prior GRN skeleton for RegVelo
43 |     inference on the zebrafish neural crest dataset (see :func:`zebrafish_nc`).
44 |
45 |     The table is downloaded from the data repository and a copy is written to
46 |     `file_path` on every call.
47 |
48 |     Arguments:
49 |     ---------
50 |     file_path
51 |         Path where the GRN is saved to and read from.
52 |
53 |     Returns
54 |     -------
55 |     Returns the prior GRN as a `pandas.DataFrame`
56 | """
57 | grn = pd.read_csv(url_grn, index_col = 0)
58 | grn.to_csv(file_path)
59 | return grn
60 |
61 | def murine_nc(data_type: str = "preprocessed"):
62 | """
63 | Mouse neural crest cells.
64 |
65 | Single-cell RNA-seq datasets of mouse neural crest cell development,
66 | subset from Qiu, Chengxiang et al.
67 |
68 | The gene regulatory network (GRN) is saved in `adata.uns["skeleton"]`,
69 | which is learned via pySCENIC.
70 |
71 | Parameters
72 | ----------
73 | data_type : str
74 | Which version of the dataset to load. Must be one of:
75 | - "preprocessed"
76 | - "normalized"
77 | - "velocyto"
78 |
79 | Returns
80 | -------
81 | AnnData
82 | Annotated data matrix (an `AnnData` object).
83 | """
84 | valid_types = ["preprocessed", "normalized", "velocyto"]
85 | if data_type not in valid_types:
86 | raise ValueError(f"Invalid data_type: '{data_type}'. Must be one of {valid_types}.")
87 |
88 | file_path = ["data/murine_nc/adata_preprocessed.h5ad","data/murine_nc/adata_gex_normalized.h5ad","data/murine_nc/adata_velocity.h5ad"]
89 |
90 | if data_type == "preprocessed":
91 | adata = read(file_path[0], backup_url=url_adata_murine_processed, sparse=True, cache=True)
92 | elif data_type == "normalized":
93 | adata = read(file_path[1], backup_url=url_adata_murine_normalized, sparse=True, cache=True)
94 | elif data_type == "velocyto":
95 | adata = read(file_path[2], backup_url=url_adata_murine_velocyto, sparse=True, cache=True)
96 |
97 | return adata
--------------------------------------------------------------------------------
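Usage sketch for the loaders above, assuming network access on first use (files are cached at the default `file_path` locations, except `zebrafish_grn`, which re-downloads its CSV on every call):

```python
from regvelo.datasets import murine_nc, zebrafish_grn, zebrafish_nc

adata = zebrafish_nc()               # AnnData, cached under data/zebrafish_nc/
grn = zebrafish_grn()                # prior GRN as a pandas DataFrame
adata_m = murine_nc("preprocessed")  # "normalized" and "velocyto" also valid
# murine_nc("raw")                   # would raise ValueError: invalid data_type
```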
/regvelo/plotting/__init__.py:
--------------------------------------------------------------------------------
1 | from .fate_probabilities import fate_probabilities
2 | from .commitment_score import commitment_score
3 | from .depletion_score import depletion_score
4 | from .get_significance import get_significance
5 |
6 | __all__ = [
7 | "fate_probabilities",
8 | "commitment_score",
9 | "depletion_score",
10 | "get_significance",
11 | ]
12 |
--------------------------------------------------------------------------------
/regvelo/plotting/commitment_score.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import scanpy as sc
4 | from anndata import AnnData
5 | from typing import Any
6 |
7 | from .utils import calculate_entropy
8 |
9 |
10 | def commitment_score(adata : AnnData,
11 | lineage_key : str = "lineages_fwd",
12 | **kwargs : Any
13 | ) -> None:
14 | """
15 | Compute and plot cell fate commitment scores based on fate probabilities.
16 |
17 | Parameters
18 | ----------
19 | adata : AnnData
20 |         Dataset containing fate probabilities (original or perturbed).
21 |     lineage_key : str
22 |         The key in .obsm that stores the fate probabilities.
23 |     kwargs : Any
24 |         Optional.
25 |         Additional keyword arguments passed to the scanpy.pl.umap function.
26 | """
27 |
28 | if lineage_key not in adata.obsm:
29 | raise KeyError(f"Key '{lineage_key}' not found in `adata.obsm`.")
30 |
31 | p = pd.DataFrame(adata.obsm[lineage_key], columns=adata.obsm[lineage_key].names.tolist())
32 | score = calculate_entropy(p)
33 | adata.obs["commitment_score"] = np.array(score)
34 |
35 | sc.pl.umap(
36 | adata,
37 | color="commitment_score",
38 | **kwargs
39 | )
40 |
41 |
--------------------------------------------------------------------------------
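A hedged usage sketch: `adata` is assumed to already carry fate probabilities in `.obsm["lineages_fwd"]` (for example written by the `fate_probabilities` helper below) plus a UMAP embedding; remaining keyword arguments go to `scanpy.pl.umap`:

```python
from regvelo.plotting import commitment_score

# Writes adata.obs["commitment_score"] (entropy over fate probabilities)
# and colors the UMAP embedding by it.
commitment_score(adata, lineage_key="lineages_fwd", cmap="viridis")
```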
/regvelo/plotting/depletion_score.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import pandas as pd
4 | from anndata import AnnData
5 | from typing import Union, Sequence, Any, Optional
6 | import cellrank as cr
7 |
8 | import seaborn as sns
9 | import matplotlib.pyplot as plt
10 |
11 |
12 | from ..tools.abundance_test import abundance_test
13 |
14 | def depletion_score(adata : AnnData,
15 | df : pd.DataFrame,
16 | color_label : str = "celltype_cluster",
17 | **kwargs : Any,
18 | ) -> None:
19 | """
20 | Plot depletion scores.
21 |
22 | Parameters
23 | ----------
24 | adata : AnnData
25 | Annotated data matrix of original model.
26 |     df : pandas.DataFrame
27 |         Depletion scores per TF and terminal state (e.g. from `regvelo.tools.depletion_score`).
28 |     color_label : str
29 |         Key in `adata.obs` whose categories (and matching `.uns` colors) define the palette.
30 |     kwargs : Any
31 |         Optional. Keyword arguments for figure size, font size, labels, legend placement, and axes.
32 | """
33 |
34 | fontsize = kwargs.get("fontsize", 14)
35 | figsize = kwargs.get("figsize", (12, 6))
36 |
37 | legend_loc = kwargs.get("legend_loc", "center left")
38 | legend_bbox = kwargs.get("legend_bbox", (1.02, 0.5))
39 |
40 | xlabel = kwargs.get("xlabel", "TF")
41 | ylabel = kwargs.get("ylabel", "Depletion score")
42 |
43 | plot_kwargs = {k: kwargs[k] for k in ("ax",) if k in kwargs}
44 |
45 | plt.figure(figsize=figsize)
46 |
47 | palette = dict(zip(adata.obs[color_label].cat.categories, adata.uns[f"{color_label}_colors"]))
48 | sns.barplot(x=xlabel, y=ylabel, hue='Terminal state', data=df, palette=palette, dodge=True, **plot_kwargs)
49 |
50 | for i in range(len(df['TF'].unique()) - 1):
51 | plt.axvline(x=i + 0.5, color='gray', linestyle='--')
52 |
53 | plt.ylabel(ylabel, fontsize=fontsize)
54 | plt.xlabel(xlabel, fontsize=fontsize)
55 | plt.xticks(fontsize=fontsize)
56 | plt.yticks(fontsize=fontsize)
57 |
58 | plt.legend(
59 | title='Terminal state',
60 | bbox_to_anchor=legend_bbox,
61 | loc=legend_loc,
62 | borderaxespad=0
63 | )
64 |
65 |
66 | plt.tight_layout()
67 | plt.show()
68 |
--------------------------------------------------------------------------------
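A sketch of how this plot pairs with the scoring function in `regvelo.tools`: the `df` produced there carries exactly the `TF`, `Depletion score`, and `Terminal state` columns consumed here, and `adata` must provide the `celltype_cluster` categories with matching colors in `.uns`. The TF and terminal-state names below are hypothetical placeholders:

```python
from regvelo.plotting.depletion_score import depletion_score as plot_depletion
from regvelo.tools.depletion_score import depletion_score as compute_depletion

perturbed = {"geneA": adata_ko}  # hypothetical knock-out AnnData
df, perturbed = compute_depletion(
    perturbed=perturbed, baseline=adata,
    terminal_state=["Pigment", "Neuron"],  # hypothetical states
)
plot_depletion(adata, df=df, color_label="celltype_cluster")
```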
/regvelo/plotting/fate_probabilities.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import cellrank as cr
3 | from anndata import AnnData
4 | from typing import Union, Sequence, Any
5 |
6 | def fate_probabilities(adata : AnnData,
7 | terminal_state : Union[str, Sequence[str]],
8 | n_states : int,
9 | save_kernel : bool = True,
10 | **kwargs : Any
11 | ) -> None:
12 | """
13 |     Compute the transition matrix and fate probabilities toward the given terminal
14 |     states, and plot the probabilities for each terminal state.
15 |
16 |     Parameters
17 |     ----------
18 |     adata : AnnData
19 |         Annotated data matrix.
20 |     terminal_state : str or Sequence[str]
21 |         Terminal state(s) to compute fate probabilities for.
22 |     n_states : int
23 |         Number of macrostates to compute.
24 |     save_kernel : bool
25 |         Whether to write the kernel to adata. Default is True.
26 |     kwargs : Any
27 |         Optional.
28 |         Additional keyword arguments passed to the CellRank functions.
29 | """
30 |
31 | macro_kwargs = {k: kwargs[k] for k in ("cluster_key", "method") if k in kwargs}
32 | compute_fate_probabilities_kwargs = {k: kwargs[k] for k in ("solver", "tol") if k in kwargs}
33 | plot_kwargs = {k: kwargs[k] for k in ("basis", "same_plot", "states", "title") if k in kwargs}
34 |
35 | vk = cr.kernels.VelocityKernel(adata).compute_transition_matrix()
36 | if save_kernel:
37 | vk.write_to_adata()
38 |
39 | estimator = cr.estimators.GPCCA(vk)
40 | estimator.compute_macrostates(n_states=n_states, **macro_kwargs)
41 | estimator.set_terminal_states(terminal_state)
42 | estimator.compute_fate_probabilities(**compute_fate_probabilities_kwargs)
43 | estimator.plot_fate_probabilities(**plot_kwargs)
44 | estimator.compute_lineage_drivers()
45 | plt.show()
46 |
47 |
--------------------------------------------------------------------------------
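Usage sketch, assuming `adata` already holds RegVelo velocities so the `VelocityKernel` can build its transition matrix; `cluster_key` and `basis` are routed to the corresponding CellRank calls, and the state names are hypothetical:

```python
from regvelo.plotting import fate_probabilities

fate_probabilities(
    adata,
    terminal_state=["Pigment", "Neuron"],  # hypothetical terminal states
    n_states=7,
    cluster_key="cell_type",  # forwarded to compute_macrostates
    basis="umap",             # forwarded to plot_fate_probabilities
)
```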
/regvelo/plotting/get_significance.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def get_significance(pvalue: float) -> str:
4 | """
5 | Return significance annotation for a p-value.
6 |
7 | Parameters
8 | ----------
9 | pvalue : float
10 | P-value to interpret.
11 |
12 | Returns
13 | -------
14 | str
15 | A string indicating the level of significance:
16 | "***" for p < 0.001,
17 | "**" for p < 0.01,
18 | "*" for p < 0.1,
19 | "n.s." (not significant) otherwise.
20 | """
21 | if pvalue < 0.001:
22 | return "***"
23 | elif pvalue < 0.01:
24 | return "**"
25 | elif pvalue < 0.1:
26 | return "*"
27 | else:
28 | return "n.s."
--------------------------------------------------------------------------------
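A quick demonstration of the mapping (note the loosest star uses a 0.1 cutoff rather than the conventional 0.05):

```python
from regvelo.plotting import get_significance

for p in (0.0005, 0.005, 0.05, 0.5):
    print(f"p={p}: {get_significance(p)}")
# p=0.0005: ***   p=0.005: **   p=0.05: *   p=0.5: n.s.
```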
/regvelo/plotting/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import scvelo as scv
4 |
5 |
6 | def calculate_entropy(prob_matrix: np.ndarray
7 | ) -> np.ndarray:
8 | """
9 | Calculate entropy for each row in a cell fate probability matrix.
10 |
11 | Parameters
12 | ----------
13 | prob_matrix : np.ndarray
14 | A 2D NumPy array of shape (n_cells, n_lineages) where each row represents
15 | a cell's fate probabilities across different lineages. Each row should sum to 1.
16 |
17 | Returns
18 | -------
19 | entropy : np.ndarray
20 | A 1D NumPy array of length n_cells containing the entropy values for each cell.
21 | """
22 | log_probs = np.zeros_like(prob_matrix)
23 | mask = prob_matrix != 0
24 | np.log2(prob_matrix, where=mask, out=log_probs)
25 |
26 | entropy = -np.sum(prob_matrix * log_probs, axis=1)
27 | return entropy
28 |
29 |
--------------------------------------------------------------------------------
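A small worked example of the entropy computation, runnable as-is:

```python
import numpy as np
from regvelo.plotting.utils import calculate_entropy

prob = np.array([
    [1.0, 0.0, 0.0],  # fully committed cell -> entropy 0
    [1/3, 1/3, 1/3],  # maximally undecided cell -> log2(3) ~ 1.585
    [0.5, 0.5, 0.0],  # torn between two fates -> exactly 1 bit
])
print(calculate_entropy(prob))  # [0.    1.585 1.   ]
```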
/regvelo/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .preprocess_data import preprocess_data
2 | from .sanity_check import sanity_check
3 | from .set_prior_grn import set_prior_grn
4 | from .filter_genes import filter_genes
5 |
6 | __all__ = [
7 | "preprocess_data",
8 | "set_prior_grn",
9 | "sanity_check",
10 | "filter_genes"
11 | ]
12 |
--------------------------------------------------------------------------------
/regvelo/preprocessing/filter_genes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | from anndata import AnnData
5 |
6 | def filter_genes(adata: AnnData) -> AnnData:
7 | """Filter genes in an AnnData object to ensure each gene has upstream regulators.
8 |
9 | The function iteratively refines the skeleton matrix to maintain only genes with regulatory connections. Only used
10 | by `soft_constraint=False` RegVelo model.
11 |
12 | Parameters
13 | ----------
14 | adata
15 | Annotated data object (AnnData) containing gene expression data, a skeleton matrix of regulatory interactions,
16 | and a list of regulators.
17 |
18 | Returns
19 | -------
20 | adata
21 | Updated AnnData object with filtered genes and a refined skeleton matrix where all genes have at least one
22 | upstream regulator.
23 | """
24 | # Initial filtering based on regulators
25 | var_mask = adata.var_names.isin(adata.uns["regulators"])
26 |
27 |     # Subset to genes that appear in the regulator list
28 | adata = adata[:, var_mask].copy()
29 |
30 | # Update skeleton matrix
31 | skeleton = adata.uns["skeleton"].loc[adata.var_names.tolist(), adata.var_names.tolist()]
32 | adata.uns["skeleton"] = skeleton
33 |
34 | # Iterative refinement
35 | while adata.uns["skeleton"].sum(0).min() == 0:
36 | # Update filtering based on skeleton
37 | skeleton = adata.uns["skeleton"]
38 | mask = skeleton.sum(0) > 0
39 |
40 | regulators = adata.var_names[mask].tolist()
41 | print(f"Number of genes: {len(regulators)}")
42 |
43 | # Filter skeleton and update `adata`
44 | skeleton = skeleton.loc[regulators, regulators]
45 | adata.uns["skeleton"] = skeleton
46 |
47 | # Update adata with filtered genes
48 | adata = adata[:, mask].copy()
49 | adata.uns["regulators"] = regulators
50 | adata.uns["targets"] = regulators
51 |
52 | # Re-index skeleton with updated gene names
53 | skeleton_df = pd.DataFrame(
54 | adata.uns["skeleton"],
55 | index=adata.uns["regulators"],
56 | columns=adata.uns["targets"],
57 | )
58 | adata.uns["skeleton"] = skeleton_df
59 |
60 | return adata
--------------------------------------------------------------------------------
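A toy sketch of the iterative pruning, with a hypothetical three-gene skeleton in which `g3` lacks any upstream regulator:

```python
import numpy as np
import pandas as pd
from anndata import AnnData
from regvelo.preprocessing import filter_genes

genes = ["g1", "g2", "g3"]
adata = AnnData(np.random.rand(5, 3))
adata.var_names = genes
adata.uns["regulators"] = np.array(genes)
# Column sums count incoming edges; g3 has none and is pruned.
adata.uns["skeleton"] = pd.DataFrame(
    [[0, 1, 0],
     [1, 0, 0],
     [0, 0, 0]],
    index=genes, columns=genes,
)
adata = filter_genes(adata)
print(adata.var_names.tolist())  # ['g1', 'g2']
```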
/regvelo/preprocessing/preprocess_data.py:
--------------------------------------------------------------------------------
1 | # the following code is adapted from velovi
2 | from typing import Optional
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import scvelo as scv
7 | from anndata import AnnData
8 | from sklearn.preprocessing import MinMaxScaler
9 |
10 |
11 | def preprocess_data(
12 | adata: AnnData,
13 | spliced_layer: Optional[str] = "Ms",
14 | unspliced_layer: Optional[str] = "Mu",
15 | min_max_scale: bool = True,
16 | filter_on_r2: bool = True,
17 | ) -> AnnData:
18 | """Preprocess data.
19 |
20 | This function removes poorly detected genes and minmax scales the data.
21 |
22 | Parameters
23 | ----------
24 | adata
25 | Annotated data matrix.
26 | spliced_layer
27 | Name of the spliced layer.
28 | unspliced_layer
29 | Name of the unspliced layer.
30 | min_max_scale
31 | Min-max scale spliced and unspliced
32 | filter_on_r2
33 | Filter out genes according to linear regression fit
34 |
35 | Returns
36 | -------
37 | Preprocessed adata.
38 | """
39 | if min_max_scale:
40 | scaler = MinMaxScaler()
41 | adata.layers[spliced_layer] = scaler.fit_transform(adata.layers[spliced_layer])
42 |
43 | scaler = MinMaxScaler()
44 | adata.layers[unspliced_layer] = scaler.fit_transform(
45 | adata.layers[unspliced_layer]
46 | )
47 |
48 | if filter_on_r2:
49 | scv.tl.velocity(adata, mode="deterministic")
50 |
51 | adata = adata[
52 | :, np.logical_and(adata.var.velocity_r2 > 0, adata.var.velocity_gamma > 0)
53 | ].copy()
54 | adata = adata[:, adata.var.velocity_genes].copy()
55 |
56 | return adata
--------------------------------------------------------------------------------
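Usage sketch following the standard scVelo preprocessing, which provides the smoothed `Ms`/`Mu` layers this function expects (parameter values here are illustrative, not prescriptions):

```python
import scvelo as scv
from regvelo.preprocessing import preprocess_data

scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)
scv.pp.moments(adata, n_pcs=30, n_neighbors=30)  # creates layers "Ms" and "Mu"

adata = preprocess_data(adata)  # min-max scale Ms/Mu, keep well-fit velocity genes
```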
/regvelo/preprocessing/sanity_check.py:
--------------------------------------------------------------------------------
1 |
2 | from pathlib import Path
3 | from typing import Optional, Union
4 | from urllib.request import urlretrieve
5 |
6 | import torch
7 | import numpy as np
8 | import pandas as pd
9 | import scvelo as scv
10 | from anndata import AnnData
11 | from sklearn.preprocessing import MinMaxScaler
12 | from scipy.spatial.distance import cdist
13 |
14 | def sanity_check(
15 | adata : AnnData,
16 | ) -> AnnData:
17 |
18 | """
19 |     Sanity check.
20 |
21 |     Iteratively prunes genes (and the stored skeleton/network) until every gene has at least one upstream regulator.
22 |
23 | Parameters
24 | ----------
25 | adata
26 | Annotated data matrix.
27 | """
28 |
29 | gene_name = adata.var.index.tolist()
30 | full_name = adata.uns["regulators"]
31 | index = [i in gene_name for i in full_name]
32 | full_name = full_name[index]
33 | adata = adata[:,full_name].copy()
34 |
35 | W = adata.uns["skeleton"]
36 | W = W[index,:]
37 | W = W[:,index]
38 |
39 | adata.uns["skeleton"] = W
40 | W = adata.uns["network"]
41 | W = W[index,:]
42 | W = W[:,index]
43 | #csgn = csgn[index,:,:]
44 | #csgn = csgn[:,index,:]
45 | adata.uns["network"] = W
46 |
47 | ###
48 | for i in range(1000):
49 | if adata.uns["skeleton"].sum(0).min()>0:
50 | break
51 | else:
52 | W = np.array(adata.uns["skeleton"])
53 | gene_name = adata.var.index.tolist()
54 |
55 |             indicator = W.sum(0) > 0  ## every gene needs at least one upstream regulator
56 | regulators = [gene for gene, boolean in zip(gene_name, indicator) if boolean]
57 | targets = [gene for gene, boolean in zip(gene_name, indicator) if boolean]
58 | print("num regulators: "+str(len(regulators)))
59 | print("num targets: "+str(len(targets)))
60 | W = np.array(adata.uns["skeleton"])
61 | W = W[indicator,:]
62 | W = W[:,indicator]
63 | adata.uns["skeleton"] = W
64 |
65 | W = np.array(adata.uns["network"])
66 | W = W[indicator,:]
67 | W = W[:,indicator]
68 | adata.uns["network"] = W
69 |
70 | #csgn = csgn[indicator,:,:]
71 | #csgn = csgn[:,indicator,:]
72 | #adata.uns["csgn"] = csgn
73 |
74 | adata.uns["regulators"] = regulators
75 | adata.uns["targets"] = targets
76 |
77 | W = pd.DataFrame(adata.uns["skeleton"],index = adata.uns["regulators"],columns = adata.uns["targets"])
78 | W = W.loc[regulators,targets]
79 | adata.uns["skeleton"] = W
80 | W = pd.DataFrame(adata.uns["network"],index = adata.uns["regulators"],columns = adata.uns["targets"])
81 | W = W.loc[regulators,targets]
82 | adata.uns["network"] = W
83 | adata = adata[:,indicator].copy()
84 |
85 | return adata
--------------------------------------------------------------------------------
/regvelo/preprocessing/set_prior_grn.py:
--------------------------------------------------------------------------------
1 |
2 | from pathlib import Path
3 | from typing import Optional, Union
4 | from urllib.request import urlretrieve
5 |
6 | import torch
7 | import numpy as np
8 | import pandas as pd
9 | import scvelo as scv
10 | from anndata import AnnData
11 | from sklearn.preprocessing import MinMaxScaler
12 | from scipy.spatial.distance import cdist
13 |
14 | def set_prior_grn(adata: AnnData, gt_net: pd.DataFrame, keep_dim: bool = False) -> AnnData:
15 | """Adds prior gene regulatory network (GRN) information to an AnnData object.
16 |
17 | Parameters
18 | ----------
19 | adata : AnnData
20 | Annotated data matrix with gene expression data.
21 | gt_net : pd.DataFrame
22 | Prior gene regulatory network (targets as rows, regulators as columns).
23 | keep_dim : bool, optional
24 | If True, output AnnData retains original dimensions. Default is False.
25 |
26 | Returns
27 | -------
28 | AnnData
29 | Updated AnnData object with GRN stored in .uns["skeleton"].
30 | """
31 | # Identify regulators and targets present in adata
32 | regulator_mask = adata.var_names.isin(gt_net.columns)
33 | target_mask = adata.var_names.isin(gt_net.index)
34 | regulators = adata.var_names[regulator_mask]
35 | targets = adata.var_names[target_mask]
36 |
37 | if keep_dim:
38 | skeleton = pd.DataFrame(0, index=adata.var_names, columns=adata.var_names, dtype=float)
39 | common_targets = list(set(adata.var_names).intersection(gt_net.index))
40 | common_regulators = list(set(adata.var_names).intersection(gt_net.columns))
41 | skeleton.loc[common_targets, common_regulators] = gt_net.loc[common_targets, common_regulators]
42 | gt_net = skeleton.copy()
43 |
44 | # Compute correlation matrix based on gene expression layer "Ms"
45 | gex = adata.layers["Ms"]
46 | correlation = 1 - cdist(gex.T, gex.T, metric="correlation")
47 | #correlation = torch.tensor(correlation).float()
48 | correlation = correlation[np.ix_(target_mask, regulator_mask)]
49 | correlation[np.isnan(correlation)] = 0
50 |
51 | # Align and combine ground-truth GRN with expression correlation
52 | filtered_gt = gt_net.loc[targets, regulators]
53 | grn = filtered_gt * correlation
54 |
55 | # Binarize the GRN
56 | grn = (np.abs(grn) >= 0.01).astype(int)
57 | np.fill_diagonal(grn.values, 0) # Remove self-loops
58 |
59 | if keep_dim:
60 | skeleton = pd.DataFrame(0, index=adata.var_names, columns=adata.var_names, dtype=int)
61 | skeleton.loc[grn.columns, grn.index] = grn.T
62 | else:
63 | # Prune genes with no edges
64 | grn = grn.loc[grn.sum(axis=1) > 0, grn.sum(axis=0) > 0]
65 | genes = grn.index.union(grn.columns)
66 | skeleton = pd.DataFrame(0, index=genes, columns=genes, dtype=int)
67 | skeleton.loc[grn.columns, grn.index] = grn.T
68 |
69 | # Subset the adata to GRN genes and store in .uns
70 | adata = adata[:, skeleton.index].copy()
71 | skeleton = skeleton.loc[adata.var_names, adata.var_names]
72 |
73 | adata.uns["regulators"] = adata.var_names.to_numpy()
74 | adata.uns["targets"] = adata.var_names.to_numpy()
75 | adata.uns["skeleton"] = skeleton
76 | adata.uns["network"] = skeleton.copy()
77 |
78 | return adata
79 |
--------------------------------------------------------------------------------
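A toy sketch with a hypothetical two-regulator prior; with random expression the exact pruning outcome can vary, but the mechanics are the same:

```python
import numpy as np
import pandas as pd
from anndata import AnnData
from regvelo.preprocessing import set_prior_grn

rng = np.random.default_rng(0)
adata = AnnData(rng.random((100, 3)).astype(np.float32))
adata.var_names = ["tf1", "tf2", "g1"]
adata.layers["Ms"] = adata.X.copy()  # smoothed spliced expression is required

# Prior network: targets as rows, regulators as columns.
gt_net = pd.DataFrame(
    [[0.0, 1.0],
     [1.0, 0.0],
     [1.0, 1.0]],
    index=["tf1", "tf2", "g1"], columns=["tf1", "tf2"],
)
adata = set_prior_grn(adata, gt_net)
print(adata.uns["skeleton"])  # binarized gene-by-gene skeleton
```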
/regvelo/tools/TFScanning_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | import numpy as np
4 |
5 | import cellrank as cr
6 | from anndata import AnnData
7 | from scvelo import logging as logg
8 | import os, shutil
9 | from typing import Dict, Optional, Sequence, Tuple, Union
10 |
11 | from .._model import REGVELOVI
12 | from .utils import split_elements, combine_elements
13 | from .abundance_test import abundance_test
14 |
15 |
16 | def TFScanning_func(
17 | model : str,
18 | adata : AnnData,
19 | cluster_label : Optional[str] = None,
20 | terminal_states : Optional[Union[str, Sequence[str], Dict[str, Sequence[str]], pd.Series]] = None,
21 | KO_list : Optional[Union[str, Sequence[str], Dict[str, Sequence[str]], pd.Series]] = None,
22 | n_states : Optional[Union[int, Sequence[int]]] = None,
23 | cutoff : Optional[Union[int, Sequence[int]]] = 1e-3,
24 | method : Optional[str] = "likelihood",
25 | combined_kernel : Optional[bool] = False,
26 | ) -> Dict[str, Union[float, pd.DataFrame]]:
27 |
28 | """
29 | Perform in silico TF regulon knock-out screening
30 |
31 | Parameters
32 | ----------
33 | model
34 | The saved address for the RegVelo model.
35 | adata
36 |         Annotated data matrix.
37 | cluster_label
38 | Key in :attr:`~anndata.AnnData.obs` to associate names and colors with :attr:`terminal_states`.
39 | terminal_states
40 | subset of :attr:`macrostates`.
41 | KO_list
42 | List of TF combinations to simulate knock-out (KO) effects
43 | Can be single TF e.g. geneA
44 | or double TFs e.g. geneB_geneC
45 | example input: ["geneA","geneB_geneC"]
46 | n_states
47 | Number of macrostates to compute.
48 | cutoff
49 |         The threshold for determining which regulatory links to mute.
50 | method
51 | Quantify perturbation effects via `likelihood` or `t-statistics`
52 | combined_kernel
53 | Use combined kernel (0.8*VelocityKernel + 0.2*ConnectivityKernel)
54 | """
55 |
56 | reg_vae = REGVELOVI.load(model, adata)
57 | adata = reg_vae.add_regvelo_outputs_to_adata(adata = adata)
58 | raw_GRN = reg_vae.module.v_encoder.fc1.weight.detach().clone()
59 | perturb_GRN = reg_vae.module.v_encoder.fc1.weight.detach().clone()
60 |
61 | ## define kernel matrix
62 | vk = cr.kernels.VelocityKernel(adata)
63 | vk.compute_transition_matrix()
64 | ck = cr.kernels.ConnectivityKernel(adata).compute_transition_matrix()
65 |
66 | if combined_kernel:
67 | g2 = cr.estimators.GPCCA(0.8 * vk + 0.2 * ck)
68 | else:
69 | g2 = cr.estimators.GPCCA(vk)
70 |
71 | ## evaluate the fate prob on original space
72 | g2.compute_macrostates(n_states=n_states, n_cells = 30, cluster_key=cluster_label)
73 | # set a high number of states, and merge some of them and rename
74 | if terminal_states is None:
75 | g2.predict_terminal_states()
76 | terminal_states = g2.terminal_states.cat.categories.tolist()
77 | g2.set_terminal_states(
78 | terminal_states
79 | )
80 | g2.compute_fate_probabilities(solver="direct")
81 | fate_prob = g2.fate_probabilities
82 | sampleID = adata.obs.index.tolist()
83 | fate_name = fate_prob.names.tolist()
84 | fate_prob = pd.DataFrame(fate_prob,index= sampleID,columns=fate_name)
85 | fate_prob_original = fate_prob.copy()
86 |
87 | ## create dictionary
88 | terminal_id = terminal_states.copy()
89 | terminal_type = terminal_states.copy()
90 | for i in terminal_states:
91 | for j in [1,2,3,4,5,6,7,8,9,10]:
92 | terminal_id.append(i+"_"+str(j))
93 | terminal_type.append(i)
94 | terminal_dict = dict(zip(terminal_id, terminal_type))
95 | n_states = len(g2.macrostates.cat.categories.tolist())
96 |
97 | coef = []
98 | pvalue = []
99 | for tf in split_elements(KO_list):
100 | perturb_GRN = raw_GRN.clone()
101 | vec = perturb_GRN[:,[i in tf for i in adata.var.index.tolist()]].clone()
102 | vec[vec.abs() > cutoff] = 0
103 | perturb_GRN[:,[i in tf for i in adata.var.index.tolist()]]= vec
104 | reg_vae_perturb = REGVELOVI.load(model, adata)
105 | reg_vae_perturb.module.v_encoder.fc1.weight.data = perturb_GRN
106 |
107 | adata_target = reg_vae_perturb.add_regvelo_outputs_to_adata(adata = adata)
108 | ## perturb the regulations
109 | vk = cr.kernels.VelocityKernel(adata_target)
110 | vk.compute_transition_matrix()
111 | ck = cr.kernels.ConnectivityKernel(adata_target).compute_transition_matrix()
112 |
113 | if combined_kernel:
114 | g2 = cr.estimators.GPCCA(0.8 * vk + 0.2 * ck)
115 | else:
116 | g2 = cr.estimators.GPCCA(vk)
117 | ## evaluate the fate prob on original space
118 | n_states_perturb = n_states
119 | while True:
120 | try:
121 |                 # Try to compute macrostates with the current number of states
122 | g2.compute_macrostates(n_states=n_states_perturb, n_cells = 30, cluster_key=cluster_label)
123 | break
124 |             except Exception:
125 |                 # If an error is raised, increment n_states_perturb and try again (double knock-out results must be recomputed)
126 | n_states_perturb += 1
127 | vk = cr.kernels.VelocityKernel(adata)
128 | vk.compute_transition_matrix()
129 | ck = cr.kernels.ConnectivityKernel(adata).compute_transition_matrix()
130 | if combined_kernel:
131 | g = cr.estimators.GPCCA(0.8 * vk + 0.2 * ck)
132 | else:
133 | g = cr.estimators.GPCCA(vk)
134 | ## evaluate the fate prob on original space
135 | g.compute_macrostates(n_states=n_states_perturb, n_cells = 30,cluster_key=cluster_label)
136 | ## set a high number of states, and merge some of them and rename
137 | if terminal_states is None:
138 | g.predict_terminal_states()
139 | terminal_states = g.terminal_states.cat.categories.tolist()
140 | g.set_terminal_states(
141 | terminal_states
142 | )
143 | g.compute_fate_probabilities(solver="direct")
144 | fate_prob = g.fate_probabilities
145 | sampleID = adata.obs.index.tolist()
146 | fate_name = fate_prob.names.tolist()
147 | fate_prob = pd.DataFrame(fate_prob,index= sampleID,columns=fate_name)
148 |
149 |         ## intersect the perturbed macrostates with the baseline terminal states
150 | terminal_states_perturb = g2.macrostates.cat.categories.tolist()
151 | terminal_states_perturb = list(set(terminal_states_perturb).intersection(terminal_states))
152 |
153 | g2.set_terminal_states(
154 | terminal_states_perturb
155 | )
156 | g2.compute_fate_probabilities(solver="direct")
157 | fb = g2.fate_probabilities
158 | sampleID = adata.obs.index.tolist()
159 | fate_name = fb.names.tolist()
160 | fb = pd.DataFrame(fb,index= sampleID,columns=fate_name)
161 | fate_prob2 = pd.DataFrame(columns= terminal_states, index=sampleID)
162 |
163 | for i in terminal_states_perturb:
164 | fate_prob2.loc[:,i] = fb.loc[:,i]
165 |
166 | fate_prob2 = fate_prob2.fillna(0)
167 | arr = np.array(fate_prob2.sum(0))
168 | arr[arr!=0] = 1
169 | fate_prob = fate_prob * arr
170 |
171 | y = [0] * fate_prob.shape[0] + [1] * fate_prob2.shape[0]
172 | fate_prob2.index = [i + "_perturb" for i in fate_prob2.index]
173 | test_result = abundance_test(fate_prob, fate_prob2, method)
174 | coef.append(test_result.loc[:, "coefficient"])
175 | pvalue.append(test_result.loc[:, "FDR adjusted p-value"])
176 | logg.info("Done "+ combine_elements([tf])[0])
177 | fate_prob = fate_prob_original.copy()
178 |
179 | d = {'TF': KO_list, 'coefficient': coef, 'pvalue': pvalue}
180 | return d
--------------------------------------------------------------------------------
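Usage sketch for the screen; `models/rgv_run` is a hypothetical directory containing a trained, saved RegVelo model, and the gene names are placeholders:

```python
import pandas as pd
from regvelo.tools.TFScanning_func import TFScanning_func

res = TFScanning_func(
    model="models/rgv_run",            # hypothetical saved-model path
    adata=adata,
    cluster_label="cell_type",
    KO_list=["geneA", "geneB_geneC"],  # single and combinatorial knock-outs
    n_states=7,
)
coef = pd.DataFrame(res["coefficient"], index=res["TF"])  # KO x terminal state
```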
/regvelo/tools/TFscreening.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | import numpy as np
4 |
5 | import cellrank as cr
6 | from anndata import AnnData
7 | from scvelo import logging as logg
8 | import os, shutil
9 | from typing import Dict, Optional, Sequence, Tuple, Union
10 |
11 | from .._model import REGVELOVI
12 | from .utils import split_elements, get_list_name
13 | from .TFScanning_func import TFScanning_func
14 |
15 |
16 | def TFscreening(
17 | adata : AnnData,
18 | prior_graph : torch.Tensor,
19 | lam : Optional[int] = 1,
20 | lam2 : Optional[int] = 0,
21 | soft_constraint : Optional[bool] = True,
22 | TF_list : Optional[Union[str, Sequence[str], Dict[str, Sequence[str]], pd.Series]] = None,
23 | cluster_label : Optional[str] = None,
24 | terminal_states : Optional[Union[str, Sequence[str], Dict[str, Sequence[str]], pd.Series]] = None,
25 | KO_list : Optional[Union[str, Sequence[str], Dict[str, Sequence[str]], pd.Series]] = None,
26 | n_states : Optional[Union[int, Sequence[int]]] = 8,
27 | cutoff : Optional[float] = 1e-3,
28 | max_nruns : Optional[float] = 5,
29 | method : Optional[str] = "likelihood",
30 | dir : Optional[str] = None
31 | ) -> Tuple[pd.DataFrame, pd.DataFrame]:
32 | """
33 | Perform in silico TF regulon knock-out screening
34 |
35 | Parameters
36 | ----------
37 | adata
38 |         Annotated data matrix.
39 | prior_graph
40 | A prior graph for RegVelo inference
41 |     lam
42 |         Regularization parameter controlling the strength of the prior-knowledge constraint.
43 |     lam2
44 |         Regularization parameter controlling the strength of the L1 regularization on the Jacobian matrix.
45 | soft_constraint
46 | Apply soft constraint mode RegVelo.
47 | TF_list
48 | The TF list used for RegVelo inference.
49 | cluster_label
50 | Key in :attr:`~anndata.AnnData.obs` to associate names and colors with :attr:`terminal_states`.
51 | terminal_states
52 | subset of :attr:`macrostates`.
53 | KO_list
54 | List of TF combinations to simulate knock-out (KO) effects
55 | can be single TF e.g. geneA
56 | or double TFs e.g. geneB_geneC
57 | example input: ["geneA","geneB_geneC"]
58 | n_states
59 | Number of macrostates to compute.
60 | cutoff
61 |         The threshold for determining which regulatory links to mute.
--------------------------------------------------------------------------------
/regvelo/tools/_tsi.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Set
2 |
3 | from anndata import AnnData
4 |
5 |
6 | def generate_sequence(k: int, n: int) -> List[int]:
7 | """Generates a sequence of numbers from 1 to k. If the length of the sequence is less than n, the remaining positions are filled with the value k.
8 |
9 | Parameters
10 | ----------
11 | k: The last value to appear in the initial sequence.
12 | n: The target length of the sequence.
13 |
14 | Returns
15 | -------
16 | List[int]: A list of integers from 1 to k, padded with k to length n if necessary.
17 | """
18 | sequence = list(range(1, k + 1))
19 |
20 | # If the length of the sequence is already >= n, trim it to n
21 | if len(sequence) >= n:
22 | return sequence[:n]
23 |
24 | # Fill the rest of the sequence with the number k
25 | sequence.extend([k] * (n - len(sequence)))
26 |
27 | return sequence
28 |
29 |
30 | def plot_tsi(
31 | adata: AnnData,
32 | kernel: Any,
33 | threshold: float,
34 | terminal_states: Set[str],
35 | cluster_key: str,
36 | max_states: int = 12,
37 | ) -> List[int]:
38 | """Calculate the number of unique terminal states for each macrostate count.
39 |
40 | Parameters
41 | ----------
42 | adata
43 | Annotated data matrix (e.g., from single-cell experiments).
44 | kernel
45 | Computational kernel used to compute macrostates and predict terminal states.
46 | threshold
47 | Stability threshold for predicting terminal states.
48 | terminal_states
49 | Set of known terminal state names to match against.
50 | cluster_key
51 | Key in `adata.obs` that identifies cluster assignments for cells.
52 | max_states
53 | Maximum number of macrostates to consider. Default is 12.
54 |
55 | Returns
56 | -------
57 | list of int
58 | A list where each entry represents the count of unique terminal states found
59 |         at each macrostate count from 1 to `max_states` - 1.
60 | """
61 | # Create a mapping of state identifiers to their corresponding types
62 | all_states = list(set(adata.obs[cluster_key].tolist()))
63 | all_id = all_states.copy()
64 | all_type = all_states.copy()
65 | for state in all_states:
66 | for i in range(1, max_states + 1):
67 | all_id.append(f"{state}_{i}")
68 | all_type.append(state)
69 | all_dict = dict(zip(all_id, all_type))
70 |
71 | pre_value = []
72 |
73 | for num_macro in range(1, max_states):
74 | try:
75 | # Compute macrostates and predict terminal states
76 | kernel.compute_macrostates(n_states=num_macro, cluster_key=cluster_key)
77 | kernel.predict_terminal_states(stability_threshold=threshold)
78 |
79 | # Map terminal states to their types using `all_dict`
80 | pre_terminal = kernel.terminal_states.cat.categories.tolist()
81 | subset_dict = {key: all_dict[key] for key in pre_terminal}
82 | pre_terminal_names = list(set(subset_dict.values()))
83 |
84 | # Count overlap with known terminal states
85 | pre_value.append(len(set(pre_terminal_names).intersection(terminal_states)))
86 |
87 | except: # noqa
88 | # Log error and repeat the last valid value or use 0 if empty
89 | pre_value.append(pre_value[-1] if pre_value else 0)
90 |
91 | return pre_value
92 |
93 |
94 | def get_tsi_score(
95 | adata: AnnData,
96 | points: List[float],
97 | cluster_key: str,
98 | terminal_states: Set[str],
99 | kernel: Any,
100 | max_states: int = 12,
101 | ) -> List[float]:
102 |     """Calculate the terminal state identification (TSI) score for given thresholds.
103 |
104 | Parameters
105 | ----------
106 | adata
107 | Annotated data matrix (e.g., from single-cell experiments).
108 | points
109 | List of threshold values to evaluate for stability of terminal states.
110 | cluster_key
111 | Key in `adata.obs` to access cluster assignments for cells.
112 | terminal_states
113 | Set of known terminal state names to match against.
114 | kernel
115 | Computational kernel used to compute macrostates and predict terminal states.
116 | max_states
117 | Maximum number of macrostates to consider. Default is 12.
118 |
119 | Returns
120 | -------
121 | list of float
122 | A list of TSI scores, one for each threshold in `points`. Each score represents
123 | the normalized area under the staircase function compared to the goal sequence.
124 | """
125 | # Define the goal sequence and calculate its area
126 | x_values = range(max_states)
127 | y_values = [0] + generate_sequence(len(terminal_states), max_states - 1)
128 | area_gs = sum((x_values[i + 1] - x_values[i]) * y_values[i] for i in range(len(x_values) - 1))
129 |
130 | tsi_score = []
131 |
132 | for threshold in points:
133 | # Compute the staircase function for the current threshold
134 | pre_value = plot_tsi(adata, kernel, threshold, terminal_states, cluster_key, max_states)
135 | y_values = [0] + pre_value
136 |
137 | # Calculate the area under the staircase function
138 | area_velo = sum((x_values[i + 1] - x_values[i]) * y_values[i] for i in range(len(x_values) - 1))
139 |
140 | # Compute the normalized TSI score and append to results
141 | tsi_score.append(area_velo / area_gs if area_gs else 0)
142 |
143 | return tsi_score
144 |
--------------------------------------------------------------------------------
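To make the normalization concrete: the goal staircase for, say, three terminal states rises by one state per added macrostate and then saturates, and the TSI score divides the observed staircase's area by this goal area (a sketch, assuming the module is importable as `regvelo.tools._tsi`):

```python
from regvelo.tools._tsi import generate_sequence

max_states = 12
goal = [0] + generate_sequence(3, max_states - 1)
print(goal)  # [0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3]

# Same left-endpoint area rule used inside get_tsi_score:
area_goal = sum(goal[i] for i in range(max_states - 1))
print(area_goal)  # 27; a perfect method scores area_velo / area_goal = 1.0
```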
/regvelo/tools/abundance_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import pandas as pd
4 | import numpy as np
5 |
6 | from scipy.stats import ranksums, ttest_ind
7 | from sklearn.metrics import roc_auc_score
8 |
9 | import os, shutil
10 |
11 | from .._model import REGVELOVI
12 | from .utils import p_adjust_bh
13 |
14 |
15 | def abundance_test(
16 | prob_raw : pd.DataFrame,
17 | prob_pert : pd.DataFrame,
18 | method : str = "likelihood"
19 | ) -> pd.DataFrame:
20 | """
21 | Perform an abundance test between two probability datasets.
22 |
23 | Parameters
24 | ----------
25 | prob_raw : pd.DataFrame
26 | Raw probabilities dataset.
27 | prob_pert : pd.DataFrame
28 | Perturbed probabilities dataset.
29 | method : str, optional (default="likelihood")
30 | Method to calculate scores: "likelihood" or "t-statistics".
31 |
32 | Returns
33 | -------
34 | pd.DataFrame
35 | Dataframe with coefficients, p-values, and FDR adjusted p-values.
36 | """
37 | y = [1] * prob_raw.shape[0] + [0] * prob_pert.shape[0]
38 | X = pd.concat([prob_raw, prob_pert], axis=0)
39 |
40 | table = []
41 | for i in range(prob_raw.shape[1]):
42 | pred = np.array(X.iloc[:, i])
43 | if np.sum(pred) == 0:
44 | score, pval = np.nan, np.nan
45 | else:
46 | pval = ranksums(pred[np.array(y) == 0], pred[np.array(y) == 1])[1]
47 | if method == "t-statistics":
48 | score = ttest_ind(pred[np.array(y) == 0], pred[np.array(y) == 1])[0]
49 | elif method == "likelihood":
50 | score = roc_auc_score(y, pred)
51 | else:
52 | raise NotImplementedError("Supported methods are 't-statistics' and 'likelihood'.")
53 |
54 | table.append(np.expand_dims(np.array([score, pval]), 0))
55 |
56 | table = np.concatenate(table, axis=0)
57 | table = pd.DataFrame(table, index=prob_raw.columns, columns=["coefficient", "p-value"])
58 | table["FDR adjusted p-value"] = p_adjust_bh(table["p-value"].tolist())
59 | return table
--------------------------------------------------------------------------------
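A self-contained demonstration with synthetic fate probabilities, where one lineage is depleted in the perturbed condition and the other is untouched:

```python
import numpy as np
import pandas as pd
from regvelo.tools.abundance_test import abundance_test

rng = np.random.default_rng(0)
raw = pd.DataFrame(rng.uniform(0.4, 0.6, size=(100, 2)), columns=["Pigment", "Neuron"])
pert = raw.copy()
pert["Pigment"] = (pert["Pigment"] - 0.2).clip(lower=0)  # deplete one fate

print(abundance_test(raw, pert, method="likelihood"))
# "Pigment": AUROC near 1 with a tiny FDR-adjusted p-value;
# "Neuron": AUROC of exactly 0.5 and p-value 1 (identical distributions).
```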
/regvelo/tools/depletion_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | from anndata import AnnData
4 | from typing import Union, Sequence, Any, Optional, Tuple
5 | import cellrank as cr
6 |
7 | from .abundance_test import abundance_test
8 |
9 | def depletion_score(perturbed : dict[str, AnnData],
10 | baseline : AnnData,
11 | terminal_state : Union[str, Sequence[str]],
12 | **kwargs : Any,
13 | ) -> Tuple[pd.DataFrame, dict[str, AnnData]]:
14 | """
15 | Compute depletion scores.
16 |
17 | Parameters
18 | ----------
19 | perturbed : dict[str, AnnData]
20 | Dictionary mapping TF candidates to their perturbed AnnData objects.
21 | baseline : AnnData
22 | Annotated data matrix. Fate probabilities already computed.
23 | terminal_state : str or Sequence[str]
24 | List of terminal states to compute probabilities for.
25 | kwargs : Any
26 |         Optional.
27 |         Additional keyword arguments passed to the CellRank macrostate and fate-probability computations.
28 |
29 | Returns
30 | -------
31 | Tuple[pd.DataFrame, dict[str, AnnData]]
32 | A tuple containing:
33 |
34 | - **df** – Summary of depletion scores and associated statistics.
35 | - **adata_perturb_dict** – Dictionary mapping TFs to their perturbed AnnData objects.
36 | """
37 |
38 | macro_kwargs = {k: kwargs[k] for k in ("n_states", "n_cells", "cluster_key", "method") if k in kwargs}
39 | compute_fate_probabilities_kwargs = {k: kwargs[k] for k in ("solver", "tol") if k in kwargs}
40 |
41 | if "lineages_fwd" not in baseline.obsm:
42 | raise KeyError("Lineages not found in baseline.obsm. Please compute lineages first.")
43 |
44 | if isinstance(terminal_state, str):
45 | terminal_state = [terminal_state]
46 |
47 | # selecting indices of cells that have reached a terminal state
48 | ct_indices = {
49 | ct: baseline.obs["term_states_fwd"][baseline.obs["term_states_fwd"] == ct].index.tolist()
50 | for ct in terminal_state
51 | }
52 |
53 | fate_prob_perturb = {}
54 | for TF, adata_target_perturb in perturbed.items():
55 | vk = cr.kernels.VelocityKernel(adata_target_perturb).compute_transition_matrix()
56 | vk.write_to_adata()
57 | estimator = cr.estimators.GPCCA(vk)
58 |
59 | estimator.compute_macrostates(**macro_kwargs)
60 |
61 | estimator.set_terminal_states(ct_indices)
62 | estimator.compute_fate_probabilities(**compute_fate_probabilities_kwargs)
63 |
64 | perturbed[TF] = adata_target_perturb
65 |
66 | perturbed_prob = pd.DataFrame(
67 | adata_target_perturb.obsm["lineages_fwd"],
68 | columns=adata_target_perturb.obsm["lineages_fwd"].names.tolist()
69 | )[terminal_state]
70 |
71 | fate_prob_perturb[TF] = perturbed_prob
72 |
73 | fate_prob_raw = pd.DataFrame(
74 | baseline.obsm["lineages_fwd"],
75 | columns=baseline.obsm["lineages_fwd"].names.tolist()
76 | )
77 |
78 | dfs = []
79 | for TF, perturbed_prob in fate_prob_perturb.items():
80 | stats = abundance_test(prob_raw=fate_prob_raw, prob_pert=perturbed_prob)
81 | df = pd.DataFrame(
82 | {
83 | "Depletion score": stats.iloc[:, 0].tolist(),
84 | "p-value": stats.iloc[:, 1].tolist(),
85 | "FDR adjusted p-value": stats.iloc[:, 2].tolist(),
86 | "Terminal state": stats.index.tolist(),
87 | "TF": [TF] * stats.shape[0],
88 | }
89 | )
90 | dfs.append(df)
91 |
92 | df = pd.concat(dfs)
93 |
94 |     df["Depletion score"] = 2 * (0.5 - df["Depletion score"])  # rescale the AUROC-based coefficient from [0, 1] to [-1, 1]
95 |
96 | return df, perturbed
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
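Usage sketch; `perturbed` maps TF names to the AnnData objects returned by `in_silico_block_simulation`, `baseline` must already carry `.obsm["lineages_fwd"]` and `.obs["term_states_fwd"]`, and the names below are placeholders:

```python
from regvelo.tools.depletion_score import depletion_score

df, perturbed = depletion_score(
    perturbed={"geneA": adata_ko},         # hypothetical knock-out AnnData
    baseline=adata,
    terminal_state=["Pigment", "Neuron"],  # hypothetical terminal states
    n_states=7,
    cluster_key="cell_type",
)
print(df.head())
```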
/regvelo/tools/in_silico_block_simulation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | import numpy as np
4 |
5 | from anndata import AnnData
6 |
7 | import os, shutil
8 |
9 | from .._model import REGVELOVI
10 |
11 | def in_silico_block_simulation(
12 | model : str,
13 | adata : AnnData,
14 | TF : str,
15 | effects : int = 0,
16 |     cutoff : float = 1e-3,
17 | customized_GRN : torch.Tensor = None
18 | ) -> tuple:
19 |     """Perform an in silico TF regulon knock-out.
20 |
21 | Parameters
22 | ----------
23 | model
24 | The saved address for the RegVelo model.
25 | adata
26 |         Annotated data matrix.
27 |     TF
28 |         The candidate TF whose regulon is knocked out.
29 |     effects
30 |         The coefficient that replaces the TF's outgoing weights in the GRN.
31 |     cutoff
32 |         The threshold for determining which regulatory links to mute.
33 |     customized_GRN
34 |         A customized perturbed GRN used in place of the cutoff-based knock-out.
35 |     """
36 |
37 | reg_vae_perturb = REGVELOVI.load(model,adata)
38 |
39 | perturb_GRN = reg_vae_perturb.module.v_encoder.fc1.weight.detach().clone()
40 |
41 | if customized_GRN is None:
42 | perturb_GRN[(perturb_GRN[:,[i == TF for i in adata.var.index]].abs()>cutoff).cpu().numpy().reshape(-1),[i == TF for i in adata.var.index]] = effects
43 | reg_vae_perturb.module.v_encoder.fc1.weight.data = perturb_GRN
44 | else:
45 | device = perturb_GRN.device
46 | customized_GRN = customized_GRN.to(device)
47 | reg_vae_perturb.module.v_encoder.fc1.weight.data = customized_GRN
48 |
49 | adata_target_perturb = reg_vae_perturb.add_regvelo_outputs_to_adata(adata = adata)
50 |
51 | return adata_target_perturb, reg_vae_perturb
--------------------------------------------------------------------------------
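Usage sketch (the model path and TF name are placeholders):

```python
from regvelo.tools.in_silico_block_simulation import in_silico_block_simulation

# Mute every GRN link out of geneA with |weight| > cutoff, then regenerate
# velocities under the perturbed network.
adata_ko, model_ko = in_silico_block_simulation(
    model="models/rgv_run",  # hypothetical saved-model path
    adata=adata,
    TF="geneA",
)
```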
/regvelo/tools/markov_density_simulation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from anndata import AnnData
4 | def markov_density_simulation(
5 | adata : "AnnData",
6 | T : np.ndarray,
7 | start_indices,
8 | terminal_indices,
9 | terminal_states,
10 | n_steps : int = 100,
11 | n_simulations : int = 200,
12 | method: str = "stepwise",
13 | seed : int = 0,
14 | ):
15 |
16 | """
17 | Simulate transitions on a velocity-derived Markov transition matrix.
18 |
19 | Parameters
20 | ----------
21 | adata : AnnData
22 | Annotated data object.
23 | T : np.ndarray
24 | Transition matrix of shape (n_cells, n_cells).
25 | start_indices : array-like
26 | Indices of starting cells.
27 | terminal_indices : array-like
28 | Indices of terminal (absorbing) cells.
29 | terminal_states : array-like
30 | Labels of terminal states corresponding to cells in `adata.obs["term_states_fwd"]`.
31 | n_steps : int, optional
32 | Maximum number of steps per simulation (default: 100).
33 | n_simulations : int, optional
34 | Number of simulations per starting cell (default: 200).
35 | method : {'stepwise', 'one-step'}, optional
36 | Simulation method to use:
37 | - 'stepwise': simulate trajectories step by step.
38 | - 'one-step': sample directly from T^n.
39 | seed : int, optional
40 | Random seed for reproducibility (default: 0).
41 |
42 | Returns
43 | -------
44 | visits : pd.Series
45 | Number of simulations that ended in each terminal cell.
46 | visits_dens : pd.Series
47 | Proportion of simulations that ended in each terminal cell.
48 | """
49 | np.random.seed(seed)
50 |
51 | T = np.asarray(T)
52 | start_indices = np.asarray(start_indices)
53 | terminal_indices = np.asarray(terminal_indices)
54 | terminal_set = set(terminal_indices)
55 | n_cells = T.shape[0]
56 |
57 | arrivals_array = np.zeros(n_cells, dtype=int)
58 |
59 | if method == "stepwise":
60 | row_sums = T.sum(axis=1)
61 | cum_T = np.cumsum(T, axis=1)
62 |
63 | for start in start_indices:
64 | for _ in range(n_simulations):
65 | current = start
66 | for _ in range(n_steps):
67 | if row_sums[current] == 0:
68 | break # dead end
69 | r = np.random.rand()
70 | next_state = np.searchsorted(cum_T[current], r * row_sums[current])
71 | current = next_state
72 | if current in terminal_set:
73 | arrivals_array[current] += 1
74 | break
75 |
76 | elif method == "one-step":
77 | T_end = np.linalg.matrix_power(T, n_steps)
78 | for start in start_indices:
79 | x0 = np.zeros(n_cells)
80 | x0[start] = 1
81 |             x_end = x0 @ T_end  # state distribution after n_steps transitions
82 |             if x_end.sum() > 0:
83 |                 # Renormalize so `p` sums to 1 even when T has dead-end rows.
84 |                 samples = np.random.choice(n_cells, size=n_simulations, p=x_end / x_end.sum())
84 | for s in samples:
85 | if s in terminal_set:
86 | arrivals_array[s] += 1
87 | else:
88 | raise ValueError(f"Invalid probability distribution: x_end sums to 0 for start index {start}")
89 | else:
90 | raise ValueError("method must be 'stepwise' or 'one-step'")
91 |
92 | total_simulations = n_simulations * len(start_indices)
93 | visits = pd.Series({tid: arrivals_array[tid] for tid in terminal_indices}, dtype=int)
94 | visits_dens = pd.Series({tid: arrivals_array[tid] / total_simulations for tid in terminal_indices})
95 |
96 |     adata.obs[f"visits_{method}"] = np.nan
97 |     # Assign via .loc to avoid chained assignment, which can silently fail under copy-on-write.
98 |     adata.obs.loc[adata.obs_names[terminal_indices], f"visits_{method}"] = visits_dens.values
98 |
99 |     dens_cum = []
100 |     for ts in terminal_states:
101 |         ts_cells = np.where(adata.obs["term_states_fwd"] == ts)[0]
102 |         # Restrict to cells that were actually simulated into, to avoid KeyError.
103 |         density = visits_dens.loc[visits_dens.index.intersection(ts_cells)].sum()
104 |         dens_cum.append(density)
105 |
106 |     print("Proportion of simulations reaching a terminal cell:", sum(dens_cum))
106 |
107 | return visits, visits_dens
--------------------------------------------------------------------------------
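A small end-to-end example on a toy four-cell transition matrix; the absorbing cells 2 and 3 carry the made-up terminal labels "A" and "B":

import numpy as np
import pandas as pd
from anndata import AnnData

from regvelo.tools.markov_density_simulation import markov_density_simulation

# Toy 4-cell chain: cells 0 and 1 flow towards the absorbing cells 2 and 3.
T = np.array([
    [0.0, 0.5, 0.5, 0.0],
    [0.0, 0.0, 0.5, 0.5],
    [0.0, 0.0, 1.0, 0.0],  # absorbing
    [0.0, 0.0, 0.0, 1.0],  # absorbing
])
adata = AnnData(np.zeros((4, 1)))
adata.obs["term_states_fwd"] = pd.Categorical([None, None, "A", "B"])

visits, visits_dens = markov_density_simulation(
    adata, T,
    start_indices=[0, 1],
    terminal_indices=[2, 3],
    terminal_states=["A", "B"],
    n_steps=50, n_simulations=100, method="stepwise", seed=0,
)
print(visits)  # number of simulations absorbed in each terminal cell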
/regvelo/tools/perturbation_effect.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import pandas as pd
4 | from anndata import AnnData
5 | from typing import Union, Sequence
6 |
7 | def perturbation_effect(
8 | adata_perturb : AnnData,
9 | adata : AnnData,
10 | terminal_state : Union[str, Sequence[str]],
11 | ) -> AnnData:
12 | """
13 |     Compute the change in fate probabilities towards terminal states after perturbation. Negative values correspond to a
14 |     decrease in probability, while positive values indicate an increase.
15 |
16 | Parameters
17 | ----------
18 | adata_perturb : AnnData
19 | Annotated data matrix of perturbed GRN.
20 | adata : AnnData
21 | Annotated data matrix of unperturbed GRN.
22 | terminal_state : str or Sequence[str]
23 | List of terminal states to compute probabilities for.
24 |
25 | Returns
26 | -------
27 | AnnData
28 | Annotated data matrix with the following added:
29 |         - `adata.obs["perturbation effect on {state}"]` : Change in fate probability towards each terminal state after perturbation.
30 | """
31 |
32 | if isinstance(terminal_state, str):
33 | terminal_state = [terminal_state]
34 |
35 | if "lineages_fwd" in adata.obsm and "lineages_fwd" in adata_perturb.obsm:
36 | perturb_df = pd.DataFrame(
37 | adata_perturb.obsm["lineages_fwd"], columns=adata_perturb.obsm["lineages_fwd"].names.tolist()
38 | )
39 | original_df = pd.DataFrame(
40 | adata.obsm["lineages_fwd"], columns=adata.obsm["lineages_fwd"].names.tolist()
41 | )
42 |
43 | for state in terminal_state:
44 | adata.obs[f"perturbation effect on {state}"] = np.array(perturb_df[state] - original_df[state])
45 |
46 | return adata
47 | else:
48 | raise ValueError("Lineages not computed. Please compute lineages before using this function.")
--------------------------------------------------------------------------------
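The computation itself is a per-state difference of fate-probability columns; a toy illustration with made-up probabilities:

import numpy as np
import pandas as pd

# Toy fate probabilities (rows = cells) before and after a perturbation.
original_df = pd.DataFrame({"A": [0.7, 0.4], "B": [0.3, 0.6]})
perturb_df = pd.DataFrame({"A": [0.5, 0.2], "B": [0.5, 0.8]})

for state in ["A", "B"]:
    effect = np.array(perturb_df[state] - original_df[state])
    print(f"perturbation effect on {state}:", effect)
# perturbation effect on A: [-0.2 -0.2]  (fate probability decreased)
# perturbation effect on B: [0.2 0.2]    (fate probability increased)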
/regvelo/tools/set_output.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from anndata import AnnData
4 | from typing import Any
5 |
6 |
7 | # Code mostly taken from veloVI reproducibility repo
8 | # https://yoseflab.github.io/velovi_reproducibility/estimation_comparison/simulation_w_inferred_rates.html
9 |
10 |
11 | def set_output(
12 | adata : AnnData,
13 | vae : Any,
14 | n_samples: int = 30,
15 | batch_size: int | None = None
16 | ) -> None:
17 |     """Add inference results to ``adata``.
18 |
19 |     Parameters
20 |     ----------
21 |     adata : AnnData
22 |         Annotated data matrix.
23 |     vae : Any
24 |         Trained RegVelo model.
25 |     n_samples : int, optional
26 |         Number of posterior samples to use for estimation. Default is 30.
27 |     batch_size : int, optional
28 |         Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
29 |     """
30 |
31 | latent_time = vae.get_latent_time(n_samples=n_samples, batch_size=batch_size)
32 | velocities = vae.get_velocity(n_samples=n_samples, batch_size=batch_size)
33 |
34 | t = latent_time.values
35 | scaling = 20 / t.max(0)
36 |
37 | adata.layers["velocity"] = velocities / scaling
38 | adata.layers["latent_time_velovi"] = latent_time
39 |
40 | rates = vae.get_rates()
41 | if "alpha" in rates:
42 | adata.var["fit_alpha"] = rates["alpha"] / scaling
43 | adata.var["fit_beta"] = rates["beta"] / scaling
44 | adata.var["fit_gamma"] = rates["gamma"] / scaling
45 |
46 | adata.layers["fit_t"] = latent_time * scaling[np.newaxis, :]
47 | adata.var["fit_scaling"] = 1.0
--------------------------------------------------------------------------------
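The per-gene scaling follows the veloVI convention of stretching each gene's latent time to a maximum of 20 and dividing velocities and rates by the same factor; a toy illustration of the `fit_t` computation:

import numpy as np

# Toy latent times for 3 cells x 2 genes.
t = np.array([
    [1.0, 5.0],
    [2.0, 10.0],
    [4.0, 20.0],
])
scaling = 20 / t.max(0)  # per-gene factor: [5.0, 1.0]
fit_t = t * scaling[np.newaxis, :]
print(fit_t[:, 0])  # [ 5. 10. 20.] -- gene 1's latent time now spans up to 20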
/regvelo/tools/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | import numpy as np
4 |
5 | import cellrank as cr
6 | from anndata import AnnData
7 | from scvelo import logging as logg
8 | import os,shutil
9 | from typing import Dict, Optional, Sequence, Tuple, Union
10 |
11 | from .._model import REGVELOVI
12 |
13 | def split_elements(character_list):
14 |     """Split each underscore-joined element into a list of its parts."""
15 | result_list = []
16 | for element in character_list:
17 | if '_' in element:
18 | parts = element.split('_')
19 | result_list.append(parts)
20 | else:
21 | result_list.append([element])
22 | return result_list
23 |
24 | def combine_elements(split_list):
25 |     """Join each list of parts back into a single underscore-separated string."""
26 | result_list = []
27 | for parts in split_list:
28 | combined_element = "_".join(parts)
29 | result_list.append(combined_element)
30 | return result_list
31 |
32 | def get_list_name(lst):
33 |     """Return the keys of a mapping as a list."""
34 |     names = [name for name in lst.keys()]
35 |     return names
36 |
37 |
38 | def p_adjust_bh(p):
39 | """Benjamini-Hochberg p-value correction for multiple hypothesis testing."""
40 |     p = np.asarray(p, dtype=float)
41 | by_descend = p.argsort()[::-1]
42 | by_orig = by_descend.argsort()
43 | steps = float(len(p)) / np.arange(len(p), 0, -1)
44 | q = np.minimum(1, np.minimum.accumulate(steps * p[by_descend]))
45 | return q[by_orig]
46 |
47 |
--------------------------------------------------------------------------------
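p_adjust_bh implements the Benjamini-Hochberg step-up procedure: sort the p-values, multiply each by m/rank, and enforce monotonicity from the largest p-value downwards. A quick check on toy p-values:

from regvelo.tools.utils import p_adjust_bh

p = [0.01, 0.02, 0.03, 0.5]
print(p_adjust_bh(p))  # approximately [0.04, 0.04, 0.04, 0.5]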
/reproduce_env/regvelo.yaml:
--------------------------------------------------------------------------------
1 | name: regvelo_test
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - bioconda
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1
10 | - _openmp_mutex=4.5
11 | - absl-py=2.1.0
12 | - alsa-lib=1.2.13
13 | - aom=3.9.1
14 | - arboreto=0.1.6
15 | - attr=2.5.1
16 | - aws-c-auth=0.8.0
17 | - aws-c-cal=0.8.1
18 | - aws-c-common=0.10.5
19 | - aws-c-compression=0.3.0
20 | - aws-c-event-stream=0.5.0
21 | - aws-c-http=0.9.2
22 | - aws-c-io=0.15.3
23 | - aws-c-mqtt=0.11.0
24 | - aws-c-s3=0.7.5
25 | - aws-c-sdkutils=0.2.1
26 | - aws-checksums=0.2.2
27 | - aws-crt-cpp=0.29.7
28 | - aws-sdk-cpp=1.11.458
29 | - azure-core-cpp=1.14.0
30 | - azure-identity-cpp=1.10.0
31 | - azure-storage-blobs-cpp=12.13.0
32 | - azure-storage-common-cpp=12.8.0
33 | - azure-storage-files-datalake-cpp=12.12.0
34 | - bedtools=2.31.1
35 | - binutils=2.43
36 | - binutils_impl_linux-64=2.43
37 | - binutils_linux-64=2.43
38 | - blosc=1.21.6
39 | - bokeh=3.6.2
40 | - brotli=1.1.0
41 | - brotli-bin=1.1.0
42 | - brotli-python=1.1.0
43 | - bzip2=1.0.8
44 | - c-ares=1.34.3
45 | - c-compiler=1.7.0
46 | - ca-certificates=2025.1.31
47 | - cached-property=1.5.2
48 | - cached_property=1.5.2
49 | - cccl=2.5.0
50 | - cellrank=2.0.6
51 | - certifi=2025.1.31
52 | - cffi=1.17.1
53 | - charset-normalizer=3.4.0
54 | - click=8.1.7
55 | - cloudpickle=3.1.0
56 | - colorama=0.4.6
57 | - cpython=3.10.15
58 | - cuda-cccl=12.4.127
59 | - cuda-cccl_linux-64=12.4.127
60 | - cuda-command-line-tools=12.6.2
61 | - cuda-compiler=12.6.2
62 | - cuda-crt-dev_linux-64=12.4.131
63 | - cuda-crt-tools=12.4.131
64 | - cuda-cudart=12.4.127
65 | - cuda-cudart-dev=12.4.127
66 | - cuda-cudart-dev_linux-64=12.4.127
67 | - cuda-cudart-static=12.4.127
68 | - cuda-cudart-static_linux-64=12.4.127
69 | - cuda-cudart_linux-64=12.4.127
70 | - cuda-cuobjdump=12.4.127
71 | - cuda-cupti=12.4.127
72 | - cuda-cupti-dev=12.4.127
73 | - cuda-cuxxfilt=12.4.127
74 | - cuda-documentation=12.4.127
75 | - cuda-driver-dev=12.4.127
76 | - cuda-driver-dev_linux-64=12.4.127
77 | - cuda-gdb=12.4.127
78 | - cuda-libraries=12.4.1
79 | - cuda-libraries-dev=12.6.2
80 | - cuda-libraries-static=12.6.2
81 | - cuda-nsight=12.4.127
82 | - cuda-nvcc=12.4.131
83 | - cuda-nvcc-dev_linux-64=12.4.131
84 | - cuda-nvcc-impl=12.4.131
85 | - cuda-nvcc-tools=12.4.131
86 | - cuda-nvcc_linux-64=12.4.1
87 | - cuda-nvdisasm=12.4.127
88 | - cuda-nvml-dev=12.4.127
89 | - cuda-nvprof=12.4.127
90 | - cuda-nvprune=12.4.127
91 | - cuda-nvrtc=12.4.127
92 | - cuda-nvrtc-dev=12.4.127
93 | - cuda-nvrtc-static=12.4.127
94 | - cuda-nvtx=12.4.127
95 | - cuda-nvvm-dev_linux-64=12.4.131
96 | - cuda-nvvm-impl=12.4.131
97 | - cuda-nvvm-tools=12.4.131
98 | - cuda-nvvp=12.4.127
99 | - cuda-opencl=12.4.127
100 | - cuda-opencl-dev=12.4.127
101 | - cuda-profiler-api=12.4.127
102 | - cuda-runtime=12.4.1
103 | - cuda-sanitizer-api=12.4.127
104 | - cuda-toolkit=12.4.1
105 | - cuda-tools=12.6.2
106 | - cuda-version=12.4
107 | - cuda-visual-tools=12.6.2
108 | - cxx-compiler=1.7.0
109 | - cycler=0.12.1
110 | - cytoolz=1.0.0
111 | - dav1d=1.2.1
112 | - dbus=1.13.6
113 | - docrep=0.3.2
114 | - et_xmlfile=2.0.0
115 | - exceptiongroup=1.2.2
116 | - expat=2.6.4
117 | - ffmpeg=4.4.0
118 | - fftw=3.3.10
119 | - filelock=3.16.1
120 | - font-ttf-dejavu-sans-mono=2.37
121 | - font-ttf-inconsolata=3.000
122 | - font-ttf-source-code-pro=2.038
123 | - font-ttf-ubuntu=0.83
124 | - fontconfig=2.15.0
125 | - fonts-conda-ecosystem=1
126 | - fonts-conda-forge=1
127 | - freetype=2.12.1
128 | - future=1.0.0
129 | - gcc=12.4.0
130 | - gcc_impl_linux-64=12.4.0
131 | - gcc_linux-64=12.4.0
132 | - gds-tools=1.9.1.3
133 | - get-annotations=0.1.2
134 | - gflags=2.2.2
135 | - giflib=5.2.2
136 | - glog=0.7.1
137 | - gmp=6.3.0
138 | - gmpy2=2.1.5
139 | - gnutls=3.6.13
140 | - gxx=12.4.0
141 | - gxx_impl_linux-64=12.4.0
142 | - gxx_linux-64=12.4.0
143 | - h2=4.1.0
144 | - h5py=3.12.1
145 | - hdf5=1.14.3
146 | - hpack=4.0.0
147 | - htslib=1.21
148 | - humanize=4.11.0
149 | - hyperframe=6.0.1
150 | - hypre=2.31.0
151 | - idna=3.10
152 | - importlib-metadata=8.5.0
153 | - importlib_metadata=8.5.0
154 | - importlib_resources=6.4.5
155 | - jinja2=3.1.4
156 | - joblib=1.4.2
157 | - kernel-headers_linux-64=3.10.0
158 | - keyutils=1.6.1
159 | - kiwisolver=1.4.7
160 | - krb5=1.21.3
161 | - lame=3.100
162 | - lcms2=2.16
163 | - ld_impl_linux-64=2.43
164 | - lerc=4.0.0
165 | - libabseil=20240722.0
166 | - libaec=1.1.3
167 | - libarrow=18.1.0
168 | - libarrow-acero=18.1.0
169 | - libarrow-dataset=18.1.0
170 | - libarrow-substrait=18.1.0
171 | - libavif16=1.1.1
172 | - libblas=3.9.0
173 | - libbrotlicommon=1.1.0
174 | - libbrotlidec=1.1.0
175 | - libbrotlienc=1.1.0
176 | - libcap=2.69
177 | - libcblas=3.9.0
178 | - libcrc32c=1.1.2
179 | - libcublas=12.4.5.8
180 | - libcublas-dev=12.4.5.8
181 | - libcublas-static=12.4.5.8
182 | - libcufft=11.2.1.3
183 | - libcufft-dev=11.2.1.3
184 | - libcufft-static=11.2.1.3
185 | - libcufile=1.9.1.3
186 | - libcufile-dev=1.9.1.3
187 | - libcufile-static=1.9.1.3
188 | - libcurand=10.3.5.147
189 | - libcurand-dev=10.3.5.147
190 | - libcurand-static=10.3.5.147
191 | - libcurl=8.10.1
192 | - libcusolver=11.6.1.9
193 | - libcusolver-dev=11.6.1.9
194 | - libcusolver-static=11.6.1.9
195 | - libcusparse=12.3.1.170
196 | - libcusparse-dev=12.3.1.170
197 | - libcusparse-static=12.3.1.170
198 | - libdeflate=1.22
199 | - libedit=3.1.20191231
200 | - libev=4.33
201 | - libevent=2.1.12
202 | - libexpat=2.6.4
203 | - libfabric=1.22.0
204 | - libffi=3.4.2
205 | - libgcc=14.1.0
206 | - libgcc-devel_linux-64=12.4.0
207 | - libgcc-ng=14.1.0
208 | - libgcrypt=1.11.0
209 | - libgcrypt-devel=1.11.0
210 | - libgcrypt-lib=1.11.0
211 | - libgcrypt-tools=1.11.0
212 | - libgfortran=14.1.0
213 | - libgfortran-ng=14.1.0
214 | - libgfortran5=14.1.0
215 | - libglib=2.82.2
216 | - libgomp=14.1.0
217 | - libgoogle-cloud=2.31.0
218 | - libgoogle-cloud-storage=2.31.0
219 | - libgpg-error=1.51
220 | - libgrpc=1.67.1
221 | - libhwloc=2.11.2
222 | - libiconv=1.17
223 | - libjpeg-turbo=3.0.0
224 | - liblapack=3.9.0
225 | - libllvm14=14.0.6
226 | - libnghttp2=1.64.0
227 | - libnl=3.11.0
228 | - libnpp=12.2.5.30
229 | - libnpp-dev=12.2.5.30
230 | - libnpp-static=12.2.5.30
231 | - libnsl=2.0.1
232 | - libnvfatbin=12.4.127
233 | - libnvfatbin-dev=12.4.127
234 | - libnvfatbin-static=12.4.127
235 | - libnvjitlink=12.4.127
236 | - libnvjitlink-dev=12.4.127
237 | - libnvjitlink-static=12.4.127
238 | - libnvjpeg=12.3.1.117
239 | - libnvjpeg-dev=12.3.1.117
240 | - libnvjpeg-static=12.3.1.117
241 | - libopenblas=0.3.28
242 | - libparquet=18.1.0
243 | - libpng=1.6.44
244 | - libprotobuf=5.28.2
245 | - libptscotch=7.0.4
246 | - libre2-11=2024.07.02
247 | - libsanitizer=12.4.0
248 | - libscotch=7.0.4
249 | - libsqlite=3.46.1
250 | - libssh2=1.11.1
251 | - libstdcxx=14.1.0
252 | - libstdcxx-devel_linux-64=12.4.0
253 | - libstdcxx-ng=14.1.0
254 | - libsystemd0=256.7
255 | - libthrift=0.21.0
256 | - libtiff=4.7.0
257 | - libtorch=2.5.1
258 | - libudev1=256.7
259 | - libutf8proc=2.9.0
260 | - libuuid=2.38.1
261 | - libuv=1.49.2
262 | - libvpx=1.11.0
263 | - libwebp=1.4.0
264 | - libwebp-base=1.4.0
265 | - libxcb=1.17.0
266 | - libxcrypt=4.4.36
267 | - libxkbcommon=1.7.0
268 | - libxkbfile=1.1.0
269 | - libxml2=2.13.5
270 | - libzlib=1.3.1
271 | - locket=1.0.0
272 | - lz4=4.3.3
273 | - lz4-c=1.9.4
274 | - markdown=3.6
275 | - markdown-it-py=3.0.0
276 | - markupsafe=3.0.2
277 | - mdurl=0.1.2
278 | - metis=5.1.0
279 | - mpc=1.3.1
280 | - mpfr=4.2.1
281 | - mpi=1.0.1
282 | - mpi4py=4.0.1
283 | - mpich=4.2.3
284 | - mpmath=1.3.0
285 | - msgpack-python=1.1.0
286 | - mudata=0.3.1
287 | - mumps-include=5.7.3
288 | - mumps-mpi=5.7.3
289 | - munkres=1.1.4
290 | - natsort=8.4.0
291 | - ncurses=6.5
292 | - nest-asyncio=1.6.0
293 | - nettle=3.6
294 | - nomkl=1.0
295 | - nsight-compute=2024.1.1.4
296 | - nspr=4.36
297 | - nss=3.105
298 | - numba=0.60.0
299 | - numpy=1.26.4
300 | - numpy_groupies=0.11.2
301 | - numpyro=0.15.3
302 | - ocl-icd=2.3.2
303 | - openh264=2.1.1
304 | - openjpeg=2.5.2
305 | - openpyxl=3.1.5
306 | - openssl=3.4.1
307 | - opt-einsum=3.4.0
308 | - opt_einsum=3.4.0
309 | - orc=2.0.3
310 | - pandas=2.2.3
311 | - parmetis=4.0.3
312 | - partd=1.4.2
313 | - pcre2=10.44
314 | - petsc=3.21.5
315 | - petsc4py=3.21.5
316 | - pip=24.2
317 | - progressbar2=4.5.0
318 | - protobuf=5.28.2
319 | - pthread-stubs=0.4
320 | - pyarrow-core=18.1.0
321 | - pybind11-abi=4
322 | - pycparser=2.22
323 | - pygam=0.9.1
324 | - pygments=2.18.0
325 | - pygpcca=1.0.4
326 | - pynndescent=0.5.13
327 | - pyro-api=0.1.2
328 | - pysocks=1.7.1
329 | - python=3.10.15
330 | - python-dateutil=2.9.0.post0
331 | - python-tzdata=2024.2
332 | - python-utils=3.9.0
333 | - python_abi=3.10
334 | - pytorch-cuda=12.4
335 | - pytorch-lightning=2.4.0
336 | - pytorch-mutex=1.0
337 | - pyyaml=6.0.2
338 | - qhull=2020.2
339 | - rav1e=0.6.6
340 | - rdma-core=54.0
341 | - re2=2024.07.02
342 | - readline=8.2
343 | - requests=2.32.3
344 | - s2n=1.5.9
345 | - samtools=1.21
346 | - scalapack=2.2.0
347 | - scanpy=1.10.3
348 | - scvi-tools=1.2.0
349 | - seaborn=0.13.2
350 | - seaborn-base=0.13.2
351 | - session-info=1.0.0
352 | - setuptools=75.1.0
353 | - six=1.16.0
354 | - sleef=3.7
355 | - slepc=3.21.2
356 | - slepc4py=3.21.2
357 | - snappy=1.2.1
358 | - sortedcontainers=2.4.0
359 | - sparse=0.15.4
360 | - statsmodels=0.14.4
361 | - suitesparse=7.8.3
362 | - superlu=5.2.2
363 | - superlu_dist=9.1.0
364 | - svt-av1=2.3.0
365 | - sysroot_linux-64=2.17
366 | - tbb=2022.0.0
367 | - tblib=3.0.0
368 | - tensorboard=2.18.0
369 | - threadpoolctl=3.5.0
370 | - tk=8.6.13
371 | - toolz=1.0.0
372 | - tornado=6.4.2
373 | - typing-extensions=4.12.2
374 | - typing_extensions=4.12.2
375 | - tzdata=2024b
376 | - ucx=1.17.0
377 | - unicodedata2=15.1.0
378 | - urllib3=2.2.3
379 | - wayland=1.23.1
380 | - wheel=0.44.0
381 | - x264=1!161.3030
382 | - x265=3.5
383 | - xcb-util=0.4.1
384 | - xcb-util-cursor=0.1.5
385 | - xcb-util-image=0.4.0
386 | - xcb-util-keysyms=0.4.1
387 | - xcb-util-renderutil=0.3.10
388 | - xcb-util-wm=0.4.2
389 | - xkeyboard-config=2.43
390 | - xorg-libice=1.1.1
391 | - xorg-libsm=1.2.4
392 | - xorg-libx11=1.8.10
393 | - xorg-libxau=1.0.11
394 | - xorg-libxcomposite=0.4.6
395 | - xorg-libxdamage=1.1.6
396 | - xorg-libxdmcp=1.1.5
397 | - xorg-libxext=1.3.6
398 | - xorg-libxfixes=6.0.1
399 | - xorg-libxi=1.8.2
400 | - xorg-libxrandr=1.5.4
401 | - xorg-libxrender=0.9.11
402 | - xorg-libxtst=1.2.5
403 | - xorg-xorgproto=2024.1
404 | - xyzservices=2024.9.0
405 | - xz=5.2.6
406 | - yaml=0.2.5
407 | - zict=3.0.0
408 | - zlib=1.3.1
409 | - zstandard=0.23.0
410 | - zstd=1.5.6
411 | - pip:
412 | - aiohappyeyeballs==2.4.3
413 | - aiohttp==3.10.10
414 | - aiosignal==1.3.1
415 | - alabaster==0.7.16
416 | - anndata==0.11.0rc2
417 | - annotated-types==0.7.0
418 | - anyio==4.6.0
419 | - appdirs==1.4.4
420 | - argon2-cffi==23.1.0
421 | - argon2-cffi-bindings==21.2.0
422 | - array-api-compat==1.9
423 | - arrow==1.3.0
424 | - asciitree==0.3.3
425 | - asttokens==2.4.1
426 | - astunparse==1.6.3
427 | - async-lru==2.0.4
428 | - async-timeout==4.0.3
429 | - babel==2.16.0
430 | - backoff==2.2.1
431 | - biofluff==3.0.4
432 | - biopython==1.84
433 | - biothings-client==0.3.1
434 | - blessed==1.20.0
435 | - boltons==24.1.0
436 | - boto3==1.35.63
437 | - botocore==1.35.63
438 | - bucketcache==0.12.1
439 | - celloracle==0.20.0
440 | - chex==0.1.87
441 | - colorcet==3.1.0
442 | - comm==0.2.2
443 | - configparser==7.1.0
444 | - contextlib2==21.6.0
445 | - contourpy==1.3.0
446 | - croniter==1.4.1
447 | - csaps==1.2.0
448 | - cython==3.0.11
449 | - dask==2024.2.1
450 | - dask-expr==0.5.3
451 | - dateutils==0.6.12
452 | - debugpy==1.8.7
453 | - deepdiff==7.0.1
454 | - diskcache==5.6.3
455 | - distributed==2024.2.1
456 | - dm-tree==0.1.8
457 | - docopt==0.6.2
458 | - docutils==0.21.2
459 | - editor==1.6.6
460 | - etils==1.9.4
461 | - executing==2.1.0
462 | - fa2-modified==0.3.10
463 | - fastapi==0.115.5
464 | - fasteners==0.19
465 | - feather-format==0.4.1
466 | - flatbuffers==24.3.25
467 | - flax==0.9.0
468 | - fonttools==4.54.1
469 | - fqdn==1.5.1
470 | - frozenlist==1.4.1
471 | - fsspec==2024.9.0
472 | - ftpretty==0.4.0
473 | - gast==0.6.0
474 | - gdown==5.2.0
475 | - genomepy==0.16.1
476 | - gimmemotifs==0.17.2
477 | - goatools==1.4.12
478 | - google-pasta==0.2.0
479 | - grpcio==1.66.2
480 | - h11==0.14.0
481 | - hnswlib==0.8.0
482 | - htseq==2.0.9
483 | - httpcore==1.0.6
484 | - httpx==0.27.2
485 | - igraph==0.11.6
486 | - imagesize==1.4.1
487 | - inquirer==3.4.0
488 | - iprogress==0.4
489 | - ipykernel==6.29.5
490 | - ipython==8.28.0
491 | - ipywidgets==8.1.5
492 | - isoduration==20.11.0
493 | - iteround==1.0.4
494 | - itsdangerous==2.2.0
495 | - jax==0.4.34
496 | - jaxlib==0.4.34
497 | - jedi==0.19.1
498 | - jmespath==1.0.1
499 | - joypy==0.2.6
500 | - json5==0.9.25
501 | - jsonpickle==4.0.0
502 | - jsonpointer==3.0.0
503 | - jupyter==1.1.1
504 | - jupyter-console==6.6.3
505 | - jupyter-events==0.10.0
506 | - jupyter-lsp==2.2.5
507 | - jupyter-server==2.14.2
508 | - jupyter-server-terminals==0.5.3
509 | - jupyterlab==4.2.5
510 | - jupyterlab-server==2.27.3
511 | - jupyterlab-widgets==3.0.13
512 | - keras==3.7.0
513 | - legacy-api-wrap==1.4
514 | - leidenalg==0.10.2
515 | - libclang==18.1.1
516 | - lightning==2.0.9.post0
517 | - lightning-cloud==0.5.70
518 | - lightning-utilities==0.11.7
519 | - llvmlite==0.43.0
520 | - logbook==1.8.0
521 | - logomaker==0.8
522 | - loguru==0.7.2
523 | - loompy==3.0.7
524 | - louvain==0.8.2
525 | - matplotlib==3.6.3
526 | - matplotlib-inline==0.1.7
527 | - ml-collections==0.1.1
528 | - ml-dtypes==0.4.1
529 | - mplscience==0.0.7
530 | - multidict==6.1.0
531 | - multipledispatch==1.0.0
532 | - mygene==3.2.2
533 | - mysql-connector-python==9.1.0
534 | - namex==0.0.8
535 | - networkx==3.4.1
536 | - norns==0.1.6
537 | - nose==1.3.7
538 | - notebook==7.2.2
539 | - notebook-shim==0.2.4
540 | - numcodecs==0.13.1
541 | - numdifftools==0.9.41
542 | - nvidia-cublas-cu12==12.4.5.8
543 | - nvidia-cuda-cupti-cu12==12.4.127
544 | - nvidia-cuda-nvrtc-cu12==12.4.127
545 | - nvidia-cuda-runtime-cu12==12.4.127
546 | - nvidia-cudnn-cu12==9.1.0.70
547 | - nvidia-cufft-cu12==11.2.1.3
548 | - nvidia-curand-cu12==10.3.5.147
549 | - nvidia-cusolver-cu12==11.6.1.9
550 | - nvidia-cusparse-cu12==12.3.1.170
551 | - nvidia-nccl-cu12==2.21.5
552 | - nvidia-nvjitlink-cu12==12.4.127
553 | - nvidia-nvtx-cu12==12.4.127
554 | - optax==0.2.3
555 | - optree==0.13.1
556 | - orbax-checkpoint==0.7.0
557 | - ordered-set==4.1.0
558 | - overrides==7.7.0
559 | - packaging==24.1
560 | - parso==0.8.4
561 | - patsy==0.5.6
562 | - pexpect==4.9.0
563 | - pillow==10.4.0
564 | - platformdirs==4.3.6
565 | - plotly==5.24.1
566 | - prometheus-client==0.21.0
567 | - prompt-toolkit==3.0.48
568 | - propcache==0.2.0
569 | - psutil==6.0.0
570 | - ptyprocess==0.7.0
571 | - pure-eval==0.2.3
572 | - pyarrow==18.0.0
573 | - pybedtools==0.10.0
574 | - pybigwig==0.3.23
575 | - pydantic==2.1.1
576 | - pydantic-core==2.4.0
577 | - pydot==3.0.2
578 | - pyfaidx==0.8.1.3
579 | - pyjwt==2.10.0
580 | - pymde==0.1.18
581 | - pyparsing==3.1.4
582 | - pyro-ppl==1.9.1
583 | - pysam==0.22.1
584 | - python-igraph==0.11.6
585 | - python-json-logger==2.0.7
586 | - python-multipart==0.0.17
587 | - pytz==2024.2
588 | - pyvis==0.3.2
589 | - qnorm==0.8.1
590 | - readchar==4.2.1
591 | - regvelo==0.1.0
592 | - represent==2.1
593 | - rfc3339-validator==0.1.4
594 | - rfc3986-validator==0.1.1
595 | - rich==13.9.2
596 | - rpy2==3.5.17
597 | - runs==1.2.2
598 | - s3transfer==0.10.3
599 | - scib==1.1.5
600 | - scikit-learn==1.5.2
601 | - scipy==1.10.1
602 | - scvelo==0.3.2
603 | - send2trash==1.8.3
604 | - sniffio==1.3.1
605 | - snowballstemmer==2.2.0
606 | - sphinx==7.3.7
607 | - sphinx-autodoc-typehints==2.3.0
608 | - sphinxcontrib-applehelp==2.0.0
609 | - sphinxcontrib-devhelp==2.0.0
610 | - sphinxcontrib-htmlhelp==2.1.0
611 | - sphinxcontrib-jsmath==1.0.1
612 | - sphinxcontrib-qthelp==2.0.0
613 | - sphinxcontrib-serializinghtml==2.0.0
614 | - splicejac==0.0.1
615 | - stack-data==0.6.3
616 | - starlette==0.41.2
617 | - starsessions==1.3.0
618 | - stdlib-list==0.10.0
619 | - sympy==1.13.1
620 | - tenacity==9.0.0
621 | - tensorboard-data-server==0.7.2
622 | - tensorflow==2.18.0
623 | - tensorflow-io-gcs-filesystem==0.37.1
624 | - tensorstore==0.1.66
625 | - termcolor==2.5.0
626 | - terminado==0.18.1
627 | - texttable==1.7.0
628 | - tf-keras==2.18.0
629 | - tomli==2.0.2
630 | - torch==2.5.1
631 | - torchaudio==2.5.1
632 | - torchmetrics==1.4.3
633 | - torchode==1.0.0
634 | - torchtyping==0.1.5
635 | - torchvision==0.20.1
636 | - tqdm==4.66.5
637 | - traitlets==5.14.3
638 | - triton==3.1.0
639 | - typeguard==2.13.3
640 | - types-python-dateutil==2.9.0.20241003
641 | - tzlocal==5.3.1
642 | - umap-learn==0.5.6
643 | - unitvelo==0.2.5.2
644 | - uri-template==1.3.0
645 | - uvicorn==0.32.0
646 | - velocyto==0.17.17
647 | - velovi==0.3.1
648 | - wcwidth==0.2.13
649 | - webcolors==24.8.0
650 | - websocket-client==1.8.0
651 | - websockets==12.0
652 | - werkzeug==3.0.4
653 | - widgetsnbextension==4.0.13
654 | - wrapt==1.16.0
655 | - xarray==2024.9.0
656 | - xdg==6.0.0
657 | - xlsxwriter==3.2.0
658 | - xmod==1.8.1
659 | - xxhash==3.5.0
660 | - yarl==1.15.0
661 | - zarr==2.18.3
662 | - zipp==3.20.2
663 | prefix: /home/icb/weixu.wang/miniconda3/envs/regvelo_test
664 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # This is a shim so that GitHub can detect the package; the actual build is done with Poetry.
4 |
5 | import setuptools
6 |
7 | if __name__ == "__main__":
8 | setuptools.setup(name="regvelo")
9 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/regvelo/5ed133bd37a563390ee4ec909b528f13d75c1b8a/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_regvelo.py:
--------------------------------------------------------------------------------
1 | ## test RegVelo
2 | import numpy as np
3 | import pandas as pd
4 | from scvi.data import synthetic_iid
5 | from regvelo import REGVELOVI
6 | import torch
7 |
8 | def test_regvelo():
9 | adata = synthetic_iid()
10 | adata.layers["spliced"] = adata.X.copy()
11 | adata.layers["unspliced"] = adata.X.copy()
12 | adata.var_names = "Gene" + adata.var_names
13 | n_gene = len(adata.var_names)
14 | ## create W
15 |     grn_matrix = np.random.choice([0, 1], size=(n_gene, n_gene), p=[0.8, 0.2]).T
16 | W = pd.DataFrame(grn_matrix, index=adata.var_names, columns=adata.var_names)
17 | adata.uns["skeleton"] = W
18 | TF_list = adata.var_names.tolist()
19 |
20 | ## training process
21 | W = adata.uns["skeleton"].copy()
22 | W = torch.tensor(np.array(W))
23 | REGVELOVI.setup_anndata(adata, spliced_layer="spliced", unspliced_layer="unspliced")
24 |
25 |     ## Training the model
26 |     # hard constraint
27 |     reg_vae = REGVELOVI(adata, W=W.T, regulators=TF_list, soft_constraint=False)
28 |     reg_vae.train()
29 |     # soft constraint
30 |     reg_vae = REGVELOVI(adata, W=W.T, regulators=TF_list, soft_constraint=True)
31 |     reg_vae.train()
32 |
33 |     reg_vae.get_latent_representation()
34 |     reg_vae.get_velocity()
35 |     reg_vae.get_latent_time()
36 |
37 |     assert reg_vae.history is not None
38 |     print(reg_vae)
--------------------------------------------------------------------------------