├── .github
│   └── workflows
│       ├── publish-conda.yml
│       └── python-publish.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── License.md
├── README.md
├── assets
│   ├── Visual_Abstract.png
│   └── tool_logo.svg
├── conda.recipe
│   ├── conda_build_config.yaml
│   ├── environment.yml
│   └── meta.yaml
├── docs
│   ├── Data_Preparation.html
│   ├── README.md
│   └── tutorial_main.html
├── notebooks
│   ├── Data_Preparation.ipynb
│   ├── README.md
│   └── tutorial_main.ipynb
├── priors
│   ├── collectri_human_net.csv
│   └── collectri_mouse_net.csv
├── pyproject.toml
├── requirements.txt
├── scregulate
│   ├── __init__.py
│   ├── __version__.py
│   ├── auto_tuning.py
│   ├── datasets.py
│   ├── fine_tuning.py
│   ├── loss_functions.py
│   ├── train.py
│   ├── train_utils.py
│   ├── ulm_standalone.py
│   ├── utils.py
│   └── vae_model.py
├── setup.cfg
└── setup.py
/.github/workflows/publish-conda.yml:
--------------------------------------------------------------------------------
1 | name: Publish Conda Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | jobs:
8 | build-and-upload-conda:
9 | runs-on: ubuntu-latest
10 |
11 | steps:
12 | - name: Checkout code
13 | uses: actions/checkout@v4
14 |
15 | - name: Setup Miniconda
16 | uses: conda-incubator/setup-miniconda@v3.0.4
17 | with:
18 | environment-file: ./conda.recipe/environment.yml
19 | activate-environment: build_env
20 | auto-update-conda: true
21 | use-mamba: true
22 | channels: conda-forge,pytorch
23 | channel-priority: strict
24 |
25 | - name: Install conda-build and anaconda-client
26 | shell: bash
27 | run: |
28 | eval "$(conda shell.bash hook)"
29 | conda activate build_env
30 | conda install conda-build anaconda-client -y
31 |
32 | - name: Build conda package
33 | shell: bash
34 | run: |
35 | eval "$(conda shell.bash hook)"
36 | conda activate build_env
37 | export GIT_TAG_NAME="${GITHUB_REF##*/}"
38 | export GIT_TAG_NAME="${GIT_TAG_NAME#v}" # strip leading "v" if present
39 |
40 | # Inject the version into the Python module so setuptools can read it
41 | echo "__version__ = \"${GIT_TAG_NAME}\"" > scregulate/__version__.py
42 |
43 | conda build conda.recipe --output-folder dist \
44 | --channel conda-forge --channel pytorch
45 |
46 | - name: Upload to Anaconda
47 | shell: bash
48 | env:
49 | ANACONDA_TOKEN: ${{ secrets.ANACONDA_TOKEN }}
50 | run: |
51 | eval "$(conda shell.bash hook)"
52 | conda activate build_env
53 | shopt -s nullglob
54 | for file in dist/*/*.conda dist/*/*.tar.bz2; do
55 | echo "Uploading $file"
56 | anaconda -t "$ANACONDA_TOKEN" upload "$file" --force
57 | done
58 |
59 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | permissions:
8 | contents: read
9 |
10 | jobs:
11 | release-build:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - name: Checkout code
16 | uses: actions/checkout@v4
17 |
18 | - name: Set up Python
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: "3.10"
22 |
23 | - name: Install build tools
24 | run: |
25 | python -m pip install --upgrade pip
26 | pip install build
27 |
28 | - name: Inject version from GitHub tag
29 | run: |
30 | VERSION="${GITHUB_REF##*/}"
31 | VERSION="${VERSION#v}" # Strip leading 'v' if present
32 | echo "__version__ = \"${VERSION}\"" > scregulate/__version__.py
33 | cat scregulate/__version__.py
34 |
35 |
36 | - name: Build distributions
37 | run: python -m build
38 |
39 | - name: Upload distributions as artifacts
40 | uses: actions/upload-artifact@v4
41 | with:
42 | name: release-dists
43 | path: dist/
44 |
45 | pypi-publish:
46 | runs-on: ubuntu-latest
47 | needs: release-build
48 | permissions:
49 | id-token: write
50 |
51 | environment:
52 | name: pypi
53 | url: https://pypi.org/project/scRegulate/${{ github.event.release.tag_name }}
54 |
55 | steps:
56 | - name: Download built distributions
57 | uses: actions/download-artifact@v4
58 | with:
59 | name: release-dists
60 | path: dist/
61 |
62 | - name: Publish to PyPI
63 | uses: pypa/gh-action-pypi-publish@release/v1
64 | with:
65 | packages-dir: dist/
66 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
5 |
6 | ## Our Standards
7 | Examples of behavior that contributes to creating a positive environment include:
8 | - Using welcoming and inclusive language
9 | - Being respectful of differing viewpoints and experiences
10 | - Gracefully accepting constructive criticism
11 | - Focusing on what is best for the community
12 | - Showing empathy towards other community members
13 |
14 | Examples of unacceptable behavior by participants include:
15 | - The use of sexualized language or imagery and unwelcome sexual attention or advances
16 | - Trolling, insulting/derogatory comments, and personal or political attacks
17 | - Public or private harassment
18 | - Publishing others' private information, such as a physical or electronic address, without explicit permission
19 | - Other conduct which could reasonably be considered inappropriate in a professional setting
20 |
21 | ## Our Responsibilities
22 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
23 |
24 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
25 |
26 | ## Scope
27 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
28 |
29 | ## Enforcement
30 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at [mzandi2@uic.edu](mailto:mzandi2@uic.edu). All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
31 |
32 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.
33 |
34 | ## Attribution
35 | This Code of Conduct is adapted from the Contributor Covenant, version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html)
36 |
37 | For answers to common questions about this code of conduct, see [https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq)
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Mehrdad Zandigohar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/License.md:
--------------------------------------------------------------------------------
1 | YEAR: 2024
2 |
3 | COPYRIGHT HOLDER: scRegulate Authors
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # scRegulate
2 | **S**ingle-**C**ell **Regula**tory-Embedded Variational Inference of **T**ranscription Factor Activity from Gene **E**xpression
3 |
4 | [Issues](https://github.com/YDaiLab/scRegulate/issues)
5 | [PyPI](https://pypi.org/project/scRegulate/)
6 | [Anaconda](https://anaconda.org/zandigohar/scregulate)
7 | [Documentation](https://ydailab.github.io/scRegulate/)
8 |
9 |
10 | ## Introduction
11 | **scRegulate** is a tool for inferring transcription factor activity from single-cell/single-nucleus RNA-seq data using generative modeling. It leverages a unified learning framework to model cellular regulatory networks, providing accurate insights into transcriptional regulation. With its efficient clustering capabilities, **scRegulate** facilitates the analysis of complex biological data, making it a valuable resource for studies in genomics and molecular biology.
12 |
13 |
14 |
15 |
16 |
17 | For further information and example tutorials, please check our documentation:
18 | - [PBMC 3K Tutorial (HTML)](https://ydailab.github.io/scRegulate/tutorial_main.html)
19 | - [Reproducibility Guide (HTML)](https://ydailab.github.io/scRegulate/Data_Preparation.html)
20 |
21 | If you have any questions or concerns, feel free to [open an issue](https://github.com/YDaiLab/scRegulate/issues).
22 |
23 | ## Requirements
24 | scRegulate is implemented in PyTorch. Running scRegulate on a CUDA-enabled GPU is highly recommended when one is available.
25 |
26 |
27 | Before installing and running scRegulate, ensure you have the following libraries installed:
28 | - **PyTorch** (version 2.0 or higher)
29 | Install with the exact command from the [PyTorch “Get Started” page](https://pytorch.org/get-started/locally/) for your OS, Python version and (optionally) CUDA toolkit.
30 | - **NumPy** (version 1.23 or higher)
31 | - **Scanpy** (version 1.9 or higher)
32 | - **Anndata** (version 0.8 or higher)
33 |
34 | You can install these dependencies using `pip`:
35 |
36 | ```bash
37 | pip install torch numpy scanpy anndata
38 | ```
39 |
40 | ## Installation
41 |
42 | **Option 1:**
43 | You can install **scRegulate** via pip for a lightweight installation:
44 |
45 | ```bash
46 | pip install scregulate
47 | ```
48 |
49 | **Option 2:**
50 | Alternatively, if you want the latest, unreleased version, you can install it directly from the source on GitHub:
51 |
52 | ```bash
53 | pip install git+https://github.com/YDaiLab/scRegulate.git
54 | ```
55 |
56 | **Option 3:**
57 | If you prefer Conda or Mamba for environment management, you can install **scRegulate** along with its extra dependencies:
58 |
59 | **Conda:**
60 | ```bash
61 | conda install -c zandigohar scregulate
62 | ```
63 |
64 | **Mamba:**
65 | ```bash
66 | mamba create -n scRegulate -c zandigohar scregulate
67 | ```
68 |
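After installation, a minimal end-to-end sketch looks like the following (the input path is a placeholder and all training options are left at their defaults; see the [PBMC 3K Tutorial](https://ydailab.github.io/scRegulate/tutorial_main.html) for a complete walkthrough):

```python
import scanpy as sc
import scregulate as scr

# Normalized expression is expected in .X (see the FAQ on input formats below).
adata = sc.read_h5ad("my_data.h5ad")          # placeholder path to your AnnData file

# Bundled CollecTRI TF-target prior (downloaded and cached on first use).
net = scr.collectri_prior(species="human")

# Train the regulatory-embedded VAE; with return_outputs=True (the default),
# train_model returns the trained model plus additional outputs.
model, *outputs = scr.train_model(rna_data=adata, net=net)
```
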
69 | ## FAQ
70 |
71 | **Q1: Do I need a GPU to run scRegulate?**
72 | No, a GPU is not required. However, using a CUDA-enabled GPU is strongly recommended for faster training and inference, especially with large datasets.
73 |
74 | **Q2: How do I know if I can use a GPU with scRegulate?**
75 | There are two quick checks:
76 |
77 | 1. **System check**
78 | In your terminal, run `nvidia-smi`. If you see your GPU listed (model, memory, driver version), your machine has an NVIDIA GPU with the driver installed.
79 |
80 | 2. **Python check**
81 | In a Python shell, run:
82 | ```python
83 | import torch
84 | print(torch.cuda.is_available()) # True means PyTorch can see your GPU
85 | print(torch.cuda.device_count()) # How many GPUs are usable
86 | ```
87 |
88 | **Q3: Can I use scRegulate with Seurat or R-based tools?**
89 | scRegulate is written in Python and works directly with `AnnData` objects (e.g., from Scanpy). You can convert Seurat objects to AnnData using tools like `SeuratDisk`.
90 |
91 | **Q4: How can I visualize inferred TF activities?**
92 | TF activities inferred by scRegulate are stored in the `obsm` slot of the AnnData object. You can use `scanpy.pl.embedding`, `scanpy.pl.heatmap`, or export the matrix for custom plots.
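For example, a minimal plotting sketch (assuming the fine-tuned activities live under the `TF_finetuned` key written by `fine_tune_clusters`, and using a hypothetical TF name):

```python
import scanpy as sc
import anndata as ad
import pandas as pd

# Wrap the TF-activity matrix in its own AnnData so standard Scanpy plotting applies.
tf_names = adata.uns["GRN_posterior"]["TF_names"]   # TF names stored during training
tf_adata = ad.AnnData(
    X=adata.obsm["TF_finetuned"],
    obs=adata.obs.copy(),
    var=pd.DataFrame(index=tf_names),
)

sc.pp.neighbors(tf_adata)
sc.tl.umap(tf_adata)
sc.pl.umap(tf_adata, color=["PAX5"])   # replace with any TF of interest
```

Note that `fine_tune_clusters` also returns a ready-made TF-activity AnnData, which can be plotted the same way.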
93 |
94 | **Q5: What kind of prior networks does scRegulate accept?**
95 | scRegulate supports user-provided gene regulatory networks (GRNs) in CSV or matrix format. These can be curated from public databases or inferred from ATAC-seq or motif analysis.
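The long-format CSV expected by `train_model` has `source` (TF), `target` (gene), and `weight` columns. A sketch using the bundled CollecTRI prior and a toy custom prior:

```python
import pandas as pd
from scregulate import collectri_prior

# Bundled CollecTRI prior (downloaded and cached on first use).
net = collectri_prior(species="human")

# Or supply your own prior in the same long format (toy values shown):
custom_net = pd.DataFrame({
    "source": ["TF_A", "TF_A", "TF_B"],     # transcription factors (hypothetical)
    "target": ["GENE1", "GENE2", "GENE1"],  # target genes (hypothetical)
    "weight": [1.0, 1.0, -1.0],             # sign/strength of regulation
})
```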
96 |
97 | **Q6: Can I use scRegulate for multi-omics integration?**
98 | Not directly. While scRegulate focuses on TF activity from RNA, you can incorporate priors derived from other omics (e.g., ATAC) to **guide** the model.
99 |
100 | **Q7: What file formats are supported?**
101 | scRegulate works with `.h5ad` files (AnnData format). Input files should contain gene expression matrices with proper normalization.
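The exact preprocessing used in the manuscript is documented in the Data Preparation notebook; a typical Scanpy normalization before running scRegulate might look like this (a sketch, not necessarily the authors' exact pipeline):

```python
import scanpy as sc

adata = sc.read_h5ad("my_data.h5ad")           # placeholder path
sc.pp.normalize_total(adata, target_sum=1e4)   # library-size normalization
sc.pp.log1p(adata)                             # log(1 + x) transform
adata.write_h5ad("my_data_normalized.h5ad")    # save for later use
```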
102 |
103 | **Q8: How do I cite scRegulate?**
104 | See the [Citation](#citation) section below for the latest reference and preprint link.
105 |
106 | **Q9: How can I reproduce the paper’s results?**
107 | See our [Reproducibility Guide](https://github.com/YDaiLab/scRegulate/blob/main/notebooks/Data_Preparation.ipynb) for step-by-step data-preparation instructions, then run scRegulate on the prepared data.
108 |
109 | ## Citation
110 |
111 | The **scRegulate** manuscript is currently under peer review.
112 |
113 | If you use **scRegulate** in your research, please cite:
114 |
115 | Mehrdad Zandigohar, Jalees Rehman and Yang Dai (2025). **scRegulate: Single-Cell Regulatory-Embedded Variational Inference of Transcription Factor Activity from Gene Expression**, Bioinformatics Journal (under review).
116 |
117 | 📄 Read the preprint on bioRxiv: [10.1101/2025.04.17.649372](https://doi.org/10.1101/2025.04.17.649372)
118 |
119 | ## Development & Contact
120 | scRegulate was developed and is actively maintained by Mehrdad Zandigohar as part of his PhD research at the University of Illinois Chicago (UIC), in the lab of Dr. Yang Dai.
121 |
122 | 📬 For private questions, please email: mzandi2@uic.edu
123 |
124 | 🤝 For collaboration inquiries, please contact PI: Dr. Yang Dai (yangdai@uic.edu)
125 |
126 | Contributions, feature suggestions, and feedback are always welcome!
127 |
128 | ## License
129 |
130 | The code in **scRegulate** is licensed under the [MIT License](https://opensource.org/licenses/MIT), which permits academic and commercial use, modification, and distribution.
131 |
132 | Please note that any third-party dependencies bundled with **scRegulate** may have their own respective licenses.
133 |
134 |
--------------------------------------------------------------------------------
/assets/Visual_Abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YDaiLab/scRegulate/52df4611f4e99014c20caca9ebf321dc44d660b3/assets/Visual_Abstract.png
--------------------------------------------------------------------------------
/assets/tool_logo.svg:
--------------------------------------------------------------------------------
(SVG artwork: scRegulate tool logo; markup omitted)
--------------------------------------------------------------------------------
/conda.recipe/conda_build_config.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - conda-forge
3 | - pytorch
4 | - defaults
5 |
6 | python:
7 | - 3.10
8 | numpy:
9 | - 1.23
10 |
--------------------------------------------------------------------------------
/conda.recipe/environment.yml:
--------------------------------------------------------------------------------
1 | name: build_env
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - python=3.10
8 | - pip
9 | - pytorch>=2.0
10 | - scanpy=1.10.4
11 | - anndata=0.10.8
12 | - numpy=1.26.4
13 | - matplotlib-base>=3.6
14 | - optuna >=3.0
15 |
--------------------------------------------------------------------------------
/conda.recipe/meta.yaml:
--------------------------------------------------------------------------------
1 | {% set version = environ.get('GIT_TAG_NAME', '0.0.0') %}
2 |
3 | package:
4 | name: scregulate
5 | version: "{{ version }}"
6 |
7 | source:
8 | path: ..
9 |
10 | build:
11 | noarch: python
12 | script: python -m pip install .
13 | number: 0
14 |
15 | requirements:
16 | host:
17 | - python >=3.10,<3.11
18 | - pip
19 | run:
20 | - python >=3.10,<3.11
21 | - numpy >=1.26
22 | - pytorch >=2.0,<2.3
23 | - scanpy >=1.10,<1.11
24 | - anndata >=0.10
25 | - matplotlib-base >=3.6,<3.9
26 | - pillow >=8.0
27 | - optuna >=3.0
28 |
29 | test:
30 | imports:
31 | - scregulate
32 | commands:
33 | - pip check
34 | requires:
35 | - pip
36 | source_files:
37 | - setup.py
38 | channels:
39 | - conda-forge
40 | - pytorch
41 |
42 | about:
43 | home: https://github.com/YDaiLab/scRegulate
44 | dev_url: https://github.com/YDaiLab/scRegulate
45 | license: MIT
46 | license_file: LICENSE
47 | summary: Python Toolkit for Transcription Factor Activity Inference and Clustering of scRNA-seq Data
48 | description: |
49 | scRegulate is a Python toolkit designed for inferring transcription factor activity
50 | and clustering single-cell RNA-seq data using variational inference with biological priors.
51 | It is built on PyTorch and Scanpy, and supports CUDA acceleration for high-performance inference.
52 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
📘 scRegulate Tutorials

This page contains multiple tutorials to help you:

- Run transcription factor (TF) inference on your own data
- Reproduce the results in our manuscript

🧬 TF Inference on PBMC 3K
This tutorial walks you through the basic usage of scRegulate on a small PBMC 3K dataset.

📄 Reproducing Manuscript Results
This notebook replicates the preprocessing and analysis pipeline used in our paper.
--------------------------------------------------------------------------------
/notebooks/README.md:
--------------------------------------------------------------------------------
1 | This page contains multiple tutorials consisting of:
2 | 1) TF inference on PBMC 3K [Notebook](https://github.com/YDaiLab/scRegulate/blob/main/notebooks/tutorial_main.ipynb) | [HTML](https://ydailab.github.io/scRegulate/tutorial_main.html)
3 | 2) Reproducing manuscript results [Notebook](https://github.com/YDaiLab/scRegulate/blob/main/notebooks/Data_Preparation.ipynb) | [HTML](https://ydailab.github.io/scRegulate/Data_Preparation.html)
4 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 77.0.3"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "scRegulate"
7 | dynamic = ["version"]
8 | description = "Python Toolkit for Transcription Factor Activity Inference and Clustering of scRNA-seq Data"
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | license = "MIT"
12 | authors = [
13 | {name = "Mehrdad Zandigohar", email = "mehr.zgohar@gmail.com"},
14 | ]
15 | dependencies = [
16 | "torch>=2.0",
17 | "numpy>=1.23",
18 | "scanpy>=1.9",
19 | "anndata>=0.8"
20 | ]
21 |
22 | [project.urls]
23 | "Homepage" = "https://github.com/YDaiLab/scRegulate"
24 | "Documentation" = "https://github.com/YDaiLab/scRegulate#readme"
25 | "Issue Tracker" = "https://github.com/YDaiLab/scRegulate/issues"
26 | "Paper (bioRxiv)" = "https://doi.org/10.1101/2025.04.17.649372"
27 |
28 | [tool.setuptools.dynamic]
29 | version = { attr = "scregulate.__version__.__version__" }
30 |
31 |
32 |
33 | [tool.setuptools.packages.find]
34 | where = ["."]
35 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.0
2 | numpy>=1.23
3 | scanpy>=1.9
4 | anndata>=0.8
5 |
--------------------------------------------------------------------------------
/scregulate/__init__.py:
--------------------------------------------------------------------------------
1 | # scregulate/__init__.py
2 | from .__version__ import __version__
3 |
4 | # Core training and auto-tuning
5 | from .train import train_model, adapt_prior_and_data
6 | from .auto_tuning import auto_tune
7 |
8 | # Model architecture
9 | from .vae_model import scRNA_VAE
10 |
11 | # Loss functions and gradient norm
12 | from .loss_functions import loss_function, get_gradient_norm
13 |
14 | # Training utilities
15 | from .train_utils import (
16 | schedule_parameter,
17 | schedule_mask_factor,
18 | apply_gradient_mask,
19 | compute_average_loss,
20 | clip_gradients,
21 | )
22 |
23 | # General utilities
24 | from .utils import (
25 | set_random_seed,
26 | create_dataloader,
27 | clear_memory,
28 | extract_GRN,
29 | to_torch_tensor,
30 | set_active_modality,
31 | extract_modality,
32 | )
33 |
34 | # GRN prior utilization (collectri)
35 | from .datasets import collectri_prior
36 |
--------------------------------------------------------------------------------
/scregulate/__version__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.26"
2 |
--------------------------------------------------------------------------------
/scregulate/auto_tuning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import optuna
4 | from optuna.samplers import TPESampler, NSGAIISampler, QMCSampler
5 | from optuna.pruners import HyperbandPruner, SuccessiveHalvingPruner
6 |
7 | import logging
8 | from .utils import set_active_modality
9 |
10 | from .fine_tuning import fine_tune_clusters
11 | from .train import train_model
12 | from .train import adapt_prior_and_data
13 |
14 | # ---------- Logger Configuration ----------
15 | autotune_logger = logging.getLogger("autotune")
16 | autotune_logger.setLevel(logging.INFO)
17 | handler = logging.StreamHandler()
18 | formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
19 | handler.setFormatter(formatter)
20 | if not autotune_logger.handlers:
21 | autotune_logger.addHandler(handler)
22 | autotune_logger.propagate = False
23 |
24 | # ---------- Common Validation Loss Function ----------
25 | def compute_validation_loss(model, val_adata):
26 | autotune_logger.info("Computing validation loss...")
27 |
28 | # Retrieve gene names used in the model
29 | if not hasattr(model, "gene_names"):
30 | raise ValueError("Model does not have gene_names attribute. Ensure it is set during training.")
31 | model_genes = model.gene_names # List of gene names used during training
32 |
33 | # Align validation data with model's gene names
34 | adata_genes = val_adata.var_names
35 | shared_genes = set(adata_genes).intersection(set(model_genes))
36 | if len(shared_genes) < len(model_genes):
37 | val_adata = val_adata[:, list(shared_genes)].copy()
38 | autotune_logger.info(f"Aligned validation data to {len(shared_genes)} shared genes.")
39 |
40 | X_val = val_adata.X
41 | if not isinstance(X_val, np.ndarray):
42 | X_val = X_val.toarray()
43 |
44 | device = "cuda" if torch.cuda.is_available() else "cpu"
45 | X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
46 |
47 | model.eval()
48 | with torch.no_grad():
49 | mu, logvar = model.encode(X_val_tensor)
50 | z = model.reparameterize(mu, logvar)
51 | tf_activity = model.decode(z)
52 | recon_x = model.tf_mapping(tf_activity)
53 | mse_loss = torch.mean((recon_x - X_val_tensor) ** 2).item()
54 |
55 | autotune_logger.info(f"Validation MSE loss: {mse_loss}")
56 | return mse_loss
57 |
58 | # ---------- Training Objective Function ----------
59 | def training_objective(trial, train_adata, val_adata, net, train_model, **kwargs):
60 | autotune_logger.info("Running training objective...")
61 |
62 | # Suggest hyperparameters
63 | freeze_epochs = trial.suggest_int('freeze_epochs', kwargs['freeze_epochs_range'][0], kwargs['freeze_epochs_range'][1])
64 | alpha_scale = trial.suggest_float('alpha_scale', kwargs['alpha_scale_range'][0], kwargs['alpha_scale_range'][1])
65 | alpha_max = trial.suggest_float('alpha_max', kwargs['alpha_max_range'][0], kwargs['alpha_max_range'][1])
66 | beta_max = trial.suggest_float('beta_max', kwargs['beta_max_range'][0], kwargs['beta_max_range'][1])
67 | gamma_max = trial.suggest_float('gamma_max', kwargs['gamma_max_range'][0], kwargs['gamma_max_range'][1])
68 | learning_rate = trial.suggest_float('learning_rate', kwargs['learning_rate_range'][0], kwargs['learning_rate_range'][1], log=True)
69 |
70 | # Train model
71 | model, _, _ = train_model(
72 | rna_data=train_adata,
73 | net=net,
74 | encode_dims=kwargs['encode_dims'],
75 | decode_dims=kwargs['decode_dims'],
76 | z_dim=kwargs['z_dim'],
77 | epochs=kwargs['epochs'],
78 | freeze_epochs=freeze_epochs,
79 | learning_rate=learning_rate,
80 |
81 | alpha_start=0,
82 | alpha_scale=alpha_scale,
83 | alpha_max=alpha_max,
84 |
85 | beta_start=0,
86 | beta_max=beta_max,
87 |
88 | gamma_start=0,
89 | gamma_max=gamma_max,
90 |
91 | log_interval=kwargs['log_interval'],
92 | early_stopping_patience=kwargs['early_stopping_patience'],
93 | min_targets=kwargs['min_targets'],
94 | min_TFs=kwargs['min_TFs'],
95 | device=None,
96 | return_outputs=True,
97 | verbose=kwargs['verbose']
98 | )
99 |
100 | # Compute validation loss
101 | val_loss = compute_validation_loss(model, val_adata)
102 | return val_loss
103 |
104 | # ---------- Fine-Tuning Objective Function ----------
105 | def fine_tuning_objective(trial, processed_adata, model, **kwargs):
106 | autotune_logger.info("Running fine-tuning objective...")
107 |
108 | # Suggest hyperparameters
109 | beta_max = trial.suggest_float("beta_max", kwargs['beta_max_range'][0], kwargs['beta_max_range'][1])
110 | max_weight_norm = trial.suggest_float("max_weight_norm", kwargs['max_weight_norm_range'][0], kwargs['max_weight_norm_range'][1])
111 | tf_mapping_lr = trial.suggest_float("tf_mapping_lr", kwargs['tf_mapping_lr_range'][0], kwargs['tf_mapping_lr_range'][1], log=True)
112 | fc_output_lr = trial.suggest_float("fc_output_lr", kwargs['fc_output_lr_range'][0], kwargs['fc_output_lr_range'][1], log=True)
113 | default_lr = trial.suggest_float("default_lr", kwargs['default_lr_range'][0], kwargs['default_lr_range'][1], log=True)
114 |
115 | # Log modality availability
116 | autotune_logger.info("Checking modalities in processed_adata...")
117 | if not hasattr(processed_adata, "modality"):
118 | autotune_logger.error("processed_adata does not have 'modality' attribute.")
119 | raise ValueError("Missing 'modality' attribute in processed_adata.")
120 |
121 | if "TF" not in processed_adata.modality:
122 | autotune_logger.error(f"Available modalities: {list(processed_adata.modality.keys())}")
123 | raise ValueError("Missing modality['TF'] in processed_adata.")
124 |
125 | autotune_logger.info(f"Modalities available: {list(processed_adata.modality.keys())}")
126 |
127 | # Log cluster_key and its addition to obs
128 | cluster_key = kwargs.get('cluster_key', None)
129 | autotune_logger.info(f"Adding cluster_key '{cluster_key}' to processed_adata.obs...")
130 |
131 | # Fine-tune model
132 | _, _, fine_model , _ = fine_tune_clusters(
133 | processed_adata=processed_adata,
134 | model=model,
135 | cluster_key=cluster_key,
136 | epochs=kwargs['epochs'],
137 | device=kwargs['device'],
138 | verbose=kwargs['verbose'],
139 | beta_start=0,
140 | beta_max=beta_max,
141 | max_weight_norm=max_weight_norm,
142 | early_stopping_patience=kwargs['early_stopping_patience'],
143 | tf_mapping_lr=tf_mapping_lr,
144 | fc_output_lr=fc_output_lr,
145 | default_lr=default_lr # Fine-tuned base learning rate
146 | )
147 |
148 | # Compute validation loss
149 | autotune_logger.info("Computing validation loss after fine-tuning...")
150 | val_loss = compute_validation_loss(fine_model, processed_adata)
151 | return val_loss
152 |
153 |
154 | # ---------- Unified Auto-Tuning Function ----------
155 | def auto_tune(
156 | mode, # 'training' or 'fine-tuning'
157 | processed_adata=None,
158 | model=None,
159 | net=None,
160 | train_model=None, # Explicitly added as a direct argument
161 | n_trials=10,
162 | **kwargs
163 | ):
164 | autotune_logger.info(f"Starting auto-tuning for {mode} with {n_trials} trials...")
165 |
166 | if processed_adata is None:
167 | raise ValueError("processed_adata must be provided for auto-tuning.")
168 |
169 |
170 | if mode == "training":
171 | train_val_split_ratio = kwargs.get('train_val_split_ratio', 0.8)
172 | np.random.seed(0)
173 | indices = np.arange(processed_adata.n_obs)
174 | np.random.shuffle(indices)
175 | train_size = int(train_val_split_ratio * len(indices))
176 | train_idx = indices[:train_size]
177 | val_idx = indices[train_size:]
178 |
179 | train_adata = processed_adata[train_idx].copy()
180 | val_adata = processed_adata[val_idx].copy()
181 |
182 | # Adapt prior for training data
183 | W_prior, gene_names, TF_names = adapt_prior_and_data(train_adata, net, kwargs['min_targets'], kwargs['min_TFs'])
184 |
185 | # Align validation data to training gene names
186 | val_genes = val_adata.var_names.intersection(gene_names)
187 | val_adata = val_adata[:, list(val_genes)].copy()
188 |
189 | elif mode == "fine-tuning":
190 | # Ensure the active modality is set to 'RNA'
191 | autotune_logger.info("Setting the active modality to 'RNA'...")
192 | processed_adata = set_active_modality(processed_adata, "RNA")
193 | autotune_logger.info("Active modality set to 'RNA'.")
194 |
195 | if not hasattr(model, "gene_names"):
196 | raise ValueError("Model does not have gene_names attribute. Ensure it is set during training.")
197 | gene_names = model.gene_names
198 | fine_tuned_genes = processed_adata.var_names.intersection(gene_names)
199 | processed_adata._inplace_subset_var(list(fine_tuned_genes))
200 | autotune_logger.info(f"Aligned processed_adata to {len(fine_tuned_genes)} genes.")
201 | train_adata, val_adata = None, None
202 | else:
203 | raise ValueError("Invalid mode. Choose 'training' or 'fine-tuning'.")
204 |
205 | # Create study
206 | #study = optuna.create_study(direction='minimize', sampler=TPESampler())
207 |
208 | # Use BOHB (Bayesian Optimization with HyperBand)
209 | sampler = TPESampler(multivariate=True, seed=42)
210 |
211 | pruner = HyperbandPruner(
212 | min_resource=500, # Let model run at least 500 epochs before considering pruning
213 | max_resource=kwargs['epochs'], # Full training budget
214 | reduction_factor=3, # Prune bottom 2/3 at each round
215 | bootstrap_count=0
216 | )
217 |
218 | study = optuna.create_study(direction='minimize', sampler=sampler, pruner=pruner)
219 |
220 | if mode == "training":
221 | study.optimize(
222 | lambda trial: training_objective(
223 | trial, train_adata, val_adata, net, train_model, **kwargs
224 | ),
225 | n_trials=n_trials
226 | )
227 | elif mode == "fine-tuning":
228 | study.optimize(
229 | lambda trial: fine_tuning_objective(
230 | trial, processed_adata, model, **kwargs
231 | ),
232 | n_trials=n_trials
233 | )
234 | else:
235 | raise ValueError("Invalid mode. Choose 'training' or 'fine-tuning'.")
236 |
237 | autotune_logger.info(f"Best hyperparameters: {study.best_params}")
238 | autotune_logger.info(f"Best validation loss: {study.best_value}")
239 | return study.best_params, study.best_value
240 |
--------------------------------------------------------------------------------
/scregulate/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import urllib.request
3 | import pandas as pd
4 |
5 | def _download_file_if_needed(filename: str, url: str, target_dir: str) -> str:
6 | os.makedirs(target_dir, exist_ok=True)
7 | filepath = os.path.join(target_dir, filename)
8 | if not os.path.exists(filepath):
9 | print(f"[scRegulate] Downloading {filename} to {filepath}...")
10 | try:
11 | urllib.request.urlretrieve(url, filepath)
12 | except Exception as e:
13 | raise RuntimeError(f"Failed to download {url}.\nError: {e}")
14 | return filepath
15 |
16 | def collectri_prior(species: str = "human") -> pd.DataFrame:
17 | """
18 | Returns the collectri prior dataframe for the given species ("human" or "mouse").
19 | Downloads the file on first use and caches it under ~/.scregulate/priors/.
20 |
21 | Parameters:
22 | - species: str = "human" or "mouse"
23 |
24 | Returns:
25 | - pd.DataFrame with TF-target prior network
26 | """
27 | base_url = "https://github.com/YDaiLab/scRegulate/raw/main/priors/"
28 | species_to_filename = {
29 | "human": "collectri_human_net.csv",
30 | "mouse": "collectri_mouse_net.csv"
31 | }
32 |
33 | if species not in species_to_filename:
34 | raise ValueError("species must be either 'human' or 'mouse'")
35 |
36 | filename = species_to_filename[species]
37 | url = base_url + filename
38 | target_dir = os.path.expanduser("~/.scregulate/priors")
39 |
40 | local_path = _download_file_if_needed(filename, url, target_dir)
41 | return pd.read_csv(local_path)
42 |
--------------------------------------------------------------------------------
/scregulate/fine_tuning.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import numpy as np
4 | import scanpy as sc
5 | import pandas as pd
6 | from torch.utils.data import DataLoader
7 | from sklearn.preprocessing import MinMaxScaler
8 | import logging
9 | import torch.nn.functional as F
10 | from .loss_functions import loss_function
11 | from .utils import (
12 | set_random_seed, # To ensure reproducibility
13 | create_dataloader, # For DataLoader creation
14 | clear_memory, # To handle memory cleanup
15 | to_torch_tensor, # Converts dense or sparse data to PyTorch tensors
16 | schedule_parameter, # For parameter scheduling during training
17 | )
18 | set_random_seed()
19 |
20 | # ---------- Logging Configuration ----------
21 | finetune_logger = logging.getLogger("finetune")
22 | finetune_logger.setLevel(logging.INFO)
23 |
24 | # Reset logger: Remove any existing handlers
25 | while finetune_logger.handlers:
26 | finetune_logger.removeHandler(finetune_logger.handlers[0])
27 |
28 | # Define a new handler
29 | handler = logging.StreamHandler() # Logs to console; use FileHandler for a file
30 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
31 | handler.setFormatter(formatter)
32 | finetune_logger.addHandler(handler)
33 |
34 | finetune_logger.propagate = False # Prevent duplicate logs from root logger
35 |
36 | # ---------- Sparsity Handling ----------
37 | def sparse_collate_fn(batch):
38 | """
39 | Custom collate function to handle sparse tensors.
40 |
41 | Args:
42 | batch: A batch of data (potentially sparse).
43 |
44 | Returns:
45 | torch.Tensor: Dense tensor batch.
46 | """
47 | if isinstance(batch[0], torch.Tensor):
48 | # Convert sparse tensors to dense format before stacking.
49 | # Note: torch.sparse_coo_tensor is a constructor, not a type, so sparsity is
50 | # detected via each tensor's .is_sparse flag rather than isinstance().
51 | return torch.stack([b.to_dense() if b.is_sparse else b for b in batch])
52 | else:
53 | raise TypeError(f"Unsupported batch type: {type(batch[0])}")
54 |
55 | # ---------- Extract TF Activities ----------
56 | def extract_tf_activities_deterministic(model, data_loader, device="cuda"):
57 | """
58 | Extract transcription factor (TF) activities deterministically using the trained model.
59 |
60 | Args:
61 | model (torch.nn.Module): The trained model to use for inference.
62 | data_loader (DataLoader): DataLoader containing the cluster-specific data.
63 | device (str): Device for computation ('cuda' or 'cpu'). Default is 'cuda'.
64 |
65 | Returns:
66 | np.ndarray: Matrix of inferred TF activities.
67 | """
68 | model.eval()
69 | tf_activities = []
70 | with torch.no_grad():
71 | for batch in data_loader:
72 | batch = batch.to(device)
73 | mu, _ = model.encode(batch) # Use mean (mu) for deterministic encoding
74 | tf_activity = model.decode(mu).clamp(min=0)
75 | #tf_activity = model.decode(mu).clamp(min=0, max=5.0)
76 | tf_activities.append(tf_activity.cpu().numpy())
77 | return np.concatenate(tf_activities, axis=0)
78 |
79 | # ---------- Cluster-specific handling ----------
80 | def scale_and_format_grns(W_posteriors, gene_names, tf_names):
81 | """
82 | Apply MinMax scaling to each W matrix (per cluster) and return a dict of DataFrames.
83 | """
84 | scaled_grns = {}
85 | for cluster, W in W_posteriors.items():
86 | scaled = MinMaxScaler().fit_transform(W)
87 | scaled_df = pd.DataFrame(scaled, index=gene_names, columns=tf_names)
88 | scaled_grns[cluster] = scaled_df
89 |
90 | return scaled_grns
91 |
92 | def finalize_fine_tuning_outputs(processed_adata, all_tf_activities_matrix, tf_names, W_posteriors, gene_names):
93 | """
94 | Prepares and returns the processed AnnData, TF activity AnnData, and GRNs.
95 | """
96 | # Attach TF activities to AnnData
97 | processed_adata.obsm["TF_finetuned"] = all_tf_activities_matrix
98 | processed_adata.uns["W_posteriors_per_cluster"] = W_posteriors
99 |
100 | # Create new AnnData object for TF activities
101 | tf_activities_adata = sc.AnnData(
102 | X=all_tf_activities_matrix,
103 | obs=processed_adata.obs.copy(),
104 | var=pd.DataFrame(index=tf_names),
105 | )
106 |
107 | # Format GRNs
108 | scaled_grns = scale_and_format_grns(
109 | W_posteriors, gene_names=gene_names, tf_names=tf_names
110 | )
111 |
112 | return processed_adata, tf_activities_adata, scaled_grns
113 |
114 |
115 |
116 | # ---------- Fine-Tuning for All Clusters ----------
117 | def fine_tune_clusters(
118 | processed_adata,
119 | model,
120 | cluster_key=None,
121 | epochs=500,
122 | batch_size=3500,
123 | device=None,
124 | max_weight_norm=4,
125 | log_interval=100,
126 | early_stopping_patience=250,
127 | min_epochs=0,
128 | beta_start=0,
129 | beta_max=0.5,
130 | tf_mapping_lr=4e-04, # Learning rate for tf_mapping layer
131 | fc_output_lr=2e-05, # Learning rate for fc_output layer
132 | default_lr=3.5e-05, # Default learning rate for other layers
133 | verbose=True,
134 | ):
135 | """
136 | Fine-tune the model for each cluster in the AnnData object with logging intervals and early stopping.
137 |
138 | Args:
139 | processed_adata (AnnData): AnnData object containing RNA data and cluster annotations.
140 | model (torch.nn.Module): Pre-trained model to fine-tune.
141 | cluster_key (str): Key in `.obs` containing cluster annotations. Default is None.
142 | epochs (int): Number of epochs for fine-tuning each cluster. Default is 500.
143 | batch_size (int): Batch size for fine-tuning. Default is 3500.
144 | device (str): Device to perform computation ('cuda' or 'cpu'). Default is None (auto-detects CUDA).
145 | beta_start (float): Initial value for beta warm-up. Default is 0.
146 | beta_max (float): Maximum value for beta warm-up. Default is 0.5.
147 | max_weight_norm (float): Maximum weight norm for clipping. Default is 4.
148 | log_interval (int): Number of epochs between logging updates. Default is 100.
149 | early_stopping_patience (int): Number of epochs to wait without improvement. Default is 250.
150 | min_epochs (int): Minimum number of epochs before checking for early stopping. Default is 0.
151 | tf_mapping_lr (float): Learning rate for the tf_mapping layer. Default is 4e-4.
152 | fc_output_lr (float): Learning rate for the fc_output layer. Default is 2e-5.
153 | default_lr (float): Learning rate for other layers. Default is 3.5e-5.
154 | verbose (bool): If True, enables detailed logging.
155 |
156 | Returns:
157 | tuple: (processed_adata, tf_activities_adata, fine_tuned_model, scaled_grns), where processed_adata carries fine-tuned TF activities (`obsm["TF_finetuned"]`) and cluster-specific W_posteriors (`uns`), followed by a TF-activity AnnData, the fine-tuned model, and the scaled GRN(s).
158 |
159 | Notes:
160 | - Do not change the learning rates unless you are an expert.
161 | - tf_mapping_lr (4e-4): Faster learning rate for the tf_mapping layer.
162 | - fc_output_lr (2e-5): Slower learning rate for the fc_output layer.
163 | - default_lr (3.5e-5): Learning rate for any other trainable layers.
164 | """
165 | set_random_seed()
166 | batch_size = int(batch_size)
167 |
168 | # Set logging verbosity, control verbosity of train_model logger only.
169 | if verbose:
170 | finetune_logger.setLevel(logging.INFO)
171 | else:
172 | finetune_logger.setLevel(logging.WARNING)
173 |
174 | # Auto-detect device if not specified
175 | if device is None:
176 | device = "cuda" if torch.cuda.is_available() else "cpu"
177 |
178 | finetune_logger.info("Starting fine-tuning for cluster(s)...")
179 |
180 | # Align gene indices to match GRN dimensions
181 | grn_gene_names = processed_adata.uns["GRN_posterior"]["gene_names"]
182 | adata_gene_names = processed_adata.var.index
183 | gene_indices = [adata_gene_names.get_loc(gene) for gene in grn_gene_names if gene in adata_gene_names]
184 |
185 |
186 | finetune_logger.info(f"Aligning data to {len(gene_indices)} genes matching the GRN.")
187 |
188 | if not cluster_key:
189 | finetune_logger.info("Cluster key not provided, fine-tuning on all cells together...")
190 | cluster_key = 'all'
191 | processed_adata.obs[cluster_key] = 1
192 |
193 | unique_clusters = processed_adata.obs[cluster_key].unique()
194 | original_model = copy.deepcopy(model)
195 | W_posteriors_per_cluster = {}
196 | all_tf_activities = []
197 | cluster_indices_list = []
198 | total_cells = len(processed_adata)
199 | for cluster in unique_clusters:
200 | if len(unique_clusters)>1:
201 | finetune_logger.info(f"Fine-tuning {cluster} for {epochs} epochs...")
202 | else:
203 | finetune_logger.info(f"Fine-tuning on all cells for {epochs} epochs...")
204 |
205 | cluster_size = len(processed_adata.obs[processed_adata.obs[cluster_key] == cluster])
206 | proportional_epochs = max(min_epochs, int(epochs * (cluster_size / total_cells)))
207 |
208 | # Extract cluster data and ensure dense format
209 | cluster_indices = processed_adata.obs[cluster_key] == cluster
210 | cluster_data = processed_adata[cluster_indices].X[:, gene_indices]
211 |
212 | # Convert sparse to dense if needed
213 | if not isinstance(cluster_data, np.ndarray):
214 | cluster_data = cluster_data.todense()
215 | cluster_tensor = to_torch_tensor(cluster_data, device=device)
216 |
217 | # Create DataLoader
218 | cluster_loader = create_dataloader(cluster_tensor, batch_size=batch_size)
219 |
220 | # Fine-tune model for the cluster with early stopping
221 | model_copy = copy.deepcopy(original_model)
222 | best_loss, epochs_no_improve = float("inf"), 0
223 |
224 | # Optimizer with configurable learning rates
225 | optimizer = torch.optim.AdamW([
226 | {"params": model_copy.tf_mapping.parameters(), "lr": tf_mapping_lr}, # Learning rate for tf_mapping
227 | {"params": model_copy.fc_output.parameters(), "lr": fc_output_lr}, # Learning rate for fc_output
228 | {"params": [p for n, p in model_copy.named_parameters()
229 | if "tf_mapping" not in n and "fc_output" not in n], "lr": default_lr} # Default LR
230 | ])
231 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min",
232 | patience=early_stopping_patience//2,
233 | factor=0.5, min_lr=default_lr)
234 | for epoch in range(proportional_epochs):
235 | model_copy.train()
236 | total_loss, total_samples = 0, 0
237 | # Warm-up beta
238 | beta = schedule_parameter(epoch, beta_start, beta_max, epochs)
239 |
240 | for batch in cluster_loader:
241 | batch = batch.to(device)
242 | optimizer.zero_grad()
243 |
244 | # Forward pass
245 | recon_batch, mu, logvar = model_copy(batch)
246 | MSE, _ = loss_function(recon_batch, batch, mu, logvar)
247 | l1_reg = torch.sum(torch.abs(model_copy.tf_mapping.weight))
248 |
249 | cluster_weight = len(cluster_indices) / len(processed_adata)
250 | loss = (MSE + beta * l1_reg)# * cluster_weight
251 |
252 | # Backward pass
253 | loss.backward()
254 | optimizer.step()
255 |
256 | # Weight clipping
257 | with torch.no_grad():
258 | model_copy.tf_mapping.weight.clamp_(-max_weight_norm, max_weight_norm)
259 |
260 | total_loss += loss.item() * len(batch)
261 | total_samples += len(batch)
262 |
263 | # Average loss for the epoch
264 | avg_loss = total_loss / total_samples
265 | scheduler.step(avg_loss)
266 |
267 | # Logging
268 | if epoch % log_interval == 0 or epoch == epochs - 1:
269 | finetune_logger.info(f"Epoch {epoch+1}, Avg Loss: {avg_loss/processed_adata.shape[1]:.4f}")
270 |
271 | # Early stopping
272 | if epoch > min_epochs:
273 | if avg_loss < best_loss:
274 | best_loss = avg_loss
275 | epochs_no_improve = 0
276 | else:
277 | epochs_no_improve += 1
278 | if epochs_no_improve >= early_stopping_patience:
279 | finetune_logger.info(f"Early stopping for cluster {cluster} at epoch {epoch+1}")
280 | break
281 |
282 | # Extract W_posterior and TF activities
283 | W_posterior_cluster = model_copy.tf_mapping.weight.detach().cpu().numpy()
284 | W_posteriors_per_cluster[cluster] = W_posterior_cluster
285 |
286 | tf_activities = extract_tf_activities_deterministic(model_copy, cluster_loader, device=device)
287 |
288 | cluster_indices_list.append(np.where(cluster_indices)[0])
289 | all_tf_activities.append(tf_activities)
290 |
291 | # Clear memory
292 | clear_memory(device)
293 |
294 | # [Modification 3] - We now build the final matrix in the ORIGINAL cell order
295 | n_cells = processed_adata.n_obs
296 |
297 | # If there's at least one cluster, get the size of the TF dimension from the first cluster's result
298 | n_tfs = all_tf_activities[0].shape[1] if len(all_tf_activities) > 0 else 0
299 |
300 | # [Modification 4] - Create a zeroed matrix to hold TF activities for ALL cells
301 | all_tf_activities_matrix = np.zeros((n_cells, n_tfs), dtype=np.float32)
302 |
303 | # [Modification 5] - Insert each cluster's activities into the correct rows
304 | for row_indices, cluster_acts in zip(cluster_indices_list, all_tf_activities):
305 | all_tf_activities_matrix[row_indices, :] = cluster_acts
306 |
307 | # Attach to adata
308 | processed_adata.obsm["TF_finetuned"] = all_tf_activities_matrix
309 | processed_adata.uns["W_posteriors_per_cluster"] = W_posteriors_per_cluster
310 |
311 | tf_names = processed_adata.modality['TF'].var.index
312 |
313 | tf_activities_adata = sc.AnnData(
314 | X=all_tf_activities_matrix,
315 | obs=processed_adata.obs.copy(),
316 | var=pd.DataFrame(index=tf_names),
317 | )
318 |
319 | gene_names = processed_adata.uns['GRN_posterior']['gene_names']
320 | tf_names = processed_adata.uns['GRN_posterior']['TF_names']
321 |
322 | processed_adata, tf_activities_adata, scaled_grns = finalize_fine_tuning_outputs(
323 | processed_adata,
324 | all_tf_activities_matrix,
325 | tf_names=tf_names,
326 | W_posteriors=W_posteriors_per_cluster,
327 | gene_names=gene_names,
328 | )
329 |
330 | # If only one GRN (no cluster_key given), return a single DataFrame
331 | if len(scaled_grns) == 1:
332 | scaled_grns = list(scaled_grns.values())[0]
333 |
334 | finetune_logger.info("Fine-tuning completed for all clusters.")
335 | return processed_adata, tf_activities_adata, model_copy, scaled_grns
336 |
337 |
338 |
--------------------------------------------------------------------------------
/scregulate/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def loss_function(recon_x, x, mu, logvar):
5 | """
6 | Computes the loss function for the Variational Autoencoder (VAE).
7 | The loss consists of:
8 | - Mean Squared Error (MSE) between the input and reconstructed data.
9 | - Kullback-Leibler Divergence (KLD) to regularize the latent space.
10 |
11 | Args:
12 | recon_x: Reconstructed input.
13 | x: Original input.
14 | mu: Mean vector from the encoder.
15 | logvar: Log variance vector from the encoder.
16 |
17 | Returns:
18 | MSE: Reconstruction loss.
19 | KLD: KL divergence loss.
20 | """
21 | MSE = F.mse_loss(recon_x, x, reduction='sum') # Reconstruction loss
22 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) # KL Divergence
23 | return MSE, KLD
24 |
25 |
26 | def get_gradient_norm(model, norm_type=2):
27 | """
28 | Computes the total gradient norm for a model.
29 |
30 | Args:
31 | model: The PyTorch model.
32 | norm_type: The type of norm to compute (default is 2 for Euclidean norm).
33 |
34 | Returns:
35 | total_norm: The total gradient norm.
36 | """
37 | total_norm = 0.0
38 | for p in model.parameters():
39 | if p.grad is not None:
40 | param_norm = p.grad.data.norm(norm_type)
41 | total_norm += param_norm.item() ** norm_type
42 | total_norm = total_norm ** (1. / norm_type)
43 | return total_norm
44 |
--------------------------------------------------------------------------------
/scregulate/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import pandas as pd
4 | import numpy as np
5 | import scanpy as sc
6 | import gc
7 | import logging
8 | from . import ulm_standalone as ulm
9 | from .vae_model import scRNA_VAE
10 | from .loss_functions import loss_function
11 | from sklearn.preprocessing import MinMaxScaler
12 | from .utils import (
13 | create_dataloader,
14 | schedule_parameter,
15 | schedule_mask_factor,
16 | apply_gradient_mask,
17 | clip_gradients,
18 | compute_average_loss,
19 | clear_memory,
20 | set_random_seed,
21 | )
22 | set_random_seed()
23 |
24 | # Initialize logging
25 | # Dedicated logger for train_model
26 | # Reset logging to avoid duplicate handlers
27 | train_logger = logging.getLogger("train_model")
28 | train_logger.setLevel(logging.INFO) # Default level
29 |
30 | # Remove any existing handlers
31 | for handler in train_logger.handlers[:]:
32 | train_logger.removeHandler(handler)
33 |
34 | # Create new handler
35 | handler = logging.StreamHandler()
36 | formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
37 | handler.setFormatter(formatter)
38 | train_logger.addHandler(handler)
39 |
40 | train_logger.propagate = False # Prevent duplicate logs from parent loggers
41 |
42 | def adapt_prior_and_data(rna_data, net, min_targets=10, min_TFs=1):
43 | """
44 | Adapt the RNA data and network to create W_prior.
45 |
46 | Args:
47 | rna_data (AnnData): AnnData object containing RNA data and metadata.
48 | net (pd.DataFrame): DataFrame containing 'source', 'target', and 'weight' columns.
49 | min_targets (int): Minimum number of target genes per TF to retain (default: 10).
50 | min_TFs (int): Minimum number of TFs per target gene to retain (default: 1).
51 |
52 | Returns:
53 | tuple: (W_prior (torch.Tensor), gene_names (list), TF_names (list))
54 | """
55 | train_logger.info("Adapting prior and data...")
56 | train_logger.info(f"Initial genes in RNA data: {rna_data.var.index.size}")
57 |
58 | # Validate network structure
59 | required_columns = {'source', 'target', 'weight'}
60 | if not required_columns.issubset(net.columns):
61 | train_logger.error(f"Network DataFrame is missing columns: {required_columns - set(net.columns)}")
62 | raise ValueError(f"The 'net' DataFrame must contain: {required_columns}")
63 |
64 | # Create binary matrix
65 | binary_matrix = pd.crosstab(net['source'], net['target'], values=net['weight'], aggfunc='first').fillna(0)
66 |
67 | # Filter genes in RNA data
68 | keep_genes = rna_data.var.index.intersection(net['target'])
69 | if keep_genes.empty:
70 | train_logger.error("No overlapping genes between RNA data and network.")
71 | raise ValueError("No overlapping genes between RNA data and network.")
72 | rna_data._inplace_subset_var(keep_genes)
73 | train_logger.info(f"Genes retained after intersection with network: {len(keep_genes)}")
74 | train_logger.info(f"Initial TFs in GRN matrix: {binary_matrix.index.size}")
75 |
76 | # Align binary matrix with RNA data
77 | binary_matrix = binary_matrix.loc[:, binary_matrix.columns.intersection(rna_data.var.index)]
78 | if binary_matrix.empty:
79 | train_logger.error("Binary matrix is empty after alignment.")
80 | raise ValueError("Binary matrix empty after alignment with RNA data.")
81 |
82 |     # Drop TFs with fewer than min_targets targets and genes regulated by fewer than min_TFs TFs
83 | binary_matrix = binary_matrix.loc[binary_matrix.sum(axis=1) >= min_targets, :]
84 | binary_matrix = binary_matrix.loc[:, binary_matrix.sum(axis=0) >= min_TFs]
85 | if binary_matrix.empty:
86 | train_logger.error(f"Binary matrix empty after filtering (min_targets={min_targets}, min_TFs={min_TFs}).")
87 | raise ValueError("Binary matrix empty after filtering for min_targets and min_TFs.")
88 |
89 | # Convert to W_prior tensor
90 | W_prior = torch.tensor(binary_matrix.values, dtype=torch.float32).T
91 |
92 | # Metadata
93 | gene_names = binary_matrix.columns.tolist()
94 | TF_names = binary_matrix.index.tolist()
95 |
96 | train_logger.info(f"Retained {len(gene_names)} genes and {len(TF_names)} transcription factors.")
97 |
98 | return W_prior, gene_names, TF_names
99 |
100 | def train_model(
101 | rna_data, net,
102 | encode_dims=None, decode_dims=None, z_dim=40,
103 | train_val_split_ratio=0.8, random_state=42,
104 | batch_size=3500, epochs=2000, freeze_epochs=1500, learning_rate=2.5e-4,
105 | alpha_start=0, alpha_max=0.7, alpha_scale=0.06,
106 | beta_start=0, beta_max=0.4,
107 | gamma_start=0, gamma_max=2.6,
108 | log_interval=500, early_stopping_patience=350,
109 | min_targets=20, min_TFs=1, device=None, return_outputs=True, verbose=True
110 | ):
111 |
112 | """
113 | Train the scRNA_VAE model using in-memory data.
114 |
115 | This function trains a variational autoencoder (VAE) designed to infer transcription factor (TF) activities and reconstruct RNA expression. It takes an AnnData object and a gene regulatory network (GRN) prior to initialize the model.
116 |
117 | Args:
118 | rna_data (AnnData, required): AnnData object containing RNA data (normalized expression in `.X` and gene names in `.var.index`).
119 | net (pd.DataFrame, required): GRN DataFrame with columns:
120 | - 'source': Transcription factors (TFs).
121 | - 'target': Target genes.
122 | - 'weight': Interaction weights between TFs and genes.
123 | encode_dims (list, optional): List of integers specifying the sizes of hidden layers for the encoder. Default is [512].
124 | decode_dims (list, optional): List of integers specifying the sizes of hidden layers for the decoder. Default is [1024].
125 | z_dim (int, optional): Latent space dimensionality. Default is 40.
126 | train_val_split_ratio (float, optional): Ratio of the training data to the entire dataset. Default is 0.8 (80% training, 20% validation).
127 | random_state (int, optional): Seed for reproducible splits. Default is 42.
128 |         batch_size (int, optional): Number of samples per training batch. Default is 3,500. If set to None, the entire training split is used as a single batch.
129 | epochs (int, optional): Total number of epochs for training. Default is 2,000.
130 | freeze_epochs (int, optional): Number of epochs during which the mask factor is gradually applied to enforce prior structure. Default is 1,500.
131 | learning_rate (float, optional): Learning rate for the optimizer. Default is 2.5e-4.
132 | alpha_start (float, optional): Initial value of `alpha`, controlling the weight of initialized TF activities. Default is 0.
133 | alpha_max (float, optional): Maximum value of `alpha` during training. Default is 0.7.
134 | alpha_scale (float, optional): Scaling factor for `alpha` during training. Default is 0.06.
135 | beta_start (float, optional): Initial value of `beta`, controlling the weight of the KL divergence term in the loss. Default is 0.
136 | beta_max (float, optional): Maximum value of `beta` during training. Default is 0.4.
137 | gamma_start (float, optional): Initial value of `gamma`, controlling the L1 regularization of TF-target weights. Default is 0.
138 | gamma_max (float, optional): Maximum value of `gamma` during training. Default is 2.6.
139 | log_interval (int, optional): Frequency (in epochs) for logging training progress. Default is 500.
140 | early_stopping_patience (int, optional): Number of epochs to wait for improvement before stopping training early. Default is 350.
141 |         min_targets (int, optional): Minimum number of target genes per TF to retain in the GRN prior. Default is 20.
142 | min_TFs (int, optional): Minimum number of TFs per target gene to retain in the GRN prior. Default is 1.
143 | device (str, optional): Device for computation ('cuda' or 'cpu'). If not specified, defaults to 'cuda' if available.
144 | return_outputs (bool, optional): If True, returns both the trained model and the modified AnnData object with embeddings and reconstructions. Default is True.
145 | verbose (bool, optional): If True, enable detailed logging. Default is True.
146 |
147 |     Returns:
148 |         model (scRNA_VAE): Trained scRNA_VAE model instance.
149 |         If `return_outputs` is True, additionally returns:
150 |             rna_data (AnnData): Updated AnnData object with a `.modality` dictionary of views:
151 |                 - `modality["latent"]`: Latent space embedding (cells x latent dimensions).
152 |                 - `modality["TF"]`: Inferred transcription factor activities (cells x TFs) taken from the layer before the output layer of the model.
153 |                 - `modality["recon_RNA"]`: Reconstructed RNA expression (cells x genes).
154 |                 - `modality["RNA"]`: Original RNA expression data (cells x genes).
155 |                 and the following `.uns` entries:
156 |                 - `uns["GRN_prior"]`: Dictionary containing:
157 |                     - `"matrix"`: The GRN prior weight matrix (genes x TFs).
158 |                     - `"TF_names"` and `"gene_names"`: TF and gene name lists.
159 |                 - `uns["GRN_posterior"]`: Dictionary containing:
160 |                     - `"matrix"`: The min-max normalized GRN posterior weight matrix (genes x TFs).
161 |                     - `"TF_names"` and `"gene_names"`: TF and gene name lists.
162 |             GRN_posterior (pd.DataFrame): Posterior GRN as a DataFrame with genes as rows and TFs as columns.
163 |
164 |     Notes:
165 |         - The TF activities stored in `rna_data.modality["TF"]` reflect the model's inferred activities, not the raw ULM estimates.
166 |         - Raw ULM estimates are stored as `rna_data.obsm["ulm_estimate"]` (p-values in `rna_data.obsm["ulm_pvals"]`).
167 |         - `rna_data.uns["GRN_prior"]` contains the aligned GRN matrix fed to the model as a prior.
168 |         - `rna_data.uns["GRN_posterior"]` contains the inferred GRN matrix learned by the model.
169 |         - Training halts early if the validation loss stops improving for `early_stopping_patience` epochs after `alpha` reaches its maximum value.
170 |         - Ensure that `net` and `rna_data` overlap in their genes to avoid errors during initialization.
171 |
172 |     Examples:
173 |         >>> trained_model, updated_data, grn_posterior = train_model(
174 |         ...     rna_data=my_rna_data,
175 |         ...     net=my_grn
176 |         ... )
177 |         >>> GRN_prior = updated_data.uns["GRN_prior"]["matrix"]
178 |         >>> GRN_posterior = updated_data.uns["GRN_posterior"]["matrix"]
179 | """
180 | if not (0 < train_val_split_ratio < 1):
181 | raise ValueError("train_val_split_ratio must be between 0 and 1.")
182 | if encode_dims is None:
183 | encode_dims = [512]
184 | if decode_dims is None:
185 | decode_dims = [1024]
186 |
187 | set_random_seed(seed=random_state)
188 | rna_data = rna_data.copy()
189 |
190 |     # Set logging verbosity (affects the train_model logger only).
191 | if verbose:
192 | train_logger.setLevel(logging.INFO)
193 | else:
194 | train_logger.setLevel(logging.WARNING)
195 |
196 | # Auto-detect device if not specified
197 | if device is None:
198 | device = "cuda" if torch.cuda.is_available() else "cpu"
199 |
200 | # Determine batch size dynamically if not provided
201 | if batch_size is None:
202 | batch_size = int(train_val_split_ratio * rna_data.n_obs)
203 | train_logger.info(f"No batch size provided. Using default batch size equal to the number of training samples: {batch_size}")
204 | else:
205 | train_logger.info(f"Using provided batch size: {batch_size}")
206 |
207 | train_logger.info("=" * 40)
208 | train_logger.info("Starting scRegulate TF inference Training Pipeline")
209 | train_logger.info("=" * 40)
210 |
211 | # Adapt prior and prepare data
212 | W_prior, gene_names, TF_names = adapt_prior_and_data(rna_data, net, min_targets, min_TFs)
213 |
214 | # Split into training and validation sets
215 | train_logger.info(f"Splitting data with train-validation split ratio={train_val_split_ratio}")
216 | n_cells = rna_data.n_obs
217 | all_indices = np.arange(n_cells)
218 | train_indices = np.random.choice(all_indices, size=int(train_val_split_ratio * n_cells), replace=False)
219 | val_indices = np.setdiff1d(all_indices, train_indices)
220 |
221 | # Use the indices to reorder cell names
222 | train_cell_names = rna_data.obs.index[train_indices]
223 | val_cell_names = rna_data.obs.index[val_indices]
224 | shuffled_cell_names = np.concatenate([train_cell_names, val_cell_names])
225 |
226 | # Align metadata to the reordering applied to the data
227 | rna_data = rna_data.copy()
228 | rna_data = rna_data[shuffled_cell_names, :] # This reorders both .obs and .X
229 |
230 | # Validation step: Ensure alignment
231 | assert np.array_equal(rna_data.obs.index, shuffled_cell_names), "Mismatch in obs index and shuffled cell names"
232 |
233 | # Run ULM once for the entire dataset
234 | train_logger.info("Running ULM...")
235 | ulm_start = time.time()
236 | ulm.run_ulm(rna_data, net=net, min_n=min_targets, batch_size=batch_size, source="source", target="target", weight="weight", verbose=verbose)
237 | train_logger.info(f"ULM completed in {time.time() - ulm_start:.2f}s")
238 | ulm_estimates = torch.tensor(rna_data.obsm["ulm_estimate"].loc[:, TF_names].to_numpy(), dtype=torch.float32)
239 |
240 | # Split ULM estimates
241 | ulm_estimates_train = ulm_estimates[train_indices, :]
242 | ulm_estimates_val = ulm_estimates[val_indices, :]
243 |
244 | # Prepare data for training and validation
245 | gene_indices = [rna_data.var.index.get_loc(gene) for gene in gene_names]
246 | scRNA_train = torch.tensor(
247 | rna_data.X[train_indices][:, gene_indices].todense() if not isinstance(rna_data.X, np.ndarray) else rna_data.X[train_indices][:, gene_indices],
248 | dtype=torch.float32
249 | )
250 | scRNA_val = torch.tensor(
251 | rna_data.X[val_indices][:, gene_indices].todense() if not isinstance(rna_data.X, np.ndarray) else rna_data.X[val_indices][:, gene_indices],
252 | dtype=torch.float32
253 | )
254 |
255 | # Validate ULM estimates
256 | train_logger.info("Validating ULM estimates...")
257 |
258 | if ulm_estimates_train.isnan().any() or ulm_estimates_val.isnan().any():
259 | train_logger.error("ULM estimates contain NaN values. Check data preprocessing.")
260 | raise ValueError("Invalid ULM estimates.")
261 | train_logger.info(f"ULM estimates validation passed. Shape: {ulm_estimates.shape}")
262 |
263 | # Transfer data to device
264 | train_logger.info(f"Transferring data to device {device}")
265 | scRNA_train, scRNA_val = scRNA_train.to(device), scRNA_val.to(device)
266 | ulm_estimates_train, ulm_estimates_val = ulm_estimates_train.to(device), ulm_estimates_val.to(device)
267 | W_prior = W_prior.to(device)
268 | train_logger.info("Transfer complete...")
269 |
270 | # Clear memory after transfer
271 | clear_memory(device)
272 |
273 | # Initialize model
274 | model = scRNA_VAE(scRNA_train.shape[1], encode_dims, decode_dims, z_dim, W_prior.shape[1]).to(device)
275 | model.W_prior = W_prior
276 | model.gene_names = gene_names
277 |
278 | # Optimizer and scheduler
279 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
280 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=200, factor=0.5, min_lr=1e-6)
281 | mask = (W_prior == 0).float()
282 |
283 | best_val_loss, epochs_no_improve = float("inf"), 0
284 | batch_size = int(batch_size)
285 | train_loader = create_dataloader(scRNA_train, batch_size=batch_size)
286 | val_loader = create_dataloader(scRNA_val, batch_size=batch_size)
287 | start_time = time.time()
288 |
289 | alpha, beta, gamma = alpha_start, beta_start, gamma_start # Initialize parameters
290 |
291 | # Training loop
292 | for epoch in range(epochs):
293 | model.train()
294 | total_loss, total_samples = 0, 0
295 | mask_factor = schedule_mask_factor(epoch, freeze_epochs)
296 |
297 | for batch_idx, batch in enumerate(train_loader):
298 | optimizer.zero_grad()
299 | start_idx = batch_idx * batch_size
300 | end_idx = start_idx + batch.size(0)
301 | tf_activity_slice = ulm_estimates_train[start_idx:end_idx]
302 | if tf_activity_slice.size(0) != batch.size(0):
303 | continue
304 |
305 | # Forward pass and loss calculation
306 | recon_batch, mu, logvar = model(batch, tf_activity_init=tf_activity_slice, alpha=alpha)
307 | MSE, KLD = loss_function(recon_batch, batch, mu, logvar)
308 | loss = MSE + beta * KLD + gamma * torch.sum(torch.abs(model.tf_mapping.weight) * mask)
309 | loss.backward()
310 |
311 | # Apply gradient masking and clipping
312 | if epoch < freeze_epochs:
313 | apply_gradient_mask(model.tf_mapping, mask, mask_factor)
314 | clip_gradients(model)
315 | optimizer.step()
316 |
317 | total_loss += loss.detach() #* len(batch)
318 | total_samples += len(batch)
319 |
320 | avg_train_loss = compute_average_loss(total_loss, total_samples)
321 |
322 | # Evaluate on validation set
323 | model.eval()
324 | val_loss, val_samples = 0, 0
325 | with torch.no_grad():
326 | for val_batch_idx, val_batch in enumerate(val_loader):
327 | start_idx = val_batch_idx * batch_size
328 | end_idx = start_idx + val_batch.size(0)
329 | tf_val_slice = ulm_estimates_val[start_idx:end_idx]
330 | if tf_val_slice.size(0) != val_batch.size(0):
331 | continue
332 |
333 | recon_val, mu_val, logvar_val = model(val_batch, tf_activity_init=tf_val_slice, alpha=alpha)
334 | MSE_val, KLD_val = loss_function(recon_val, val_batch, mu_val, logvar_val)
335 | loss_val = MSE_val + beta * KLD_val + gamma * torch.sum(torch.abs(model.tf_mapping.weight) * mask)
336 | val_loss += loss_val.detach() #* len(val_batch)
337 | val_samples += len(val_batch)
338 |
339 | avg_val_loss = compute_average_loss(val_loss, val_samples)
340 |
341 |         # Update parameters based on validation loss improvement
342 |         improved = avg_val_loss <= best_val_loss
343 |         if improved:
344 |             best_val_loss = avg_val_loss  # Update the best validation loss
345 |             alpha = alpha + (alpha_max - alpha_start) * alpha_scale
346 |             beta = schedule_parameter(epoch, beta_start, beta_max, epochs)
347 |             gamma = schedule_parameter(epoch, gamma_start, gamma_max, epochs)
348 |
349 |
350 |         alpha = min(alpha, alpha_max)
351 |
352 |         scheduler.step(avg_val_loss)
353 |
354 |         # Clear memory after each epoch
355 |         clear_memory(device)
356 |
357 |         # Check early stopping (the patience counter only advances once alpha has reached alpha_max)
358 |         if early_stopping_patience is not None:
359 |             if improved:
360 |                 epochs_no_improve = 0
361 |             elif alpha >= alpha_max:
362 |                 epochs_no_improve += 1
363 |
364 |
365 |             if epochs_no_improve >= early_stopping_patience:
366 |                 train_logger.info(f"Early stopping at epoch {epoch + 1}")
367 |                 break
368 |
369 | # Log progress
370 | if epoch % log_interval == 0:
371 | n_features = scRNA_train.shape[1]
372 | train_loss_per_sample = avg_train_loss / (len(train_indices) * n_features)
373 | val_loss_per_sample = avg_val_loss / (len(val_indices)* n_features)
374 | train_logger.info(
375 | f"Epoch {epoch+1}: Avg Train Loss = {train_loss_per_sample:.4f}, Avg Val Loss = {val_loss_per_sample:.4f}, "
376 | f"Alpha = {alpha:.4f}, Beta = {beta:.4f}, Gamma = {gamma:.4f}, Mask Factor: {mask_factor:.2f}"
377 | )
378 |
379 | # End timing and log final metrics
380 | total_time = time.time() - start_time
381 | train_logger.info("Training completed in %.2fs", total_time)
382 |
383 | clear_memory(device)
384 | # Add latent space, TF activities, reconstructed RNA, and original RNA to AnnData
385 | if return_outputs:
386 | latent_space = []
387 | TF_space = []
388 | reconstructed_rna = []
389 | model.eval()
390 |
391 | # Use the model's batch encoding to extract relevant spaces
392 | # Encode all samples (train + val) together
393 |
394 | latent_space = np.concatenate([
395 | model.encodeBatch(loader, device=device, out='z') for loader in [train_loader, val_loader]
396 | ], axis=0)
397 | # Reorder latent embeddings to match RNA order
398 | latent_space = latent_space[np.argsort(np.concatenate([train_indices, val_indices])), :]
399 |
400 |
401 | TF_space = np.concatenate([
402 | model.encodeBatch(loader, device=device, out='tf') for loader in [train_loader, val_loader]
403 | ], axis=0)
404 |
405 | # Reorder TF embeddings to match RNA order
406 | TF_space = TF_space[np.argsort(np.concatenate([train_indices, val_indices])), :]
407 |
408 |
409 | reconstructed_rna = np.concatenate([
410 | model.encodeBatch(loader, device=device, out='x') for loader in [train_loader, val_loader]
411 | ], axis=0)
412 |
413 | # Reorder reconstructed_rna embeddings to match RNA order
414 | reconstructed_rna = reconstructed_rna[np.argsort(np.concatenate([train_indices, val_indices])), :]
415 |
416 |
417 | # Extract posterior weight matrix
418 | W_posterior = model.tf_mapping.weight.detach().cpu().numpy()
419 |
420 | # Reorder obs based on concatenated train and validation indices
421 | meta_obs = rna_data.obs.copy()
422 |
423 | rna_modality = sc.AnnData(
424 | X=rna_data.X, # Original RNA expression
425 | obs=meta_obs,
426 | var=rna_data.var.copy(),
427 | obsm=rna_data.obsm.copy(),
428 | uns=rna_data.uns.copy(),
429 | layers=rna_data.layers.copy() if hasattr(rna_data, "layers") else None
430 | )
431 |
432 | rna_modality.uns["type"] = "RNA"
433 |
434 | tf_modality = sc.AnnData(
435 | X=TF_space,
436 | obs=meta_obs,
437 | var=pd.DataFrame(index=TF_names) # Index of transcription factors
438 | )
439 | tf_modality.uns["type"] = "TF"
440 |
441 | recon_rna_modality = sc.AnnData(
442 | X=reconstructed_rna,
443 | obs=meta_obs,
444 | var=pd.DataFrame(index=gene_names) # Use aligned gene names
445 | )
446 | recon_rna_modality.uns["type"] = "Reconstructed RNA"
447 |
448 | latent_modality = sc.AnnData(
449 | X=latent_space,
450 | obs=meta_obs,
451 | var=pd.DataFrame(index=[f"latent_{i}" for i in range(latent_space.shape[1])])
452 | )
453 | latent_modality.uns["type"] = "Latent Space"
454 |
455 | # Add modalities to the parent AnnData object
456 | rna_data.modality = {
457 | "RNA": rna_modality,
458 | "TF": tf_modality,
459 | "recon_RNA": recon_rna_modality,
460 | "latent": latent_modality
461 | }
462 |
463 | rna_data.uns["main_modality"] = "RNA" # Set default modality
464 | rna_data.uns["current_modality"] = "RNA" # Track current modality
465 | train_logger.info(f"Default modality set to: {rna_data.uns['main_modality']}")
466 | train_logger.info(f"Current modality: {rna_data.uns['current_modality']}")
467 |
468 |
469 | # Store GRN_prior in .uns
470 | rna_data.uns["GRN_prior"] = {
471 | "matrix": W_prior.cpu().numpy(),
472 | "TF_names": TF_names,
473 | "gene_names": gene_names,
474 | }
475 |
476 | # Store normalized GRN_posterior in .uns
477 | rna_data.uns["GRN_posterior"] = {
478 | "matrix": MinMaxScaler().fit_transform(model.tf_mapping.weight.detach().cpu().numpy()),
479 | "TF_names": TF_names,
480 | "gene_names": gene_names,
481 | }
482 |
483 | GRN_posterior = pd.DataFrame(
484 | rna_data.uns['GRN_posterior']['matrix'].copy(),
485 | index=rna_data.uns['GRN_posterior']['gene_names'],
486 | columns=rna_data.uns['GRN_posterior']['TF_names']
487 | )
488 |
489 | train_logger.info("`GRN_prior` and `GRN_posterior` stored in the AnnData object under .uns")
490 |
491 | # Final logging
492 | train_logger.info("=" * 40)
493 | train_logger.info("[FINAL SUMMARY]")
494 | train_logger.info(f"Training stopped after {epoch + 1} epochs.")
495 | train_logger.info(f"Final Train Loss: {avg_train_loss:.4f}")
496 | train_logger.info(f"Final Valid Loss: {avg_val_loss:.4f}")
497 | train_logger.info(f"Final Alpha: {alpha:.4f}, Beta: {beta:.4f}, Gamma: {gamma:.4f}")
498 | train_logger.info(f"Total Training Time: {total_time:.2f}s")
499 |
500 | # Retrieve and log shapes from modalities
501 | train_logger.info(f"Latent Space Shape: {rna_data.modality['latent'].X.shape}")
502 | train_logger.info(f"TF Space Shape: {rna_data.modality['TF'].X.shape}")
503 | train_logger.info(f"Reconstructed RNA Shape: {rna_data.modality['recon_RNA'].X.shape}")
504 | train_logger.info(f"Original RNA Shape: {rna_data.modality['RNA'].X.shape}")
505 |
506 | train_logger.info(f"TFs: {len(TF_names)}, Genes: {rna_data.modality['RNA'].X.shape[1]}")
507 | train_logger.info("=" * 40)
508 |
509 |
510 | return model, rna_data, GRN_posterior
511 |
512 | return model
513 |
--------------------------------------------------------------------------------
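
A minimal end-to-end sketch of calling train_model follows. The input file path is hypothetical, the AnnData is assumed to hold normalized, log-transformed expression in .X, and the bundled prior CSV is assumed to already use 'source'/'target'/'weight' column names.

import pandas as pd
import scanpy as sc
from scregulate.train import train_model

adata = sc.read_h5ad("my_cells.h5ad")                 # hypothetical preprocessed input
net = pd.read_csv("priors/collectri_human_net.csv")   # GRN prior (columns assumed: source, target, weight)

# With return_outputs=True (the default), three values come back.
model, adata_out, grn_posterior = train_model(adata, net)

# Inferred TF activities and the prior/posterior GRNs.
tf_activities = adata_out.modality["TF"].to_df()      # cells x TFs
grn_prior = adata_out.uns["GRN_prior"]["matrix"]      # genes x TFs
print(tf_activities.shape, grn_posterior.shape)
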
/scregulate/train_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def schedule_parameter(epoch, start_value, max_value, total_epochs, scale=0.6):
5 | """
6 | Schedules a parameter (e.g., alpha, beta, gamma) to grow from `start_value` to `max_value`
7 | over a given fraction of the training epochs.
8 |
9 | Args:
10 | epoch: Current epoch.
11 | start_value: Starting value of the parameter.
12 | max_value: Maximum value of the parameter.
13 | total_epochs: Total number of training epochs.
14 | scale: Fraction of epochs over which the parameter reaches `max_value`.
15 |
16 | Returns:
17 | Scheduled parameter value for the current epoch.
18 | """
19 | scaled_epochs = total_epochs * scale
20 | if epoch >= scaled_epochs:
21 | return max_value
22 | return start_value + (max_value - start_value) * (epoch / scaled_epochs)
23 |
24 |
25 | def schedule_mask_factor(epoch, freeze_epochs):
26 | """
27 | Computes the mask factor using a sigmoidal schedule.
28 |
29 | Args:
30 | epoch: Current epoch.
31 |         freeze_epochs: Controls the decay; the factor starts near 1, passes 0.5 at freeze_epochs // 2, and approaches 0 near freeze_epochs.
32 |
33 | Returns:
34 | Mask factor value for the current epoch.
35 | """
36 | return 1.0 / (1.0 + np.exp((epoch - freeze_epochs // 2) / (freeze_epochs // 20)))
37 |
38 |
39 | def apply_gradient_mask(linear_layer, mask_float, mask_factor):
40 | """
41 | Applies a gradient mask to a layer's weights.
42 |
43 | Args:
44 | linear_layer: Linear layer whose gradients are masked.
45 | mask_float: Precomputed mask (e.g., `W_prior == 0`).
46 | mask_factor: Factor controlling the strength of the mask.
47 | """
48 | with torch.no_grad():
49 | linear_layer.weight.grad.mul_(1 - mask_factor * mask_float)
50 |
51 |
52 | def compute_average_loss(total_loss, total_samples):
53 | """
54 | Computes the average loss per sample.
55 |
56 | Args:
57 | total_loss: Total accumulated loss over an epoch.
58 | total_samples: Total number of samples in the epoch.
59 |
60 | Returns:
61 | Average loss per sample.
62 | """
63 | return total_loss.item() / total_samples
64 |
65 |
66 | def clip_gradients(model, max_norm=0.5):
67 | """
68 | Clips gradients of the model to prevent exploding gradients.
69 |
70 | Args:
71 | model: PyTorch model.
72 | max_norm: Maximum allowed norm for gradients.
73 | """
74 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
75 |
--------------------------------------------------------------------------------
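
A small sketch of what these schedules produce, using the train_model defaults (2,000 total epochs, 1,500 freeze epochs, beta ramping from 0 to 0.4); the sampled epochs are arbitrary.

from scregulate.train_utils import schedule_parameter, schedule_mask_factor

total_epochs, freeze_epochs = 2000, 1500
for epoch in (0, 600, 1200, 1800):
    beta = schedule_parameter(epoch, start_value=0.0, max_value=0.4, total_epochs=total_epochs)
    mask = schedule_mask_factor(epoch, freeze_epochs)
    # beta ramps linearly until 60% of training, then stays at its maximum;
    # the mask factor decays sigmoidally from ~1 toward 0 around freeze_epochs // 2.
    print(f"epoch {epoch:4d}: beta={beta:.3f}, mask_factor={mask:.3f}")
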
/scregulate/ulm_standalone.py:
--------------------------------------------------------------------------------
1 | """
2 | Standalone Implementation of the Univariate Linear Model (ULM).
3 |
4 | This implementation is based on the DecoupleR package's `run_ulm` function but
5 | has been adapted to work independently without any external dependencies. It
6 | calculates transcription factor activity scores from a gene expression matrix
7 | and a regulatory network.
8 |
9 | Citation:
10 | - For the original ULM method, please cite decoupleR:
11 |   Badia-i-Mompel, P., Vélez Santiago, J., Braunger, J., et al. (2022). decoupleR: ensemble of
12 |   computational methods to infer biological activities from omics data. *Bioinformatics Advances*, 2(1), vbac016.
13 | """
14 |
15 | import numpy as np
16 | import pandas as pd
17 | from scipy.sparse import csr_matrix
18 | from scipy.stats import t
19 | from tqdm.auto import tqdm
20 | import logging
21 | import warnings
22 |
23 | logger = logging.getLogger(__name__)
24 | from .utils import set_random_seed
25 | set_random_seed()
26 | # --- Helper Functions ---
27 |
28 | def mat_cov(A, b):
29 | return np.dot(b.T - b.mean(), A - A.mean(axis=0)) / (b.shape[0] - 1)
30 |
31 | def mat_cor(A, b):
32 | cov = mat_cov(A, b)
33 | ssd = np.std(A, axis=0, ddof=1) * np.std(b, axis=0, ddof=1).reshape(-1, 1)
34 | return cov / ssd
35 |
36 | def t_val(r, df):
37 | return r * np.sqrt(df / ((1.0 - r + 1.0e-16) * (1.0 + r + 1.0e-16)))
38 |
39 | def filt_min_n(c, net, min_n=5):
40 | msk = np.isin(net['target'].values.astype('U'), c)
41 | net = net.iloc[msk]
42 |
43 | sources, counts = np.unique(net['source'].values.astype('U'), return_counts=True)
44 | msk = np.isin(net['source'].values.astype('U'), sources[counts >= min_n])
45 |
46 | net = net[msk]
47 | if net.shape[0] == 0:
48 | raise ValueError(f"No sources with more than min_n={min_n} targets.")
49 | return net
50 |
51 | def match(c, r, net):
52 | regX = np.zeros((c.shape[0], net.shape[1]), dtype=np.float32)
53 | c_dict = {gene: i for i, gene in enumerate(c)}
54 | idxs = [c_dict[gene] for gene in r if gene in c_dict]
55 | regX[idxs, :] = net[: len(idxs), :]
56 | return regX
57 |
58 | def rename_net(net, source="source", target="target", weight="weight"):
59 | assert source in net.columns, f"Column '{source}' not found in net."
60 | assert target in net.columns, f"Column '{target}' not found in net."
61 | if weight is not None:
62 | assert weight in net.columns, f"Column '{weight}' not found in net."
63 |
64 | net = net.rename(columns={source: "source", target: "target", weight: "weight"})
65 | net = net.reindex(columns=["source", "target", "weight"])
66 |
67 | if net.duplicated(["source", "target"]).sum() > 0:
68 | raise ValueError("net contains repeated edges.")
69 | return net
70 |
71 | def get_net_mat(net):
72 | X = net.pivot(columns="source", index="target", values="weight")
73 | X[np.isnan(X)] = 0
74 | sources = X.columns.values
75 | targets = X.index.values
76 | X = X.values
77 | return sources.astype("U"), targets.astype("U"), X.astype(np.float32)
78 |
79 |
80 | # --- Main Functionality ---
81 |
82 | def ulm(mat, net, batch_size):
83 | logger.debug("Starting ULM calculation...")
84 | n_samples = mat.shape[0]
85 | n_features, n_fsets = net.shape
86 | df = n_features - 2 # Degrees of freedom
87 |
88 | if isinstance(mat, csr_matrix):
89 | logger.debug("Matrix is sparse, processing in batches...")
90 | n_batches = int(np.ceil(n_samples / batch_size))
91 | es = np.zeros((n_samples, n_fsets), dtype=np.float32)
92 | # The progress bar is only shown if logger level <= INFO
93 | show_progress = logger.getEffectiveLevel() <= logging.INFO
94 | for i in tqdm(range(n_batches), disable=not show_progress):
95 | start, end = i * batch_size, i * batch_size + batch_size
96 | batch = mat[start:end].toarray().T
97 | r = mat_cor(net, batch)
98 | es[start:end] = t_val(r, df)
99 | else:
100 | logger.debug("Matrix is dense, processing at once...")
101 | r = mat_cor(net, mat.T)
102 | es = t_val(r, df)
103 |
104 | pv = t.sf(abs(es), df) * 2
105 | logger.debug("ULM calculation complete.")
106 | return es, pv
107 |
108 |
109 | def run_ulm(adata, net, batch_size, source="source", target="target", weight="weight", min_n=5, verbose=True):
110 | """
111 | Run ULM on a Scanpy AnnData object and store results in `.obsm`.
112 |
113 | Args:
114 |         adata: AnnData object with gene expression data.
115 |         net: DataFrame with columns [source, target, weight].
116 |         batch_size (int): Batch size for ULM if data is large.
117 |         source (str): Name of the TF column in net.
118 |         target (str): Name of the target gene column in net.
119 |         weight (str): Name of the weight column in net.
120 |         min_n (int): Minimum number of targets per TF.
121 | verbose (bool): If True, set logging level to INFO, else WARNING.
122 |
123 | Returns:
124 | adata with ULM results in adata.obsm["ulm_estimate"] and adata.obsm["ulm_pvals"].
125 | """
126 | # Set logger level based on verbose
127 | logger.setLevel(logging.INFO if verbose else logging.WARNING)
128 |
129 | logger.info("Initializing ULM analysis...")
130 | logger.info("Extracting gene expression data...")
131 | mat = adata.to_df()
132 | genes = adata.var_names
133 |
134 | logger.info("Preparing the regulatory network...")
135 | net = rename_net(net, source=source, target=target, weight=weight)
136 | net = filt_min_n(genes, net, min_n=min_n)
137 |
138 | logger.info("Matching genes in the network to gene expression data...")
139 | sources, targets, net_mat = get_net_mat(net)
140 | net_mat = match(genes, targets, net_mat)
141 |
142 | logger.info(f"ULM parameters: {mat.shape[0]} samples, {len(genes)} genes, {net_mat.shape[1]} TFs.")
143 | logger.info("Calculating ULM estimates and p-values...")
144 | estimate, pvals = ulm(mat.values, net_mat, batch_size=batch_size)
145 |
146 | logger.info("Processing ULM results...")
147 | ulm_estimate = pd.DataFrame(estimate, index=mat.index, columns=sources)
148 | ulm_pvals = pd.DataFrame(pvals, index=mat.index, columns=sources)
149 |
150 | logger.info("Storing ULM results in AnnData object...")
151 | with warnings.catch_warnings():
152 | warnings.simplefilter("ignore", category=UserWarning) # Suppress UserWarnings
153 | adata.obsm["ulm_estimate"] = np.maximum(ulm_estimate, 0)
154 | adata.obsm["ulm_pvals"] = ulm_pvals
155 |
158 |
159 | logger.info("ULM analysis complete.")
160 | return adata
161 |
--------------------------------------------------------------------------------
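
The standalone ULM can also be run outside train_model; a short sketch follows, with a hypothetical input file and an illustrative batch size and min_n (the prior CSV is assumed to use 'source'/'target'/'weight' columns).

import pandas as pd
import scanpy as sc
from scregulate import ulm_standalone as ulm

adata = sc.read_h5ad("my_cells.h5ad")                 # hypothetical preprocessed input
net = pd.read_csv("priors/collectri_human_net.csv")

# Scores land in adata.obsm["ulm_estimate"] (negative scores clipped to 0),
# p-values in adata.obsm["ulm_pvals"].
ulm.run_ulm(adata, net=net, batch_size=5000, min_n=10, verbose=True)
tf_scores = adata.obsm["ulm_estimate"]                # cells x TFs DataFrame
print(tf_scores.shape)
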
/scregulate/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | import scanpy as sc
5 | from torch.utils.data import DataLoader
6 | import gc
7 | import random
8 |
9 | # ---------- Random Seed Utility ----------
10 | def set_random_seed(seed=42):
11 | """
12 | Sets the random seed for reproducibility across multiple libraries.
13 |
14 | Args:
15 | seed (int): Random seed value to set.
16 | """
17 | random.seed(seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 | if torch.cuda.is_available():
21 | torch.cuda.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | torch.backends.cudnn.deterministic = True # Ensures deterministic behavior
24 | torch.backends.cudnn.benchmark = False # Disables benchmarking to avoid randomness
25 | set_random_seed()
26 |
27 | # ---------- Data Loading/Handling ----------
28 | def as_GTRD(net):
29 |     """Convert a wide TF (rows) x gene (columns) matrix into a CollecTRI-style edge list with 'source', 'target', 'weight' columns, keeping only edges with weight == 1."""
30 |     W_prior_df = net.copy()
31 |     collectri_format = W_prior_df.stack().reset_index()
32 |     collectri_format.columns = ["source", "target", "weight"]
33 |     return collectri_format[collectri_format["weight"] == 1].reset_index(drop=True)
34 |
35 | def as_TF_Link(net):
36 |     """Convert a TFLink-style table with 'Name.TF' and 'Name.Target' columns into a CollecTRI-style edge list with 'source', 'target', 'weight' columns (weight fixed to 1)."""
37 |     collectri_format = net[['Name.TF', 'Name.Target']].copy()
38 |     collectri_format.columns = ["source", "target"]
39 |     collectri_format = collectri_format.assign(weight=1)
40 |     collectri_format = collectri_format.drop_duplicates()
41 |     return collectri_format.reset_index(drop=True)
42 |
43 | def create_dataloader(scRNA_batch, batch_size, shuffle=False, num_workers=0, collate_fn=None):
44 | """
45 | Creates a PyTorch DataLoader for the scRNA_batch data.
46 |
47 | Args:
48 | scRNA_batch (torch.Tensor): Input data for training.
49 | batch_size (int): Batch size for training.
50 | shuffle (bool): Whether to shuffle the data during loading.
51 | num_workers (int): Number of workers for data loading.
52 | collate_fn (callable, optional): Custom collate function for DataLoader.
53 |
54 | Returns:
55 | DataLoader: PyTorch DataLoader for the scRNA_batch data.
56 | """
57 | return DataLoader(scRNA_batch, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
58 |                       collate_fn=collate_fn)  # Pass custom collate_fn if provided
59 |
60 | def clear_memory(device):
61 | """
62 | Clears memory for the specified device.
63 |
64 | Args:
65 | device (str): The device to clear memory for ('cuda' or 'cpu').
66 | """
67 | if device == "cuda":
68 | torch.cuda.empty_cache()
69 | else:
70 | gc.collect()
71 |
72 | def extract_GRN(adata, matrix_type="posterior"):
73 | """
74 | Extract the GRN matrix (either prior or posterior) from an AnnData object as a pandas DataFrame.
75 |
76 | Args:
77 | adata (AnnData): The AnnData object containing GRN information in `.uns`.
78 | matrix_type (str): The type of GRN matrix to extract, either 'prior' or 'posterior'.
79 | Defaults to 'posterior'.
80 |
81 | Returns:
82 | pd.DataFrame: A DataFrame where rows are genes and columns are transcription factors (TFs).
83 | """
84 | # Map short names to full keys
85 | matrix_key_map = {
86 | "posterior": "GRN_posterior",
87 | "prior": "GRN_prior"
88 | }
89 |
90 | if matrix_type not in matrix_key_map:
91 | raise ValueError(f"Invalid matrix_type '{matrix_type}'. Choose from 'prior' or 'posterior'.")
92 |
93 | full_key = matrix_key_map[matrix_type]
94 |
95 | if full_key not in adata.uns:
96 | raise ValueError(f"{full_key} is not stored in the provided AnnData object.")
97 |
98 | grn_data = adata.uns[full_key]
99 | matrix = grn_data["matrix"]
100 | TF_names = grn_data["TF_names"]
101 | gene_names = grn_data["gene_names"]
102 |
103 | # Determine the orientation of the matrix
104 | if matrix.shape[0] == len(TF_names) and matrix.shape[1] == len(gene_names):
105 | # TFs as rows, genes as columns (needs transposition)
106 | GRN_df = pd.DataFrame(matrix.T, index=gene_names, columns=TF_names)
107 | elif matrix.shape[0] == len(gene_names) and matrix.shape[1] == len(TF_names):
108 | # Genes as rows, TFs as columns (already correct orientation)
109 | GRN_df = pd.DataFrame(matrix, index=gene_names, columns=TF_names)
110 | else:
111 | raise ValueError("Matrix dimensions do not match the lengths of `gene_names` and `TF_names`.")
112 |
113 | return GRN_df
114 |
115 |
116 | def to_torch_tensor(matrix, device="cpu"):
117 | """
118 | Converts a dense or sparse matrix to a PyTorch tensor.
119 |
120 | Args:
121 | matrix: Numpy array or sparse matrix.
122 | device (str): Device to load the tensor onto.
123 |
124 | Returns:
125 | torch.Tensor: Dense or sparse tensor.
126 | """
127 | if isinstance(matrix, np.ndarray):
128 | return torch.tensor(matrix, dtype=torch.float32).to(device)
129 | elif hasattr(matrix, "tocoo"): # Sparse matrix support
130 | coo = matrix.tocoo()
131 | indices = torch.LongTensor(np.vstack((coo.row, coo.col)))
132 | values = torch.FloatTensor(coo.data)
133 | shape = torch.Size(coo.shape)
134 | return torch.sparse_coo_tensor(indices, values, shape).to(device)
135 | else:
136 | raise TypeError("Unsupported matrix type for conversion to PyTorch tensor.")
137 |
138 |
139 | # ---------- Parameter Scheduling ----------
140 | def schedule_parameter(epoch, start_value, max_value, total_epochs, scale=0.6):
141 | """
142 | Schedules a parameter (e.g., alpha, beta, gamma) to grow from `start_value` to `max_value`
143 | over a given fraction of the training epochs.
144 |
145 | Args:
146 | epoch: Current epoch.
147 | start_value: Starting value of the parameter.
148 | max_value: Maximum value of the parameter.
149 | total_epochs: Total number of training epochs.
150 | scale: Fraction of epochs over which the parameter reaches `max_value`.
151 |
152 | Returns:
153 | Scheduled parameter value for the current epoch.
154 | """
155 | scaled_epochs = total_epochs * scale
156 | if epoch >= scaled_epochs:
157 | return max_value
158 | return start_value + (max_value - start_value) * (epoch / scaled_epochs)
159 |
160 |
161 | def schedule_mask_factor(epoch, freeze_epochs):
162 | """
163 | Computes the mask factor using a sigmoid function.
164 |
165 | Args:
166 | epoch (int): Current epoch.
167 |         freeze_epochs (int): Controls the decay; the factor starts near 1, passes 0.5 at freeze_epochs // 2, and approaches 0 near freeze_epochs.
168 |
169 | Returns:
170 | float: Mask factor for the current epoch.
171 | """
172 | return 1.0 / (1.0 + np.exp((epoch - freeze_epochs // 2) / (freeze_epochs // 20)))
173 |
174 |
175 | # ---------- Gradient Operations ----------
176 | def apply_gradient_mask(layer, mask, mask_factor):
177 | """
178 | Applies a gradient mask to a layer's weights.
179 |
180 | Args:
181 | layer (torch.nn.Linear): Linear layer whose gradients need masking.
182 | mask (torch.Tensor): Binary mask indicating where gradients should be scaled.
183 | mask_factor (float): Scaling factor for the mask.
184 |
185 | Returns:
186 | None
187 | """
188 | with torch.no_grad():
189 | layer.weight.grad.mul_(1 - mask_factor * mask)
190 |
191 |
192 | def clip_gradients(model, max_norm=0.5):
193 | """
194 | Clips the gradients of the model parameters.
195 |
196 | Args:
197 | model (torch.nn.Module): PyTorch model.
198 | max_norm (float): Maximum norm for gradient clipping.
199 |
200 | Returns:
201 | None
202 | """
203 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
204 |
205 |
206 | # ---------- Loss Utility ----------
207 | def compute_average_loss(total_loss, total_samples):
208 | """
209 | Computes the average loss per sample.
210 |
211 | Args:
212 | total_loss (torch.Tensor): Total accumulated loss.
213 | total_samples (int): Total number of samples processed.
214 |
215 | Returns:
216 | float: Average loss per sample.
217 | """
218 | return total_loss.item() / total_samples
219 |
220 |
221 | # ---------- Modality Enhancement ----------
222 | def set_active_modality(adata, modality_name):
223 | """
224 | Sets the active modality in the AnnData object.
225 |
226 | Args:
227 | adata (AnnData): Parent AnnData object containing multiple modalities.
228 | modality_name (str): The name of the modality to activate.
229 |
230 | Returns:
231 | AnnData: A new AnnData object with the specified modality set as the active one.
232 | """
233 | # Validate input modality
234 | if modality_name not in adata.modality:
235 | raise ValueError(f"Modality '{modality_name}' not found in AnnData.modality.")
236 |
237 | # Retrieve selected modality
238 | selected_modality = adata.modality[modality_name]
239 |
240 | # Create a new AnnData object with the selected modality
241 | temp_adata = sc.AnnData(
242 | X=selected_modality.X,
243 | obs=selected_modality.obs.copy(),
244 | var=selected_modality.var.copy(),
245 | obsm=selected_modality.obsm.copy(),
246 | uns=adata.uns.copy(), # Retain shared metadata in `uns`
247 | )
248 |
249 | # Retain modality-specific `.obsm` entries
250 | temp_adata.obsm.update(adata.obsm)
251 |
252 | # Attach all modalities back to the new AnnData object
253 | temp_adata.modality = adata.modality
254 |
255 | return temp_adata
256 |
257 | def extract_modality(adata, modality_name):
258 | """
259 | Extract a specific modality as an independent AnnData object.
260 |
261 | Args:
262 | adata (AnnData): The AnnData object containing multiple modalities.
263 | modality_name (str): The name of the modality to extract.
264 |
265 | Returns:
266 | AnnData: The requested modality as a standalone AnnData object.
267 | """
268 | # Ensure the modality exists
269 | if not hasattr(adata, "modality") or modality_name not in adata.modality:
270 | raise ValueError(f"Modality '{modality_name}' not found in the AnnData object.")
271 |
272 | # Extract the requested modality
273 | selected_modality = adata.modality[modality_name]
274 | return sc.AnnData(
275 | X=selected_modality.X.copy(),
276 | obs=selected_modality.obs.copy(),
277 | var=selected_modality.var.copy(),
278 | obsm=selected_modality.obsm.copy(),
279 | uns=selected_modality.uns.copy(),
280 | layers=selected_modality.layers.copy() if hasattr(selected_modality, "layers") else None
281 | )
282 |
--------------------------------------------------------------------------------
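
A brief sketch of the post-training helpers, continuing from the train_model sketch after train.py above: `adata_out` is assumed to be the AnnData returned by train_model(..., return_outputs=True), so `.modality` and the GRN entries in `.uns` are present.

from scregulate.utils import extract_GRN, set_active_modality, extract_modality

grn_post = extract_GRN(adata_out, matrix_type="posterior")   # genes x TFs DataFrame
grn_prior = extract_GRN(adata_out, matrix_type="prior")

# Make TF activities the active view while keeping all modalities attached.
tf_view = set_active_modality(adata_out, "TF")

# Or pull one modality out as a standalone AnnData for downstream Scanpy analysis.
tf_adata = extract_modality(adata_out, "TF")
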
/scregulate/vae_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | # Define the VAE model
6 | class scRNA_VAE(nn.Module):
7 | """
8 | Variational Autoencoder (VAE) for scRNA-seq data.
9 |
10 | Args:
11 | input_dim (int): Number of input features (genes).
12 | encode_dims (list): List of hidden layer sizes for the encoder.
13 | decode_dims (list): List of hidden layer sizes for the decoder.
14 | z_dim (int): Size of the latent space.
15 | tf_dim (int): Size of the transcription factor space.
16 | ulm_init (torch.Tensor, optional): Pre-initialized TF activity matrix.
17 | """
18 |
19 | def __init__(self, input_dim, encode_dims, decode_dims, z_dim, tf_dim, ulm_init=None):
20 | super(scRNA_VAE, self).__init__()
21 |
22 | # Encoder
23 | self.encoder_layers = nn.ModuleList()
24 | previous_dim = input_dim
25 | for h_dim in encode_dims:
26 | self.encoder_layers.append(nn.Linear(previous_dim, h_dim))
27 | previous_dim = h_dim
28 | self.fc_mu = nn.Linear(previous_dim, z_dim)
29 | self.fc_logvar = nn.Linear(previous_dim, z_dim)
30 |
31 | # Decoder
32 | self.decoder_layers = nn.ModuleList()
33 | previous_dim = z_dim
34 | for h_dim in decode_dims:
35 | self.decoder_layers.append(nn.Linear(previous_dim, h_dim))
36 | previous_dim = h_dim
37 | self.fc_output = nn.Linear(previous_dim, tf_dim)
38 | self.tf_mapping = nn.Linear(tf_dim, input_dim)
39 | self.ulm_init = ulm_init
40 |
41 |
42 | def encode(self, x):
43 | for layer in self.encoder_layers:
44 | x = F.relu(layer(x))
45 | mu = self.fc_mu(x)
46 | logvar = self.fc_logvar(x)
47 | return mu, logvar
48 |
49 | def reparameterize(self, mu, logvar):
50 | std = torch.exp(0.5 * logvar)
51 | eps = torch.randn_like(std)
52 | return mu + eps * std
53 |
54 | def decode(self, z):
55 | for layer in self.decoder_layers:
56 | z = F.relu(layer(z))
57 |         return F.relu(self.fc_output(z))  # ReLU keeps decoded TF activities non-negative (previously torch.exp)
58 |
59 | def forward(self, x, tf_activity_init=None, alpha=1.0):
60 | mu, logvar = self.encode(x)
61 | z = self.reparameterize(mu, logvar)
62 |
63 | if tf_activity_init is not None:
64 | # Use a weighted combination of pre-initialized and learned TF activities
65 | tf_activity = (1 - alpha) * tf_activity_init.detach() + alpha * self.decode(z)
66 | else:
67 | # If no pre-initialized TF activities, just use the decoded values
68 | tf_activity = self.decode(z)
69 |
70 | recon_x = self.tf_mapping(tf_activity)
71 | return recon_x, mu, logvar
72 |
73 |
74 | def encodeBatch(self, dataloader, device='cuda', out='z'):
75 | output = []
76 | for batch in dataloader:
77 | batch = batch.to(device)
78 | mu, logvar = self.encode(batch)
79 | z = self.reparameterize(mu, logvar)
80 |
81 | if out == 'z':
82 | output.append(z.detach().cpu())
83 | elif out == 'x':
84 | recon_x = self.tf_mapping(self.decode(z))
85 | output.append(recon_x.detach().cpu())
86 | elif out == 'tf':
87 | tf_activity = self.decode(z)
88 | output.append(tf_activity.detach().cpu())
89 |
90 | output = torch.cat(output).numpy()
91 | return output
92 |
93 |
--------------------------------------------------------------------------------
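
A minimal sketch of driving the VAE directly, with illustrative dimensions and random tensors standing in for expression data and ULM-derived TF activities.

import torch
from scregulate.vae_model import scRNA_VAE

n_genes, n_tfs, z_dim = 500, 120, 40                  # illustrative sizes
model = scRNA_VAE(n_genes, encode_dims=[512], decode_dims=[1024], z_dim=z_dim, tf_dim=n_tfs)

x = torch.randn(16, n_genes)                          # a dummy batch of 16 cells
ulm_init = torch.rand(16, n_tfs)                      # pre-computed TF activities (e.g., from run_ulm)

# alpha blends the pre-initialized activities with the decoder's own estimate.
recon_x, mu, logvar = model(x, tf_activity_init=ulm_init, alpha=0.5)
print(recon_x.shape, mu.shape, logvar.shape)          # (16, 500), (16, 40), (16, 40)
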
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = scRegulate
3 | version = 0.1.0
4 | author = Mehrdad Zandigohar
5 | author_email = mehr.zgohar@gmail.com
6 | description = Python Toolkit for Transcription Factor Activity Inference and Clustering of scRNA-seq Data
7 | long_description = file: README.md
8 | long_description_content_type = text/markdown
9 | license = MIT
10 | url = https://github.com/YDaiLab/scRegulate
11 | project_urls =
12 | Documentation = https://github.com/YDaiLab/scRegulate#readme
13 | Issue Tracker = https://github.com/YDaiLab/scRegulate/issues
14 | Paper (bioRxiv) = https://doi.org/10.1101/2025.04.17.649372
15 |
16 | [options]
17 | packages = find:
18 | python_requires = >=3.8
19 | install_requires =
20 | torch>=2.0
21 | numpy>=1.23
22 | scanpy>=1.9
23 | anndata>=0.8
24 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="scRegulate",
5 | version="0.1.0",
6 | author="Mehrdad Zandigohar",
7 | author_email="mehr.zgohar@gmail.com",
8 | description="Python Toolkit for Transcription Factor Activity Inference and Clustering of scRNA-seq Data",
9 | long_description=open("README.md", encoding="utf-8").read(),
10 | long_description_content_type="text/markdown",
11 | url="https://github.com/YDaiLab/scRegulate",
12 | project_urls={
13 | "Documentation": "https://github.com/YDaiLab/scRegulate#readme",
14 | "Issue Tracker": "https://github.com/YDaiLab/scRegulate/issues",
15 | "Paper (bioRxiv)": "https://doi.org/10.1101/2025.04.17.649372"
16 | },
17 | license="MIT",
18 | packages=find_packages(),
19 | python_requires=">=3.8",
20 | install_requires=[
21 | "torch>=2.0",
22 | "numpy>=1.23",
23 | "scanpy>=1.9",
24 | "anndata>=0.8"
25 | ],
26 | classifiers=[
27 | "Programming Language :: Python :: 3",
28 | "License :: OSI Approved :: MIT License",
29 | "Operating System :: OS Independent",
30 | ],
31 | )
32 |
--------------------------------------------------------------------------------