├── .github └── workflows │ ├── publish-conda.yml │ └── python-publish.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── License.md ├── README.md ├── assets ├── Visual_Abstract.png └── tool_logo.svg ├── conda.recipe ├── conda_build_config.yaml ├── environment.yml └── meta.yaml ├── docs ├── Data_Preparation.html ├── README.md └── tutorial_main.html ├── notebooks ├── Data_Preparation.ipynb ├── README.md └── tutorial_main.ipynb ├── priors ├── collectri_human_net.csv └── collectri_mouse_net.csv ├── pyproject.toml ├── requirements.txt ├── scregulate ├── __init__.py ├── __version__.py ├── auto_tuning.py ├── datasets.py ├── fine_tuning.py ├── loss_functions.py ├── train.py ├── train_utils.py ├── ulm_standalone.py ├── utils.py └── vae_model.py ├── setup.cfg └── setup.py /.github/workflows/publish-conda.yml: -------------------------------------------------------------------------------- 1 | name: Publish Conda Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build-and-upload-conda: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v4 14 | 15 | - name: Setup Miniconda 16 | uses: conda-incubator/setup-miniconda@v3.0.4 17 | with: 18 | environment-file: ./conda.recipe/environment.yml 19 | activate-environment: build_env 20 | auto-update-conda: true 21 | use-mamba: true 22 | channels: conda-forge,pytorch 23 | channel-priority: strict 24 | 25 | - name: Install conda-build and anaconda-client 26 | shell: bash 27 | run: | 28 | eval "$(conda shell.bash hook)" 29 | conda activate build_env 30 | conda install conda-build anaconda-client -y 31 | 32 | - name: Build conda package 33 | shell: bash 34 | run: | 35 | eval "$(conda shell.bash hook)" 36 | conda activate build_env 37 | export GIT_TAG_NAME="${GITHUB_REF##*/}" 38 | export GIT_TAG_NAME="${GIT_TAG_NAME#v}" # strip leading "v" if present 39 | 40 | # Inject the version into the Python module so setuptools can read it 41 | echo "__version__ = \"${GIT_TAG_NAME}\"" > scregulate/__version__.py 42 | 43 | conda build conda.recipe --output-folder dist \ 44 | --channel conda-forge --channel pytorch 45 | 46 | - name: Upload to Anaconda 47 | shell: bash 48 | env: 49 | ANACONDA_TOKEN: ${{ secrets.ANACONDA_TOKEN }} 50 | run: | 51 | eval "$(conda shell.bash hook)" 52 | conda activate build_env 53 | shopt -s nullglob 54 | for file in dist/*/*.conda dist/*/*.tar.bz2; do 55 | echo "Uploading $file" 56 | anaconda -t "$ANACONDA_TOKEN" upload "$file" --force 57 | done 58 | 59 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | release-build: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@v4 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.10" 22 | 23 | - name: Install build tools 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install build 27 | 28 | - name: Inject version from GitHub tag 29 | run: | 30 | VERSION="${GITHUB_REF##*/}" 31 | VERSION="${VERSION#v}" # Strip leading 'v' if present 32 | echo "__version__ = \"${VERSION}\"" > scregulate/__version__.py 33 | cat scregulate/__version__.py 34 | 35 | 36 | - name: Build distributions 37 | run: python -m build 38 | 39 | - name: Upload 
distributions as artifacts 40 | uses: actions/upload-artifact@v4 41 | with: 42 | name: release-dists 43 | path: dist/ 44 | 45 | pypi-publish: 46 | runs-on: ubuntu-latest 47 | needs: release-build 48 | permissions: 49 | id-token: write 50 | 51 | environment: 52 | name: pypi 53 | url: https://pypi.org/project/scRegulate/${{ github.event.release.tag_name }} 54 | 55 | steps: 56 | - name: Download built distributions 57 | uses: actions/download-artifact@v4 58 | with: 59 | name: release-dists 60 | path: dist/ 61 | 62 | - name: Publish to PyPI 63 | uses: pypa/gh-action-pypi-publish@release/v1 64 | with: 65 | packages-dir: dist/ 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 5 | 6 | ## Our Standards 7 | Examples of behavior that contributes to creating a positive environment include: 8 | - Using welcoming and inclusive language 9 | - Being respectful of differing viewpoints and experiences 10 | - Gracefully accepting constructive criticism 11 | - Focusing on what is best for the community 12 | - Showing empathy towards other community members 13 | 14 | Examples of unacceptable behavior by participants include: 15 | - The use of sexualized language or imagery and unwelcome sexual attention or advances 16 | - Trolling, insulting/derogatory comments, and personal or political attacks 17 | - Public or private harassment 18 | - Publishing others' private information, such as a physical or electronic address, without explicit permission 19 | - Other conduct which could reasonably be considered inappropriate in a professional setting 20 | 21 | ## Our Responsibilities 22 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 23 | 24 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 
25 | 26 | ## Scope 27 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 28 | 29 | ## Enforcement 30 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at [mzandi2@uic.edu](mailto:mzandi2@uic.edu). All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 31 | 32 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 33 | 34 | ## Attribution 35 | This Code of Conduct is adapted from the Contributor Covenant, version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html) 36 | 37 | For answers to common questions about this code of conduct, see [https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq) 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Mehrdad Zandigohar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /License.md: -------------------------------------------------------------------------------- 1 | YEAR: 2024 2 | 3 | COPYRIGHT HOLDER: scRegulate Authors 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scRegulate 2 | **S**ingle-**C**ell **Regula**tory-Embedded Variational Inference of **T**ranscription Factor Activity from Gene **E**xpression 3 | 4 | [![GitHub issues](https://img.shields.io/github/issues/YDaiLab/scRegulate)](https://github.com/YDaiLab/scRegulate/issues) 5 | [![PyPI - Project](https://img.shields.io/pypi/v/scRegulate)](https://pypi.org/project/scRegulate/) 6 | [![Conda](https://img.shields.io/conda/v/zandigohar/scregulate?label=conda)](https://anaconda.org/zandigohar/scregulate) 7 | [![Docs](https://img.shields.io/badge/docs-GitHub%20Pages-blue)](https://ydailab.github.io/scRegulate/) 8 | 9 | 10 | ## Introduction 11 | **scRegulate** is a powerful tool designed for the inference of transcription factor activity from single cell/nucleus RNA data using advanced generative modeling techniques. It leverages a unified learning framework to optimize the modeling of cellular regulatory networks, providing researchers with accurate insights into transcriptional regulation. With its efficient clustering capabilities, **scRegulate** facilitates the analysis of complex biological data, making it an essential resource for studies in genomics and molecular biology. 12 | 13 |
14 | 15 |
16 | 17 | For further information and example tutorials, please check our documentation: 18 | - [PBMC 3K Tutorial (HTML)](https://ydailab.github.io/scRegulate/tutorial_main.html) 19 | - [Reproducibility Guide (HTML)](https://ydailab.github.io/scRegulate/Data_Preparation.html) 20 | 21 | If you have any questions or concerns, feel free to [open an issue](https://github.com/YDaiLab/scRegulate/issues). 22 | 23 | ## Requirements 24 | scRegulate is implemented in the PyTorch framework. Running scRegulate on `CUDA` is highly recommended if available. 25 | 26 | 27 | Before installing and running scRegulate, ensure you have the following libraries installed: 28 | - **PyTorch** (version 2.0 or higher) 29 | Install with the exact command from the [PyTorch “Get Started” page](https://pytorch.org/get-started/locally/) for your OS, Python version and (optionally) CUDA toolkit. 30 | - **NumPy** (version 1.23 or higher) 31 | - **Scanpy** (version 1.9 or higher) 32 | - **Anndata** (version 0.8 or higher) 33 | 34 | You can install these dependencies using `pip`: 35 | 36 | ```bash 37 | pip install torch numpy scanpy anndata 38 | ``` 39 | 40 | ## Installation 41 | 42 | **Option 1:** 43 | You can install **scRegulate** via pip for a lightweight installation: 44 | 45 | ```bash 46 | pip install scregulate 47 | ``` 48 | 49 | **Option 2:** 50 | Alternatively, if you want the latest, unreleased version, you can install it directly from the source on GitHub: 51 | 52 | ```bash 53 | pip install git+https://github.com/YDaiLab/scRegulate.git 54 | ``` 55 | 56 | **Option 3:** 57 | For users who prefer Conda or Mamba for environment management, you can install **scRegulate** along with extra dependencies: 58 | 59 | **Conda:** 60 | ```bash 61 | conda install -c zandigohar scregulate 62 | ``` 63 | 64 | **Mamba:** 65 | ```bash 66 | mamba create -n scRegulate -c zandigohar scregulate 67 | ``` 68 | 69 | ## FAQ 70 | 71 | **Q1: Do I need a GPU to run scRegulate?** 72 | No, a GPU is not required. However, using a CUDA-enabled GPU is strongly recommended for faster training and inference, especially with large datasets. 73 | 74 | **Q2: How do I know if I can use a GPU with scRegulate?** 75 | There are two quick checks: 76 | 77 | 1. **System check** 78 | In your terminal, run `nvidia-smi`. If you see your GPU listed (model, memory, driver version), your machine has an NVIDIA GPU with the driver installed. 79 | 80 | 2. **Python check** 81 | In a Python shell, run: 82 | ```python 83 | import torch 84 | print(torch.cuda.is_available()) # True means PyTorch can see your GPU 85 | print(torch.cuda.device_count()) # How many GPUs are usable 86 | ``` 87 | 88 | **Q3: Can I use scRegulate with Seurat or R-based tools?** 89 | scRegulate is written in Python and works directly with `AnnData` objects (e.g., from Scanpy). You can convert Seurat objects to AnnData using tools like `SeuratDisk`. 90 | 91 | **Q4: How can I visualize inferred TF activities?** 92 | TF activities inferred by scRegulate are stored in the `obsm` slot of the AnnData object. You can use `scanpy.pl.embedding`, `scanpy.pl.heatmap`, or export the matrix for custom plots. 93 | 94 | **Q5: What kind of prior networks does scRegulate accept?** 95 | scRegulate supports user-provided gene regulatory networks (GRNs) in CSV or matrix format. These can be curated from public databases or inferred from ATAC-seq or motif analysis. 96 | 97 | **Q6: Can I use scRegulate for multi-omics integration?** 98 | Not directly. 
While scRegulate focuses on TF activity from RNA, you can incorporate priors derived from other omics (e.g., ATAC) to **guide** the model. 99 | 100 | **Q7: What file formats are supported?** 101 | scRegulate works with `.h5ad` files (AnnData format). Input files should contain gene expression matrices with proper normalization. 102 | 103 | **Q8: How do I cite scRegulate?** 104 | See the [Citation](#citation) section below for the latest reference and preprint link. 105 | 106 | **Q9: How can I reproduce the paper’s results?** 107 | See our [Reproducibility Guide](https://github.com/YDaiLab/scRegulate/blob/main/notebooks/Data_Preparation.ipynb) for step-by-step instructions. Then run scregulate. 108 | 109 | ## Citation 110 | 111 | **scRegulate** manuscript is currently under peer review. 112 | 113 | If you use **scRegulate** in your research, please cite: 114 | 115 | Mehrdad Zandigohar, Jalees Rehman and Yang Dai (2025). **scRegulate: Single-Cell Regulatory-Embedded Variational Inference of Transcription Factor Activity from Gene Expression**, Bioinformatics Journal (under review). 116 | 117 | 📄 Read the preprint on bioRxiv: [10.1101/2025.04.17.649372](https://doi.org/10.1101/2025.04.17.649372) 118 | 119 | ## Development & Contact 120 | scRegulate was developed and is actively maintained by Mehrdad Zandigohar as part of his PhD research at the University of Illinois Chicago (UIC), in the lab of Dr. Yang Dai. 121 | 122 | 📬 For private questions, please email: mzandi2@uic.edu 123 | 124 | 🤝 For collaboration inquiries, please contact PI: Dr. Yang Dai (yangdai@uic.edu) 125 | 126 | Contributions, feature suggestions, and feedback are always welcome! 127 | 128 | ## License 129 | 130 | The code in **scRegulate** is licensed under the [MIT License](https://opensource.org/licenses/MIT), which permits academic and commercial use, modification, and distribution. 131 | 132 | Please note that any third-party dependencies bundled with **scRegulate** may have their own respective licenses. 
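## Quick Example: Accessing Inferred TF Activities

Building on Q4 above, here is a minimal sketch of pulling the inferred TF activities out of the `obsm` slot and plotting them with Scanpy. The file name `pbmc3k_scregulate.h5ad` and the fallback between the `"TF"` and `"TF_finetuned"` keys are illustrative assumptions; adapt them to the objects produced by your own run.

```python
import scanpy as sc
import pandas as pd

# Load an AnnData object previously processed by scRegulate (file name is illustrative).
adata = sc.read_h5ad("pbmc3k_scregulate.h5ad")

# TF activities are stored in .obsm: "TF" after training,
# "TF_finetuned" after optional per-cluster fine-tuning.
key = "TF_finetuned" if "TF_finetuned" in adata.obsm else "TF"
tf = sc.AnnData(adata.obsm[key], obs=adata.obs.copy())

# Embed cells by their TF activity profiles and color by a few TFs.
sc.pp.neighbors(tf)
sc.tl.umap(tf)
sc.pl.umap(tf, color=list(tf.var_names[:3]))

# Export the full matrix for custom plots (e.g., in R).
pd.DataFrame(tf.X, index=tf.obs_names, columns=tf.var_names).to_csv("tf_activities.csv")
```

Note that `fine_tune_clusters` also returns a dedicated TF-activity AnnData whose `var` index carries the TF names, which is usually more convenient than rebuilding one from `obsm` as done here.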
133 | 134 | -------------------------------------------------------------------------------- /assets/Visual_Abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YDaiLab/scRegulate/52df4611f4e99014c20caca9ebf321dc44d660b3/assets/Visual_Abstract.png -------------------------------------------------------------------------------- /assets/tool_logo.svg: -------------------------------------------------------------------------------- 1 | 2 | SEc R e g0101010001000110 179 | -------------------------------------------------------------------------------- /conda.recipe/conda_build_config.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - pytorch 4 | - defaults 5 | 6 | python: 7 | - 3.10 8 | numpy: 9 | - 1.23 10 | -------------------------------------------------------------------------------- /conda.recipe/environment.yml: -------------------------------------------------------------------------------- 1 | name: build_env 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.10 8 | - pip 9 | - pytorch>=2.0 10 | - scanpy=1.10.4 11 | - anndata=0.10.8 12 | - numpy=1.26.4 13 | - matplotlib-base>=3.6 14 | - optuna >=3.0 15 | -------------------------------------------------------------------------------- /conda.recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set version = environ.get('GIT_TAG_NAME', '0.0.0') %} 2 | 3 | package: 4 | name: scregulate 5 | version: "{{ version }}" 6 | 7 | source: 8 | path: .. 9 | 10 | build: 11 | noarch: python 12 | script: python -m pip install . 13 | number: 0 14 | 15 | requirements: 16 | host: 17 | - python >=3.10,<3.11 18 | - pip 19 | run: 20 | - python >=3.10,<3.11 21 | - numpy >=1.26 22 | - pytorch >=2.0,<2.3 23 | - scanpy >=1.10,<1.11 24 | - anndata >=0.10 25 | - matplotlib-base >=3.6,<3.9 26 | - pillow >=8.0 27 | - optuna >=3.0 28 | 29 | test: 30 | imports: 31 | - scregulate 32 | commands: 33 | - pip check 34 | requires: 35 | - pip 36 | source_files: 37 | - setup.py 38 | channels: 39 | - conda-forge 40 | - pytorch 41 | 42 | about: 43 | home: https://github.com/YDaiLab/scRegulate 44 | dev_url: https://github.com/YDaiLab/scRegulate 45 | license: MIT 46 | license_file: LICENSE 47 | summary: Python Toolkit for Transcription Factor Activity Inference and Clustering of scRNA-seq Data 48 | description: | 49 | scRegulate is a Python toolkit designed for inferring transcription factor activity 50 | and clustering single-cell RNA-seq data using variational inference with biological priors. 51 | It is built on PyTorch and Scanpy, and supports CUDA acceleration for high-performance inference. 52 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | 28 | 29 |

📘 scRegulate Tutorials

This page contains multiple tutorials to help you:
- 🧬 PBMC 3K Tutorial
- 📄 Reproduce Paper Results

🧬 TF Inference on PBMC 3K
This tutorial walks you through the basic usage of scRegulate on a small PBMC 3K dataset.

📄 Reproducing Manuscript Results
This notebook replicates the preprocessing and analysis pipeline used in our paper.
59 | 60 | 71 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | This page contains multiple tutorials consisting of: 2 | 1) TF inference on PBMC 3K [Notebook](https://github.com/YDaiLab/scRegulate/blob/main/notebooks/tutorial_main.ipynb) | [HTML](https://ydailab.github.io/scRegulate/tutorial_main.html) 3 | 2) Reproducing manuscript results [Notebook](https://github.com/YDaiLab/scRegulate/blob/main/notebooks/Data_Preparation.ipynb) | [HTML](https://ydailab.github.io/scRegulate/Data_Preparation.html) 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 77.0.3"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "scRegulate" 7 | dynamic = ["version"] 8 | description = "Python Toolkit for Transcription Factor Activity Inference and Clustering of scRNA-seq Data" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | license = "MIT" 12 | authors = [ 13 | {name = "Mehrdad Zandigohar", email = "mehr.zgohar@gmail.com"}, 14 | ] 15 | dependencies = [ 16 | "torch>=2.0", 17 | "numpy>=1.23", 18 | "scanpy>=1.9", 19 | "anndata>=0.8" 20 | ] 21 | 22 | [project.urls] 23 | "Homepage" = "https://github.com/YDaiLab/scRegulate" 24 | "Documentation" = "https://github.com/YDaiLab/scRegulate#readme" 25 | "Issue Tracker" = "https://github.com/YDaiLab/scRegulate/issues" 26 | "Paper (bioRxiv)" = "https://doi.org/10.1101/2025.04.17.649372" 27 | 28 | [tool.setuptools.dynamic] 29 | version = { attr = "scregulate.__version__.__version__" } 30 | 31 | 32 | 33 | [tool.setuptools.packages.find] 34 | where = ["."] 35 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0 2 | numpy>=1.23 3 | scanpy>=1.9 4 | anndata>=0.8 5 | -------------------------------------------------------------------------------- /scregulate/__init__.py: -------------------------------------------------------------------------------- 1 | # scregulate/__init__.py 2 | from .__version__ import __version__ 3 | 4 | # Core training and auto-tuning 5 | from .train import train_model, adapt_prior_and_data 6 | from .auto_tuning import auto_tune 7 | 8 | # Model architecture 9 | from .vae_model import scRNA_VAE 10 | 11 | # Loss functions and gradient norm 12 | from .loss_functions import loss_function, get_gradient_norm 13 | 14 | # Training utilities 15 | from .train_utils import ( 16 | schedule_parameter, 17 | schedule_mask_factor, 18 | apply_gradient_mask, 19 | compute_average_loss, 20 | clip_gradients, 21 | ) 22 | 23 | # General utilities 24 | from .utils import ( 25 | set_random_seed, 26 | create_dataloader, 27 | clear_memory, 28 | extract_GRN, 29 | to_torch_tensor, 30 | set_active_modality, 31 | extract_modality, 32 | ) 33 | 34 | # GRN prior utilization (collectri) 35 | from .datasets import collectri_prior 36 | -------------------------------------------------------------------------------- /scregulate/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.26" 2 | -------------------------------------------------------------------------------- /scregulate/auto_tuning.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import optuna 4 | from optuna.samplers import TPESampler, NSGAIISampler, QMCSampler 5 | from optuna.pruners import HyperbandPruner, SuccessiveHalvingPruner 6 | 7 | import logging 8 | from .utils import set_active_modality 9 | 10 | from .fine_tuning import fine_tune_clusters 11 | from .train import train_model 12 | from .train import adapt_prior_and_data 13 | 14 | # ---------- Logger Configuration ---------- 15 | autotune_logger = logging.getLogger("autotune") 16 | autotune_logger.setLevel(logging.INFO) 17 | handler = logging.StreamHandler() 18 | formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") 19 | handler.setFormatter(formatter) 20 | if not autotune_logger.handlers: 21 | autotune_logger.addHandler(handler) 22 | autotune_logger.propagate = False 23 | 24 | # ---------- Common Validation Loss Function ---------- 25 | def compute_validation_loss(model, val_adata): 26 | autotune_logger.info("Computing validation loss...") 27 | 28 | # Retrieve gene names used in the model 29 | if not hasattr(model, "gene_names"): 30 | raise ValueError("Model does not have gene_names attribute. Ensure it is set during training.") 31 | model_genes = model.gene_names # List of gene names used during training 32 | 33 | # Align validation data with model's gene names 34 | adata_genes = val_adata.var_names 35 | shared_genes = set(adata_genes).intersection(set(model_genes)) 36 | if len(shared_genes) < len(model_genes): 37 | val_adata = val_adata[:, list(shared_genes)].copy() 38 | autotune_logger.info(f"Aligned validation data to {len(shared_genes)} shared genes.") 39 | 40 | X_val = val_adata.X 41 | if not isinstance(X_val, np.ndarray): 42 | X_val = X_val.toarray() 43 | 44 | device = "cuda" if torch.cuda.is_available() else "cpu" 45 | X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device) 46 | 47 | model.eval() 48 | with torch.no_grad(): 49 | mu, logvar = model.encode(X_val_tensor) 50 | z = model.reparameterize(mu, logvar) 51 | tf_activity = model.decode(z) 52 | recon_x = model.tf_mapping(tf_activity) 53 | mse_loss = torch.mean((recon_x - X_val_tensor) ** 2).item() 54 | 55 | autotune_logger.info(f"Validation MSE loss: {mse_loss}") 56 | return mse_loss 57 | 58 | # ---------- Training Objective Function ---------- 59 | def training_objective(trial, train_adata, val_adata, net, train_model, **kwargs): 60 | autotune_logger.info("Running training objective...") 61 | 62 | # Suggest hyperparameters 63 | freeze_epochs = trial.suggest_int('freeze_epochs', kwargs['freeze_epochs_range'][0], kwargs['freeze_epochs_range'][1]) 64 | alpha_scale = trial.suggest_float('alpha_scale', kwargs['alpha_scale_range'][0], kwargs['alpha_scale_range'][1]) 65 | alpha_max = trial.suggest_float('alpha_max', kwargs['alpha_max_range'][0], kwargs['alpha_max_range'][1]) 66 | beta_max = trial.suggest_float('beta_max', kwargs['beta_max_range'][0], kwargs['beta_max_range'][1]) 67 | gamma_max = trial.suggest_float('gamma_max', kwargs['gamma_max_range'][0], kwargs['gamma_max_range'][1]) 68 | learning_rate = trial.suggest_float('learning_rate', kwargs['learning_rate_range'][0], kwargs['learning_rate_range'][1], log=True) 69 | 70 | # Train model 71 | model, _, _ = train_model( 72 | rna_data=train_adata, 73 | net=net, 74 | encode_dims=kwargs['encode_dims'], 75 | decode_dims=kwargs['decode_dims'], 76 | z_dim=kwargs['z_dim'], 77 | epochs=kwargs['epochs'], 78 | 
freeze_epochs=freeze_epochs, 79 | learning_rate=learning_rate, 80 | 81 | alpha_start=0, 82 | alpha_scale=alpha_scale, 83 | alpha_max=alpha_max, 84 | 85 | beta_start=0, 86 | beta_max=beta_max, 87 | 88 | gamma_start=0, 89 | gamma_max=gamma_max, 90 | 91 | log_interval=kwargs['log_interval'], 92 | early_stopping_patience=kwargs['early_stopping_patience'], 93 | min_targets=kwargs['min_targets'], 94 | min_TFs=kwargs['min_TFs'], 95 | device=None, 96 | return_outputs=True, 97 | verbose=kwargs['verbose'] 98 | ) 99 | 100 | # Compute validation loss 101 | val_loss = compute_validation_loss(model, val_adata) 102 | return val_loss 103 | 104 | # ---------- Fine-Tuning Objective Function ---------- 105 | def fine_tuning_objective(trial, processed_adata, model, **kwargs): 106 | autotune_logger.info("Running fine-tuning objective...") 107 | 108 | # Suggest hyperparameters 109 | beta_max = trial.suggest_float("beta_max", kwargs['beta_max_range'][0], kwargs['beta_max_range'][1]) 110 | max_weight_norm = trial.suggest_float("max_weight_norm", kwargs['max_weight_norm_range'][0], kwargs['max_weight_norm_range'][1]) 111 | tf_mapping_lr = trial.suggest_float("tf_mapping_lr", kwargs['tf_mapping_lr_range'][0], kwargs['tf_mapping_lr_range'][1], log=True) 112 | fc_output_lr = trial.suggest_float("fc_output_lr", kwargs['fc_output_lr_range'][0], kwargs['fc_output_lr_range'][1], log=True) 113 | default_lr = trial.suggest_float("default_lr", kwargs['default_lr_range'][0], kwargs['default_lr_range'][1], log=True) 114 | 115 | # Log modality availability 116 | autotune_logger.info("Checking modalities in processed_adata...") 117 | if not hasattr(processed_adata, "modality"): 118 | autotune_logger.error("processed_adata does not have 'modality' attribute.") 119 | raise ValueError("Missing 'modality' attribute in processed_adata.") 120 | 121 | if "TF" not in processed_adata.modality: 122 | autotune_logger.error(f"Available modalities: {list(processed_adata.modality.keys())}") 123 | raise ValueError("Missing modality['TF'] in processed_adata.") 124 | 125 | autotune_logger.info(f"Modalities available: {list(processed_adata.modality.keys())}") 126 | 127 | # Log cluster_key and its addition to obs 128 | cluster_key = kwargs.get('cluster_key', None) 129 | autotune_logger.info(f"Adding cluster_key '{cluster_key}' to processed_adata.obs...") 130 | 131 | # Fine-tune model 132 | _, _, fine_model , _ = fine_tune_clusters( 133 | processed_adata=processed_adata, 134 | model=model, 135 | cluster_key=cluster_key, 136 | epochs=kwargs['epochs'], 137 | device=kwargs['device'], 138 | verbose=kwargs['verbose'], 139 | beta_start=0, 140 | beta_max=beta_max, 141 | max_weight_norm=max_weight_norm, 142 | early_stopping_patience=kwargs['early_stopping_patience'], 143 | tf_mapping_lr=tf_mapping_lr, 144 | fc_output_lr=fc_output_lr, 145 | default_lr=default_lr # Fine-tuned base learning rate 146 | ) 147 | 148 | # Compute validation loss 149 | autotune_logger.info("Computing validation loss after fine-tuning...") 150 | val_loss = compute_validation_loss(fine_model, processed_adata) 151 | return val_loss 152 | 153 | 154 | # ---------- Unified Auto-Tuning Function ---------- 155 | def auto_tune( 156 | mode, # 'training' or 'fine-tuning' 157 | processed_adata=None, 158 | model=None, 159 | net=None, 160 | train_model=None, # Explicitly added as a direct argument 161 | n_trials=10, 162 | **kwargs 163 | ): 164 | autotune_logger.info(f"Starting auto-tuning for {mode} with {n_trials} trials...") 165 | 166 | if processed_adata is None: 167 | raise 
ValueError("processed_adata must be provided for auto-tuning.") 168 | 169 | 170 | if mode == "training": 171 | train_val_split_ratio = kwargs.get('train_val_split_ratio', 0.8) 172 | np.random.seed(0) 173 | indices = np.arange(processed_adata.n_obs) 174 | np.random.shuffle(indices) 175 | train_size = int(train_val_split_ratio * len(indices)) 176 | train_idx = indices[:train_size] 177 | val_idx = indices[train_size:] 178 | 179 | train_adata = processed_adata[train_idx].copy() 180 | val_adata = processed_adata[val_idx].copy() 181 | 182 | # Adapt prior for training data 183 | W_prior, gene_names, TF_names = adapt_prior_and_data(train_adata, net, kwargs['min_targets'], kwargs['min_TFs']) 184 | 185 | # Align validation data to training gene names 186 | val_genes = val_adata.var_names.intersection(gene_names) 187 | val_adata = val_adata[:, list(val_genes)].copy() 188 | 189 | elif mode == "fine-tuning": 190 | # Ensure the active modality is set to 'RNA' 191 | autotune_logger.info("Setting the active modality to 'RNA'...") 192 | processed_adata = set_active_modality(processed_adata, "RNA") 193 | autotune_logger.info("Active modality set to 'RNA'.") 194 | 195 | if not hasattr(model, "gene_names"): 196 | raise ValueError("Model does not have gene_names attribute. Ensure it is set during training.") 197 | gene_names = model.gene_names 198 | fine_tuned_genes = processed_adata.var_names.intersection(gene_names) 199 | processed_adata._inplace_subset_var(list(fine_tuned_genes)) 200 | autotune_logger.info(f"Aligned processed_adata to {len(fine_tuned_genes)} genes.") 201 | train_adata, val_adata = None, None 202 | else: 203 | raise ValueError("Invalid mode. Choose 'training' or 'fine-tuning'.") 204 | 205 | # Create study 206 | #study = optuna.create_study(direction='minimize', sampler=TPESampler()) 207 | 208 | # Use BOHB (Bayesian Optimization with HyperBand) 209 | sampler = TPESampler(multivariate=True, seed=42) 210 | 211 | pruner = HyperbandPruner( 212 | min_resource=500, # Let model run at least 500 epochs before considering pruning 213 | max_resource=kwargs['epochs'], # Full training budget 214 | reduction_factor=3, # Prune bottom 2/3 at each round 215 | bootstrap_count=0 216 | ) 217 | 218 | study = optuna.create_study(direction='minimize', sampler=sampler, pruner=pruner) 219 | 220 | if mode == "training": 221 | study.optimize( 222 | lambda trial: training_objective( 223 | trial, train_adata, val_adata, net, train_model, **kwargs 224 | ), 225 | n_trials=n_trials 226 | ) 227 | elif mode == "fine-tuning": 228 | study.optimize( 229 | lambda trial: fine_tuning_objective( 230 | trial, processed_adata, model, **kwargs 231 | ), 232 | n_trials=n_trials 233 | ) 234 | else: 235 | raise ValueError("Invalid mode. 
Choose 'training' or 'fine-tuning'.") 236 | 237 | autotune_logger.info(f"Best hyperparameters: {study.best_params}") 238 | autotune_logger.info(f"Best validation loss: {study.best_value}") 239 | return study.best_params, study.best_value 240 | -------------------------------------------------------------------------------- /scregulate/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | import pandas as pd 4 | 5 | def _download_file_if_needed(filename: str, url: str, target_dir: str) -> str: 6 | os.makedirs(target_dir, exist_ok=True) 7 | filepath = os.path.join(target_dir, filename) 8 | if not os.path.exists(filepath): 9 | print(f"[scRegulate] Downloading {filename} to {filepath}...") 10 | try: 11 | urllib.request.urlretrieve(url, filepath) 12 | except Exception as e: 13 | raise RuntimeError(f"Failed to download {url}.\nError: {e}") 14 | return filepath 15 | 16 | def collectri_prior(species: str = "human") -> pd.DataFrame: 17 | """ 18 | Returns the collectri prior dataframe for the given species ("human" or "mouse"). 19 | Downloads the file on first use and caches it under ~/.scregulate/priors/. 20 | 21 | Parameters: 22 | - species: str = "human" or "mouse" 23 | 24 | Returns: 25 | - pd.DataFrame with TF-target prior network 26 | """ 27 | base_url = "https://github.com/YDaiLab/scRegulate/raw/main/priors/" 28 | species_to_filename = { 29 | "human": "collectri_human_net.csv", 30 | "mouse": "collectri_mouse_net.csv" 31 | } 32 | 33 | if species not in species_to_filename: 34 | raise ValueError("species must be either 'human' or 'mouse'") 35 | 36 | filename = species_to_filename[species] 37 | url = base_url + filename 38 | target_dir = os.path.expanduser("~/.scregulate/priors") 39 | 40 | local_path = _download_file_if_needed(filename, url, target_dir) 41 | return pd.read_csv(local_path) 42 | -------------------------------------------------------------------------------- /scregulate/fine_tuning.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import numpy as np 4 | import scanpy as sc 5 | import pandas as pd 6 | from torch.utils.data import DataLoader 7 | from sklearn.preprocessing import MinMaxScaler 8 | import logging 9 | import torch.nn.functional as F 10 | from .loss_functions import loss_function 11 | from .utils import ( 12 | set_random_seed, # To ensure reproducibility 13 | create_dataloader, # For DataLoader creation 14 | clear_memory, # To handle memory cleanup 15 | to_torch_tensor, # Converts dense or sparse data to PyTorch tensors 16 | schedule_parameter, # For parameter scheduling during training 17 | ) 18 | set_random_seed() 19 | 20 | # ---------- Logging Configuration ---------- 21 | finetune_logger = logging.getLogger("finetune") 22 | finetune_logger.setLevel(logging.INFO) 23 | 24 | # Reset logger: Remove any existing handlers 25 | while finetune_logger.handlers: 26 | finetune_logger.removeHandler(finetune_logger.handlers[0]) 27 | 28 | # Define a new handler 29 | handler = logging.StreamHandler() # Logs to console; use FileHandler for a file 30 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 31 | handler.setFormatter(formatter) 32 | finetune_logger.addHandler(handler) 33 | 34 | finetune_logger.propagate = False # Prevent duplicate logs from root logger 35 | 36 | # ---------- Sparsity Handling ---------- 37 | def sparse_collate_fn(batch): 38 | """ 39 | Custom collate function to handle 
sparse tensors. 40 | 41 | Args: 42 | batch: A batch of data (potentially sparse). 43 | 44 | Returns: 45 | torch.Tensor: Dense tensor batch. 46 | """ 47 | if isinstance(batch[0], torch.Tensor): 48 | return torch.stack(batch) 49 | elif isinstance(batch[0], torch.sparse_coo_tensor): 50 | # Convert sparse tensors to dense format before stacking 51 | return torch.stack([b.to_dense() for b in batch]) 52 | else: 53 | raise TypeError(f"Unsupported batch type: {type(batch[0])}") 54 | 55 | # ---------- Extract TF Activities ---------- 56 | def extract_tf_activities_deterministic(model, data_loader, device="cuda"): 57 | """ 58 | Extract transcription factor (TF) activities deterministically using the trained model. 59 | 60 | Args: 61 | model (torch.nn.Module): The trained model to use for inference. 62 | data_loader (DataLoader): DataLoader containing the cluster-specific data. 63 | device (str): Device for computation ('cuda' or 'cpu'). Default is 'cuda'. 64 | 65 | Returns: 66 | np.ndarray: Matrix of inferred TF activities. 67 | """ 68 | model.eval() 69 | tf_activities = [] 70 | with torch.no_grad(): 71 | for batch in data_loader: 72 | batch = batch.to(device) 73 | mu, _ = model.encode(batch) # Use mean (mu) for deterministic encoding 74 | tf_activity = model.decode(mu).clamp(min=0) 75 | #tf_activity = model.decode(mu).clamp(min=0, max=5.0) 76 | tf_activities.append(tf_activity.cpu().numpy()) 77 | return np.concatenate(tf_activities, axis=0) 78 | 79 | # ---------- Cluster-sepcific handling ---------- 80 | def scale_and_format_grns(W_posteriors, gene_names, tf_names): 81 | """ 82 | Apply MinMax scaling to each W matrix (per cluster) and return a dict of DataFrames. 83 | """ 84 | scaled_grns = {} 85 | for cluster, W in W_posteriors.items(): 86 | scaled = MinMaxScaler().fit_transform(W) 87 | scaled_df = pd.DataFrame(scaled, index=gene_names, columns=tf_names) 88 | scaled_grns[cluster] = scaled_df 89 | 90 | return scaled_grns 91 | 92 | def finalize_fine_tuning_outputs(processed_adata, all_tf_activities_matrix, tf_names, W_posteriors, gene_names): 93 | """ 94 | Prepares and returns the processed AnnData, TF activity AnnData, and GRNs. 95 | """ 96 | # Attach TF activities to AnnData 97 | processed_adata.obsm["TF_finetuned"] = all_tf_activities_matrix 98 | processed_adata.uns["W_posteriors_per_cluster"] = W_posteriors 99 | 100 | # Create new AnnData object for TF activities 101 | tf_activities_adata = sc.AnnData( 102 | X=all_tf_activities_matrix, 103 | obs=processed_adata.obs.copy(), 104 | var=pd.DataFrame(index=tf_names), 105 | ) 106 | 107 | # Format GRNs 108 | scaled_grns = scale_and_format_grns( 109 | W_posteriors, gene_names=gene_names, tf_names=tf_names 110 | ) 111 | 112 | return processed_adata, tf_activities_adata, scaled_grns 113 | 114 | 115 | 116 | # ---------- Fine-Tuning for All Clusters ---------- 117 | def fine_tune_clusters( 118 | processed_adata, 119 | model, 120 | cluster_key=None, 121 | epochs=500, 122 | batch_size=3500, 123 | device=None, 124 | max_weight_norm=4, 125 | log_interval=100, 126 | early_stopping_patience=250, 127 | min_epochs=0, 128 | beta_start=0, 129 | beta_max=0.5, 130 | tf_mapping_lr=4e-04, # Learning rate for tf_mapping layer 131 | fc_output_lr=2e-05, # Learning rate for fc_output layer 132 | default_lr=3.5e-05, # Default learning rate for other layers 133 | verbose=True, 134 | ): 135 | """ 136 | Fine-tune the model for each cluster in the AnnData object with logging intervals and early stopping. 
137 | 138 | Args: 139 | processed_adata (AnnData): AnnData object containing RNA data and cluster annotations. 140 | model (torch.nn.Module): Pre-trained model to fine-tune. 141 | cluster_key (str): Key in `.obs` containing cluster annotations. Default is None. 142 | epochs (int): Number of epochs for fine-tuning each cluster. Default is 5000. 143 | batch_size (int): Batch size for fine-tuning. Default is 2000. 144 | device (str): Device to perform computation ('cuda' or 'cpu'). Default is 'cuda'. 145 | beta_start (float): Initial value for beta warm-up. Default is 0.01. 146 | beta_max (float): Maximum value for beta warm-up. Default is 0.1. 147 | max_weight_norm (float): Maximum weight norm for clipping. Default is 0.1. 148 | log_interval (int): Number of epochs between logging updates. Default is 100. 149 | early_stopping_patience (int): Number of epochs to wait without improvement. Default is 1000. 150 | min_epochs (int): Minimum number of epochs before checking for early stopping. Default is 1000. 151 | tf_mapping_lr (float): Learning rate for the tf_mapping layer. Default is 1e-2. 152 | fc_output_lr (float): Learning rate for the fc_output layer. Default is 1e-3. 153 | default_lr (float): Learning rate for other layers. Default is 1e-4. 154 | verbose (bool): If True, enables detailed logging. 155 | 156 | Returns: 157 | AnnData: Updated AnnData object with fine-tuned TF activities (`obsm`) and cluster-specific W_posteriors (`uns`). 158 | 159 | Notes: 160 | - Do not change the learning rates unless you are an expert. 161 | - tf_mapping_lr (1e-2): Faster learning rate for tf_mapping layer. 162 | - fc_output_lr (1e-3): Slower learning rate for fc_output layer. 163 | - default_lr (1e-4): Learning rate for any other trainable layers. 164 | """ 165 | set_random_seed() 166 | batch_size = int(batch_size) 167 | 168 | # Set logging verbosity, control verbosity of train_model logger only. 
169 | if verbose: 170 | finetune_logger.setLevel(logging.INFO) 171 | else: 172 | finetune_logger.setLevel(logging.WARNING) 173 | 174 | # Auto-detect device if not specified 175 | if device is None: 176 | device = "cuda" if torch.cuda.is_available() else "cpu" 177 | 178 | finetune_logger.info("Starting fine-tuning for cluster(s)...") 179 | 180 | # Align gene indices to match GRN dimensions 181 | grn_gene_names = processed_adata.uns["GRN_posterior"]["gene_names"] 182 | adata_gene_names = processed_adata.var.index 183 | gene_indices = [adata_gene_names.get_loc(gene) for gene in grn_gene_names if gene in adata_gene_names] 184 | 185 | 186 | finetune_logger.info(f"Aligning data to {len(gene_indices)} genes matching the GRN.") 187 | 188 | if not cluster_key: 189 | finetune_logger.info("Cluster key not provided, fine-tuning on all cells together...") 190 | cluster_key = 'all' 191 | processed_adata.obs[cluster_key] = 1 192 | 193 | unique_clusters = processed_adata.obs[cluster_key].unique() 194 | original_model = copy.deepcopy(model) 195 | W_posteriors_per_cluster = {} 196 | all_tf_activities = [] 197 | cluster_indices_list = [] 198 | total_cells = len(processed_adata) 199 | for cluster in unique_clusters: 200 | if len(unique_clusters)>1: 201 | finetune_logger.info(f"Fine-tuning {cluster} for {epochs} epochs...") 202 | else: 203 | finetune_logger.info(f"Fine-tuning on all cells for {epochs} epochs...") 204 | 205 | cluster_size = len(processed_adata.obs[processed_adata.obs[cluster_key] == cluster]) 206 | proportional_epochs = max(min_epochs, int(epochs * (cluster_size / total_cells))) 207 | 208 | # Extract cluster data and ensure dense format 209 | cluster_indices = processed_adata.obs[cluster_key] == cluster 210 | cluster_data = processed_adata[cluster_indices].X[:, gene_indices] 211 | 212 | # Convert sparse to dense if needed 213 | if not isinstance(cluster_data, np.ndarray): 214 | cluster_data = cluster_data.todense() 215 | cluster_tensor = to_torch_tensor(cluster_data, device=device) 216 | 217 | # Create DataLoader 218 | cluster_loader = create_dataloader(cluster_tensor, batch_size=batch_size) 219 | 220 | # Fine-tune model for the cluster with early stopping 221 | model_copy = copy.deepcopy(original_model) 222 | best_loss, epochs_no_improve = float("inf"), 0 223 | 224 | # Optimizer with configurable learning rates 225 | optimizer = torch.optim.AdamW([ 226 | {"params": model_copy.tf_mapping.parameters(), "lr": tf_mapping_lr}, # Learning rate for tf_mapping 227 | {"params": model_copy.fc_output.parameters(), "lr": fc_output_lr}, # Learning rate for fc_output 228 | {"params": [p for n, p in model_copy.named_parameters() 229 | if "tf_mapping" not in n and "fc_output" not in n], "lr": default_lr} # Default LR 230 | ]) 231 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", 232 | patience=early_stopping_patience//2, 233 | factor=0.5, min_lr=default_lr) 234 | for epoch in range(proportional_epochs): 235 | model_copy.train() 236 | total_loss, total_samples = 0, 0 237 | # Warm-up beta 238 | beta = schedule_parameter(epoch, beta_start, beta_max, epochs) 239 | 240 | for batch in cluster_loader: 241 | batch = batch.to(device) 242 | optimizer.zero_grad() 243 | 244 | # Forward pass 245 | recon_batch, mu, logvar = model_copy(batch) 246 | MSE, _ = loss_function(recon_batch, batch, mu, logvar) 247 | l1_reg = torch.sum(torch.abs(model_copy.tf_mapping.weight)) 248 | 249 | cluster_weight = len(cluster_indices) / len(processed_adata) 250 | loss = (MSE + beta * l1_reg)# * cluster_weight 251 
| 252 | # Backward pass 253 | loss.backward() 254 | optimizer.step() 255 | 256 | # Weight clipping 257 | with torch.no_grad(): 258 | model_copy.tf_mapping.weight.clamp_(-max_weight_norm, max_weight_norm) 259 | 260 | total_loss += loss.item() * len(batch) 261 | total_samples += len(batch) 262 | 263 | # Average loss for the epoch 264 | avg_loss = total_loss / total_samples 265 | scheduler.step(avg_loss) 266 | 267 | # Logging 268 | if epoch % log_interval == 0 or epoch == epochs - 1: 269 | finetune_logger.info(f"Epoch {epoch+1}, Avg Loss: {avg_loss/processed_adata.shape[1]:.4f}") 270 | 271 | # Early stopping 272 | if epoch > min_epochs: 273 | if avg_loss < best_loss: 274 | best_loss = avg_loss 275 | epochs_no_improve = 0 276 | else: 277 | epochs_no_improve += 1 278 | if epochs_no_improve >= early_stopping_patience: 279 | finetune_logger.info(f"Early stopping for cluster {cluster} at epoch {epoch+1}") 280 | break 281 | 282 | # Extract W_posterior and TF activities 283 | W_posterior_cluster = model_copy.tf_mapping.weight.detach().cpu().numpy() 284 | W_posteriors_per_cluster[cluster] = W_posterior_cluster 285 | 286 | tf_activities = extract_tf_activities_deterministic(model_copy, cluster_loader, device=device) 287 | 288 | cluster_indices_list.append(np.where(cluster_indices)[0]) 289 | all_tf_activities.append(tf_activities) 290 | 291 | # Clear memory 292 | clear_memory(device) 293 | 294 | # [Modification 3] - We now build the final matrix in the ORIGINAL cell order 295 | n_cells = processed_adata.n_obs 296 | 297 | # If there's at least one cluster, get the size of the TF dimension from the first cluster's result 298 | n_tfs = all_tf_activities[0].shape[1] if len(all_tf_activities) > 0 else 0 299 | 300 | # [Modification 4] - Create a zeroed matrix to hold TF activities for ALL cells 301 | all_tf_activities_matrix = np.zeros((n_cells, n_tfs), dtype=np.float32) 302 | 303 | # [Modification 5] - Insert each cluster's activities into the correct rows 304 | for row_indices, cluster_acts in zip(cluster_indices_list, all_tf_activities): 305 | all_tf_activities_matrix[row_indices, :] = cluster_acts 306 | 307 | # Attach to adata 308 | processed_adata.obsm["TF_finetuned"] = all_tf_activities_matrix 309 | processed_adata.uns["W_posteriors_per_cluster"] = W_posteriors_per_cluster 310 | 311 | tf_names = processed_adata.modality['TF'].var.index 312 | 313 | tf_activities_adata = sc.AnnData( 314 | X=all_tf_activities_matrix, 315 | obs=processed_adata.obs.copy(), 316 | var=pd.DataFrame(index=tf_names), 317 | ) 318 | 319 | gene_names = processed_adata.uns['GRN_posterior']['gene_names'] 320 | tf_names = processed_adata.uns['GRN_posterior']['TF_names'] 321 | 322 | processed_adata, tf_activities_adata, scaled_grns = finalize_fine_tuning_outputs( 323 | processed_adata, 324 | all_tf_activities_matrix, 325 | tf_names=tf_names, 326 | W_posteriors=W_posteriors_per_cluster, 327 | gene_names=gene_names, 328 | ) 329 | 330 | # If only one GRN (no cluster_key given), return a single DataFrame 331 | if len(scaled_grns) == 1: 332 | scaled_grns = list(scaled_grns.values())[0] 333 | 334 | finetune_logger.info("Fine-tuning completed for all clusters.") 335 | return processed_adata, tf_activities_adata, model_copy, scaled_grns 336 | 337 | 338 | -------------------------------------------------------------------------------- /scregulate/loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def loss_function(recon_x, x, mu, 
logvar): 5 | """ 6 | Computes the loss function for the Variational Autoencoder (VAE). 7 | The loss consists of: 8 | - Mean Squared Error (MSE) between the input and reconstructed data. 9 | - Kullback-Leibler Divergence (KLD) to regularize the latent space. 10 | 11 | Args: 12 | recon_x: Reconstructed input. 13 | x: Original input. 14 | mu: Mean vector from the encoder. 15 | logvar: Log variance vector from the encoder. 16 | 17 | Returns: 18 | MSE: Reconstruction loss. 19 | KLD: KL divergence loss. 20 | """ 21 | MSE = F.mse_loss(recon_x, x, reduction='sum') # Reconstruction loss 22 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) # KL Divergence 23 | return MSE, KLD 24 | 25 | 26 | def get_gradient_norm(model, norm_type=2): 27 | """ 28 | Computes the total gradient norm for a model. 29 | 30 | Args: 31 | model: The PyTorch model. 32 | norm_type: The type of norm to compute (default is 2 for Euclidean norm). 33 | 34 | Returns: 35 | total_norm: The total gradient norm. 36 | """ 37 | total_norm = 0.0 38 | for p in model.parameters(): 39 | if p.grad is not None: 40 | param_norm = p.grad.data.norm(norm_type) 41 | total_norm += param_norm.item() ** norm_type 42 | total_norm = total_norm ** (1. / norm_type) 43 | return total_norm 44 | -------------------------------------------------------------------------------- /scregulate/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import pandas as pd 4 | import numpy as np 5 | import scanpy as sc 6 | import gc 7 | import logging 8 | from . import ulm_standalone as ulm 9 | from .vae_model import scRNA_VAE 10 | from .loss_functions import loss_function 11 | from sklearn.preprocessing import MinMaxScaler 12 | from .utils import ( 13 | create_dataloader, 14 | schedule_parameter, 15 | schedule_mask_factor, 16 | apply_gradient_mask, 17 | clip_gradients, 18 | compute_average_loss, 19 | clear_memory, 20 | set_random_seed, 21 | ) 22 | set_random_seed() 23 | 24 | # Initialize logging 25 | # Dedicated logger for train_model 26 | # Reset logging to avoid duplicate handlers 27 | train_logger = logging.getLogger("train_model") 28 | train_logger.setLevel(logging.INFO) # Default level 29 | 30 | # Remove any existing handlers 31 | for handler in train_logger.handlers[:]: 32 | train_logger.removeHandler(handler) 33 | 34 | # Create new handler 35 | handler = logging.StreamHandler() 36 | formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") 37 | handler.setFormatter(formatter) 38 | train_logger.addHandler(handler) 39 | 40 | train_logger.propagate = False # Prevent duplicate logs from parent loggers 41 | 42 | def adapt_prior_and_data(rna_data, net, min_targets=10, min_TFs=1): 43 | """ 44 | Adapt the RNA data and network to create W_prior. 45 | 46 | Args: 47 | rna_data (AnnData): AnnData object containing RNA data and metadata. 48 | net (pd.DataFrame): DataFrame containing 'source', 'target', and 'weight' columns. 49 | min_targets (int): Minimum number of target genes per TF to retain (default: 10). 50 | min_TFs (int): Minimum number of TFs per target gene to retain (default: 1). 
51 | 52 | Returns: 53 | tuple: (W_prior (torch.Tensor), gene_names (list), TF_names (list)) 54 | """ 55 | train_logger.info("Adapting prior and data...") 56 | train_logger.info(f"Initial genes in RNA data: {rna_data.var.index.size}") 57 | 58 | # Validate network structure 59 | required_columns = {'source', 'target', 'weight'} 60 | if not required_columns.issubset(net.columns): 61 | train_logger.error(f"Network DataFrame is missing columns: {required_columns - set(net.columns)}") 62 | raise ValueError(f"The 'net' DataFrame must contain: {required_columns}") 63 | 64 | # Create binary matrix 65 | binary_matrix = pd.crosstab(net['source'], net['target'], values=net['weight'], aggfunc='first').fillna(0) 66 | 67 | # Filter genes in RNA data 68 | keep_genes = rna_data.var.index.intersection(net['target']) 69 | if keep_genes.empty: 70 | train_logger.error("No overlapping genes between RNA data and network.") 71 | raise ValueError("No overlapping genes between RNA data and network.") 72 | rna_data._inplace_subset_var(keep_genes) 73 | train_logger.info(f"Genes retained after intersection with network: {len(keep_genes)}") 74 | train_logger.info(f"Initial TFs in GRN matrix: {binary_matrix.index.size}") 75 | 76 | # Align binary matrix with RNA data 77 | binary_matrix = binary_matrix.loc[:, binary_matrix.columns.intersection(rna_data.var.index)] 78 | if binary_matrix.empty: 79 | train_logger.error("Binary matrix is empty after alignment.") 80 | raise ValueError("Binary matrix empty after alignment with RNA data.") 81 | 82 | # Remove TFs with no targets and genes with no TFs 83 | binary_matrix = binary_matrix.loc[binary_matrix.sum(axis=1) >= min_targets, :] 84 | binary_matrix = binary_matrix.loc[:, binary_matrix.sum(axis=0) >= min_TFs] 85 | if binary_matrix.empty: 86 | train_logger.error(f"Binary matrix empty after filtering (min_targets={min_targets}, min_TFs={min_TFs}).") 87 | raise ValueError("Binary matrix empty after filtering for min_targets and min_TFs.") 88 | 89 | # Convert to W_prior tensor 90 | W_prior = torch.tensor(binary_matrix.values, dtype=torch.float32).T 91 | 92 | # Metadata 93 | gene_names = binary_matrix.columns.tolist() 94 | TF_names = binary_matrix.index.tolist() 95 | 96 | train_logger.info(f"Retained {len(gene_names)} genes and {len(TF_names)} transcription factors.") 97 | 98 | return W_prior, gene_names, TF_names 99 | 100 | def train_model( 101 | rna_data, net, 102 | encode_dims=None, decode_dims=None, z_dim=40, 103 | train_val_split_ratio=0.8, random_state=42, 104 | batch_size=3500, epochs=2000, freeze_epochs=1500, learning_rate=2.5e-4, 105 | alpha_start=0, alpha_max=0.7, alpha_scale=0.06, 106 | beta_start=0, beta_max=0.4, 107 | gamma_start=0, gamma_max=2.6, 108 | log_interval=500, early_stopping_patience=350, 109 | min_targets=20, min_TFs=1, device=None, return_outputs=True, verbose=True 110 | ): 111 | 112 | """ 113 | Train the scRNA_VAE model using in-memory data. 114 | 115 | This function trains a variational autoencoder (VAE) designed to infer transcription factor (TF) activities and reconstruct RNA expression. It takes an AnnData object and a gene regulatory network (GRN) prior to initialize the model. 116 | 117 | Args: 118 | rna_data (AnnData, required): AnnData object containing RNA data (normalized expression in `.X` and gene names in `.var.index`). 119 | net (pd.DataFrame, required): GRN DataFrame with columns: 120 | - 'source': Transcription factors (TFs). 121 | - 'target': Target genes. 122 | - 'weight': Interaction weights between TFs and genes. 
123 | encode_dims (list, optional): List of integers specifying the sizes of hidden layers for the encoder. Default is [512]. 124 | decode_dims (list, optional): List of integers specifying the sizes of hidden layers for the decoder. Default is [1024]. 125 | z_dim (int, optional): Latent space dimensionality. Default is 40. 126 | train_val_split_ratio (float, optional): Ratio of the training data to the entire dataset. Default is 0.8 (80% training, 20% validation). 127 | random_state (int, optional): Seed for reproducible splits. Default is 42. 128 | batch_size (int, optional): Number of samples per training batch. Default is 3,500. 129 | epochs (int, optional): Total number of epochs for training. Default is 2,000. 130 | freeze_epochs (int, optional): Number of epochs during which the mask factor is gradually applied to enforce prior structure. Default is 1,500. 131 | learning_rate (float, optional): Learning rate for the optimizer. Default is 2.5e-4. 132 | alpha_start (float, optional): Initial value of `alpha`, controlling the weight of initialized TF activities. Default is 0. 133 | alpha_max (float, optional): Maximum value of `alpha` during training. Default is 0.7. 134 | alpha_scale (float, optional): Scaling factor for `alpha` during training. Default is 0.06. 135 | beta_start (float, optional): Initial value of `beta`, controlling the weight of the KL divergence term in the loss. Default is 0. 136 | beta_max (float, optional): Maximum value of `beta` during training. Default is 0.4. 137 | gamma_start (float, optional): Initial value of `gamma`, controlling the L1 regularization of TF-target weights. Default is 0. 138 | gamma_max (float, optional): Maximum value of `gamma` during training. Default is 2.6. 139 | log_interval (int, optional): Frequency (in epochs) for logging training progress. Default is 500. 140 | early_stopping_patience (int, optional): Number of epochs to wait for improvement before stopping training early. Default is 350. 141 | min_targets (int, optional): Minimum number of target genes per TF to retain in the GRN prior. Default is 20. 142 | min_TFs (int, optional): Minimum number of TFs per target gene to retain in the GRN prior. Default is 1. 143 | device (str, optional): Device for computation ('cuda' or 'cpu'). If not specified, defaults to 'cuda' if available. 144 | return_outputs (bool, optional): If True, returns the trained model together with the updated AnnData object (embeddings, reconstructions, GRN matrices) and the posterior GRN DataFrame. Default is True. 145 | verbose (bool, optional): If True, enable detailed logging. Default is True. 146 | 147 | Returns: 148 | model (scRNA_VAE): Trained scRNA_VAE model instance. 149 | If `return_outputs` is True, also returns `rna_data` and `GRN_posterior` (a pd.DataFrame of min-max-scaled posterior weights, genes x TFs): 150 | rna_data (AnnData): Updated AnnData object with the following fields: 151 | - `.modality["latent"]`: AnnData holding the latent space embedding (cells x latent dimensions). 152 | - `.modality["TF"]`: AnnData holding the inferred transcription factor activities (cells x TFs), taken from the TF activity layer that feeds the final reconstruction layer. 153 | - `.modality["recon_RNA"]`: AnnData holding the reconstructed RNA expression (cells x genes). 154 | - `.modality["RNA"]`: AnnData holding the original RNA expression data (cells x genes). 155 | - `uns["GRN_prior"]`: Dictionary containing: 156 | - `"matrix"`: The GRN prior weight matrix (genes x TFs). 157 | - `"TF_names"`: List of transcription factor names. 158 | - `"gene_names"`: List of gene names. 159 | - `uns["GRN_posterior"]`: Dictionary containing: 160 | - `"matrix"`: The min-max-scaled GRN posterior weight matrix (genes x TFs). 161 | - `"TF_names"`: List of transcription factor names. 162 | - `"gene_names"`: List of gene names. 
163 | 164 | Notes: 165 | - The TF activities stored in the TF modality (`rna_data.modality["TF"]`) reflect the model's inferred activities, not the raw ULM estimates. 166 | - Raw ULM estimates are stored as `rna_data.obsm["ulm_estimate"]`. 167 | - `rna_data.uns["GRN_prior"]` contains the aligned GRN matrix fed to the model as a prior. 168 | - `rna_data.uns["GRN_posterior"]` contains the inferred GRN matrix learned by the model. 169 | - Training halts early if the loss stops improving for `early_stopping_patience` epochs after `alpha` reaches its maximum value. 170 | - Ensure that the genes in `net` overlap with those in `rna_data`; otherwise initialization will raise an error. 171 | 172 | Examples: 173 | >>> trained_model, updated_data, grn_posterior = train_model( 174 | ...     rna_data=my_rna_data, 175 | ...     net=my_grn 176 | ... ) 177 | >>> GRN_prior = updated_data.uns["GRN_prior"]["matrix"] 178 | >>> GRN_posterior = updated_data.uns["GRN_posterior"]["matrix"] 179 | """ 180 | if not (0 < train_val_split_ratio < 1): 181 | raise ValueError("train_val_split_ratio must be between 0 and 1.") 182 | if encode_dims is None: 183 | encode_dims = [512] 184 | if decode_dims is None: 185 | decode_dims = [1024] 186 | 187 | set_random_seed(seed=random_state) 188 | rna_data = rna_data.copy() 189 | 190 | # Set logging verbosity (controls the train_model logger only). 191 | if verbose: 192 | train_logger.setLevel(logging.INFO) 193 | else: 194 | train_logger.setLevel(logging.WARNING) 195 | 196 | # Auto-detect device if not specified 197 | if device is None: 198 | device = "cuda" if torch.cuda.is_available() else "cpu" 199 | 200 | # Determine batch size dynamically if not provided 201 | if batch_size is None: 202 | batch_size = int(train_val_split_ratio * rna_data.n_obs) 203 | train_logger.info(f"No batch size provided. 
Using default batch size equal to the number of training samples: {batch_size}") 204 | else: 205 | train_logger.info(f"Using provided batch size: {batch_size}") 206 | 207 | train_logger.info("=" * 40) 208 | train_logger.info("Starting scRegulate TF inference Training Pipeline") 209 | train_logger.info("=" * 40) 210 | 211 | # Adapt prior and prepare data 212 | W_prior, gene_names, TF_names = adapt_prior_and_data(rna_data, net, min_targets, min_TFs) 213 | 214 | # Split into training and validation sets 215 | train_logger.info(f"Splitting data with train-validation split ratio={train_val_split_ratio}") 216 | n_cells = rna_data.n_obs 217 | all_indices = np.arange(n_cells) 218 | train_indices = np.random.choice(all_indices, size=int(train_val_split_ratio * n_cells), replace=False) 219 | val_indices = np.setdiff1d(all_indices, train_indices) 220 | 221 | # Use the indices to reorder cell names 222 | train_cell_names = rna_data.obs.index[train_indices] 223 | val_cell_names = rna_data.obs.index[val_indices] 224 | shuffled_cell_names = np.concatenate([train_cell_names, val_cell_names]) 225 | 226 | # Align metadata to the reordering applied to the data 227 | rna_data = rna_data.copy() 228 | rna_data = rna_data[shuffled_cell_names, :] # This reorders both .obs and .X 229 | 230 | # Validation step: Ensure alignment 231 | assert np.array_equal(rna_data.obs.index, shuffled_cell_names), "Mismatch in obs index and shuffled cell names" 232 | 233 | # Run ULM once for the entire dataset 234 | train_logger.info("Running ULM...") 235 | ulm_start = time.time() 236 | ulm.run_ulm(rna_data, net=net, min_n=min_targets, batch_size=batch_size, source="source", target="target", weight="weight", verbose=verbose) 237 | train_logger.info(f"ULM completed in {time.time() - ulm_start:.2f}s") 238 | ulm_estimates = torch.tensor(rna_data.obsm["ulm_estimate"].loc[:, TF_names].to_numpy(), dtype=torch.float32) 239 | 240 | # Split ULM estimates 241 | ulm_estimates_train = ulm_estimates[train_indices, :] 242 | ulm_estimates_val = ulm_estimates[val_indices, :] 243 | 244 | # Prepare data for training and validation 245 | gene_indices = [rna_data.var.index.get_loc(gene) for gene in gene_names] 246 | scRNA_train = torch.tensor( 247 | rna_data.X[train_indices][:, gene_indices].todense() if not isinstance(rna_data.X, np.ndarray) else rna_data.X[train_indices][:, gene_indices], 248 | dtype=torch.float32 249 | ) 250 | scRNA_val = torch.tensor( 251 | rna_data.X[val_indices][:, gene_indices].todense() if not isinstance(rna_data.X, np.ndarray) else rna_data.X[val_indices][:, gene_indices], 252 | dtype=torch.float32 253 | ) 254 | 255 | # Validate ULM estimates 256 | train_logger.info("Validating ULM estimates...") 257 | # Validate ULM estimates 258 | if ulm_estimates_train.isnan().any() or ulm_estimates_val.isnan().any(): 259 | train_logger.error("ULM estimates contain NaN values. Check data preprocessing.") 260 | raise ValueError("Invalid ULM estimates.") 261 | train_logger.info(f"ULM estimates validation passed. 
Shape: {ulm_estimates.shape}") 262 | 263 | # Transfer data to device 264 | train_logger.info(f"Transferring data to device {device}") 265 | scRNA_train, scRNA_val = scRNA_train.to(device), scRNA_val.to(device) 266 | ulm_estimates_train, ulm_estimates_val = ulm_estimates_train.to(device), ulm_estimates_val.to(device) 267 | W_prior = W_prior.to(device) 268 | train_logger.info("Transfer complete...") 269 | 270 | # Clear memory after transfer 271 | clear_memory(device) 272 | 273 | # Initialize model 274 | model = scRNA_VAE(scRNA_train.shape[1], encode_dims, decode_dims, z_dim, W_prior.shape[1]).to(device) 275 | model.W_prior = W_prior 276 | model.gene_names = gene_names 277 | 278 | # Optimizer and scheduler 279 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) 280 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=200, factor=0.5, min_lr=1e-6) 281 | mask = (W_prior == 0).float() 282 | 283 | best_val_loss, epochs_no_improve = float("inf"), 0 284 | batch_size = int(batch_size) 285 | train_loader = create_dataloader(scRNA_train, batch_size=batch_size) 286 | val_loader = create_dataloader(scRNA_val, batch_size=batch_size) 287 | start_time = time.time() 288 | 289 | alpha, beta, gamma = alpha_start, beta_start, gamma_start # Initialize parameters 290 | 291 | # Training loop 292 | for epoch in range(epochs): 293 | model.train() 294 | total_loss, total_samples = 0, 0 295 | mask_factor = schedule_mask_factor(epoch, freeze_epochs) 296 | 297 | for batch_idx, batch in enumerate(train_loader): 298 | optimizer.zero_grad() 299 | start_idx = batch_idx * batch_size 300 | end_idx = start_idx + batch.size(0) 301 | tf_activity_slice = ulm_estimates_train[start_idx:end_idx] 302 | if tf_activity_slice.size(0) != batch.size(0): 303 | continue 304 | 305 | # Forward pass and loss calculation 306 | recon_batch, mu, logvar = model(batch, tf_activity_init=tf_activity_slice, alpha=alpha) 307 | MSE, KLD = loss_function(recon_batch, batch, mu, logvar) 308 | loss = MSE + beta * KLD + gamma * torch.sum(torch.abs(model.tf_mapping.weight) * mask) 309 | loss.backward() 310 | 311 | # Apply gradient masking and clipping 312 | if epoch < freeze_epochs: 313 | apply_gradient_mask(model.tf_mapping, mask, mask_factor) 314 | clip_gradients(model) 315 | optimizer.step() 316 | 317 | total_loss += loss.detach() #* len(batch) 318 | total_samples += len(batch) 319 | 320 | avg_train_loss = compute_average_loss(total_loss, total_samples) 321 | 322 | # Evaluate on validation set 323 | model.eval() 324 | val_loss, val_samples = 0, 0 325 | with torch.no_grad(): 326 | for val_batch_idx, val_batch in enumerate(val_loader): 327 | start_idx = val_batch_idx * batch_size 328 | end_idx = start_idx + val_batch.size(0) 329 | tf_val_slice = ulm_estimates_val[start_idx:end_idx] 330 | if tf_val_slice.size(0) != val_batch.size(0): 331 | continue 332 | 333 | recon_val, mu_val, logvar_val = model(val_batch, tf_activity_init=tf_val_slice, alpha=alpha) 334 | MSE_val, KLD_val = loss_function(recon_val, val_batch, mu_val, logvar_val) 335 | loss_val = MSE_val + beta * KLD_val + gamma * torch.sum(torch.abs(model.tf_mapping.weight) * mask) 336 | val_loss += loss_val.detach() #* len(val_batch) 337 | val_samples += len(val_batch) 338 | 339 | avg_val_loss = compute_average_loss(val_loss, val_samples) 340 | 341 | # Update parameters based on validation loss improvement 342 | if avg_val_loss <= best_val_loss: 343 | best_val_loss = avg_val_loss # Update the best validation loss 344 | alpha = alpha + (alpha_max - 
alpha_start) * alpha_scale 345 | beta = schedule_parameter(epoch, beta_start, beta_max, epochs) 346 | gamma = schedule_parameter(epoch, gamma_start, gamma_max, epochs) 347 | 348 | 349 | 350 | alpha = min(alpha, alpha_max) 351 | 352 | scheduler.step(avg_val_loss) 353 | 354 | # Clear memory after each epoch 355 | clear_memory(device) 356 | 357 | # Check early stopping 358 | if early_stopping_patience is not None: 359 | if avg_val_loss < best_val_loss: 360 | best_val_loss = avg_val_loss 361 | epochs_no_improve = 0 362 | elif alpha == 1.0: 363 | epochs_no_improve += 1 364 | 365 | if epochs_no_improve >= early_stopping_patience: 366 | train_logger.info(f"Early stopping at epoch {epoch + 1}") 367 | break 368 | 369 | # Log progress 370 | if epoch % log_interval == 0: 371 | n_features = scRNA_train.shape[1] 372 | train_loss_per_sample = avg_train_loss / (len(train_indices) * n_features) 373 | val_loss_per_sample = avg_val_loss / (len(val_indices)* n_features) 374 | train_logger.info( 375 | f"Epoch {epoch+1}: Avg Train Loss = {train_loss_per_sample:.4f}, Avg Val Loss = {val_loss_per_sample:.4f}, " 376 | f"Alpha = {alpha:.4f}, Beta = {beta:.4f}, Gamma = {gamma:.4f}, Mask Factor: {mask_factor:.2f}" 377 | ) 378 | 379 | # End timing and log final metrics 380 | total_time = time.time() - start_time 381 | train_logger.info("Training completed in %.2fs", total_time) 382 | 383 | clear_memory(device) 384 | # Add latent space, TF activities, reconstructed RNA, and original RNA to AnnData 385 | if return_outputs: 386 | latent_space = [] 387 | TF_space = [] 388 | reconstructed_rna = [] 389 | model.eval() 390 | 391 | # Use the model's batch encoding to extract relevant spaces 392 | # Encode all samples (train + val) together 393 | # Encode all samples (train + val) together 394 | latent_space = np.concatenate([ 395 | model.encodeBatch(loader, device=device, out='z') for loader in [train_loader, val_loader] 396 | ], axis=0) 397 | # Reorder latent embeddings to match RNA order 398 | latent_space = latent_space[np.argsort(np.concatenate([train_indices, val_indices])), :] 399 | 400 | 401 | TF_space = np.concatenate([ 402 | model.encodeBatch(loader, device=device, out='tf') for loader in [train_loader, val_loader] 403 | ], axis=0) 404 | 405 | # Reorder TF embeddings to match RNA order 406 | TF_space = TF_space[np.argsort(np.concatenate([train_indices, val_indices])), :] 407 | 408 | 409 | reconstructed_rna = np.concatenate([ 410 | model.encodeBatch(loader, device=device, out='x') for loader in [train_loader, val_loader] 411 | ], axis=0) 412 | 413 | # Reorder reconstructed_rna embeddings to match RNA order 414 | reconstructed_rna = reconstructed_rna[np.argsort(np.concatenate([train_indices, val_indices])), :] 415 | 416 | 417 | # Extract posterior weight matrix 418 | W_posterior = model.tf_mapping.weight.detach().cpu().numpy() 419 | 420 | # Reorder obs based on concatenated train and validation indices 421 | meta_obs = rna_data.obs.copy() 422 | 423 | rna_modality = sc.AnnData( 424 | X=rna_data.X, # Original RNA expression 425 | obs=meta_obs, 426 | var=rna_data.var.copy(), 427 | obsm=rna_data.obsm.copy(), 428 | uns=rna_data.uns.copy(), 429 | layers=rna_data.layers.copy() if hasattr(rna_data, "layers") else None 430 | ) 431 | 432 | rna_modality.uns["type"] = "RNA" 433 | 434 | tf_modality = sc.AnnData( 435 | X=TF_space, 436 | obs=meta_obs, 437 | var=pd.DataFrame(index=TF_names) # Index of transcription factors 438 | ) 439 | tf_modality.uns["type"] = "TF" 440 | 441 | recon_rna_modality = sc.AnnData( 442 | 
X=reconstructed_rna, 443 | obs=meta_obs, 444 | var=pd.DataFrame(index=gene_names) # Use aligned gene names 445 | ) 446 | recon_rna_modality.uns["type"] = "Reconstructed RNA" 447 | 448 | latent_modality = sc.AnnData( 449 | X=latent_space, 450 | obs=meta_obs, 451 | var=pd.DataFrame(index=[f"latent_{i}" for i in range(latent_space.shape[1])]) 452 | ) 453 | latent_modality.uns["type"] = "Latent Space" 454 | 455 | # Add modalities to the parent AnnData object 456 | rna_data.modality = { 457 | "RNA": rna_modality, 458 | "TF": tf_modality, 459 | "recon_RNA": recon_rna_modality, 460 | "latent": latent_modality 461 | } 462 | 463 | rna_data.uns["main_modality"] = "RNA" # Set default modality 464 | rna_data.uns["current_modality"] = "RNA" # Track current modality 465 | train_logger.info(f"Default modality set to: {rna_data.uns['main_modality']}") 466 | train_logger.info(f"Current modality: {rna_data.uns['current_modality']}") 467 | 468 | 469 | # Store GRN_prior in .uns 470 | rna_data.uns["GRN_prior"] = { 471 | "matrix": W_prior.cpu().numpy(), 472 | "TF_names": TF_names, 473 | "gene_names": gene_names, 474 | } 475 | 476 | # Store normalized GRN_posterior in .uns 477 | rna_data.uns["GRN_posterior"] = { 478 | "matrix": MinMaxScaler().fit_transform(model.tf_mapping.weight.detach().cpu().numpy()), 479 | "TF_names": TF_names, 480 | "gene_names": gene_names, 481 | } 482 | 483 | GRN_posterior = pd.DataFrame( 484 | rna_data.uns['GRN_posterior']['matrix'].copy(), 485 | index=rna_data.uns['GRN_posterior']['gene_names'], 486 | columns=rna_data.uns['GRN_posterior']['TF_names'] 487 | ) 488 | 489 | train_logger.info("`GRN_prior` and `GRN_posterior` stored in the AnnData object under .uns") 490 | 491 | # Final logging 492 | train_logger.info("=" * 40) 493 | train_logger.info("[FINAL SUMMARY]") 494 | train_logger.info(f"Training stopped after {epoch + 1} epochs.") 495 | train_logger.info(f"Final Train Loss: {avg_train_loss:.4f}") 496 | train_logger.info(f"Final Valid Loss: {avg_val_loss:.4f}") 497 | train_logger.info(f"Final Alpha: {alpha:.4f}, Beta: {beta:.4f}, Gamma: {gamma:.4f}") 498 | train_logger.info(f"Total Training Time: {total_time:.2f}s") 499 | 500 | # Retrieve and log shapes from modalities 501 | train_logger.info(f"Latent Space Shape: {rna_data.modality['latent'].X.shape}") 502 | train_logger.info(f"TF Space Shape: {rna_data.modality['TF'].X.shape}") 503 | train_logger.info(f"Reconstructed RNA Shape: {rna_data.modality['recon_RNA'].X.shape}") 504 | train_logger.info(f"Original RNA Shape: {rna_data.modality['RNA'].X.shape}") 505 | 506 | train_logger.info(f"TFs: {len(TF_names)}, Genes: {rna_data.modality['RNA'].X.shape[1]}") 507 | train_logger.info("=" * 40) 508 | 509 | 510 | return model, rna_data, GRN_posterior 511 | 512 | return model 513 | -------------------------------------------------------------------------------- /scregulate/train_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def schedule_parameter(epoch, start_value, max_value, total_epochs, scale=0.6): 5 | """ 6 | Schedules a parameter (e.g., alpha, beta, gamma) to grow from `start_value` to `max_value` 7 | over a given fraction of the training epochs. 8 | 9 | Args: 10 | epoch: Current epoch. 11 | start_value: Starting value of the parameter. 12 | max_value: Maximum value of the parameter. 13 | total_epochs: Total number of training epochs. 14 | scale: Fraction of epochs over which the parameter reaches `max_value`. 
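        Example (illustrative numbers): with start_value=0.0, max_value=0.4, total_epochs=2000 and the default scale=0.6, the value ramps linearly over the first 1200 epochs (e.g. schedule_parameter(600, 0.0, 0.4, 2000) returns 0.2) and stays at 0.4 from epoch 1200 onward.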
15 | 16 | Returns: 17 | Scheduled parameter value for the current epoch. 18 | """ 19 | scaled_epochs = total_epochs * scale 20 | if epoch >= scaled_epochs: 21 | return max_value 22 | return start_value + (max_value - start_value) * (epoch / scaled_epochs) 23 | 24 | 25 | def schedule_mask_factor(epoch, freeze_epochs): 26 | """ 27 | Computes the mask factor using a sigmoidal schedule. 28 | 29 | Args: 30 | epoch: Current epoch. 31 | freeze_epochs: Epoch at which the mask fully applies. 32 | 33 | Returns: 34 | Mask factor value for the current epoch. 35 | """ 36 | return 1.0 / (1.0 + np.exp((epoch - freeze_epochs // 2) / (freeze_epochs // 20))) 37 | 38 | 39 | def apply_gradient_mask(linear_layer, mask_float, mask_factor): 40 | """ 41 | Applies a gradient mask to a layer's weights. 42 | 43 | Args: 44 | linear_layer: Linear layer whose gradients are masked. 45 | mask_float: Precomputed mask (e.g., `W_prior == 0`). 46 | mask_factor: Factor controlling the strength of the mask. 47 | """ 48 | with torch.no_grad(): 49 | linear_layer.weight.grad.mul_(1 - mask_factor * mask_float) 50 | 51 | 52 | def compute_average_loss(total_loss, total_samples): 53 | """ 54 | Computes the average loss per sample. 55 | 56 | Args: 57 | total_loss: Total accumulated loss over an epoch. 58 | total_samples: Total number of samples in the epoch. 59 | 60 | Returns: 61 | Average loss per sample. 62 | """ 63 | return total_loss.item() / total_samples 64 | 65 | 66 | def clip_gradients(model, max_norm=0.5): 67 | """ 68 | Clips gradients of the model to prevent exploding gradients. 69 | 70 | Args: 71 | model: PyTorch model. 72 | max_norm: Maximum allowed norm for gradients. 73 | """ 74 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 75 | -------------------------------------------------------------------------------- /scregulate/ulm_standalone.py: -------------------------------------------------------------------------------- 1 | """ 2 | Standalone Implementation of the Univariate Linear Model (ULM). 3 | 4 | This implementation is based on the DecoupleR package's `run_ulm` function but 5 | has been adapted to work independently without any external dependencies. It 6 | calculates transcription factor activity scores from a gene expression matrix 7 | and a regulatory network. 8 | 9 | Citation: 10 | - For the original ULM method, please cite DecoupleR: 11 | Badia-i-Mompel, P., Wessels, L., & Reinders, M. (2021). DecoupleR: a computational framework 12 | to infer molecular activities from omics data. *Bioinformatics*. 
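Minimal usage sketch (illustrative; assumes `adata` is a normalized AnnData object and `net` is a long-format DataFrame with 'source', 'target' and 'weight' columns; the batch size shown is arbitrary):

    from scregulate import ulm_standalone as ulm
    ulm.run_ulm(adata, net=net, batch_size=5000, min_n=5)
    activities = adata.obsm["ulm_estimate"]  # cells x TFs activity scores (negative values clipped to 0)
    pvals = adata.obsm["ulm_pvals"]          # two-sided p-values for each score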
13 | """ 14 | 15 | import numpy as np 16 | import pandas as pd 17 | from scipy.sparse import csr_matrix 18 | from scipy.stats import t 19 | from tqdm.auto import tqdm 20 | import logging 21 | import warnings 22 | 23 | logger = logging.getLogger(__name__) 24 | from .utils import set_random_seed 25 | set_random_seed() 26 | # --- Helper Functions --- 27 | 28 | def mat_cov(A, b): 29 | return np.dot(b.T - b.mean(), A - A.mean(axis=0)) / (b.shape[0] - 1) 30 | 31 | def mat_cor(A, b): 32 | cov = mat_cov(A, b) 33 | ssd = np.std(A, axis=0, ddof=1) * np.std(b, axis=0, ddof=1).reshape(-1, 1) 34 | return cov / ssd 35 | 36 | def t_val(r, df): 37 | return r * np.sqrt(df / ((1.0 - r + 1.0e-16) * (1.0 + r + 1.0e-16))) 38 | 39 | def filt_min_n(c, net, min_n=5): 40 | msk = np.isin(net['target'].values.astype('U'), c) 41 | net = net.iloc[msk] 42 | 43 | sources, counts = np.unique(net['source'].values.astype('U'), return_counts=True) 44 | msk = np.isin(net['source'].values.astype('U'), sources[counts >= min_n]) 45 | 46 | net = net[msk] 47 | if net.shape[0] == 0: 48 | raise ValueError(f"No sources with more than min_n={min_n} targets.") 49 | return net 50 | 51 | def match(c, r, net): 52 | regX = np.zeros((c.shape[0], net.shape[1]), dtype=np.float32) 53 | c_dict = {gene: i for i, gene in enumerate(c)} 54 | idxs = [c_dict[gene] for gene in r if gene in c_dict] 55 | regX[idxs, :] = net[: len(idxs), :] 56 | return regX 57 | 58 | def rename_net(net, source="source", target="target", weight="weight"): 59 | assert source in net.columns, f"Column '{source}' not found in net." 60 | assert target in net.columns, f"Column '{target}' not found in net." 61 | if weight is not None: 62 | assert weight in net.columns, f"Column '{weight}' not found in net." 63 | 64 | net = net.rename(columns={source: "source", target: "target", weight: "weight"}) 65 | net = net.reindex(columns=["source", "target", "weight"]) 66 | 67 | if net.duplicated(["source", "target"]).sum() > 0: 68 | raise ValueError("net contains repeated edges.") 69 | return net 70 | 71 | def get_net_mat(net): 72 | X = net.pivot(columns="source", index="target", values="weight") 73 | X[np.isnan(X)] = 0 74 | sources = X.columns.values 75 | targets = X.index.values 76 | X = X.values 77 | return sources.astype("U"), targets.astype("U"), X.astype(np.float32) 78 | 79 | 80 | # --- Main Functionality --- 81 | 82 | def ulm(mat, net, batch_size): 83 | logger.debug("Starting ULM calculation...") 84 | n_samples = mat.shape[0] 85 | n_features, n_fsets = net.shape 86 | df = n_features - 2 # Degrees of freedom 87 | 88 | if isinstance(mat, csr_matrix): 89 | logger.debug("Matrix is sparse, processing in batches...") 90 | n_batches = int(np.ceil(n_samples / batch_size)) 91 | es = np.zeros((n_samples, n_fsets), dtype=np.float32) 92 | # The progress bar is only shown if logger level <= INFO 93 | show_progress = logger.getEffectiveLevel() <= logging.INFO 94 | for i in tqdm(range(n_batches), disable=not show_progress): 95 | start, end = i * batch_size, i * batch_size + batch_size 96 | batch = mat[start:end].toarray().T 97 | r = mat_cor(net, batch) 98 | es[start:end] = t_val(r, df) 99 | else: 100 | logger.debug("Matrix is dense, processing at once...") 101 | r = mat_cor(net, mat.T) 102 | es = t_val(r, df) 103 | 104 | pv = t.sf(abs(es), df) * 2 105 | logger.debug("ULM calculation complete.") 106 | return es, pv 107 | 108 | 109 | def run_ulm(adata, net, batch_size, source="source", target="target", weight="weight", min_n=5, verbose=True): 110 | """ 111 | Run ULM on a Scanpy AnnData object and 
store results in `.obsm`. 112 | 113 | Args: 114 | adata: AnnData object with gene expression data. 115 | net: DataFrame with columns [source, target, weight]. 116 | source (str): Name of the TF column in net. 117 | target (str): Name of the target gene column in net. 118 | weight (str): Name of the weight column in net. 119 | batch_size (int): Batch size for ULM if data is large. 120 | min_n (int): Minimum number of targets per TF. 121 | verbose (bool): If True, set logging level to INFO, else WARNING. 122 | 123 | Returns: 124 | adata with ULM results in adata.obsm["ulm_estimate"] and adata.obsm["ulm_pvals"]. 125 | """ 126 | # Set logger level based on verbose 127 | logger.setLevel(logging.INFO if verbose else logging.WARNING) 128 | 129 | logger.info("Initializing ULM analysis...") 130 | logger.info("Extracting gene expression data...") 131 | mat = adata.to_df() 132 | genes = adata.var_names 133 | 134 | logger.info("Preparing the regulatory network...") 135 | net = rename_net(net, source=source, target=target, weight=weight) 136 | net = filt_min_n(genes, net, min_n=min_n) 137 | 138 | logger.info("Matching genes in the network to gene expression data...") 139 | sources, targets, net_mat = get_net_mat(net) 140 | net_mat = match(genes, targets, net_mat) 141 | 142 | logger.info(f"ULM parameters: {mat.shape[0]} samples, {len(genes)} genes, {net_mat.shape[1]} TFs.") 143 | logger.info("Calculating ULM estimates and p-values...") 144 | estimate, pvals = ulm(mat.values, net_mat, batch_size=batch_size) 145 | 146 | logger.info("Processing ULM results...") 147 | ulm_estimate = pd.DataFrame(estimate, index=mat.index, columns=sources) 148 | ulm_pvals = pd.DataFrame(pvals, index=mat.index, columns=sources) 149 | 150 | logger.info("Storing ULM results in AnnData object...") 151 | with warnings.catch_warnings(): 152 | warnings.simplefilter("ignore", category=UserWarning) # Suppress UserWarnings 153 | adata.obsm["ulm_estimate"] = np.maximum(ulm_estimate, 0) 154 | adata.obsm["ulm_pvals"] = ulm_pvals 155 | 156 | #adata.obsm["ulm_estimate"] = np.maximum(ulm_estimate, 0) 157 | #adata.obsm["ulm_pvals"] = ulm_pvals 158 | 159 | logger.info("ULM analysis complete.") 160 | return adata 161 | -------------------------------------------------------------------------------- /scregulate/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | import scanpy as sc 5 | from torch.utils.data import DataLoader 6 | import gc 7 | import random 8 | 9 | # ---------- Random Seed Utility ---------- 10 | def set_random_seed(seed=42): 11 | """ 12 | Sets the random seed for reproducibility across multiple libraries. 13 | 14 | Args: 15 | seed (int): Random seed value to set. 
16 | """ 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | if torch.cuda.is_available(): 21 | torch.cuda.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | torch.backends.cudnn.deterministic = True # Ensures deterministic behavior 24 | torch.backends.cudnn.benchmark = False # Disables benchmarking to avoid randomness 25 | set_random_seed() 26 | 27 | # ---------- Data Loading/Handling ---------- 28 | def as_GTRD(net): 29 | W_prior_df = net.copy() 30 | collectri_format = W_prior_df.stack().reset_index() 31 | collectri_format.columns = ["source", "target", "weight"].copy() 32 | collectri_format = collectri_format[collectri_format["weight"] == 1] 33 | return collectri_format.reset_index(drop=True) 34 | 35 | def as_TF_Link(net): 36 | W_prior_df = net.copy() 37 | collectri_format = W_prior_df[['Name.TF', 'Name.Target']] 38 | collectri_format.columns = ["source", "target"] 39 | collectri_format = collectri_format.assign(weight=1) 40 | collectri_format = collectri_format.drop_duplicates() 41 | return collectri_format.reset_index(drop=True) 42 | 43 | def create_dataloader(scRNA_batch, batch_size, shuffle=False, num_workers=0, collate_fn=None): 44 | """ 45 | Creates a PyTorch DataLoader for the scRNA_batch data. 46 | 47 | Args: 48 | scRNA_batch (torch.Tensor): Input data for training. 49 | batch_size (int): Batch size for training. 50 | shuffle (bool): Whether to shuffle the data during loading. 51 | num_workers (int): Number of workers for data loading. 52 | collate_fn (callable, optional): Custom collate function for DataLoader. 53 | 54 | Returns: 55 | DataLoader: PyTorch DataLoader for the scRNA_batch data. 56 | """ 57 | return DataLoader(scRNA_batch, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, 58 | collate_fn=collate_fn) # Pass custom collate_fn if provided) 59 | 60 | def clear_memory(device): 61 | """ 62 | Clears memory for the specified device. 63 | 64 | Args: 65 | device (str): The device to clear memory for ('cuda' or 'cpu'). 66 | """ 67 | if device == "cuda": 68 | torch.cuda.empty_cache() 69 | else: 70 | gc.collect() 71 | 72 | def extract_GRN(adata, matrix_type="posterior"): 73 | """ 74 | Extract the GRN matrix (either prior or posterior) from an AnnData object as a pandas DataFrame. 75 | 76 | Args: 77 | adata (AnnData): The AnnData object containing GRN information in `.uns`. 78 | matrix_type (str): The type of GRN matrix to extract, either 'prior' or 'posterior'. 79 | Defaults to 'posterior'. 80 | 81 | Returns: 82 | pd.DataFrame: A DataFrame where rows are genes and columns are transcription factors (TFs). 83 | """ 84 | # Map short names to full keys 85 | matrix_key_map = { 86 | "posterior": "GRN_posterior", 87 | "prior": "GRN_prior" 88 | } 89 | 90 | if matrix_type not in matrix_key_map: 91 | raise ValueError(f"Invalid matrix_type '{matrix_type}'. 
Choose from 'prior' or 'posterior'.") 92 | 93 | full_key = matrix_key_map[matrix_type] 94 | 95 | if full_key not in adata.uns: 96 | raise ValueError(f"{full_key} is not stored in the provided AnnData object.") 97 | 98 | grn_data = adata.uns[full_key] 99 | matrix = grn_data["matrix"] 100 | TF_names = grn_data["TF_names"] 101 | gene_names = grn_data["gene_names"] 102 | 103 | # Determine the orientation of the matrix 104 | if matrix.shape[0] == len(TF_names) and matrix.shape[1] == len(gene_names): 105 | # TFs as rows, genes as columns (needs transposition) 106 | GRN_df = pd.DataFrame(matrix.T, index=gene_names, columns=TF_names) 107 | elif matrix.shape[0] == len(gene_names) and matrix.shape[1] == len(TF_names): 108 | # Genes as rows, TFs as columns (already correct orientation) 109 | GRN_df = pd.DataFrame(matrix, index=gene_names, columns=TF_names) 110 | else: 111 | raise ValueError("Matrix dimensions do not match the lengths of `gene_names` and `TF_names`.") 112 | 113 | return GRN_df 114 | 115 | 116 | def to_torch_tensor(matrix, device="cpu"): 117 | """ 118 | Converts a dense or sparse matrix to a PyTorch tensor. 119 | 120 | Args: 121 | matrix: Numpy array or sparse matrix. 122 | device (str): Device to load the tensor onto. 123 | 124 | Returns: 125 | torch.Tensor: Dense or sparse tensor. 126 | """ 127 | if isinstance(matrix, np.ndarray): 128 | return torch.tensor(matrix, dtype=torch.float32).to(device) 129 | elif hasattr(matrix, "tocoo"): # Sparse matrix support 130 | coo = matrix.tocoo() 131 | indices = torch.LongTensor(np.vstack((coo.row, coo.col))) 132 | values = torch.FloatTensor(coo.data) 133 | shape = torch.Size(coo.shape) 134 | return torch.sparse_coo_tensor(indices, values, shape).to(device) 135 | else: 136 | raise TypeError("Unsupported matrix type for conversion to PyTorch tensor.") 137 | 138 | 139 | # ---------- Parameter Scheduling ---------- 140 | def schedule_parameter(epoch, start_value, max_value, total_epochs, scale=0.6): 141 | """ 142 | Schedules a parameter (e.g., alpha, beta, gamma) to grow from `start_value` to `max_value` 143 | over a given fraction of the training epochs. 144 | 145 | Args: 146 | epoch: Current epoch. 147 | start_value: Starting value of the parameter. 148 | max_value: Maximum value of the parameter. 149 | total_epochs: Total number of training epochs. 150 | scale: Fraction of epochs over which the parameter reaches `max_value`. 151 | 152 | Returns: 153 | Scheduled parameter value for the current epoch. 154 | """ 155 | scaled_epochs = total_epochs * scale 156 | if epoch >= scaled_epochs: 157 | return max_value 158 | return start_value + (max_value - start_value) * (epoch / scaled_epochs) 159 | 160 | 161 | def schedule_mask_factor(epoch, freeze_epochs): 162 | """ 163 | Computes the mask factor using a sigmoid function. 164 | 165 | Args: 166 | epoch (int): Current epoch. 167 | freeze_epochs (int): Epoch at which the mask factor reaches its midpoint. 168 | 169 | Returns: 170 | float: Mask factor for the current epoch. 171 | """ 172 | return 1.0 / (1.0 + np.exp((epoch - freeze_epochs // 2) / (freeze_epochs // 20))) 173 | 174 | 175 | # ---------- Gradient Operations ---------- 176 | def apply_gradient_mask(layer, mask, mask_factor): 177 | """ 178 | Applies a gradient mask to a layer's weights. 179 | 180 | Args: 181 | layer (torch.nn.Linear): Linear layer whose gradients need masking. 182 | mask (torch.Tensor): Binary mask indicating where gradients should be scaled. 183 | mask_factor (float): Scaling factor for the mask. 
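        Example (illustrative): the update applied below is grad *= (1 - mask_factor * mask), so with mask_factor=1.0 the gradients at masked positions (mask == 1, i.e. edges absent from the prior) are zeroed, with mask_factor=0.5 they are halved, and entries with mask == 0 are left unchanged.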
184 | 185 | Returns: 186 | None 187 | """ 188 | with torch.no_grad(): 189 | layer.weight.grad.mul_(1 - mask_factor * mask) 190 | 191 | 192 | def clip_gradients(model, max_norm=0.5): 193 | """ 194 | Clips the gradients of the model parameters. 195 | 196 | Args: 197 | model (torch.nn.Module): PyTorch model. 198 | max_norm (float): Maximum norm for gradient clipping. 199 | 200 | Returns: 201 | None 202 | """ 203 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm) 204 | 205 | 206 | # ---------- Loss Utility ---------- 207 | def compute_average_loss(total_loss, total_samples): 208 | """ 209 | Computes the average loss per sample. 210 | 211 | Args: 212 | total_loss (torch.Tensor): Total accumulated loss. 213 | total_samples (int): Total number of samples processed. 214 | 215 | Returns: 216 | float: Average loss per sample. 217 | """ 218 | return total_loss.item() / total_samples 219 | 220 | 221 | # ---------- Modality Enhancement ---------- 222 | def set_active_modality(adata, modality_name): 223 | """ 224 | Sets the active modality in the AnnData object. 225 | 226 | Args: 227 | adata (AnnData): Parent AnnData object containing multiple modalities. 228 | modality_name (str): The name of the modality to activate. 229 | 230 | Returns: 231 | AnnData: A new AnnData object with the specified modality set as the active one. 232 | """ 233 | # Validate input modality 234 | if modality_name not in adata.modality: 235 | raise ValueError(f"Modality '{modality_name}' not found in AnnData.modality.") 236 | 237 | # Retrieve selected modality 238 | selected_modality = adata.modality[modality_name] 239 | 240 | # Create a new AnnData object with the selected modality 241 | temp_adata = sc.AnnData( 242 | X=selected_modality.X, 243 | obs=selected_modality.obs.copy(), 244 | var=selected_modality.var.copy(), 245 | obsm=selected_modality.obsm.copy(), 246 | uns=adata.uns.copy(), # Retain shared metadata in `uns` 247 | ) 248 | 249 | # Retain modality-specific `.obsm` entries 250 | temp_adata.obsm.update(adata.obsm) 251 | 252 | # Attach all modalities back to the new AnnData object 253 | temp_adata.modality = adata.modality 254 | 255 | return temp_adata 256 | 257 | def extract_modality(adata, modality_name): 258 | """ 259 | Extract a specific modality as an independent AnnData object. 260 | 261 | Args: 262 | adata (AnnData): The AnnData object containing multiple modalities. 263 | modality_name (str): The name of the modality to extract. 264 | 265 | Returns: 266 | AnnData: The requested modality as a standalone AnnData object. 
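        Example (illustrative; assumes `train_model` was run with `return_outputs=True`, so that a "TF" modality exists):
            tf_adata = extract_modality(rna_data, "TF")  # standalone AnnData of inferred TF activities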
267 | """ 268 | # Ensure the modality exists 269 | if not hasattr(adata, "modality") or modality_name not in adata.modality: 270 | raise ValueError(f"Modality '{modality_name}' not found in the AnnData object.") 271 | 272 | # Extract the requested modality 273 | selected_modality = adata.modality[modality_name] 274 | return sc.AnnData( 275 | X=selected_modality.X.copy(), 276 | obs=selected_modality.obs.copy(), 277 | var=selected_modality.var.copy(), 278 | obsm=selected_modality.obsm.copy(), 279 | uns=selected_modality.uns.copy(), 280 | layers=selected_modality.layers.copy() if hasattr(selected_modality, "layers") else None 281 | ) 282 | -------------------------------------------------------------------------------- /scregulate/vae_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # Define the VAE model 6 | class scRNA_VAE(nn.Module): 7 | """ 8 | Variational Autoencoder (VAE) for scRNA-seq data. 9 | 10 | Args: 11 | input_dim (int): Number of input features (genes). 12 | encode_dims (list): List of hidden layer sizes for the encoder. 13 | decode_dims (list): List of hidden layer sizes for the decoder. 14 | z_dim (int): Size of the latent space. 15 | tf_dim (int): Size of the transcription factor space. 16 | ulm_init (torch.Tensor, optional): Pre-initialized TF activity matrix. 17 | """ 18 | 19 | def __init__(self, input_dim, encode_dims, decode_dims, z_dim, tf_dim, ulm_init=None): 20 | super(scRNA_VAE, self).__init__() 21 | 22 | # Encoder 23 | self.encoder_layers = nn.ModuleList() 24 | previous_dim = input_dim 25 | for h_dim in encode_dims: 26 | self.encoder_layers.append(nn.Linear(previous_dim, h_dim)) 27 | previous_dim = h_dim 28 | self.fc_mu = nn.Linear(previous_dim, z_dim) 29 | self.fc_logvar = nn.Linear(previous_dim, z_dim) 30 | 31 | # Decoder 32 | self.decoder_layers = nn.ModuleList() 33 | previous_dim = z_dim 34 | for h_dim in decode_dims: 35 | self.decoder_layers.append(nn.Linear(previous_dim, h_dim)) 36 | previous_dim = h_dim 37 | self.fc_output = nn.Linear(previous_dim, tf_dim) 38 | self.tf_mapping = nn.Linear(tf_dim, input_dim) 39 | self.ulm_init = ulm_init 40 | 41 | 42 | def encode(self, x): 43 | for layer in self.encoder_layers: 44 | x = F.relu(layer(x)) 45 | mu = self.fc_mu(x) 46 | logvar = self.fc_logvar(x) 47 | return mu, logvar 48 | 49 | def reparameterize(self, mu, logvar): 50 | std = torch.exp(0.5 * logvar) 51 | eps = torch.randn_like(std) 52 | return mu + eps * std 53 | 54 | def decode(self, z): 55 | for layer in self.decoder_layers: 56 | z = F.relu(layer(z)) 57 | return F.relu(self.fc_output(z)) # Was torch.exp() 58 | 59 | def forward(self, x, tf_activity_init=None, alpha=1.0): 60 | mu, logvar = self.encode(x) 61 | z = self.reparameterize(mu, logvar) 62 | 63 | if tf_activity_init is not None: 64 | # Use a weighted combination of pre-initialized and learned TF activities 65 | tf_activity = (1 - alpha) * tf_activity_init.detach() + alpha * self.decode(z) 66 | else: 67 | # If no pre-initialized TF activities, just use the decoded values 68 | tf_activity = self.decode(z) 69 | 70 | recon_x = self.tf_mapping(tf_activity) 71 | return recon_x, mu, logvar 72 | 73 | 74 | def encodeBatch(self, dataloader, device='cuda', out='z'): 75 | output = [] 76 | for batch in dataloader: 77 | batch = batch.to(device) 78 | mu, logvar = self.encode(batch) 79 | z = self.reparameterize(mu, logvar) 80 | 81 | if out == 'z': 82 | output.append(z.detach().cpu()) 83 | elif 
out == 'x': 84 | recon_x = self.tf_mapping(self.decode(z)) 85 | output.append(recon_x.detach().cpu()) 86 | elif out == 'tf': 87 | tf_activity = self.decode(z) 88 | output.append(tf_activity.detach().cpu()) 89 | 90 | output = torch.cat(output).numpy() 91 | return output 92 | 93 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = scRegulate 3 | version = 0.1.0 4 | author = Mehrdad Zandigohar 5 | author_email = mehr.zgohar@gmail.com 6 | description = Python Toolkit for Transcription Factor Activity Inference and Clustering of scRNA-seq Data 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | license = MIT 10 | url = https://github.com/YDaiLab/scRegulate 11 | project_urls = 12 | Documentation = https://github.com/YDaiLab/scRegulate#readme 13 | Issue Tracker = https://github.com/YDaiLab/scRegulate/issues 14 | Paper (bioRxiv) = https://doi.org/10.1101/2025.04.17.649372 15 | 16 | [options] 17 | packages = find: 18 | python_requires = >=3.8 19 | install_requires = 20 | torch>=2.0 21 | numpy>=1.23 22 | scanpy>=1.9 23 | anndata>=0.8 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="scRegulate", 5 | version="0.1.0", 6 | author="Mehrdad Zandigohar", 7 | author_email="mehr.zgohar@gmail.com", 8 | description="Python Toolkit for Transcription Factor Activity Inference and Clustering of scRNA-seq Data", 9 | long_description=open("README.md", encoding="utf-8").read(), 10 | long_description_content_type="text/markdown", 11 | url="https://github.com/YDaiLab/scRegulate", 12 | project_urls={ 13 | "Documentation": "https://github.com/YDaiLab/scRegulate#readme", 14 | "Issue Tracker": "https://github.com/YDaiLab/scRegulate/issues", 15 | "Paper (bioRxiv)": "https://doi.org/10.1101/2025.04.17.649372" 16 | }, 17 | license="MIT", 18 | packages=find_packages(), 19 | python_requires=">=3.8", 20 | install_requires=[ 21 | "torch>=2.0", 22 | "numpy>=1.23", 23 | "scanpy>=1.9", 24 | "anndata>=0.8" 25 | ], 26 | classifiers=[ 27 | "Programming Language :: Python :: 3", 28 | "License :: OSI Approved :: MIT License", 29 | "Operating System :: OS Independent", 30 | ], 31 | ) 32 | --------------------------------------------------------------------------------