├── .codecov.yaml ├── .editorconfig ├── .github ├── actions │ └── setup │ │ └── action.yaml └── workflows │ ├── build.yaml │ ├── release.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── data ├── gene_annotations │ ├── gencode.vM25.chr_patch_hapl_scaff.annotation.gtf.gz │ └── human_mouse_gene_orthologs.csv └── gene_programs │ └── metabolite_enzyme_sensor_gps │ ├── human_metabolite_enzymes.tsv │ ├── human_metabolite_sensors.tsv │ ├── mouse_metabolite_enzymes.tsv │ └── mouse_metabolite_sensors.tsv ├── docs ├── Makefile ├── _static │ ├── .gitkeep │ ├── css │ │ └── custom.css │ ├── nichecompass_fig1.png │ ├── nichecompass_logo.svg │ └── nichecompass_logo_readme.png ├── _templates │ ├── .gitkeep │ └── autosummary │ │ └── class.rst ├── api │ ├── developer.md │ ├── index.md │ └── user.md ├── conf.py ├── contributing.md ├── extensions │ └── typed_returns.py ├── index.md ├── installation.md ├── references.bib ├── references.md ├── release_notes │ └── index.md ├── requirements.txt ├── tutorials │ ├── index.md │ ├── multimodal_tutorials.md │ ├── notebooks │ │ ├── mouse_brain_multimodal.ipynb │ │ ├── mouse_cns_sample_integration.ipynb │ │ ├── mouse_cns_single_sample.ipynb │ │ └── mouse_cns_spatial_reference_mapping.ipynb │ ├── sample_integration_tutorials.md │ ├── single_sample_tutorials.md │ └── spatial_reference_mapping_tutorials.md └── user_guide │ └── index.md ├── environment.yaml ├── pyproject.toml ├── src └── nichecompass │ ├── __init__.py │ ├── benchmarking │ ├── __init__.py │ ├── cas.py │ ├── clisis.py │ ├── gcs.py │ ├── metrics.py │ ├── mlami.py │ ├── nasw.py │ └── utils.py │ ├── data │ ├── __init__.py │ ├── dataloaders.py │ ├── dataprocessors.py │ ├── datareaders.py │ ├── datasets.py │ └── utils.py │ ├── models │ ├── __init__.py │ ├── basemodelmixin.py │ ├── nichecompass.py │ └── utils.py │ ├── modules │ ├── __init__.py │ ├── basemodulemixin.py │ ├── losses.py │ ├── vgaemodulemixin.py │ └── vgpgae.py │ ├── nn │ ├── __init__.py │ ├── aggregators.py │ ├── decoders.py │ ├── encoders.py │ ├── layercomponents.py │ ├── layers.py │ └── utils.py │ ├── train │ ├── __init__.py │ ├── basetrainermixin.py │ ├── metrics.py │ ├── trainer.py │ └── utils.py │ └── utils │ ├── __init__.py │ ├── analysis.py │ ├── gene_programs.py │ ├── multimodal_mapping.py │ └── utils.py └── tests └── test_basic.py /.codecov.yaml: -------------------------------------------------------------------------------- 1 | # Based on pydata/xarray 2 | codecov: 3 | require_ci_to_pass: no 4 | 5 | coverage: 6 | status: 7 | project: 8 | default: 9 | # Require 1% coverage, i.e., always succeed 10 | target: 1 11 | patch: false 12 | changes: false 13 | 14 | comment: 15 | layout: diff, flags, files 16 | behavior: once 17 | require_base: no 18 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.github/actions/setup/action.yaml: 
-------------------------------------------------------------------------------- 1 | name: Setup 2 | 3 | inputs: 4 | python-version: 5 | required: false 6 | default: '3.9' 7 | torch-version: 8 | required: false 9 | default: '2.0.0' 10 | cuda-version: 11 | required: false 12 | default: cpu 13 | full_install: 14 | required: false 15 | default: true 16 | 17 | runs: 18 | using: composite 19 | 20 | steps: 21 | - name: Set up Python ${{ inputs.python-version }} 22 | uses: actions/setup-python@v4.3.0 23 | with: 24 | python-version: ${{ inputs.python-version }} 25 | check-latest: true 26 | cache: pip 27 | cache-dependency-path: | 28 | pyproject.toml 29 | 30 | - name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }} 31 | run: | 32 | pip install torch==${{ inputs.torch-version }} --extra-index-url https://download.pytorch.org/whl/${{ inputs.cuda-version }} 33 | python -c "import torch; print('PyTorch:', torch.__version__)" 34 | python -c "import torch; print('CUDA available:', torch.cuda.is_available())" 35 | python -c "import torch; print('CUDA:', torch.version.cuda)" 36 | shell: bash 37 | 38 | - name: Install extension packages 39 | if: ${{ inputs.full_install == 'true' }} 40 | run: | 41 | pip install scipy 42 | pip install --no-index --upgrade torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-${{ inputs.torch-version }}+${{ inputs.cuda-version }}.html 43 | shell: bash 44 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | package: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 3.9 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.9" 22 | cache: "pip" 23 | cache-dependency-path: "**/pyproject.toml" 24 | - name: Install build dependencies 25 | run: python -m pip install --upgrade pip wheel twine build 26 | - name: Build package 27 | run: python -m build 28 | - name: Check package 29 | run: twine check --strict dist/*.whl 30 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/ 8 | jobs: 9 | release: 10 | name: Upload release to PyPI 11 | runs-on: ubuntu-latest 12 | environment: 13 | name: pypi 14 | url: https://pypi.org/p/nichecompass/ 15 | permissions: 16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 17 | steps: 18 | - uses: actions/checkout@v4 19 | with: 20 | filter: blob:none 21 | fetch-depth: 0 22 | - uses: actions/setup-python@v5 23 | with: 24 | python-version: "3.9" 25 | cache: "pip" 26 | - run: pip install build 27 | - run: python -m build 28 | - name: Publish package distributions to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 
7 | branches: [main] 8 | schedule: 9 | - cron: "0 5 1,15 * *" 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | test: 17 | runs-on: ${{ matrix.os }} 18 | defaults: 19 | run: 20 | shell: bash -e {0} # -e to fail on error 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | include: 26 | - os: ubuntu-latest 27 | python: "3.9" 28 | 29 | name: ${{ matrix.os }} Python ${{ matrix.python }} 30 | 31 | env: 32 | OS: ${{ matrix.os }} 33 | PYTHON: ${{ matrix.python }} 34 | 35 | steps: 36 | - uses: actions/checkout@v4 37 | - name: Set up Python ${{ matrix.python }} 38 | uses: actions/setup-python@v5 39 | with: 40 | python-version: ${{ matrix.python }} 41 | cache: "pip" 42 | cache-dependency-path: "**/pyproject.toml" 43 | 44 | - name: Set up external packages 45 | uses: ./.github/actions/setup 46 | - name: Install test dependencies 47 | run: | 48 | python -m pip install --upgrade pip wheel 49 | - name: Install dependencies 50 | run: | 51 | pip install ${{ matrix.pip-flags }} ".[dev,tests]" 52 | - name: Test 53 | env: 54 | MPLBACKEND: agg 55 | PLATFORM: ${{ matrix.os }} 56 | DISPLAY: :42 57 | run: | 58 | coverage run -m pytest -v --color=yes 59 | - name: Report coverage 60 | run: | 61 | coverage report 62 | - name: Upload coverage 63 | uses: codecov/codecov-action@v3 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # DS_Store 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | venv/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # vscode 133 | .vscode/settings.json 134 | 135 | # data 136 | data/gene_programs/* 137 | data/spatial_omics/* 138 | !data/gene_programs/metabolite_enzyme_sensor_gps/ 139 | 140 | # artifacts 141 | artifacts 142 | 143 | # charliecloud image for HPC 144 | /env/nichecompass.tar.gz -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.6.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/PyCQA/flake8 7 | rev: 4.0.1 8 | hooks: 9 | - id: flake8 10 | - repo: https://github.com/pycqa/isort 11 | rev: 5.10.1 12 | hooks: 13 | - id: isort 14 | name: isort (python) 15 | additional_dependencies: [toml] 16 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html 2 | version: 2 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.9" 7 | sphinx: 8 | configuration: docs/conf.py 9 | # disable this for more lenient docs builds 10 | fail_on_warning: false 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | - method: pip 15 | path: . 16 | extra_requirements: 17 | - docsbuild 18 | submodules: 19 | include: [docs/tutorials/notebooks] 20 | recursive: true 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Sebastian Birk, Carlos Talavera-López, Mohammad Lotfollahi 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | nichecompass-logo 2 | 3 | [![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://github.com/Lotfollahi-lab/nichecompass/blob/main/LICENSE) 4 | [![Stars](https://img.shields.io/github/stars/Lotfollahi-lab/nichecompass?logo=GitHub&color=yellow)](https://github.com/Lotfollahi-lab/nichecompass/stargazers) 5 | [![PyPI](https://img.shields.io/pypi/v/nichecompass.svg)](https://pypi.org/project/nichecompass) 6 | [![PyPIDownloads](https://static.pepy.tech/badge/nichecompass)](https://pepy.tech/project/nichecompass) 7 | [![Docs](https://readthedocs.org/projects/nichecompass/badge/?version=latest)](https://nichecompass.readthedocs.io/en/stable/?badge=stable) 8 | [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 9 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) 10 | 11 | NicheCompass (**N**iche **I**dentification based on **C**ellular grap**H** **E**mbeddings of **COM**munication **P**rograms **A**ligned across **S**patial **S**amples) is a package for end-to-end analysis of spatial multi-omics data, including spatial atlas building, niche identification & characterization, cell-cell communication inference and spatial reference mapping. It is built on top of [PyG](https://pytorch-geometric.readthedocs.io/en/latest/) and [AnnData](https://anndata.readthedocs.io/en/latest/). 12 | 13 | ## Resources 14 | - An installation guide, tutorials and API documentation are available in the [documentation](https://nichecompass.readthedocs.io/). 15 | - Please use [issues](https://github.com/Lotfollahi-lab/nichecompass/issues) to submit bug reports. 16 | - If you would like to contribute, check out the [contributing guide](https://nichecompass.readthedocs.io/en/latest/contributing.html). 17 | - If you find NicheCompass useful for your research, please consider citing the NicheCompass manuscript. 18 | 19 | ## Reference 20 | ``` 21 | @article{Birk2025, 22 | author = {Birk, S. and Bonafonte-Pard{\`a}s, I. and Feriz, A. M. 
and et al.}, 23 | title = {Quantitative characterization of cell niches in spatially resolved omics data}, 24 | journal = {Nature Genetics}, 25 | year = {2025}, 26 | doi = {10.1038/s41588-025-02120-6}, 27 | url = {https://doi.org/10.1038/s41588-025-02120-6} 28 | } 29 | ``` 30 | -------------------------------------------------------------------------------- /data/gene_annotations/gencode.vM25.chr_patch_hapl_scaff.annotation.gtf.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lotfollahi-lab/nichecompass/2231b021122a9c06aaed55b9a81188877a230cd4/data/gene_annotations/gencode.vM25.chr_patch_hapl_scaff.annotation.gtf.gz -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lotfollahi-lab/nichecompass/2231b021122a9c06aaed55b9a81188877a230cd4/docs/_static/.gitkeep -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */ 2 | div.cell_output table.dataframe { 3 | font-size: 0.8em; 4 | } 5 | 6 | .wy-nav-side { 7 | background: #242335; 8 | } 9 | 10 | .wy-side-nav-search { 11 | background-color: #F3F4F7; 12 | } 13 | 14 | .icon.icon-home, 15 | .icon.icon-home .logo { 16 | color: #000000; /* Black color */ 17 | } 18 | -------------------------------------------------------------------------------- /docs/_static/nichecompass_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lotfollahi-lab/nichecompass/2231b021122a9c06aaed55b9a81188877a230cd4/docs/_static/nichecompass_fig1.png -------------------------------------------------------------------------------- /docs/_static/nichecompass_logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/nichecompass_logo_readme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lotfollahi-lab/nichecompass/2231b021122a9c06aaed55b9a81188877a230cd4/docs/_static/nichecompass_logo_readme.png -------------------------------------------------------------------------------- /docs/_templates/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lotfollahi-lab/nichecompass/2231b021122a9c06aaed55b9a81188877a230cd4/docs/_templates/.gitkeep -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | Attributes table 12 | ~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autosummary:: 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | Methods table 24 | ~~~~~~~~~~~~~ 25 | 26 | .. autosummary:: 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | 35 | {% block attributes_documentation %} 36 | {% if attributes %} 37 | Attributes 38 | ~~~~~~~~~~~ 39 | 40 | {% for item in attributes %} 41 | 42 | .. autoattribute:: {{ [objname, item] | join(".") }} 43 | {%- endfor %} 44 | 45 | {% endif %} 46 | {% endblock %} 47 | 48 | {% block methods_documentation %} 49 | {% if methods %} 50 | Methods 51 | ~~~~~~~ 52 | 53 | {% for item in methods %} 54 | {%- if item != '__init__' %} 55 | 56 | .. automethod:: {{ [objname, item] | join(".") }} 57 | {%- endif -%} 58 | {%- endfor %} 59 | 60 | {% endif %} 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /docs/api/developer.md: -------------------------------------------------------------------------------- 1 | # Developer 2 | 3 | ## Benchmarking 4 | 5 | ```{eval-rst} 6 | .. module:: nichecompass.benchmarking 7 | .. currentmodule:: nichecompass 8 | 9 | .. autosummary:: 10 | :toctree: generated 11 | 12 | benchmarking.utils.compute_knn_graph_connectivities_and_distances 13 | ``` 14 | 15 | ## Data 16 | 17 | ```{eval-rst} 18 | .. module:: nichecompass.data 19 | .. currentmodule:: nichecompass 20 | 21 | .. autosummary:: 22 | :toctree: generated 23 | 24 | data.initialize_dataloaders 25 | data.edge_level_split 26 | data.node_level_split_mask 27 | data.prepare_data 28 | data.SpatialAnnTorchDataset 29 | ``` 30 | 31 | ## Models 32 | 33 | ```{eval-rst} 34 | .. module:: nichecompass.models 35 | .. currentmodule:: nichecompass 36 | 37 | .. autosummary:: 38 | :toctree: generated 39 | 40 | models.utils.load_saved_files 41 | models.utils.validate_var_names 42 | models.utils.initialize_model 43 | ``` 44 | 45 | ## Modules 46 | 47 | ```{eval-rst} 48 | .. module:: nichecompass.modules 49 | .. currentmodule:: nichecompass 50 | 51 | .. autosummary:: 52 | :toctree: generated 53 | 54 | modules.VGPGAE 55 | modules.VGAEModuleMixin 56 | modules.BaseModuleMixin 57 | modules.compute_cat_covariates_contrastive_loss 58 | modules.compute_edge_recon_loss 59 | modules.compute_gp_group_lasso_reg_loss 60 | modules.compute_gp_l1_reg_loss 61 | modules.compute_kl_reg_loss 62 | modules.compute_omics_recon_nb_loss 63 | ``` 64 | 65 | ## NN 66 | 67 | ```{eval-rst} 68 | .. module:: nichecompass.nn 69 | .. currentmodule:: nichecompass 70 | 71 | .. 
autosummary:: 72 | :toctree: generated 73 | 74 | nn.OneHopAttentionNodeLabelAggregator 75 | nn.OneHopGCNNormNodeLabelAggregator 76 | nn.OneHopSumNodeLabelAggregator 77 | nn.CosineSimGraphDecoder 78 | nn.FCOmicsFeatureDecoder 79 | nn.MaskedOmicsFeatureDecoder 80 | nn.Encoder 81 | nn.MaskedLinear 82 | nn.AddOnMaskedLayer 83 | ``` 84 | 85 | ## Train 86 | 87 | ```{eval-rst} 88 | .. module:: nichecompass.train 89 | .. currentmodule:: nichecompass 90 | 91 | .. autosummary:: 92 | :toctree: generated 93 | 94 | train.Trainer 95 | train.eval_metrics 96 | train.plot_eval_metrics 97 | ``` 98 | -------------------------------------------------------------------------------- /docs/api/index.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | Import NicheCompass as: 4 | 5 | ``` 6 | import nichecompass 7 | ``` 8 | 9 | ```{toctree} 10 | :maxdepth: 2 11 | 12 | user 13 | developer 14 | ``` -------------------------------------------------------------------------------- /docs/api/user.md: -------------------------------------------------------------------------------- 1 | # User 2 | 3 | ## Benchmarking 4 | 5 | ```{eval-rst} 6 | .. module:: nichecompass.benchmarking 7 | .. currentmodule:: nichecompass 8 | 9 | .. autosummary:: 10 | :toctree: generated 11 | 12 | benchmarking.compute_benchmarking_metrics 13 | benchmarking.compute_cas 14 | benchmarking.compute_clisis 15 | benchmarking.compute_gcs 16 | benchmarking.compute_mlami 17 | benchmarking.compute_nasw 18 | ``` 19 | 20 | ## Data 21 | ```{eval-rst} 22 | .. module:: nichecompass.data 23 | .. currentmodule:: nichecompass 24 | 25 | .. autosummary:: 26 | :toctree: generated 27 | 28 | data.load_spatial_adata_from_csv 29 | ``` 30 | 31 | ## Models 32 | 33 | ```{eval-rst} 34 | .. module:: nichecompass.models 35 | .. currentmodule:: nichecompass 36 | 37 | .. autosummary:: 38 | :toctree: generated 39 | 40 | models.NicheCompass 41 | ``` 42 | 43 | ## Gene Program Utilities 44 | 45 | ```{eval-rst} 46 | .. module:: nichecompass.utils 47 | .. currentmodule:: nichecompass 48 | 49 | .. autosummary:: 50 | :toctree: generated 51 | 52 | utils.add_gps_from_gp_dict_to_adata 53 | utils.add_multimodal_mask_to_adata 54 | utils.extract_gp_dict_from_collectri_tf_network 55 | utils.extract_gp_dict_from_nichenet_lrt_interactions 56 | utils.extract_gp_dict_from_mebocost_es_interactions 57 | utils.extract_gp_dict_from_omnipath_lr_interactions 58 | utils.filter_and_combine_gp_dict_gps 59 | utils.get_gene_annotations 60 | utils.generate_multimodal_mapping_dict 61 | utils.get_unique_genes_from_gp_dict 62 | utils.generate_enriched_gp_info_plots 63 | ``` 64 | 65 | ## Cell-Cell Communication Utilities 66 | 67 | ```{eval-rst} 68 | .. module:: nichecompass.utils 69 | .. currentmodule:: nichecompass 70 | 71 | .. autosummary:: 72 | :toctree: generated 73 | 74 | utils.compute_communication_gp_network 75 | utils.visualize_communication_gp_network 76 | ``` 77 | 78 | ## Other Utilities 79 | ```{eval-rst} 80 | .. module:: nichecompass.utils 81 | .. currentmodule:: nichecompass 82 | 83 | .. autosummary:: 84 | :toctree: generated 85 | 86 | utils.create_new_color_dict 87 | ``` 88 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | import sys 9 | from datetime import datetime 10 | from importlib.metadata import metadata 11 | from pathlib import Path 12 | 13 | HERE = Path(__file__).parent 14 | sys.path.insert(0, str(HERE / "extensions")) 15 | 16 | # -- Project information ----------------------------------------------------- 17 | 18 | # NOTE: If you installed your project in editable mode, this might be stale. 19 | # If this is the case, reinstall it to refresh the metadata 20 | info = metadata("nichecompass") 21 | project_name = info.get("Name", "NicheCompass") 22 | project = project_name 23 | author = info.get("Author", "Sebastian Birk") 24 | copyright = f"{datetime.now():%Y}, {author}." 25 | version = info["Version"] 26 | urls = dict(pu.split(", ") for pu in info.get_all("Project-URL")) 27 | repository_url = urls["Source"] 28 | 29 | # The full version, including alpha/beta/rc tags 30 | release = info["Version"] 31 | 32 | bibtex_bibfiles = ["references.bib"] 33 | templates_path = ["_templates"] 34 | nitpicky = True # Warn about broken links 35 | needs_sphinx = "4.0" 36 | 37 | html_context = { 38 | "display_github": True, # Integrate GitHub 39 | "github_user": "Lotfollahi-lab", # Username 40 | "github_repo": project_name, # Repo name 41 | "github_version": "main", # Version 42 | "conf_py_path": "/docs/", # Path in the checkout to the docs root 43 | } 44 | 45 | # -- General configuration --------------------------------------------------- 46 | 47 | # Add any Sphinx extension module names here, as strings. 48 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 49 | extensions = [ 50 | "myst_nb", 51 | "sphinx_copybutton", 52 | "sphinx.ext.autodoc", 53 | "sphinx.ext.intersphinx", 54 | "sphinx.ext.autosummary", 55 | "sphinx.ext.extlinks", 56 | "sphinx.ext.napoleon", 57 | "sphinxcontrib.bibtex", 58 | "sphinx_autodoc_typehints", 59 | "sphinx.ext.mathjax", 60 | "sphinx_design", 61 | "IPython.sphinxext.ipython_console_highlighting", 62 | "sphinxext.opengraph", 63 | "hoverxref.extension", 64 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 65 | ] 66 | 67 | autosummary_generate = True 68 | autodoc_member_order = "groupwise" 69 | default_role = "literal" 70 | napoleon_google_docstring = False 71 | napoleon_numpy_docstring = True 72 | napoleon_include_init_with_doc = False 73 | napoleon_use_rtype = True # having a separate entry generally helps readability 74 | napoleon_use_param = True 75 | myst_heading_anchors = 6 # create anchors for h1-h6 76 | myst_enable_extensions = [ 77 | "amsmath", 78 | "colon_fence", 79 | "deflist", 80 | "dollarmath", 81 | "html_image", 82 | "html_admonition", 83 | ] 84 | myst_url_schemes = ("http", "https", "mailto") 85 | nb_output_stderr = "remove" 86 | nb_execution_mode = "off" 87 | nb_merge_streams = True 88 | typehints_defaults = "braces" 89 | 90 | source_suffix = { 91 | ".rst": "restructuredtext", 92 | ".ipynb": "myst-nb", 93 | ".myst": "myst-nb", 94 | } 95 | 96 | intersphinx_mapping = { 97 | "python": ("https://docs.python.org/3", None), 98 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None), 99 | "numpy": ("https://numpy.org/doc/stable/", None), 100 | } 101 | 102 | # List of patterns, relative to source directory, that match files and 103 | # directories to ignore when looking for source files. 
104 | # This pattern also affects html_static_path and html_extra_path. 105 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 106 | 107 | # -- Options for HTML output ------------------------------------------------- 108 | 109 | # The theme to use for HTML and HTML Help pages. See the documentation for 110 | # a list of builtin themes. 111 | # 112 | html_theme = "sphinx_rtd_theme" 113 | html_static_path = ["_static"] 114 | html_css_files = ["css/custom.css"] 115 | 116 | html_title = "NicheCompass" 117 | html_logo = "_static/nichecompass_logo.svg" 118 | 119 | html_theme_options = { 120 | "repository_url": repository_url, 121 | "use_repository_button": True, 122 | "path_to_docs": "docs/", 123 | "navigation_with_keys": False, 124 | } 125 | 126 | pygments_style = "default" 127 | 128 | nitpick_ignore = [ 129 | # If building the documentation fails because of a missing link that is outside your control, 130 | # you can add an exception to this list. 131 | # ("py:class", "igraph.Graph"), 132 | ] 133 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | This will be added shortly. 4 | 5 | 6 | 7 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui 8 | [codecov]: https://about.codecov.io/sign-up/ 9 | [codecov docs]: https://docs.codecov.com/docs 10 | [codecov bot]: https://docs.codecov.com/docs/team-bot 11 | [codecov app]: https://github.com/apps/codecov 12 | [pre-commit.ci]: https://pre-commit.ci/ 13 | [readthedocs.org]: https://readthedocs.org/ 14 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/ 15 | [jupytext]: https://jupytext.readthedocs.io/en/latest/ 16 | [pre-commit]: https://pre-commit.com/ 17 | [pytest]: https://docs.pytest.org/ 18 | [semver]: https://semver.org/ 19 | [sphinx]: https://www.sphinx-doc.org/en/master/ 20 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html 21 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 22 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html 23 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints 24 | [pypi]: https://pypi.org/ 25 | -------------------------------------------------------------------------------- /docs/extensions/typed_returns.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py 2 | # with some minor adjustment 3 | from __future__ import annotations 4 | 5 | import re 6 | from collections.abc import Generator, Iterable 7 | 8 | from sphinx.application import Sphinx 9 | from sphinx.ext.napoleon import NumpyDocstring 10 | 11 | 12 | def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: 13 | for line in lines: 14 | if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line): 15 | yield f'-{m["param"]} (:class:`~{m["type"]}`)' 16 | else: 17 | yield line 18 | 19 | 20 | def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]: 21 | lines_raw = self._dedent(self._consume_to_next_section()) 22 | if lines_raw[0] == ":": 23 | del lines_raw[0] 24 | lines = self._format_block(":returns: ", list(_process_return(lines_raw))) 25 | if lines and lines[-1]: 26 | lines.append("") 27 | return lines 28 | 29 | 30 | def setup(app: Sphinx): 31 | """Set app.""" 32 | 
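    # Monkey-patch NumpyDocstring so numpydoc-style "Returns" sections are
    # handled by the custom _parse_returns_section above, which renders each
    # "name : type" entry with the type as a cross-referenced class link
    # (see _process_return).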
NumpyDocstring._parse_returns_section = _parse_returns_section 33 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | NicheCompass (**N**iche **I**dentification based on **C**ellular grap**H** **E**mbeddings of **COM**munication **P**rograms **A**ligned across **S**patial **S**amples) is a package for end-to-end analysis of spatial multi-omics data, including spatial atlas building, niche identification & characterization, cell-cell communication inference and spatial reference mapping. It is built on top of [PyG](https://pytorch-geometric.readthedocs.io/en/latest/) and [AnnData](https://anndata.readthedocs.io/en/latest/). 4 | The package is developed and maintained by the [Lotfollahi Lab](https://github.com/Lotfollahi-lab) at the Wellcome Sanger Institute. 5 | 6 | ::::{grid} 1 2 3 3 7 | :gutter: 2 8 | 9 | :::{grid-item-card} Installation {octicon}`plug;1em;` 10 | :link: installation 11 | :link-type: doc 12 | 13 | Check out the installation guide. 14 | ::: 15 | 16 | :::{grid-item-card} Tutorials {octicon}`play;1em;` 17 | :link: tutorials/index 18 | :link-type: doc 19 | 20 | Learn by following example applications of NicheCompass. 21 | ::: 22 | 23 | :::{grid-item-card} User Guide {octicon}`book;1em;` 24 | :link: user_guide/index 25 | :link-type: doc 26 | 27 | Review good practices for training NicheCompass models on your own data. 28 | ::: 29 | 30 | :::{grid-item-card} API {octicon}`info;1em;` 31 | :link: api/index 32 | :link-type: doc 33 | 34 | Find a detailed description of NicheCompass APIs. 35 | ::: 36 | 37 | :::{grid-item-card} Release Notes {octicon}`tag;1em;` 38 | :link: release_notes/index 39 | :link-type: doc 40 | 41 | Follow the latest changes to NicheCompass. 42 | ::: 43 | 44 | :::{grid-item-card} Contributing {octicon}`code;1em;` 45 | :link: contributing 46 | :link-type: doc 47 | 48 | Help improve NicheCompass. 49 | ::: 50 | :::: 51 | 52 | If you find NicheCompass useful for your research, please consider citing the NicheCompass manuscript. 53 | 54 | ```{toctree} 55 | :hidden: true 56 | :maxdepth: 3 57 | :titlesonly: true 58 | 59 | installation 60 | tutorials/index 61 | api/index 62 | release_notes/index 63 | contributing.md 64 | references.md 65 | ``` 66 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | NicheCompass is available for Python 3.9. It does not yet support Apple silicon. 4 | 5 | 6 | We do not recommend installation on your system Python. Please set up a virtual 7 | environment, e.g. via conda through the [Mambaforge] distribution, or create a 8 | [Docker] image. 9 | 10 | ## Additional Libraries 11 | 12 | To use NicheCompass, you need to install some external libraries. These include: 13 | - [PyTorch] 14 | - [PyTorch Scatter] 15 | - [PyTorch Sparse] 16 | - [bedtools] 17 | 18 | We recommend installing the PyTorch libraries with GPU support. If you have 19 | CUDA, this can be done as: 20 | 21 | ``` 22 | pip install torch==${TORCH}+${CUDA} --extra-index-url https://download.pytorch.org/whl/${CUDA} 23 | pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html 24 | ``` 25 | where `${TORCH}` and `${CUDA}` should be replaced by the specific PyTorch and 26 | CUDA versions, respectively. 
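If you are unsure which versions an existing PyTorch installation provides, you can query it as an optional sanity check:
```
python -c "import torch; print(torch.__version__, torch.version.cuda)"
```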
27 | 28 | For example, for PyTorch 2.0.0 and CUDA 11.7, type: 29 | ``` 30 | pip install torch==2.0.0+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 31 | pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.0.0+cu117.html 32 | ``` 33 | 34 | To install bedtools, you can use conda: 35 | ``` 36 | conda install bedtools=2.30.0 -c bioconda 37 | ``` 38 | 39 | Alternatively, we have provided a conda environment file with all required 40 | external libraries, which you can use as: 41 | ``` 42 | conda env create -f environment.yaml 43 | ``` 44 | 45 | To enable GPU support for JAX, run the following after the installation: 46 | ``` 47 | pip install jaxlib==0.4.7+cuda${CUDA}.cudnn${CUDNN} -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 48 | ``` 49 | 50 | For example, for CUDA 11.7, type: 51 | ``` 52 | pip install jaxlib==0.4.7+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 53 | ``` 54 | 55 | 56 | ## Installation via PyPI 57 | 58 | Subsequently, install NicheCompass via pip: 59 | ``` 60 | pip install nichecompass 61 | ``` 62 | 63 | Install optional dependencies required for benchmarking, multimodal analysis, running tutorials, etc. with: 64 | ``` 65 | pip install nichecompass[all] 66 | ``` 67 | 68 | [Mambaforge]: https://github.com/conda-forge/miniforge 69 | [Docker]: https://www.docker.com 70 | [PyTorch]: http://pytorch.org 71 | [PyTorch Scatter]: https://github.com/rusty1s/pytorch_scatter 72 | [PyTorch Sparse]: https://github.com/rusty1s/pytorch_sparse 73 | [bedtools]: https://bedtools.readthedocs.io 74 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{Virshup_2023, 2 | doi = {10.1038/s41587-023-01733-8}, 3 | url = {https://doi.org/10.1038%2Fs41587-023-01733-8}, 4 | year = 2023, 5 | month = {apr}, 6 | publisher = {Springer Science and Business Media {LLC}}, 7 | author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. Theis}, 8 | title = {The scverse project provides a computational ecosystem for single-cell omics data analysis}, 9 | journal = {Nature Biotechnology} 10 | } -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | ```{bibliography} 4 | :cited: 5 | ``` 6 | -------------------------------------------------------------------------------- /docs/release_notes/index.md: -------------------------------------------------------------------------------- 1 | # Release notes 2 | 3 | All notable changes to this project will be documented in this file. The format 4 | is based on [keep a changelog], and this project adheres to 5 | [Semantic Versioning]. 
Full commit history is available in the [commit logs]. 6 | 7 | ### 0.2.3 (17.02.2025) 8 | 9 | - Added numpy<2 dependency as version upgrade of NumPy to major version 2 breaks required scanpy version. 10 | [@sebastianbirk] 11 | 12 | ### 0.2.2 (09.01.2025) 13 | 14 | - Synced repository with Zenodo to mint DOI for publication. 15 | [@sebastianbirk] 16 | 17 | ### 0.2.1 (15.10.2024) 18 | 19 | - Added a user guide to the package documentation. 20 | [@sebastianbirk] 21 | 22 | ### 0.2.0 (22.08.2024) 23 | 24 | - Fixed a bug in the configuration of random seeds. 25 | - Fixed a bug in the definition of MEBOCOST prior gene programs. 26 | - Raised the default quality filter threshold for the retrieval of OmniPath gene programs. 27 | - Modified the GP filtering logic to combine GPs with the same source genes and drop GPs that do not have source genes if they have a specified overlap in target genes with other GPs. 28 | - Changed the default hyperparameters for model training based on ablation experiments. 29 | [@sebastianbirk] 30 | 31 | ### 0.1.2 (13.02.2024) 32 | 33 | - The version was incremented due to package upload requirements. 34 | [@sebastianbirk] 35 | 36 | ### 0.1.1 (13.02.2024) 37 | 38 | - The version was incremented due to package upload requirements. 39 | [@sebastianbirk] 40 | 41 | ### 0.1.0 (13.02.2024) 42 | 43 | - First NicheCompass version. 44 | [@sebastianbirk] 45 | 46 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/ 47 | [Semantic Versioning]: https://semver.org/spec/v2.0.0.html 48 | [commit logs]: https://github.com/Lotfollahi-lab/nichecompass/commits 49 | [@sebastianbirk]: https://github.com/sebastianbirk 50 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | https://download.pytorch.org/whl/cpu/torch-1.13.0%2Bcpu-cp39-cp39-linux_x86_64.whl 2 | https://data.pyg.org/whl/torch-1.13.0%2Bcpu/torch_sparse-0.6.17%2Bpt113cpu-cp39-cp39-linux_x86_64.whl 3 | https://data.pyg.org/whl/torch-1.13.0%2Bcpu/torch_scatter-2.1.1%2Bpt113cpu-cp39-cp39-linux_x86_64.whl 4 | -------------------------------------------------------------------------------- /docs/tutorials/index.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | Get started with NicheCompass by following our tutorials. 
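All tutorials share the same high-level workflow. The minimal sketch below outlines it end to end; the call pattern and arguments are illustrative assumptions only, and the exact constructor and training options are introduced step by step in each tutorial:
```
import scanpy as sc
import squidpy as sq

from nichecompass.models import NicheCompass

# Load a preprocessed spatial AnnData object with coordinates stored in
# adata.obsm["spatial"] ("my_sample.h5ad" is a placeholder file name).
adata = sc.read_h5ad("my_sample.h5ad")

# Build the spatial KNN neighborhood graph the model operates on
# (neighborhood sizes between 4 and 12 worked well empirically,
# see the user guide).
sq.gr.spatial_neighbors(adata, n_neighs=8)

# Initialize and train the model; gene program priors and further
# arguments are covered in the tutorials.
model = NicheCompass(adata)
model.train()
```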
4 | 5 | ```{toctree} 6 | :maxdepth: 2 7 | 8 | single_sample_tutorials 9 | sample_integration_tutorials 10 | spatial_reference_mapping_tutorials 11 | multimodal_tutorials 12 | ``` -------------------------------------------------------------------------------- /docs/tutorials/multimodal_tutorials.md: -------------------------------------------------------------------------------- 1 | # Multimodal Tutorials 2 | 3 | ```{toctree} 4 | :maxdepth: 1 5 | 6 | notebooks/mouse_brain_multimodal 7 | ``` -------------------------------------------------------------------------------- /docs/tutorials/sample_integration_tutorials.md: -------------------------------------------------------------------------------- 1 | # Sample Integration Tutorials 2 | ```{toctree} 3 | :maxdepth: 1 4 | 5 | notebooks/mouse_cns_sample_integration 6 | ``` -------------------------------------------------------------------------------- /docs/tutorials/single_sample_tutorials.md: -------------------------------------------------------------------------------- 1 | # Single Sample Tutorials 2 | 3 | ```{toctree} 4 | :maxdepth: 1 5 | 6 | notebooks/mouse_cns_single_sample 7 | ``` -------------------------------------------------------------------------------- /docs/tutorials/spatial_reference_mapping_tutorials.md: -------------------------------------------------------------------------------- 1 | # Spatial Reference Mapping Tutorials 2 | 3 | ```{toctree} 4 | :maxdepth: 1 5 | 6 | notebooks/mouse_cns_spatial_reference_mapping 7 | ``` -------------------------------------------------------------------------------- /docs/user_guide/index.md: -------------------------------------------------------------------------------- 1 | # User guide 2 | 3 | ## Hyperparameter selection 4 | 5 | We conducted various ablation experiments on both simulated and real spatial transcriptomics data to evaluate important model design choices and hyperparameters. The detailed results and interpretations can be found in the NicheCompass manuscript. 6 | 7 | In summary, we recommend the following: 8 | - Regarding the loss, we observed that finding a balance between gene expression and edge reconstruction is a key element for good niche identification (NID) and GP recovery (GPR) performance, while regularization of de novo GP weights is essential for GPR. The loss weights have been specified in the NicheCompass package accordingly and in most cases do not need to be changed by the user. 9 | - With respect to the size of the KNN neighborhood graph, a smaller number of neighbors is more efficient in NID and de novo GP detection while a larger number of neighbors can facilitate GPR of prior GPs. Here, we recommend users to specify a neighborhood size based on the expected range of interactions in the tissue. Empirically, we observed a neighborhood size between 4 and 12 to work well. 10 | - The inclusion of de novo GPs is crucial for GPR. However, the number of de novo GPs should not be too high as important genes might be split across multiple GPs. The default in the NicheCompass package is 100 and we do not recommend increasing this number. 11 | - GP pruning can slightly improve GPR and NID while reducing the embedding size of the model; we therefore recommend users to use the default setting of weak GP pruning. 12 | - We recommend that users define prior GPs solely based on the biology that they are interested in (as opposed to including as many prior GPs as possible). 
13 | - We recommend users to use a GATv2 encoder layer (as opposed to a GCNConv encoder layer) unless performance is a bottleneck or niche characterization is not a priority and the data has single-cell resolution. 14 | - Since the use of prior GPs can significantly improve NID compared to a scenario without prior GPs, we recommend users to use the default set of prior GPs even if interpretability is not a main objective. -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: nichecompass 2 | channels: 3 | - pytorch 4 | - pyg 5 | - nvidia 6 | - bioconda 7 | dependencies: 8 | - python=3.9 9 | - pip 10 | - pytorch=2.0.0 11 | - pytorch-scatter=2.1.1 12 | - pytorch-sparse=0.6.17 13 | - pytorch-cuda=11.7 14 | - bedtools=2.30.0 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | [project] 6 | name = "nichecompass" 7 | version = "0.2.3" 8 | description = "End-to-end analysis of spatial multi-omics data" 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | license = {file = "LICENSE"} 12 | authors = [ 13 | {name = "Sebastian Birk"}, 14 | ] 15 | maintainers = [ 16 | {name = "Sebastian Birk", email = "sebastian.birk@outlook.com"}, 17 | ] 18 | urls.Documentation = "https://nichecompass.readthedocs.io/" 19 | urls.Source = "https://github.com/Lotfollahi-lab/nichecompass" 20 | urls.Home-page = "https://github.com/Lotfollahi-lab/nichecompass" 21 | classifiers = [ 22 | "Development Status :: 3 - Alpha", 23 | "Intended Audience :: Science/Research", 24 | "Natural Language :: English", 25 | "Programming Language :: Python :: 3.9", 26 | "Operating System :: POSIX :: Linux", 27 | "Topic :: Scientific/Engineering :: Bio-Informatics", 28 | ] 29 | dependencies = [ 30 | "mlflow>=1.28.0", 31 | "pyreadr>=0.4.6", 32 | "scanpy>=1.9.3,<1.10.0", 33 | "torch-geometric>=2.2.0", 34 | "omnipath>=1.0.7", 35 | "decoupler>=1.4.0", 36 | "numpy<2", 37 | ] 38 | optional-dependencies.dev = [ 39 | "pre-commit", 40 | "twine>=4.0.2", 41 | ] 42 | optional-dependencies.docs = [ 43 | "docutils>=0.8,!=0.18.*,!=0.19.*", 44 | "sphinx>=4.1", 45 | "sphinx-book-theme>=1.0.0", 46 | "myst-nb", 47 | "myst-parser", 48 | "sphinxcontrib-bibtex>=1.0.0", 49 | "sphinx-autodoc-typehints", 50 | "sphinx_rtd_theme", 51 | "sphinxext-opengraph", 52 | "sphinx-copybutton", 53 | "sphinx-design", 54 | "sphinx-hoverxref", 55 | # For notebooks 56 | "ipykernel", 57 | "ipython", 58 | "pandas", 59 | ] 60 | optional-dependencies.docsbuild = [ 61 | "nichecompass[docs,benchmarking,multimodal]" 62 | ] # docs build dependencies 63 | optional-dependencies.tests = [ 64 | "pytest", 65 | "coverage", 66 | ] 67 | optional-dependencies.benchmarking = [ 68 | "scib-metrics>=0.3.3", 69 | "pynndescent>=0.5.8", 70 | "scikit-misc>=0.3.0", 71 | "squidpy>=1.2.2", 72 | "jax==0.4.7", 73 | "jaxlib==0.4.7" 74 | ] 75 | optional-dependencies.multimodal = [ 76 | "scglue>=0.3.2", 77 | ] 78 | optional-dependencies.tutorials = [ 79 | "jupyter", 80 | ] 81 | optional-dependencies.all = ["nichecompass[dev,docs,tests,benchmarking,multimodal,tutorials]"] 82 | 83 | [tool.coverage.run] 84 | source = ["nichecompass"] 85 | omit = [ 86 | "**/test_*.py", 87 | ] 88 | 89 | [tool.pytest.ini_options] 90 | testpaths = ["tests"] 91 | xfail_strict = true 92 | addopts = [ 93 | 
"--import-mode=importlib", # allow using test files with same name 94 | ] 95 | 96 | [tool.ruff] 97 | line-length = 120 98 | src = ["src"] 99 | extend-include = ["*.ipynb"] 100 | 101 | [tool.ruff.format] 102 | docstring-code-format = true 103 | 104 | [tool.ruff.lint] 105 | select = [ 106 | "F", # Errors detected by Pyflakes 107 | "E", # Error detected by Pycodestyle 108 | "W", # Warning detected by Pycodestyle 109 | "I", # isort 110 | "D", # pydocstyle 111 | "B", # flake8-bugbear 112 | "TID", # flake8-tidy-imports 113 | "C4", # flake8-comprehensions 114 | "BLE", # flake8-blind-except 115 | "UP", # pyupgrade 116 | "RUF100", # Report unused noqa directives 117 | ] 118 | ignore = [ 119 | # line too long -> we accept long comment lines; formatter gets rid of long code lines 120 | "E501", 121 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient 122 | "E731", 123 | # allow I, O, l as variable names -> I is the identity matrix 124 | "E741", 125 | # Missing docstring in public package 126 | "D104", 127 | # Missing docstring in public module 128 | "D100", 129 | # Missing docstring in __init__ 130 | "D107", 131 | # Errors from function calls in argument defaults. These are fine when the result is immutable. 132 | "B008", 133 | # __magic__ methods are are often self-explanatory, allow missing docstrings 134 | "D105", 135 | # first line should end with a period [Bug: doesn't work with single-line docstrings] 136 | "D400", 137 | # First line should be in imperative mood; try rephrasing 138 | "D401", 139 | ## Disable one in each pair of mutually incompatible rules 140 | # We don’t want a blank line before a class docstring 141 | "D203", 142 | # We want docstrings to start immediately after the opening triple quote 143 | "D213", 144 | ] 145 | 146 | [tool.ruff.lint.pydocstyle] 147 | convention = "numpy" 148 | 149 | [tool.ruff.lint.per-file-ignores] 150 | "docs/*" = ["I"] 151 | "tests/*" = ["D"] 152 | "*/__init__.py" = ["F401"] 153 | 154 | [tool.cruft] 155 | skip = [ 156 | "tests", 157 | "src/**/__init__.py", 158 | "src/**/basic.py", 159 | "docs/api.md", 160 | "docs/changelog.md", 161 | "docs/references.bib", 162 | "docs/references.md", 163 | "docs/notebooks/example.ipynb", 164 | ] 165 | -------------------------------------------------------------------------------- /src/nichecompass/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | 3 | from . 
import data, models, modules, nn, train, utils 4 | 5 | __all__ = ["data", "models", "modules", "nn", "train", "utils"] 6 | 7 | __version__ = version("nichecompass") -------------------------------------------------------------------------------- /src/nichecompass/benchmarking/__init__.py: -------------------------------------------------------------------------------- 1 | # This is a trick to make jax use the right cudnn version (needs to be executed 2 | # before importing scanpy) 3 | #import jax.numpy as jnp 4 | #temp_array = jnp.array([1, 2, 3]) 5 | #temp_idx = jnp.array([1]) 6 | #temp_array[temp_idx] 7 | 8 | from .cas import compute_avg_cas, compute_cas 9 | from .clisis import compute_clisis 10 | from .gcs import compute_avg_gcs, compute_gcs 11 | from .metrics import compute_benchmarking_metrics 12 | from .mlami import compute_mlami 13 | from .nasw import compute_nasw 14 | 15 | __all__ = ["compute_avg_cas", 16 | "compute_avg_gcs", 17 | "compute_benchmarking_metrics", 18 | "compute_cas", 19 | "compute_clisis", 20 | "compute_gcs", 21 | "compute_mlami", 22 | "compute_nasw"] -------------------------------------------------------------------------------- /src/nichecompass/benchmarking/clisis.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the Cell Type Local Inverse Simpson's Index Similarity 3 | (CLISIS) benchmark for testing how accurately the latent nearest neighbor graph 4 | preserves local neighborhood cell type heterogeneity from the spatial (physical) 5 | nearest neighbor graph. It is a measure for local cell type neighborhood 6 | preservation. 7 | """ 8 | 9 | from typing import Optional 10 | 11 | import numpy as np 12 | from anndata import AnnData 13 | from scib_metrics import lisi_knn 14 | 15 | from .utils import compute_knn_graph_connectivities_and_distances 16 | 17 | 18 | def compute_clisis( 19 | adata: AnnData, 20 | cell_type_key: str="cell_type", 21 | batch_key: Optional[str]=None, 22 | spatial_knng_key: str="spatial_knng", 23 | latent_knng_key: str="nichecompass_latent_knng", 24 | spatial_key: Optional[str]="spatial", 25 | latent_key: Optional[str]="nichecompass_latent", 26 | n_neighbors: Optional[int]=90, 27 | n_jobs: int=1, 28 | seed: int=0) -> float: 29 | """ 30 | Compute the Cell Type Local Inverse Simpson's Index Similarity (CLISIS). The 31 | CLISIS measures how accurately the latent nearest neighbor graph preserves 32 | local neighborhood cell type heterogeneity from the spatial nearest neighbor 33 | graph. The CLISIS ranges between '0' and '1' with higher values indicating 34 | better local neighborhood cell type heterogeneity preservation. It is 35 | computed by first calculating the Cell Type Local Inverse Simpson's Index 36 | (CLISI) as proposed by Luecken, M. D. et al. Benchmarking atlas-level data 37 | integration in single-cell genomics. Nat. Methods 19, 41–50 (2022) on the 38 | latent and spatial nearest neighbor graph(s) respectively.* Afterwards, the 39 | ratio of the two CLISI scores is taken and logarithmized as proposed by 40 | Heidari, E. et al. Supervised spatial inference of dissociated single-cell 41 | data with SageNet. bioRxiv 2022.04.14.488419 (2022) 42 | doi:10.1101/2022.04.14.488419, leveraging the properties of the log that 43 | np.log2(x/y) = -np.log2(y/x) and np.log2(x/x) = 0. At this stage, values 44 | closer to 0 indicate better local neighborhood cell type heterogeneity 45 | preservation. 
We then normalize the resulting 46 | value by the maximum possible value that would occur in the case of minimal local neighborhood cell type 47 | preservation to scale our metric between '0' and '1'. Finally, we compute 48 | the median of the absolute normalized scores and subtract it from 1 so that 49 | values closer to '1' indicate better local neighborhood cell type 50 | heterogeneity preservation. For intuition: with 4 unique cell types, a cell whose latent CLISI is 2 but whose spatial CLISI is 4 yields np.log2(2/4) = -1, which the maximum np.log2(4) = 2 normalizes to -0.5. 51 | 52 | If a ´batch_key´ is provided, separate spatial nearest neighbor graphs per 53 | batch will be computed and the spatial clisi scores are computed for each 54 | batch separately. 55 | 56 | If existent, uses precomputed nearest neighbor graphs stored in 57 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and 58 | ´adata.obsp[latent_knng_key + '_connectivities']´. 59 | Alternatively, computes them on the fly using ´spatial_key´, ´latent_key´ 60 | and ´n_neighbors´, and stores them in 61 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and 62 | ´adata.obsp[latent_knng_key + '_connectivities']´ respectively. 63 | 64 | * The Inverse Simpson's Index measures the expected number of 65 | samples needed to be sampled before two are drawn from the same category. 66 | The Local Inverse Simpson's Index combines perplexity-based neighborhood 67 | construction with the Inverse Simpson's Index to account for distances 68 | between neighbors. The CLISI score is the LISI applied to cell nearest 69 | neighbor graphs with cell types as categories, and indicates the effective 70 | number of different cell types represented in the local neighborhood of each 71 | cell. If the cells are well mixed, we might expect the CLISI score to be 72 | close to the number of unique cell types (e.g. neighborhoods with an equal 73 | number of cells from 2 cell types get a CLISI of 2). Note, however, that 74 | even under perfect mixing, the value would be smaller than the number of 75 | unique cell types if the absolute number of cells is different for different 76 | cell types. 77 | 78 | Parameters 79 | ---------- 80 | adata: 81 | AnnData object with cell type annotations stored in 82 | ´adata.obs[cell_type_key]´, precomputed nearest neighbor graphs stored 83 | in ´adata.obsp[spatial_knng_key + '_connectivities']´ and 84 | ´adata.obsp[latent_knng_key + '_connectivities']´ or spatial coordinates 85 | stored in ´adata.obsm[spatial_key]´ and the latent representation from a 86 | model stored in ´adata.obsm[latent_key]´. 87 | cell_type_key: 88 | Key under which the cell type annotations are stored in ´adata.obs´. 89 | batch_key: 90 | Key under which the batches are stored in ´adata.obs´. If ´None´, the 91 | adata is assumed to only have one unique batch. 92 | spatial_knng_key: 93 | Key under which the spatial nearest neighbor graph is / will be stored 94 | in ´adata.obsp´ with the suffix '_connectivities'. 95 | latent_knng_key: 96 | Key under which the latent nearest neighbor graph is / will be stored in 97 | ´adata.obsp´ with the suffix '_connectivities'. 98 | spatial_key: 99 | Key under which the spatial coordinates are stored in ´adata.obsm´. 100 | latent_key: 101 | Key under which the latent representation from a model is stored in 102 | ´adata.obsm´. 103 | n_neighbors: 104 | Number of neighbors used for the construction of the nearest neighbor 105 | graphs from the spatial coordinates and the latent representation from 106 | a model in case they are constructed. 107 | n_jobs: 108 | Number of jobs to use for parallelization of neighbor search. 109 | seed: 110 | Random seed for reproducibility. 
111 | 112 | Returns 113 | ---------- 114 | clisis: 115 | The Cell Type Local Inverse Simpson's Index Similarity. 116 | """ 117 | # Adding '_connectivities' as expected / added by 118 | # 'compute_knn_graph_connectivities_and_distances' 119 | spatial_knng_connectivities_key = spatial_knng_key + "_connectivities" 120 | latent_knng_connectivities_key = latent_knng_key + "_connectivities" 121 | 122 | if spatial_knng_connectivities_key in adata.obsp: 123 | print("Using precomputed spatial nearest neighbor graph...") 124 | 125 | print("Computing spatial cell CLISI scores for entire dataset...") 126 | spatial_cell_clisi_scores = lisi_knn( 127 | X=adata.obsp[f"{spatial_knng_key}_distances"], 128 | labels=adata.obs[cell_type_key]) 129 | 130 | elif batch_key is None: 131 | print("Computing spatial nearest neighbor graph for entire dataset...") 132 | compute_knn_graph_connectivities_and_distances( 133 | adata=adata, 134 | feature_key=spatial_key, 135 | knng_key=spatial_knng_key, 136 | n_neighbors=n_neighbors, 137 | random_state=seed, 138 | n_jobs=n_jobs) 139 | 140 | print("Computing spatial cell CLISI scores for entire dataset...") 141 | spatial_cell_clisi_scores = lisi_knn( 142 | X=adata.obsp[f"{spatial_knng_key}_distances"], 143 | labels=adata.obs[cell_type_key]) 144 | 145 | elif batch_key is not None: 146 | # Compute cell CLISI scores for spatial nearest neighbor graph 147 | # of each batch separately and add to array 148 | unique_batches = adata.obs[batch_key].unique().tolist() 149 | spatial_cell_clisi_scores = np.zeros(len(adata)) 150 | adata.obs["index"] = np.arange(len(adata)) 151 | for batch in unique_batches: 152 | adata_batch = adata[adata.obs[batch_key] == batch] 153 | 154 | print("Computing spatial nearest neighbor graph for " 155 | f"{batch_key} {batch}...") 156 | # Compute batch-specific spatial (ground truth) nearest 157 | # neighbor graph 158 | compute_knn_graph_connectivities_and_distances( 159 | adata=adata_batch, 160 | feature_key=spatial_key, 161 | knng_key=spatial_knng_key, 162 | n_neighbors=n_neighbors, 163 | random_state=seed, 164 | n_jobs=n_jobs) 165 | 166 | print("Computing spatial cell CLISI scores for " 167 | f"{batch_key} {batch}...") 168 | batch_spatial_cell_clisi_scores = lisi_knn( 169 | X=adata_batch.obsp[f"{spatial_knng_key}_distances"], 170 | labels=adata_batch.obs[cell_type_key]) 171 | 172 | # Save results 173 | spatial_cell_clisi_scores[adata_batch.obs["index"].values] = ( 174 | batch_spatial_cell_clisi_scores) 175 | 176 | if latent_knng_connectivities_key in adata.obsp: 177 | print("Using precomputed latent nearest neighbor graph...") 178 | else: 179 | print("Computing latent nearest neighbor graph...") 180 | # Compute latent connectivities 181 | compute_knn_graph_connectivities_and_distances( 182 | adata=adata, 183 | feature_key=latent_key, 184 | knng_key=latent_knng_key, 185 | n_neighbors=n_neighbors, 186 | random_state=seed, 187 | n_jobs=n_jobs) 188 | 189 | print("Computing latent cell CLISI scores...") 190 | latent_cell_clisi_scores = lisi_knn( 191 | X=adata.obsp[f"{latent_knng_key}_distances"], 192 | labels=adata.obs[cell_type_key]) 193 | 194 | print("Computing CLISIS...") 195 | cell_rclisi_scores = latent_cell_clisi_scores / spatial_cell_clisi_scores 196 | cell_log_rclisi_scores = np.log2(cell_rclisi_scores) 197 | 198 | n_cell_types = adata.obs[cell_type_key].nunique() 199 | max_cell_log_rclisi = np.log2(n_cell_types / 1) 200 | norm_cell_log_rclisi_scores = cell_log_rclisi_scores / max_cell_log_rclisi 201 | 202 | clisis = (1 - 
np.nanmedian(abs(norm_cell_log_rclisi_scores))) 203 | return clisis -------------------------------------------------------------------------------- /src/nichecompass/benchmarking/gcs.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the Graph Connectivity Similarity (GCS) benchmark for 3 | testing how accurately the latent nearest neighbor graph preserves edges and 4 | non-edges from the spatial (physical) nearest neighbor graph. 5 | """ 6 | 7 | from typing import Optional 8 | 9 | import numpy as np 10 | import scipy.sparse as sp 11 | from anndata import AnnData 12 | 13 | from .utils import compute_knn_graph_connectivities_and_distances 14 | 15 | 16 | def compute_gcs( 17 | adata: AnnData, 18 | batch_key: Optional[str]=None, 19 | spatial_knng_key: str="spatial_knng", 20 | latent_knng_key: str="nichecompass_latent_knng", 21 | spatial_key: Optional[str]="spatial", 22 | latent_key: Optional[str]="nichecompass_latent", 23 | n_neighbors: Optional[int]=15, 24 | n_jobs: int=1, 25 | seed: int=0): 26 | """ 27 | Compute the graph connectivity similarity (GCS). The GCS measures how 28 | accurately the latent nearest neighbor graph preserves edges and non-edges 29 | from the spatial (ground truth) nearest neighbor graph. A value of '1' 30 | indicates perfect graph similarity and a value of '0' indicates no graph 31 | connectivity similarity at all. 32 | 33 | If a ´batch_key´ is provided, the GCS will be computed on each batch 34 | separately, and the average across all batches is returned. 35 | 36 | If existent, uses precomputed nearest neighbor graphs stored in 37 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and 38 | ´adata.obsp[latent_knng_key + '_connectivities']´. 39 | Alternatively, computes them on the fly using ´spatial_key´, ´latent_key´ 40 | and ´n_neighbors´, and stores them in 41 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and 42 | ´adata.obsp[latent_knng_key + '_connectivities']´ respectively. 43 | 44 | Parameters 45 | ---------- 46 | adata: 47 | AnnData object with precomputed nearest neighbor graphs stored in 48 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and 49 | ´adata.obsp[latent_knng_key + '_connectivities']´ or spatial coordinates 50 | stored in ´adata.obsm[spatial_key]´ and the latent representation from a 51 | model stored in ´adata.obsm[latent_key]´. 52 | batch_key: 53 | Key under which the batches are stored in ´adata.obs´. 54 | spatial_knng_key: 55 | Key under which the spatial nearest neighbor graph is / will be stored 56 | in ´adata.obsp´ with the suffix '_connectivities'. 57 | latent_knng_key: 58 | Key under which the latent nearest neighbor graph is / will be stored in 59 | ´adata.obsp´ with the suffix '_connectivities'. 60 | spatial_key: 61 | Key under which the spatial coordinates are stored in ´adata.obsm´. 62 | latent_key: 63 | Key under which the latent representation from a model is stored in 64 | ´adata.obsm´. 65 | n_neighbors: 66 | Number of neighbors used for the construction of the nearest neighbor 67 | graphs from the spatial coordinates and the latent representation from 68 | a model in case they are constructed. 69 | n_jobs: 70 | Number of jobs to use for parallelization of neighbor search. 71 | seed: 72 | Random seed for reproducibility. 
73 | 74 | Returns 75 | ---------- 76 | gcs: 77 | Normalized matrix similarity between the spatial nearest neighbor graph 78 | and the latent nearest neighbor graph as measured by one minus the 79 | size-normalized Frobenius norm of the element-wise matrix differences. 80 | """ 81 | # Adding '_connectivities' as expected / added by 82 | # 'compute_knn_graph_connectivities_and_distances' 83 | spatial_knng_connectivities_key = spatial_knng_key + "_connectivities" 84 | latent_knng_connectivities_key = latent_knng_key + "_connectivities" 85 | 86 | if batch_key is not None: 87 | adata_batch_list = [] 88 | unique_batches = adata.obs[batch_key].unique().tolist() 89 | for batch in unique_batches: 90 | adata_batch = adata[adata.obs[batch_key] == batch] 91 | adata_batch_list.append(adata_batch) 92 | 93 | if spatial_knng_connectivities_key in adata.obsp: 94 | print("Using precomputed spatial nearest neighbor graph...") 95 | 96 | elif batch_key is None: 97 | print("Computing spatial nearest neighbor graph for entire dataset...") 98 | compute_knn_graph_connectivities_and_distances( 99 | adata=adata, 100 | feature_key=spatial_key, 101 | knng_key=spatial_knng_key, 102 | n_neighbors=n_neighbors, 103 | random_state=seed, 104 | n_jobs=n_jobs) 105 | 106 | elif batch_key is not None: 107 | # Compute spatial nearest neighbor graph for each batch separately 108 | for i, batch in enumerate(unique_batches): 109 | print("Computing spatial nearest neighbor graph for " 110 | f"{batch_key} {batch}...") 111 | compute_knn_graph_connectivities_and_distances( 112 | adata=adata_batch_list[i], 113 | feature_key=spatial_key, 114 | knng_key=spatial_knng_key, 115 | n_neighbors=n_neighbors, 116 | random_state=seed, 117 | n_jobs=n_jobs) 118 | 119 | if latent_knng_connectivities_key in adata.obsp: 120 | print("Using precomputed latent nearest neighbor graph...") 121 | elif batch_key is None: 122 | print("Computing latent nearest neighbor graph for entire dataset...") 123 | compute_knn_graph_connectivities_and_distances( 124 | adata=adata, 125 | feature_key=latent_key, 126 | knng_key=latent_knng_key, 127 | n_neighbors=n_neighbors, 128 | random_state=seed, 129 | n_jobs=n_jobs) 130 | elif batch_key is not None: 131 | # Compute latent nearest neighbor graph for each batch separately 132 | for i, batch in enumerate(unique_batches): 133 | print("Computing latent nearest neighbor graph for " 134 | f"{batch_key} {batch}...") 135 | compute_knn_graph_connectivities_and_distances( 136 | adata=adata_batch_list[i], 137 | feature_key=latent_key, 138 | knng_key=latent_knng_key, 139 | n_neighbors=n_neighbors, 140 | random_state=seed, 141 | n_jobs=n_jobs) 142 | 143 | if batch_key is None: 144 | print("Computing GCS for entire dataset...") 145 | n_neighbors = adata.uns[f"{latent_knng_key}_n_neighbors"] 146 | # Compute Frobenius norm of the matrix of differences to quantify 147 | # distance (square root of the sum of absolute squares) 148 | connectivities_diff = ( 149 | adata.obsp[latent_knng_connectivities_key] - 150 | adata.obsp[spatial_knng_connectivities_key]) 151 | gcd = sp.linalg.norm(connectivities_diff, 152 | ord="fro") 153 | 154 | # Normalize gcd to be between 0 and 1 and convert to gcs by subtracting 155 | # from 1. 
Maximum number of differences per node is 2 * n_neighbors ( 156 | # sc.pp.neighbors returns a weighted symmetric knn graph with the node- 157 | # wise sums of weights not exceeding the number of neighbors; the 158 | # maximum difference for a node is reached if none of the neighbors 159 | # coincide for the node) 160 | gcs = 1 - (gcd ** 2 / (n_neighbors * 2 * connectivities_diff.shape[0])) 161 | elif batch_key is not None: 162 | # Compute GCS per batch and average 163 | gcs_list = [] 164 | for i, batch in enumerate(unique_batches): 165 | print(f"Computing GCS for {batch_key} {batch}...") 166 | n_neighbors = adata_batch_list[i].uns[ 167 | f"{latent_knng_key}_n_neighbors"] 168 | batch_connectivities_diff = ( 169 | adata_batch_list[i].obsp[latent_knng_connectivities_key] - 170 | adata_batch_list[i].obsp[spatial_knng_connectivities_key]) 171 | batch_gcd = sp.linalg.norm(batch_connectivities_diff, 172 | ord="fro") 173 | batch_gcs = 1 - ( 174 | batch_gcd ** 2 / ( 175 | n_neighbors * 2 * batch_connectivities_diff.shape[0])) 176 | gcs_list.append(batch_gcs) 177 | gcs = np.mean(gcs_list) 178 | return gcs 179 | 180 | 181 | def compute_avg_gcs( 182 | adata: AnnData, 183 | batch_key: Optional[str]=None, 184 | spatial_key: str="spatial", 185 | latent_key: str="nichecompass_latent", 186 | min_n_neighbors: int=1, 187 | max_n_neighbors: int=15, 188 | seed: int=0) -> float: 189 | """ 190 | Compute multiple Graph Connectivity Similarities (GCS) by varying the number 191 | of neighbors used for nearest neighbor graph construction (between 192 | ´min_n_neighbors´ and ´max_n_neighbors´) and return the average GCS. Can use 193 | precomputed spatial and latent nearest neighbor graphs stored in 194 | ´adata.obsp[f'{spatial_key}_{n_neighbors}nng_connectivities']´ and 195 | ´adata.obsp[f'{latent_key}_{n_neighbors}nng_connectivities']´ respectively. 196 | 197 | Parameters 198 | ---------- 199 | adata: 200 | AnnData object with spatial coordinates stored in 201 | ´adata.obsm[spatial_key]´ and the latent representation from a model 202 | stored in ´adata.obsm[latent_key]´. Precomputed nearest neighbor graphs 203 | can optionally be stored in 204 | ´adata.obsp[f'{spatial_key}_{n_neighbors}nng_connectivities']´ 205 | and ´adata.obsp[f'{latent_key}_{n_neighbors}nng_connectivities']´. 206 | batch_key: 207 | Key under which the batches are stored in ´adata.obs´. If ´None´, the 208 | adata is assumed to only have one unique batch. 209 | spatial_key: 210 | Key under which the spatial coordinates are stored in ´adata.obsm´. 211 | latent_key: 212 | Key under which the latent representation from a model is stored in 213 | ´adata.obsm´. 214 | min_n_neighbors: 215 | Minimum number of neighbors used for computing the average GCS. 216 | max_n_neighbors: 217 | Maximum number of neighbors used for computing the average GCS. 218 | seed: 219 | Random seed for reproducibility. 220 | 221 | Returns 222 | ---------- 223 | avg_gcs: 224 | Average GCS computed over different nearest neighbor graphs with varying 225 | number of neighbors. 
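Example (a minimal usage sketch; the ´adata´ object and the key names
are illustrative assumptions, with the keys matching the parameter
defaults):
´´´
avg_gcs = compute_avg_gcs(adata=adata,
                          spatial_key="spatial",
                          latent_key="nichecompass_latent",
                          min_n_neighbors=1,
                          max_n_neighbors=15)
print(f"Average GCS: {avg_gcs:.4f}")
´´´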
226 | """ 227 | gcs_list = [] 228 | for n_neighbors in range(min_n_neighbors, max_n_neighbors): 229 | gcs_list.append(compute_gcs( 230 | adata=adata, 231 | batch_key=batch_key, 232 | spatial_knng_key=f"{spatial_key}_{n_neighbors}nng", 233 | latent_knng_key=f"{latent_key}_{n_neighbors}nng", 234 | spatial_key=spatial_key, 235 | latent_key=latent_key, 236 | n_neighbors=n_neighbors, 237 | seed=seed)) 238 | avg_gcs = np.mean(gcs_list) 239 | return avg_gcs 240 | -------------------------------------------------------------------------------- /src/nichecompass/benchmarking/mlami.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the Maximum Leiden Adjusted Mutual Info (MLAMI) benchmark 3 | for testing how accurately the latent feature space preserves global spatial 4 | organization from the spatial (physical) feature space by comparing clustering 5 | overlaps. 6 | """ 7 | 8 | from typing import Optional 9 | 10 | import numpy as np 11 | import scanpy as sc 12 | from anndata import AnnData 13 | from sklearn.metrics import adjusted_mutual_info_score 14 | 15 | from .utils import compute_knn_graph_connectivities_and_distances 16 | 17 | 18 | def compute_mlami( 19 | adata: AnnData, 20 | batch_key: Optional[str]=None, 21 | spatial_knng_key: str="spatial_knng", 22 | latent_knng_key: str="nichecompass_latent_knng", 23 | spatial_key: Optional[str]="spatial", 24 | latent_key: Optional[str]="nichecompass_latent", 25 | n_neighbors: Optional[int]=15, 26 | min_res: float=0.1, 27 | max_res: float=1.0, 28 | res_num: int=3, 29 | n_jobs: int=1, 30 | seed: int=0) -> float: 31 | """ 32 | Compute the Maximum Leiden Adjusted Mutual Info (MLAMI). The MLAMI ranges 33 | between '0' and '1' with higher values indicating that the latent feature 34 | space more accurately preserves global spatial organization from the spatial 35 | (ground truth) feature space. To compute the MLAMI, Leiden clusterings with 36 | different resolutions are computed for both nearest neighbor graphs. The 37 | Adjusted Mutual Info (AMI) between all clustering resolution pairs is 38 | computed to quantify cluster overlap and the maximum value is returned as 39 | metric for spatial organization preservation. 40 | 41 | If a ´batch_key´ is provided, the MLAMI will be computed on each batch 42 | separately (with latent Leiden clusters computed on the integrated latent 43 | space), and the average across all batches is returned. 44 | 45 | If existent, uses precomputed nearest neighbor graphs stored in 46 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and 47 | ´adata.obsp[latent_knng_key + '_connectivities']´. 48 | Alternatively, computes them on the fly using ´spatial_key´, ´latent_key´ 49 | and ´n_neighbors´, and stores them in 50 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and 51 | ´adata.obsp[latent_knng_key + '_connectivities']´ respectively. 52 | 53 | Parameters 54 | ---------- 55 | adata: 56 | AnnData object with precomputed nearest neighbor graphs stored in 57 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and 58 | ´adata.obsp[latent_knng_key + '_connectivities']´ or spatial coordinates 59 | stored in ´adata.obsm[spatial_key]´ and the latent representation from a 60 | model stored in ´adata.obsm[latent_key]´. 61 | batch_key: 62 | Key under which the batches are stored in ´adata.obs´. If ´None´, the 63 | adata is assumed to only have one unique batch. 
64 | spatial_knng_key: 65 | Key under which the spatial nearest neighbor graph is / will be stored 66 | in ´adata.obsp´ with the suffix '_connectivities'. 67 | latent_knng_key: 68 | Key under which the latent nearest neighbor graph is / will be stored in 69 | ´adata.obsp´ with the suffix '_connectivities'. 70 | spatial_key: 71 | Key under which the spatial coordinates are stored in ´adata.obsm´. 72 | latent_key: 73 | Key under which the latent representation from a model is stored in 74 | ´adata.obsm´. 75 | n_neighbors: 76 | Number of neighbors used for the construction of the nearest neighbor 77 | graphs from the spatial coordinates and the latent representation from 78 | a model in case they are constructed. 79 | min_res: 80 | Minimum resolution for Leiden clustering. 81 | max_res: 82 | Maximum resolution for Leiden clustering. 83 | res_num: 84 | Number of linearly spaced Leiden resolutions between ´min_res´ and 85 | ´max_res´ for which Leiden clusterings will be computed. 86 | n_jobs: 87 | Number of jobs to use for parallelization of neighbor search. 88 | seed: 89 | Random seed for reproducibility. 90 | 91 | Returns 92 | ---------- 93 | mlami: 94 | MLAMI between all clustering resolution pairs. 95 | """ 96 | # Adding '_connectivities' as expected / added by 97 | # 'compute_knn_graph_connectivities_and_distances' 98 | spatial_knng_connectivities_key = spatial_knng_key + "_connectivities" 99 | latent_knng_connectivities_key = latent_knng_key + "_connectivities" 100 | 101 | if batch_key is not None: 102 | adata_batch_list = [] 103 | unique_batches = adata.obs[batch_key].unique().tolist() 104 | for batch in unique_batches: 105 | adata_batch = adata[adata.obs[batch_key] == batch] 106 | adata_batch_list.append(adata_batch) 107 | 108 | if spatial_knng_connectivities_key in adata.obsp: 109 | print("Using precomputed spatial nearest neighbor graph...") 110 | elif batch_key is None: 111 | print("Computing spatial nearest neighbor graph for entire dataset...") 112 | compute_knn_graph_connectivities_and_distances( 113 | adata=adata, 114 | feature_key=spatial_key, 115 | knng_key=spatial_knng_key, 116 | n_neighbors=n_neighbors, 117 | random_state=seed, 118 | n_jobs=n_jobs) 119 | elif batch_key is not None: 120 | # Compute spatial nearest neighbor graph for each batch separately 121 | for i, batch in enumerate(unique_batches): 122 | print("Computing spatial nearest neighbor graph for " 123 | f"{batch_key} {batch}...") 124 | compute_knn_graph_connectivities_and_distances( 125 | adata=adata_batch_list[i], 126 | feature_key=spatial_key, 127 | knng_key=spatial_knng_key, 128 | n_neighbors=n_neighbors, 129 | random_state=seed, 130 | n_jobs=n_jobs) 131 | 132 | if latent_knng_connectivities_key in adata.obsp: 133 | print("Using precomputed latent nearest neighbor graph...") 134 | else: 135 | print("Computing latent nearest neighbor graph...") 136 | compute_knn_graph_connectivities_and_distances( 137 | adata=adata, 138 | feature_key=latent_key, 139 | knng_key=latent_knng_key, 140 | n_neighbors=n_neighbors, 141 | random_state=seed, 142 | n_jobs=n_jobs) 143 | 144 | # Define search space of clustering resolutions 145 | clustering_resolutions = np.linspace(start=min_res, 146 | stop=max_res, 147 | num=res_num, 148 | dtype=np.float32) 149 | 150 | if batch_key is None: 151 | print("Computing spatial Leiden clusterings for entire dataset...") 152 | # Calculate spatial Leiden clustering for different resolutions 153 | for resolution in clustering_resolutions: 154 | sc.tl.leiden(adata=adata, 155 | resolution=resolution, 
156 | random_state=seed,
157 | key_added=f"leiden_spatial_{str(resolution)}",
158 | adjacency=adata.obsp[spatial_knng_connectivities_key])
159 | elif batch_key is not None:
160 | # Compute spatial Leiden clustering for each batch separately
161 | for i, batch in enumerate(unique_batches):
162 | print("Computing spatial Leiden clusterings for "
163 | f"{batch_key} {batch}...")
164 | # Calculate spatial Leiden clustering for different resolutions
165 | for resolution in clustering_resolutions:
166 | sc.tl.leiden(
167 | adata=adata_batch_list[i],
168 | resolution=resolution,
169 | random_state=seed,
170 | key_added=f"leiden_spatial_{str(resolution)}",
171 | adjacency=adata_batch_list[i].obsp[spatial_knng_connectivities_key])
172 | 
173 | print("Computing latent Leiden clusterings...")
174 | # Calculate latent Leiden clustering for different resolutions
175 | for resolution in clustering_resolutions:
176 | sc.tl.leiden(adata,
177 | resolution=resolution,
178 | random_state=seed,
179 | key_added=f"leiden_latent_{str(resolution)}",
180 | adjacency=adata.obsp[latent_knng_connectivities_key])
181 | if batch_key is not None:
182 | for i, batch in enumerate(unique_batches):
183 | adata_batch_list[i].obs[f"leiden_latent_{str(resolution)}"] = (
184 | adata.obs[f"leiden_latent_{str(resolution)}"])
185 | 
186 | if batch_key is None:
187 | print("Computing MLAMI for entire dataset...")
188 | # Calculate max LAMI over all clustering resolutions
189 | lami_list = []
190 | for spatial_resolution in clustering_resolutions:
191 | for latent_resolution in clustering_resolutions:
192 | lami_list.append(_compute_ami(
193 | adata=adata,
194 | cluster_group1_key=f"leiden_spatial_{str(spatial_resolution)}",
195 | cluster_group2_key=f"leiden_latent_{str(latent_resolution)}"))
196 | mlami = np.max(lami_list)
197 | elif batch_key is not None:
198 | batch_mlami_list = []  # collect the per-batch MLAMI values
199 | for i, batch in enumerate(unique_batches):
200 | print(f"Computing MLAMI for {batch_key} {batch}...")
201 | batch_lami_list = []
202 | for spatial_resolution in clustering_resolutions:
203 | for latent_resolution in clustering_resolutions:
204 | batch_lami_list.append(_compute_ami(
205 | adata=adata_batch_list[i],
206 | cluster_group1_key=f"leiden_spatial_{str(spatial_resolution)}",
207 | cluster_group2_key=f"leiden_latent_{str(latent_resolution)}"))
208 | batch_mlami_list.append(np.max(batch_lami_list))
209 | mlami = np.mean(batch_mlami_list)  # average the per-batch maxima
210 | return mlami
211 | 
212 | def _compute_ami(adata: AnnData,
213 | cluster_group1_key: str,
214 | cluster_group2_key: str) -> float:
215 | """
216 | Compute the Adjusted Mutual Information (AMI) between two different
217 | cluster assignments. AMI compares the overlap of two clusterings. For
218 | details, see documentation at
219 | https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_mutual_info_score.html#sklearn.metrics.adjusted_mutual_info_score.
220 | 
221 | Parameters
222 | ----------
223 | adata:
224 | AnnData object with clustering labels stored in
225 | ´adata.obs[cluster_group1_key]´ and ´adata.obs[cluster_group2_key]´.
226 | cluster_group1_key:
227 | Key under which the clustering labels from the first clustering
228 | assignment are stored in ´adata.obs´.
229 | cluster_group2_key:
230 | Key under which the clustering labels from the second clustering
231 | assignment are stored in ´adata.obs´.
232 | 
233 | Returns
234 | ----------
235 | ami:
236 | AMI score as calculated by the sklearn implementation.
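Example (a minimal sketch; assumes two Leiden clusterings have already
been computed and stored in ´adata.obs´ under the hypothetical keys
shown below):
´´´
ami = _compute_ami(adata=adata,
                   cluster_group1_key="leiden_spatial_0.1",
                   cluster_group2_key="leiden_latent_0.1")
´´´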
237 | """ 238 | cluster_group1 = adata.obs[cluster_group1_key].tolist() 239 | cluster_group2 = adata.obs[cluster_group2_key].tolist() 240 | 241 | ami = adjusted_mutual_info_score(cluster_group1, 242 | cluster_group2, 243 | average_method="arithmetic") 244 | return ami 245 | -------------------------------------------------------------------------------- /src/nichecompass/benchmarking/nasw.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the Niche Average Silhouette Width (NASW) benchmark 3 | for testing how well the latent feature space can be clustered into distinct 4 | and compact clusters. 5 | """ 6 | 7 | from typing import Optional 8 | 9 | import numpy as np 10 | import scanpy as sc 11 | import scib_metrics 12 | from anndata import AnnData 13 | 14 | from .utils import compute_knn_graph_connectivities_and_distances 15 | 16 | 17 | def compute_nasw( 18 | adata: AnnData, 19 | latent_knng_key: str="nichecompass_latent_knng", 20 | latent_key: Optional[str]="nichecompass_latent", 21 | n_neighbors: Optional[int]=15, 22 | min_res: float=0.1, 23 | max_res: float=1.0, 24 | res_num: int=3, 25 | n_jobs: int=1, 26 | seed: int=0) -> float: 27 | """ 28 | Compute the Niche Average Silhouette Width (NASW). The NASW ranges between 29 | '0' and '1' with higher values indicating more distinct and compact 30 | clusters in the latent feature space. To compute the NASW, Leiden 31 | clusterings with different resolutions are computed for the latent nearest 32 | neighbor graph. The NASW for all clustering resolutions is computed and the 33 | average value is returned as metric for clusterability. 34 | 35 | If existent, uses a precomputed latent nearest neighbor graph stored in 36 | ´adata.obsp[latent_knng_key + '_connectivities']´. 37 | Alternatively, computes it on the fly using ´latent_key´ and ´n_neighbors´, 38 | and stores it in ´adata.obsp[latent_knng_key + '_connectivities']´. 39 | 40 | Parameters 41 | ---------- 42 | adata: 43 | AnnData object with a precomputed latent nearest neighbor graph stored 44 | in ´adata.obsp[latent_knng_key + '_connectivities']´ or the latent 45 | representation from a model stored in ´adata.obsm[latent_key]´. 46 | latent_knng_key: 47 | Key under which the latent nearest neighbor graph is / will be stored in 48 | ´adata.obsp´ with the suffix '_connectivities'. 49 | latent_key: 50 | Key under which the latent representation from a model is stored in 51 | ´adata.obsm´. 52 | n_neighbors: 53 | Number of neighbors used for the construction of the latent nearest 54 | neighbor graph from the latent representation from a model in case they 55 | are constructed. 56 | min_res: 57 | Minimum resolution for Leiden clustering. 58 | max_res: 59 | Maximum resolution for Leiden clustering. 60 | res_num: 61 | Number of linearly spaced Leiden resolutions between ´min_res´ and 62 | ´max_res´ for which Leiden clusterings will be computed. 63 | n_jobs: 64 | Number of jobs to use for parallelization of neighbor search. 65 | seed: 66 | Random seed for reproducibility. 67 | 68 | Returns 69 | ---------- 70 | nasw: 71 | Average NASW across all clustering resolutions. 
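Example (a minimal usage sketch; assumes a model's latent
representation is stored in ´adata.obsm['nichecompass_latent']´):
´´´
nasw = compute_nasw(adata=adata,
                    latent_key="nichecompass_latent",
                    n_neighbors=15)
print(f"NASW: {nasw:.4f}")
´´´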
72 | """
73 | # Adding '_connectivities' as expected / added by
74 | # 'compute_knn_graph_connectivities_and_distances'
75 | latent_knng_connectivities_key = latent_knng_key + "_connectivities"
76 | 
77 | if latent_knng_connectivities_key in adata.obsp:
78 | print("Using precomputed latent nearest neighbor graph...")
79 | else:
80 | print("Computing latent nearest neighbor graph...")
81 | compute_knn_graph_connectivities_and_distances(
82 | adata=adata,
83 | feature_key=latent_key,
84 | knng_key=latent_knng_key,
85 | n_neighbors=n_neighbors,
86 | random_state=seed,
87 | n_jobs=n_jobs)
88 | 
89 | # Define search space of clustering resolutions
90 | clustering_resolutions = np.linspace(start=min_res,
91 | stop=max_res,
92 | num=res_num,
93 | dtype=np.float32)
94 | 
95 | print("Computing latent Leiden clusterings...")
96 | # Calculate latent Leiden clustering for different resolutions
97 | for resolution in clustering_resolutions:
98 | if f"leiden_latent_{str(resolution)}" not in adata.obs:
99 | sc.tl.leiden(adata,
100 | resolution=resolution,
101 | random_state=seed,
102 | key_added=f"leiden_latent_{str(resolution)}",
103 | adjacency=adata.obsp[latent_knng_connectivities_key])
104 | else:
105 | print("Using precomputed latent Leiden clusters for resolution "
106 | f"{str(resolution)}.")
107 | 
108 | print("Computing NASW...")
109 | # Calculate the average NASW over all clustering resolutions
110 | nasw_list = []
111 | for resolution in clustering_resolutions:
112 | nasw_list.append(scib_metrics.silhouette_label(
113 | X=adata.obsm[latent_key],
114 | labels=adata.obs[f"leiden_latent_{str(resolution)}"]))
115 | nasw = np.mean(nasw_list)
116 | return nasw
--------------------------------------------------------------------------------
/src/nichecompass/benchmarking/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains helper functions for the ´benchmarking´ subpackage.
3 | """
4 | 
5 | from typing import Optional
6 | 
7 | import numpy as np
8 | import scanpy as sc
9 | from anndata import AnnData
10 | from scib_metrics.nearest_neighbors import pynndescent
11 | 
12 | 
13 | def compute_knn_graph_connectivities_and_distances(
14 | adata: AnnData,
15 | feature_key: str="nichecompass_latent",
16 | knng_key: str="nichecompass_latent_15knng",
17 | n_neighbors: int=15,
18 | random_state: int=0,
19 | n_jobs: int=1) -> None:
20 | """
21 | Compute approximate k-nearest-neighbors graph.
22 | 
23 | Parameters
24 | ----------
25 | adata:
26 | AnnData object with the features for knn graph computation stored in
27 | ´adata.obsm[feature_key]´.
28 | feature_key:
29 | Key in ´adata.obsm´ that will be used to compute the knn graph.
30 | knng_key:
31 | Key under which the knn graph connectivities will be stored
32 | in ´adata.obsp´ with the suffix '_connectivities', the knn graph
33 | distances will be stored in ´adata.obsp´ with the suffix '_distances',
34 | and the number of neighbors will be stored in ´adata.uns´ with the
35 | suffix '_n_neighbors'.
36 | n_neighbors:
37 | Number of neighbors of the knn graph.
38 | random_state:
39 | Random state for reproducibility.
40 | n_jobs:
41 | Number of jobs to use for parallelization of neighbor search.
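Example (a minimal usage sketch using the parameter defaults; assumes
features are stored in ´adata.obsm['nichecompass_latent']´):
´´´
compute_knn_graph_connectivities_and_distances(
    adata=adata,
    feature_key="nichecompass_latent",
    knng_key="nichecompass_latent_15knng",
    n_neighbors=15)
connectivities = adata.obsp["nichecompass_latent_15knng_connectivities"]
distances = adata.obsp["nichecompass_latent_15knng_distances"]
´´´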
42 | """
43 | neigh_output = pynndescent(
44 | adata.obsm[feature_key],
45 | n_neighbors=n_neighbors,
46 | random_state=random_state,
47 | n_jobs=n_jobs)
48 | indices, distances = neigh_output.indices, neigh_output.distances
49 | 
50 | # This is a trick to get lisi metrics to work by adding the tiniest possible value
51 | # to 0 distance neighbors so that each cell has the same amount of neighbors
52 | # (otherwise some cells lose neighbors with distance 0 due to sparse representation)
53 | row_idx = np.where(distances == 0)[0]
54 | col_idx = np.where(distances == 0)[1]
55 | new_row_idx = row_idx[np.where(row_idx != indices[row_idx, col_idx])[0]]
56 | new_col_idx = col_idx[np.where(row_idx != indices[row_idx, col_idx])[0]]
57 | distances[new_row_idx, new_col_idx] = (distances[new_row_idx, new_col_idx] +
58 | np.nextafter(0, 1, dtype=np.float32))
59 | 
60 | sp_distances, sp_conns = sc.neighbors._compute_connectivities_umap(
61 | indices[:, :n_neighbors],
62 | distances[:, :n_neighbors],
63 | adata.n_obs,
64 | n_neighbors=n_neighbors)
65 | adata.obsp[f"{knng_key}_connectivities"] = sp_conns
66 | adata.obsp[f"{knng_key}_distances"] = sp_distances
67 | adata.uns[f"{knng_key}_n_neighbors"] = n_neighbors
68 | 
69 | 
70 | def convert_to_one_hot(vector: np.ndarray,
71 | n_classes: Optional[int]=None) -> np.ndarray:
72 | """
73 | Converts an input 1-D vector of integer labels into a 2-D array of one-hot
74 | vectors, where for an i'th input value of j, a '1' will be inserted in the
75 | i'th row and j'th column of the output one-hot vector.
76 | 
77 | Implementation is adapted from
78 | https://github.com/theislab/scib/blob/29f79d0135f33426481f9ff05dd1ae55c8787142/scib/metrics/lisi.py#L498
79 | (05.12.22).
80 | 
81 | Parameters
82 | ----------
83 | vector:
84 | Vector to be one-hot-encoded.
85 | n_classes:
86 | Number of classes to be considered for one-hot-encoding. If ´None´, the
87 | number of classes will be inferred from ´vector´.
88 | 
89 | Returns
90 | ----------
91 | one_hot:
92 | 2-D NumPy array of one-hot-encoded vectors.
93 | 
94 | Example:
95 | ´´´
96 | vector = np.array((1, 0, 4))
97 | one_hot = convert_to_one_hot(vector)
98 | print(one_hot)
99 | [[0 1 0 0 0]
100 | [1 0 0 0 0]
101 | [0 0 0 0 1]]
102 | ´´´
103 | """
104 | if n_classes is None:
105 | n_classes = np.max(vector) + 1
106 | 
107 | one_hot = np.zeros(shape=(len(vector), n_classes))
108 | one_hot[np.arange(len(vector)), vector] = 1
109 | return one_hot.astype(int)
--------------------------------------------------------------------------------
/src/nichecompass/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloaders import initialize_dataloaders
2 | from .dataprocessors import (edge_level_split,
3 | node_level_split_mask,
4 | prepare_data)
5 | from .datareaders import load_spatial_adata_from_csv
6 | from .datasets import SpatialAnnTorchDataset
7 | 
8 | __all__ = ["initialize_dataloaders",
9 | "edge_level_split",
10 | "node_level_split_mask",
11 | "prepare_data",
12 | "load_spatial_adata_from_csv",
13 | "SpatialAnnTorchDataset"]
--------------------------------------------------------------------------------
/src/nichecompass/data/dataloaders.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains dataloaders for the training of a NicheCompass model.
3 | """ 4 | 5 | from typing import Optional 6 | 7 | from torch_geometric.data import Data 8 | from torch_geometric.loader import LinkNeighborLoader, NeighborLoader 9 | 10 | 11 | def initialize_dataloaders(node_masked_data: Data, 12 | edge_train_data: Optional[Data]=None, 13 | edge_val_data: Optional[Data]=None, 14 | edge_batch_size: Optional[int]=64, 15 | node_batch_size: int=64, 16 | n_direct_neighbors: int=-1, 17 | n_hops: int=1, 18 | shuffle: bool=True, 19 | edges_directed: bool=False, 20 | neg_edge_sampling_ratio: float=1.) -> dict: 21 | """ 22 | Initialize edge-level and node-level training and validation dataloaders. 23 | 24 | Parameters 25 | ---------- 26 | node_masked_data: 27 | PyG Data object with node-level split masks. 28 | edge_train_data: 29 | PyG Data object containing the edge-level training set. 30 | edge_val_data: 31 | PyG Data object containing the edge-level validation set. 32 | edge_batch_size: 33 | Batch size for the edge-level dataloaders. 34 | node_batch_size: 35 | Batch size for the node-level dataloaders. 36 | n_direct_neighbors: 37 | Number of sampled direct neighbors of the current batch nodes to be 38 | included in the batch. Defaults to ´-1´, which means to include all 39 | direct neighbors. 40 | n_hops: 41 | Number of neighbor hops / levels for neighbor sampling of nodes to be 42 | included in the current batch. E.g. ´2´ means to not only include 43 | sampled direct neighbors of current batch nodes but also sampled 44 | neighbors of the direct neighbors. 45 | shuffle: 46 | If `True`, shuffle the dataloaders. 47 | edges_directed: 48 | If `False`, both symmetric edge index pairs are included in the same 49 | edge-level batch (1 edge has 2 symmetric edge index pairs). 50 | neg_edge_sampling_ratio: 51 | Negative sampling ratio of edges. This is currently implemented in an 52 | approximate way, i.e. negative edges may contain false negatives. 53 | 54 | Returns 55 | ---------- 56 | loader_dict: 57 | Dictionary containing training and validation PyG LinkNeighborLoader 58 | (for edge reconstruction) and NeighborLoader (for gene expression 59 | reconstruction) objects. 
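Example (a minimal usage sketch; assumes ´data_dict´ was created with
´prepare_data´ from the ´dataprocessors´ module, and the batch sizes
are illustrative):
´´´
loader_dict = initialize_dataloaders(
    node_masked_data=data_dict["node_masked_data"],
    edge_train_data=data_dict["edge_train_data"],
    edge_val_data=data_dict["edge_val_data"],
    edge_batch_size=128,
    node_batch_size=64)
for node_batch in loader_dict["node_train_loader"]:
    pass  # a model forward pass would go here
´´´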
60 | """
61 | loader_dict = {}
62 | 
63 | # Node-level dataloaders
64 | loader_dict["node_train_loader"] = NeighborLoader(
65 | node_masked_data,
66 | num_neighbors=[n_direct_neighbors] * n_hops,
67 | batch_size=node_batch_size,
68 | directed=False,
69 | shuffle=shuffle,
70 | input_nodes=node_masked_data.train_mask)
71 | if node_masked_data.val_mask.sum() != 0:
72 | loader_dict["node_val_loader"] = NeighborLoader(
73 | node_masked_data,
74 | num_neighbors=[n_direct_neighbors] * n_hops,
75 | batch_size=node_batch_size,
76 | directed=False,
77 | shuffle=shuffle,
78 | input_nodes=node_masked_data.val_mask)
79 | 
80 | # Edge-level dataloaders
81 | if edge_train_data is not None:
82 | loader_dict["edge_train_loader"] = LinkNeighborLoader(
83 | edge_train_data,
84 | num_neighbors=[n_direct_neighbors] * n_hops,
85 | batch_size=edge_batch_size,
86 | edge_label=None, # will automatically be added as 1 for all edges
87 | edge_label_index=edge_train_data.edge_label_index[:, edge_train_data.edge_label.bool()], # limit the edges to the ones from the edge_label_adj
88 | directed=edges_directed,
89 | shuffle=shuffle,
90 | neg_sampling_ratio=neg_edge_sampling_ratio)
91 | if edge_val_data is not None and edge_val_data.edge_label.sum() != 0:
92 | loader_dict["edge_val_loader"] = LinkNeighborLoader(
93 | edge_val_data,
94 | num_neighbors=[n_direct_neighbors] * n_hops,
95 | batch_size=edge_batch_size,
96 | edge_label=None, # will automatically be added as 1 for all edges
97 | edge_label_index=edge_val_data.edge_label_index[:, edge_val_data.edge_label.bool()], # limit the edges to the ones from the edge_label_adj
98 | directed=edges_directed,
99 | shuffle=shuffle,
100 | neg_sampling_ratio=neg_edge_sampling_ratio)
101 | 
102 | return loader_dict
--------------------------------------------------------------------------------
/src/nichecompass/data/dataprocessors.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains data processors for the training of a NicheCompass model.
3 | """
4 | 
5 | from typing import List, Optional, Tuple
6 | 
7 | import scipy.sparse as sp
8 | import torch
9 | from anndata import AnnData
10 | from torch_geometric.data import Data
11 | from torch_geometric.transforms import RandomNodeSplit, RandomLinkSplit
12 | from torch_geometric.utils import add_self_loops, remove_self_loops
13 | 
14 | from .datasets import SpatialAnnTorchDataset
15 | 
16 | 
17 | def edge_level_split(data: Data,
18 | edge_label_adj: Optional[sp.csr_matrix],
19 | val_ratio: float=0.1,
20 | test_ratio: float=0.,
21 | is_undirected: bool=True,
22 | neg_sampling_ratio: float=0.) -> Tuple[Data, Data, Data]:
23 | """
24 | Split a PyG Data object into training, validation and test PyG Data objects
25 | using an edge-level split. The training split does not include edges in the
26 | validation and test splits and the validation split does not include edges
27 | in the test split. However, nodes will not be split and all node features
28 | will be accessible from all splits.
29 | 
30 | Check https://github.com/pyg-team/pytorch_geometric/issues/3668 for more
31 | context on how RandomLinkSplit works.
32 | 
33 | Parameters
34 | ----------
35 | data:
36 | PyG Data object to be split.
37 | edge_label_adj:
38 | Adjacency matrix which contains edges for edge reconstruction. If
39 | ´None´, uses the 'normal' adjacency matrix used for message passing.
40 | val_ratio:
41 | Ratio of edges to be included in the validation split.
42 | test_ratio: 43 | Ratio of edges to be included in the test split. 44 | is_undirected: 45 | If ´True´, the graph is assumed to be undirected, and positive and 46 | negative samples will not leak (reverse) edge connectivity across 47 | different splits. This is set to ´False´, as there is an issue with 48 | replication of self loops. 49 | neg_sampling_ratio: 50 | Ratio of negative sampling. This should be set to 0 if negative sampling 51 | is done by the dataloader. 52 | 53 | Returns 54 | ---------- 55 | train_data: 56 | Training PyG Data object. 57 | val_data: 58 | Validation PyG Data object. 59 | test_data: 60 | Test PyG Data object. 61 | """ 62 | # Clone data to not modify in place to not affect node split 63 | data_no_self_loops = data.clone() 64 | 65 | # Remove self loops temporarily as we don't want them as edge labels. There 66 | # is also an issue with RandomLinkSplit (self loops will be replicated for 67 | # message passing). We will add the self loops again after the split 68 | data_no_self_loops.edge_index, data_no_self_loops.edge_attr = ( 69 | remove_self_loops(edge_index=data.edge_index, 70 | edge_attr=data.edge_attr)) 71 | 72 | if edge_label_adj is not None: 73 | # Add edge label which is 1 for edges from edge_label_adj and 0 otherwise. 74 | # This will be used by dataloader to only sample edges from edge_label_adj 75 | # as opposed to from adj. 76 | data_no_self_loops.edge_label = torch.tensor( 77 | [(edge_label_adj[edge_index[0].item(), edge_index[1].item()] == 1.0) for 78 | edge_index in data_no_self_loops.edge_attr]).int() 79 | 80 | random_link_split = RandomLinkSplit( 81 | num_val=val_ratio, 82 | num_test=test_ratio, 83 | is_undirected=is_undirected, 84 | key="edge_label", # if ´edge_label´ is not existent, it will be added with 1s 85 | neg_sampling_ratio=neg_sampling_ratio) 86 | train_data, val_data, test_data = random_link_split(data_no_self_loops) 87 | 88 | # Readd self loops for message passing 89 | for split_data in [train_data, val_data, test_data]: 90 | split_data.edge_index = add_self_loops( 91 | edge_index=split_data.edge_index, 92 | num_nodes=split_data.x.shape[0])[0] 93 | split_data.edge_attr = add_self_loops( 94 | edge_index=split_data.edge_attr.t(), 95 | num_nodes=split_data.x.shape[0])[0].t() 96 | return train_data, val_data, test_data 97 | 98 | 99 | def node_level_split_mask(data: Data, 100 | val_ratio: float=0.1, 101 | test_ratio: float=0., 102 | split_key: str="x") -> Data: 103 | """ 104 | Split data on node-level into training, validation and test sets by adding 105 | node-level masks (train_mask, val_mask, test_mask) to the PyG Data object. 106 | 107 | Parameters 108 | ---------- 109 | data: 110 | PyG Data object to be split. 111 | val_ratio: 112 | Ratio of nodes to be included in the validation split. 113 | test_ratio: 114 | Ratio of nodes to be included in the test split. 115 | split_key: 116 | The attribute key of the PyG Data object that holds the ground 117 | truth labels. Only nodes in which the key is present will be split. 118 | 119 | Returns 120 | ---------- 121 | data: 122 | PyG Data object with ´train_mask´, ´val_mask´ and ´test_mask´ attributes 123 | added. 
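Example (a minimal sketch on a PyG Data object ´data´ with node
features stored in ´data.x´):
´´´
data = node_level_split_mask(data, val_ratio=0.1, test_ratio=0.)
n_train_nodes = int(data.train_mask.sum())
n_val_nodes = int(data.val_mask.sum())
´´´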
124 | """
125 | random_node_split = RandomNodeSplit(
126 | num_val=val_ratio,
127 | num_test=test_ratio,
128 | key=split_key)
129 | data = random_node_split(data)
130 | return data
131 | 
132 | 
133 | def prepare_data(adata: AnnData,
134 | cat_covariates_label_encoders: List[dict],
135 | adata_atac: Optional[AnnData]=None,
136 | counts_key: Optional[str]="counts",
137 | adj_key: str="spatial_connectivities",
138 | cat_covariates_keys: Optional[List[str]]=None,
139 | edge_val_ratio: float=0.1,
140 | edge_test_ratio: float=0.,
141 | node_val_ratio: float=0.1,
142 | node_test_ratio: float=0.) -> dict:
143 | """
144 | Prepare data for model training including edge-level and node-level train,
145 | validation, and test splits.
146 | 
147 | Parameters
148 | ----------
149 | adata:
150 | AnnData object with counts stored in ´adata.layers[counts_key]´ or
151 | ´adata.X´ depending on ´counts_key´, and sparse adjacency matrix stored
152 | in ´adata.obsp[adj_key]´.
153 | cat_covariates_label_encoders:
154 | List of categorical covariates label encoders from the model (label
155 | encoding indices need to be aligned with the ones from the model to get
156 | the correct categorical covariates embeddings).
157 | adata_atac:
158 | Additional optional AnnData object with paired spatial ATAC data.
159 | counts_key:
160 | Key under which the counts are stored in ´adata.layers´. If ´None´, uses
161 | ´adata.X´ as counts.
162 | adj_key:
163 | Key under which the sparse adjacency matrix is stored in ´adata.obsp´.
164 | cat_covariates_keys:
165 | Keys under which the categorical covariates are stored in ´adata.obs´.
166 | edge_val_ratio:
167 | Fraction of the data that is used as validation set on edge-level.
168 | edge_test_ratio:
169 | Fraction of the data that is used as test set on edge-level.
170 | node_val_ratio:
171 | Fraction of the data that is used as validation set on node-level.
172 | node_test_ratio:
173 | Fraction of the data that is used as test set on node-level.
174 | 
175 | Returns
176 | ----------
177 | data_dict:
178 | Dictionary containing edge-level training, validation and test PyG
179 | Data objects and node-level PyG Data object with split masks under keys
180 | ´edge_train_data´, ´edge_val_data´, ´edge_test_data´, and
181 | ´node_masked_data´ respectively. The edge-level PyG Data objects contain
182 | edges in the ´edge_label_index´ attribute and edge labels in the
183 | ´edge_label´ attribute.
184 | """
185 | data_dict = {}
186 | dataset = SpatialAnnTorchDataset(
187 | adata=adata,
188 | adata_atac=adata_atac,
189 | counts_key=counts_key,
190 | adj_key=adj_key,
191 | cat_covariates_keys=cat_covariates_keys,
192 | cat_covariates_label_encoders=cat_covariates_label_encoders)
193 | 
194 | # PyG Data object (has 2 edge index pairs for one edge because of symmetry;
195 | # one edge index pair will be removed in the edge-level split).
196 | data = Data(x=dataset.x,
197 | edge_index=dataset.edge_index,
198 | edge_attr=dataset.edge_index.t()) # store index of edge nodes as
199 | # edge attribute for
200 | # aggregation weight retrieval
201 | # in mini batches
202 | 
203 | if cat_covariates_keys is not None:
204 | data.cat_covariates_cats = dataset.cat_covariates_cats
205 | 
206 | # Edge-level split for edge reconstruction
207 | edge_train_data, edge_val_data, edge_test_data = edge_level_split(
208 | data=data,
209 | edge_label_adj=dataset.edge_label_adj,
210 | val_ratio=edge_val_ratio,
211 | test_ratio=edge_test_ratio)
212 | data_dict["edge_train_data"] = edge_train_data
213 | data_dict["edge_val_data"] = edge_val_data
214 | data_dict["edge_test_data"] = edge_test_data
215 | 
216 | # Node-level split for gene expression reconstruction
217 | data_dict["node_masked_data"] = node_level_split_mask(
218 | data=data,
219 | val_ratio=node_val_ratio,
220 | test_ratio=node_test_ratio)
221 | return data_dict
--------------------------------------------------------------------------------
/src/nichecompass/data/datareaders.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains data readers for the training of a NicheCompass model.
3 | """
4 | 
5 | from typing import Optional
6 | 
7 | import anndata as ad
8 | import pandas as pd
9 | import scipy.sparse as sp
10 | 
11 | 
12 | def load_spatial_adata_from_csv(counts_file_path: str,
13 | adj_file_path: str,
14 | cell_type_file_path: Optional[str]=None,
15 | adj_key: str="spatial_connectivities",
16 | cell_type_col: str="cell_type",
17 | cell_type_key: str="cell_type") -> ad.AnnData:
18 | """
19 | Create AnnData object from two csv files containing gene expression feature
20 | matrix and adjacency matrix respectively. Optionally, a third csv file with
21 | cell types can be provided.
22 | 
23 | Parameters
24 | ----------
25 | counts_file_path:
26 | File path of the csv file which contains gene expression feature matrix
27 | data.
28 | adj_file_path:
29 | File path of the csv file which contains adjacency matrix data.
30 | cell_type_file_path:
31 | File path of the csv file which contains cell type data.
32 | adj_key:
33 | Key under which the sparse adjacency matrix will be stored in
34 | ´adata.obsp´.
35 | cell_type_col:
36 | Column under which the cell type is stored in the ´cell_type_file´.
37 | cell_type_key:
38 | Key under which the cell types will be stored in ´adata.obs´.
39 | 
40 | Returns
41 | ----------
42 | adata:
43 | AnnData object with gene expression data stored in ´adata.X´ and sparse
44 | adjacency matrix (coo format) stored in ´adata.obsp[adj_key]´.
45 | """
46 | adata = ad.read_csv(counts_file_path)
47 | adj_df = pd.read_csv(adj_file_path, sep=",", header=0)
48 | adj = adj_df.values
49 | adata.obsp[adj_key] = sp.csr_matrix(adj).tocoo()
50 | 
51 | if cell_type_file_path:
52 | cell_type_df = pd.read_csv(cell_type_file_path, sep=",", header=0)
53 | adata.obs[cell_type_key] = cell_type_df[cell_type_col].values
54 | return adata
--------------------------------------------------------------------------------
/src/nichecompass/data/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the SpatialAnnTorchDataset class to provide a standardized
3 | dataset format for the training of a NicheCompass model.
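A minimal usage sketch (assumes ´adata´ holds counts in
´adata.layers['counts']´ and a symmetric adjacency matrix in
´adata.obsp['spatial_connectivities']´, and no categorical covariates
are used):
´´´
dataset = SpatialAnnTorchDataset(adata=adata,
                                 cat_covariates_label_encoders=[])
print(len(dataset), dataset.n_node_features)
´´´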
4 | """
5 | 
6 | from typing import List, Optional
7 | 
8 | import scipy.sparse as sp
9 | import torch
10 | from anndata import AnnData
11 | from torch_geometric.utils import add_self_loops, remove_self_loops
12 | 
13 | from .utils import encode_labels, sparse_mx_to_sparse_tensor
14 | 
15 | 
16 | class SpatialAnnTorchDataset():
17 | """
18 | Spatially annotated torch dataset class to extract node features, node
19 | labels, adjacency matrix and edge indices in a standardized format from an
20 | AnnData object.
21 | 
22 | Parameters
23 | ----------
24 | adata:
25 | AnnData object with counts stored in ´adata.layers[counts_key]´ or
26 | ´adata.X´ depending on ´counts_key´, and sparse adjacency matrix stored
27 | in ´adata.obsp[adj_key]´.
28 | cat_covariates_label_encoders:
29 | List of categorical covariates label encoders from the model (label
30 | encoding indices need to be aligned with the ones from the model to get
31 | the correct categorical covariates embeddings).
32 | adata_atac:
33 | Additional optional AnnData object with paired spatial ATAC data.
34 | counts_key:
35 | Key under which the counts are stored in ´adata.layers´. If ´None´, uses
36 | ´adata.X´ as counts.
37 | adj_key:
38 | Key under which the sparse adjacency matrix is stored in ´adata.obsp´.
39 | self_loops:
40 | If ´True´, add self loops to the adjacency matrix to model autocrine
41 | communication.
42 | cat_covariates_keys:
43 | Keys under which the categorical covariates are stored in ´adata.obs´.
44 | """
45 | def __init__(self,
46 | adata: AnnData,
47 | cat_covariates_label_encoders: List[dict],
48 | adata_atac: Optional[AnnData]=None,
49 | counts_key: Optional[str]="counts",
50 | adj_key: str="spatial_connectivities",
51 | edge_label_adj_key: str="edge_label_spatial_connectivities",
52 | self_loops: bool=True,
53 | cat_covariates_keys: Optional[List[str]]=None):
54 | if counts_key is None:
55 | x = adata.X
56 | else:
57 | x = adata.layers[counts_key]
58 | 
59 | # Store features in dense format
60 | if sp.issparse(x):
61 | self.x = torch.tensor(x.toarray())
62 | else:
63 | self.x = torch.tensor(x)
64 | 
65 | # Concatenate ATAC feature vector in dense format if provided
66 | if adata_atac is not None:
67 | if sp.issparse(adata_atac.X):
68 | self.x = torch.cat(
69 | (self.x, torch.tensor(adata_atac.X.toarray())), axis=1)
70 | else:
71 | self.x = torch.cat((self.x, torch.tensor(adata_atac.X)), axis=1)
72 | 
73 | # Store adjacency matrix in torch_sparse SparseTensor format
74 | if sp.issparse(adata.obsp[adj_key]):
75 | self.adj = sparse_mx_to_sparse_tensor(adata.obsp[adj_key])
76 | else:
77 | self.adj = sparse_mx_to_sparse_tensor(
78 | sp.csr_matrix(adata.obsp[adj_key]))
79 | 
80 | # Store edge label adjacency matrix
81 | if edge_label_adj_key in adata.obsp:
82 | self.edge_label_adj = sp.csr_matrix(adata.obsp[edge_label_adj_key])
83 | else:
84 | self.edge_label_adj = None
85 | 
86 | # Validate adjacency matrix symmetry
87 | if (self.adj.nnz() != self.adj.t().nnz()):
88 | raise ValueError("The input adjacency matrix has to be symmetric.")
89 | 
90 | self.edge_index = self.adj.to_torch_sparse_coo_tensor()._indices()
91 | 
92 | if self_loops:
93 | # Add self loops to account for autocrine communication
94 | # Remove self loops in case there already are some before adding new ones
95 | self.edge_index, _ = remove_self_loops(self.edge_index)
96 | self.edge_index, _ = add_self_loops(self.edge_index,
97 | num_nodes=self.x.size(0))
98 | 
99 | if cat_covariates_keys is not None:
100 | self.cat_covariates_cats = []
101 | for cat_covariate_key,
cat_covariate_label_encoder in zip(
102 | cat_covariates_keys,
103 | cat_covariates_label_encoders):
104 | cat_covariate_cats = torch.tensor(
105 | encode_labels(adata,
106 | cat_covariate_label_encoder,
107 | cat_covariate_key), dtype=torch.long)
108 | self.cat_covariates_cats.append(cat_covariate_cats)
109 | self.cat_covariates_cats = torch.stack(self.cat_covariates_cats,
110 | dim=1)
111 | 
112 | self.n_node_features = self.x.size(1)
113 | self.size_factors = self.x.sum(1) # fix for ATAC case
114 | 
115 | def __len__(self):
116 | """Return the number of observations stored in SpatialAnnTorchDataset"""
117 | return self.x.size(0)
--------------------------------------------------------------------------------
/src/nichecompass/data/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains helper functions for the ´data´ subpackage.
3 | """
4 | 
5 | import anndata as ad
6 | import numpy as np
7 | import torch
8 | from scipy.sparse import csr_matrix
9 | from torch_sparse import SparseTensor
10 | 
11 | 
12 | def encode_labels(adata: ad.AnnData,
13 | label_encoder: dict,
14 | label_key: str) -> np.ndarray:
15 | """
16 | Encode labels from an `adata` object stored in `adata.obs` to integers.
17 | 
18 | Implementation is adapted from
19 | https://github.com/theislab/scarches/blob/c21492d409150cec73d26409f9277b3ac971f4a7/scarches/dataset/trvae/_utils.py#L4
20 | (20.01.2023).
21 | 
22 | Parameters
23 | ----------
24 | adata:
25 | AnnData object with labels stored in `adata.obs[label_key]`.
26 | label_encoder:
27 | Dictionary where keys are labels and values are label encodings.
28 | label_key:
29 | Key in `adata.obs` under which the labels to be encoded are stored.
30 | 
31 | Returns
32 | -------
33 | encoded_labels:
34 | Integer-encoded labels.
35 | """
36 | unique_labels = list(np.unique(adata.obs[label_key]))
37 | encoded_labels = np.zeros(adata.shape[0])
38 | 
39 | if not set(unique_labels).issubset(set(label_encoder.keys())):
40 | print(f"Warning: Labels in adata.obs[{label_key}] are not a subset of "
41 | "the label encoder!")
42 | print("Therefore, the integer value of those labels is set to '-1'.")
43 | for unique_label in unique_labels:
44 | if unique_label not in label_encoder.keys():
45 | encoded_labels[adata.obs[label_key] == unique_label] = -1
46 | 
47 | for label, label_encoding in label_encoder.items():
48 | encoded_labels[adata.obs[label_key] == label] = label_encoding
49 | return encoded_labels
50 | 
51 | 
52 | def sparse_mx_to_sparse_tensor(sparse_mx: csr_matrix) -> SparseTensor:
53 | """
54 | Convert a scipy sparse matrix into a torch_sparse SparseTensor.
55 | 
56 | Parameters
57 | ----------
58 | sparse_mx:
59 | Sparse scipy csr_matrix.
60 | 
61 | Returns
62 | ----------
63 | sparse_tensor:
64 | torch_sparse SparseTensor object that can be utilized by PyG.
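Example (a minimal sketch on a small identity matrix):
´´´
mx = csr_matrix(np.eye(3, dtype=np.float32))
sparse_tensor = sparse_mx_to_sparse_tensor(mx)
´´´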
65 | """
66 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
67 | indices = torch.from_numpy(
68 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
69 | values = torch.from_numpy(sparse_mx.data)
70 | shape = torch.Size(sparse_mx.shape)
71 | torch_sparse_coo_tensor = torch.sparse.FloatTensor(indices, values, shape)
72 | sparse_tensor = SparseTensor.from_torch_sparse_coo_tensor(
73 | torch_sparse_coo_tensor)
74 | return sparse_tensor
--------------------------------------------------------------------------------
/src/nichecompass/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .nichecompass import NicheCompass
2 | from .basemodelmixin import BaseModelMixin
3 | 
4 | __all__ = ["NicheCompass",
5 | "BaseModelMixin"]
--------------------------------------------------------------------------------
/src/nichecompass/models/basemodelmixin.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains generic base model functionalities, added as a Mixin to the
3 | NicheCompass model.
4 | """
5 | 
6 | import inspect
7 | import os
8 | import warnings
9 | from typing import Optional
10 | 
11 | import numpy as np
12 | import pickle
13 | import scipy.sparse as sp
14 | import torch
15 | from anndata import AnnData
16 | 
17 | from .utils import initialize_model, load_saved_files, validate_var_names
18 | 
19 | 
20 | class BaseModelMixin():
21 | """
22 | Base model mixin class for universal model functionalities.
23 | 
24 | Parts of the implementation are adapted from
25 | https://github.com/theislab/scarches/blob/master/scarches/models/base/_base.py#L15
26 | (01.10.2022) and
27 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_base_model.py#L63
28 | (01.10.2022).
29 | """
30 | def _get_user_attributes(self) -> list:
31 | """
32 | Get all the attributes defined in a model instance, for example
33 | self.is_trained_.
34 | 
35 | Returns
36 | ----------
37 | attributes:
38 | Attributes defined in a model instance.
39 | """
40 | attributes = inspect.getmembers(
41 | self, lambda a: not (inspect.isroutine(a)))
42 | attributes = [a for a in attributes if not (
43 | a[0].startswith("__") and a[0].endswith("__"))]
44 | return attributes
45 | 
46 | def _get_public_attributes(self) -> dict:
47 | """
48 | Get only public attributes defined in a model instance. By convention
49 | public attributes have a trailing underscore.
50 | 
51 | Returns
52 | ----------
53 | public_attributes:
54 | Public attributes defined in a model instance.
55 | """
56 | public_attributes = self._get_user_attributes()
57 | public_attributes = {a[0]: a[1] for a in public_attributes if
58 | a[0][-1] == "_"}
59 | return public_attributes
60 | 
61 | def _get_init_params(self, locals: dict) -> dict:
62 | """
63 | Get the model init signature with associated passed in values from
64 | locals (except the AnnData object passed in).
65 | 
66 | Parameters
67 | ----------
68 | locals:
69 | Dictionary returned by calling ´locals()´.
70 | 
71 | Returns
72 | ----------
73 | user_params:
74 | Model initialization attributes defined in a model instance.
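Example (a typical call pattern from within a model's ´__init__´; a
sketch under the assumption that the mixin is inherited by the model
class, not taken verbatim from the model code):
´´´
class MyModel(BaseModelMixin):
    def __init__(self, adata, n_hidden=128):
        self.init_params_ = self._get_init_params(locals())
´´´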
75 | """ 76 | init = self.__init__ 77 | sig = inspect.signature(init) 78 | init_params = [p for p in sig.parameters] 79 | user_params = {p: locals[p] for p in locals if p in init_params} 80 | user_params = {k: v for (k, v) in user_params.items() if not 81 | isinstance(v, AnnData)} 82 | return user_params 83 | 84 | def save(self, 85 | dir_path: str, 86 | overwrite: bool=False, 87 | save_adata: bool=False, 88 | adata_file_name: str="adata.h5ad", 89 | save_adata_atac: bool=False, 90 | adata_atac_file_name: str="adata_atac.h5ad", 91 | **anndata_write_kwargs): 92 | """ 93 | Save model to disk (the Trainer optimizer state is not saved). 94 | 95 | Parameters 96 | ---------- 97 | dir_path: 98 | Path of the directory where the model will be saved. 99 | overwrite: 100 | If `True`, overwrite existing data. If `False` and a directory 101 | already exists at `dir_path`, an error will be raised. 102 | save_adata: 103 | If `True`, also saves the AnnData object. 104 | adata_file_name: 105 | File name under which the AnnData object will be saved. 106 | save_adata_atac: 107 | If `True`, also saves the ATAC AnnData object. 108 | adata_atac_file_name: 109 | File name under which the ATAC AnnData object will be saved. 110 | anndata_write_kwargs: 111 | Keyword arguments for the AnnData write function. 112 | """ 113 | if not os.path.exists(dir_path) or overwrite: 114 | os.makedirs(dir_path, exist_ok=overwrite) 115 | else: 116 | raise ValueError(f"Directory '{dir_path}' already exists. " 117 | "Please provide another directory for saving.") 118 | 119 | model_save_path = os.path.join(dir_path, "model_params.pt") 120 | attr_save_path = os.path.join(dir_path, "attr.pkl") 121 | var_names_save_path = os.path.join(dir_path, "var_names.csv") 122 | 123 | if save_adata: 124 | # Convert storage format of adjacency matrix to be writable by 125 | # adata.write() 126 | if self.adata.obsp["spatial_connectivities"] is not None: 127 | self.adata.obsp["spatial_connectivities"] = sp.csr_matrix( 128 | self.adata.obsp["spatial_connectivities"]) 129 | self.adata.write( 130 | os.path.join(dir_path, adata_file_name), **anndata_write_kwargs) 131 | 132 | if save_adata_atac: 133 | self.adata_atac.write( 134 | os.path.join(dir_path, adata_atac_file_name)) 135 | 136 | var_names = self.adata.var_names.astype(str).to_numpy() 137 | public_attributes = self._get_public_attributes() 138 | 139 | torch.save(self.model.state_dict(), model_save_path) 140 | np.savetxt(var_names_save_path, var_names, fmt="%s") 141 | with open(attr_save_path, "wb") as f: 142 | pickle.dump(public_attributes, f) 143 | 144 | @classmethod 145 | def load(cls, 146 | dir_path: str, 147 | adata: Optional[AnnData]=None, 148 | adata_atac: Optional[AnnData]=None, 149 | adata_file_name: str="adata.h5ad", 150 | adata_atac_file_name: Optional[str]=None, 151 | use_cuda: bool=False, 152 | n_addon_gps: int=0, 153 | gp_names_key: Optional[str]=None, 154 | genes_idx_key: Optional[str]=None, 155 | unfreeze_all_weights: bool=False, 156 | unfreeze_addon_gp_weights: bool=False, 157 | unfreeze_cat_covariates_embedder_weights: bool=False 158 | ) -> torch.nn.Module: 159 | """ 160 | Instantiate a model from saved output. Can be used for transfer learning 161 | scenarios and to learn de-novo gene programs by adding add-on gene 162 | programs and freezing non add-on weights. 163 | 164 | Parameters 165 | ---------- 166 | dir_path: 167 | Path to saved outputs. 168 | adata: 169 | AnnData organized in the same way as data used to train the model. 170 | If ´None´, will check for and load adata saved with the model.
171 | adata_atac: 172 | ATAC AnnData organized in the same way as data used to train the 173 | model. If ´None´ and ´adata_atac_file_name´ is not ´None´, will 174 | check for and load adata_atac saved with the model. 175 | adata_file_name: 176 | File name of the AnnData object to be loaded. 177 | adata_atac_file_name: 178 | File name of the ATAC AnnData object to be loaded. 179 | use_cuda: 180 | If `True`, load model on GPU. 181 | n_addon_gps: 182 | Number of (new) add-on gene programs to be added to the model's 183 | architecture. 184 | gp_names_key: 185 | Key under which the gene program names are stored in ´adata.uns´. genes_idx_key: Key in ´adata.uns´ under which the index of genes used for gene expression reconstruction is stored (required if ´n_addon_gps´ > 0). 186 | unfreeze_all_weights: 187 | If `True`, unfreeze all weights. 188 | unfreeze_addon_gp_weights: 189 | If `True`, unfreeze add-on gp weights. 190 | unfreeze_cat_covariates_embedder_weights: 191 | If `True`, unfreeze categorical covariates embedder weights. 192 | 193 | Returns 194 | ---------- 195 | model: 196 | Model with loaded state dictionaries and, if specified, frozen non 197 | add-on weights. 198 | """ 199 | load_adata = adata is None 200 | load_adata_atac = ((adata_atac is None) & 201 | (adata_atac_file_name is not None)) 202 | use_cuda = use_cuda and torch.cuda.is_available() 203 | map_location = torch.device("cpu") if use_cuda is False else None 204 | 205 | model_state_dict, var_names, attr_dict, new_adata, new_adata_atac = ( 206 | load_saved_files(dir_path, 207 | load_adata, 208 | adata_file_name, 209 | load_adata_atac, 210 | adata_atac_file_name, 211 | map_location=map_location)) 212 | adata = new_adata if new_adata is not None else adata 213 | adata_atac = (new_adata_atac if new_adata_atac is not None else 214 | adata_atac) 215 | 216 | validate_var_names(adata, var_names) 217 | 218 | # Include all genes in gene expression reconstruction if addon nodes 219 | # are present 220 | if n_addon_gps != 0: 221 | if genes_idx_key not in adata.uns: 222 | raise ValueError("Please specify a valid 'genes_idx_key' if " 223 | "'n_addon_gps' > 0, so that all genes can be " 224 | "included in the genes idx.") 225 | adata.uns[genes_idx_key] = np.arange(adata.n_vars * 2) 226 | 227 | # Add new categorical covariates categories from query data 228 | cat_covariates_cats = attr_dict["cat_covariates_cats_"] 229 | cat_covariates_keys = attr_dict["init_params_"]["cat_covariates_keys"] 230 | new_cat_covariates_cats = [] 231 | if cat_covariates_keys is not None: 232 | for i, cat_covariate_key in enumerate(cat_covariates_keys): 233 | new_cat_covariate_cats = [] 234 | adata_cat_covariate_cats = adata.obs[cat_covariate_key].unique().tolist() 235 | for cat_covariate_cat in adata_cat_covariate_cats: 236 | if cat_covariate_cat not in cat_covariates_cats[i]: 237 | new_cat_covariate_cats.append(cat_covariate_cat) 238 | for cat_covariate_cat in new_cat_covariate_cats: 239 | new_cat_covariates_cats.append(cat_covariate_cat) 240 | cat_covariates_cats[i].append(cat_covariate_cat) 241 | attr_dict["init_params_"]["cat_covariates_cats"] = cat_covariates_cats 242 | 243 | if n_addon_gps != 0: 244 | attr_dict["n_addon_gps_"] += n_addon_gps 245 | attr_dict["init_params_"]["n_addon_gps"] += n_addon_gps 246 | 247 | if gp_names_key is None: 248 | raise ValueError("Please specify 'gp_names_key' so that addon " 249 | "gps can be added to the gene program list.") 250 | 251 | gps = list(adata.uns[gp_names_key]) 252 | 253 | if any("addon_GP_" in gp for gp in gps): 254 | addon_gp_idx = int(gps[-1][-1]) + 1 255 | adata.uns[gp_names_key] = np.array( 256 | gps + ["addon_GP_" + str(addon_gp_idx + i) for i in 257 |
range(n_addon_gps)]) 258 | else: 259 | adata.uns[gp_names_key] = np.array( 260 | gps + ["addon_GP_" + str(i) for i in range(n_addon_gps)]) 261 | 262 | model = initialize_model(cls, adata, attr_dict, adata_atac) 263 | 264 | # Set saved attributes for the loaded model 265 | for attr, val in attr_dict.items(): 266 | setattr(model, attr, val) 267 | 268 | if n_addon_gps != 0 or len(new_cat_covariates_cats) > 0: 269 | model.model.load_and_expand_state_dict(model_state_dict) 270 | else: 271 | model.model.load_state_dict(model_state_dict) 272 | 273 | if use_cuda: 274 | model.model.cuda() 275 | model.model.eval() 276 | 277 | # First freeze all parameters and then subsequently unfreeze based on 278 | # load settings 279 | for param_name, param in model.model.named_parameters(): 280 | param.requires_grad = False 281 | model.freeze_ = True 282 | if unfreeze_all_weights: 283 | for param_name, param in model.model.named_parameters(): 284 | param.requires_grad = True 285 | model.freeze_ = False 286 | if unfreeze_addon_gp_weights: 287 | # Allow updates of add-on gp weights 288 | for param_name, param in model.model.named_parameters(): 289 | if "addon" in param_name or \ 290 | "theta" in param_name or \ 291 | "aggregator" in param_name: 292 | param.requires_grad = True 293 | if unfreeze_cat_covariates_embedder_weights: 294 | # Allow updates of categorical covariates embedder weights 295 | for param_name, param in model.model.named_parameters(): 296 | if ("cat_covariate" in param_name) & ("embedder" in param_name): 297 | param.requires_grad = True 298 | 299 | if model.freeze_ and not model.is_trained_: 300 | raise ValueError("The model has not been pre-trained and therefore " 301 | "weights should not be frozen.") 302 | 303 | return model 304 | 305 | def _check_if_trained(self, 306 | warn: bool=True): 307 | """ 308 | Check if the model is trained. 309 | 310 | Parameters 311 | ---------- 312 | warn: 313 | If the model is not trained and `warn` is `True`, issue a warning; 314 | otherwise, raise a RuntimeError. 315 | """ 316 | message = ("Trying to query inferred values from an untrained model. " 317 | "Please train the model first.") 318 | if not self.is_trained_: 319 | if warn: 320 | warnings.warn(message) 321 | else: 322 | raise RuntimeError(message) -------------------------------------------------------------------------------- /src/nichecompass/models/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains helper functions for the ´models´ subpackage. 3 | """ 4 | 5 | import logging 6 | import os 7 | import pickle 8 | from collections import OrderedDict 9 | from typing import Optional, Tuple, Literal 10 | 11 | import anndata as ad 12 | import numpy as np 13 | import torch 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def load_saved_files(dir_path: str, 20 | load_adata: bool, 21 | adata_file_name: str="adata.h5ad", 22 | load_adata_atac: bool=False, 23 | adata_atac_file_name: str="adata_atac.h5ad", 24 | map_location: Optional[Literal["cpu", "cuda"]]=None 25 | ) -> Tuple[OrderedDict, np.ndarray, dict, Optional[ad.AnnData], Optional[ad.AnnData]]: 26 | """ 27 | Helper to load saved model files. 28 | 29 | Parts of the implementation are adapted from 30 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_utils.py#L55 31 | (01.10.2022) 32 | 33 | Parameters 34 | ---------- 35 | dir_path: 36 | Path where the saved model files are stored. 37 | load_adata: 38 | If `True`, also load the stored AnnData object.
39 | adata_file_name: 40 | File name under which the AnnData object is saved. 41 | load_adata_atac: 42 | If `True`, also load the stored ATAC AnnData object. 43 | adata_atac_file_name: 44 | File name under which the ATAC AnnData object is saved. 45 | map_location: 46 | Device to which the loaded model files are mapped. 47 | 48 | Returns 49 | ---------- 50 | model_state_dict: 51 | The stored model state dict. 52 | var_names: 53 | The stored variable names. 54 | attr_dict: 55 | The stored attributes. 56 | adata: 57 | The stored AnnData object. 58 | adata_atac: 59 | The stored ATAC AnnData object. 60 | """ 61 | attr_path = os.path.join(dir_path, "attr.pkl") 62 | adata_path = os.path.join(dir_path, adata_file_name) 63 | var_names_path = os.path.join(dir_path, "var_names.csv") 64 | model_path = os.path.join(dir_path, "model_params.pt") 65 | 66 | if os.path.exists(adata_path) and load_adata: 67 | adata = ad.read(adata_path) 68 | elif not os.path.exists(adata_path) and load_adata: 69 | raise ValueError("Dir path contains no saved anndata and no adata was " 70 | "passed.") 71 | else: 72 | adata = None 73 | 74 | if load_adata_atac: 75 | adata_atac_path = os.path.join(dir_path, adata_atac_file_name) 76 | if os.path.exists(adata_atac_path): 77 | adata_atac = ad.read(adata_atac_path) 78 | else: 79 | raise ValueError("Dir path contains no saved 'adata_atac' and no " 80 | "'adata_atac' was passed.") 81 | else: 82 | adata_atac = None 83 | 84 | model_state_dict = torch.load(model_path, map_location=map_location) 85 | var_names = np.genfromtxt(var_names_path, delimiter=",", dtype=str) 86 | with open(attr_path, "rb") as handle: 87 | attr_dict = pickle.load(handle) 88 | return model_state_dict, var_names, attr_dict, adata, adata_atac 89 | 90 | 91 | def validate_var_names(adata: ad.AnnData, source_var_names: np.ndarray): 92 | """ 93 | Helper to validate variable names. 94 | 95 | Parts of the implementation are adapted from 96 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_utils.py#L141 97 | (01.10.2022) 98 | 99 | Parameters 100 | ---------- adata: AnnData object whose ´var_names´ are validated. 101 | source_var_names: 102 | Variable names against which to validate. 103 | """ 104 | user_var_names = adata.var_names.astype(str) 105 | if not np.array_equal(source_var_names, user_var_names): 106 | logger.warning( 107 | "The ´var_names´ of the passed in adata do not match the " 108 | "´var_names´ of the adata used to train the model. For valid " 109 | "results, the var_names need to be the same and in the same order " 110 | "as the adata used to train the model.") 111 | 112 | 113 | def initialize_model(cls, 114 | adata: ad.AnnData, 115 | attr_dict: dict, 116 | adata_atac: Optional[ad.AnnData]=None) -> torch.nn.Module: 117 | """ 118 | Helper to initialize a model. Adapted from 119 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_utils.py#L103. 120 | 121 | Parameters 122 | ---------- 123 | adata: 124 | AnnData object to be used for initialization. 125 | attr_dict: 126 | Dictionary with attributes for model initialization. 127 | adata_atac: 128 | ATAC AnnData object to be used for initialization.
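Returns
----------
model:
    Model instance initialized with the stored init parameters (and
    ´adata_atac´ if provided).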
129 | """ 130 | if "init_params_" not in attr_dict.keys(): 131 | raise ValueError("No init_params_ were saved by the model.") 132 | # Get the parameters for the class init signature 133 | init_params = attr_dict.pop("init_params_") 134 | 135 | # Grab all the parameters except for kwargs (is a dict) 136 | non_kwargs = {k: v for k, v in init_params.items() if not isinstance(v, dict)} 137 | # Expand out kwargs 138 | kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)} 139 | kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} # flatten nested kwargs dicts into one dict 140 | if adata_atac is None: 141 | # Support for legacy models 142 | model = cls(adata=adata, 143 | **non_kwargs, 144 | **kwargs) 145 | else: 146 | model = cls(adata=adata, 147 | adata_atac=adata_atac, 148 | **non_kwargs, 149 | **kwargs) 150 | return model -------------------------------------------------------------------------------- /src/nichecompass/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .basemodulemixin import BaseModuleMixin 2 | from .losses import (compute_cat_covariates_contrastive_loss, 3 | compute_edge_recon_loss, 4 | compute_gp_group_lasso_reg_loss, 5 | compute_gp_l1_reg_loss, 6 | compute_kl_reg_loss, 7 | compute_omics_recon_nb_loss) 8 | from .vgaemodulemixin import VGAEModuleMixin 9 | from .vgpgae import VGPGAE 10 | 11 | __all__ = ["BaseModuleMixin", 12 | "compute_cat_covariates_contrastive_loss", 13 | "compute_edge_recon_loss", 14 | "compute_gp_group_lasso_reg_loss", 15 | "compute_gp_l1_reg_loss", 16 | "compute_kl_reg_loss", 17 | "compute_omics_recon_nb_loss", 18 | "VGAEModuleMixin", 19 | "VGPGAE"] 20 | -------------------------------------------------------------------------------- /src/nichecompass/modules/basemodulemixin.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains generic module functionalities, added as a Mixin to the 3 | Variational Gene Program Graph Autoencoder module. 4 | """ 5 | 6 | import inspect 7 | from collections import OrderedDict 8 | 9 | import torch 10 | 11 | 12 | class BaseModuleMixin: 13 | """ 14 | Base module mixin class containing universal module functionalities. 15 | 16 | Parts of the implementation are adapted from 17 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_base_model.py#L63 18 | (01.10.2022). 19 | """ 20 | def _get_user_attributes(self) -> list: 21 | """ 22 | Get all the attributes defined in a module instance. 23 | 24 | Returns 25 | ---------- 26 | attributes: 27 | Attributes defined in a module instance. 28 | """ 29 | attributes = inspect.getmembers( 30 | self, lambda a: not (inspect.isroutine(a))) 31 | attributes = [a for a in attributes if not ( 32 | a[0].startswith("__") and a[0].endswith("__"))] 33 | return attributes 34 | 35 | def _get_public_attributes(self) -> dict: 36 | """ 37 | Get only public attributes defined in a module instance. By convention, 38 | public attributes have a trailing underscore. 39 | 40 | Returns 41 | ---------- 42 | public_attributes: 43 | Public attributes defined in a module instance. 44 | """ 45 | public_attributes = self._get_user_attributes() 46 | public_attributes = {a[0]: a[1] for a in public_attributes if 47 | a[0][-1] == "_"} 48 | return public_attributes 49 | 50 | def load_and_expand_state_dict(self, 51 | model_state_dict: OrderedDict): 52 | """ 53 | Load a model state dictionary into the model and expand it to account for 54 | architectural changes through e.g. add-on nodes.
55 | 56 | Parts of the implementation are adapted from 57 | https://github.com/theislab/scarches/blob/master/scarches/models/base/_base.py#L92 58 | (01.10.2022). 59 | """ 60 | load_state_dict = model_state_dict.copy() # old model architecture state 61 | # dict 62 | new_state_dict = self.state_dict() # new model architecture state dict 63 | device = next(self.parameters()).device 64 | 65 | # Update parameter tensors from old model architecture with changes from 66 | # new model architecture 67 | for key, load_param_tensor in load_state_dict.items(): 68 | new_param_tensor = new_state_dict[key] 69 | if new_param_tensor.size() == load_param_tensor.size(): 70 | continue # nothing needs to be updated 71 | else: 72 | # new model architecture parameter tensors are different from 73 | # old model architecture parameter tensors; updates are 74 | # necessary 75 | load_param_tensor = load_param_tensor.to(device) 76 | n_dims = len(new_param_tensor.shape) 77 | idx_slicers = [slice(None)] * n_dims 78 | for i in range(n_dims): 79 | dim_diff = (new_param_tensor.shape[i] - 80 | load_param_tensor.shape[i]) 81 | idx_slicers[i] = slice(-dim_diff, None) 82 | if dim_diff > 0: 83 | break 84 | expanded_param_tensor = torch.cat( 85 | [load_param_tensor, new_param_tensor[tuple(idx_slicers)]], 86 | dim=i) 87 | load_state_dict[key] = expanded_param_tensor 88 | 89 | # Add parameter tensors from new model architecture to old model 90 | # architecture state dict 91 | for key, new_param_tensor in new_state_dict.items(): 92 | if key not in load_state_dict: 93 | load_state_dict[key] = new_param_tensor 94 | 95 | self.load_state_dict(load_state_dict) 96 | -------------------------------------------------------------------------------- /src/nichecompass/modules/vgaemodulemixin.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains generic VGAE functionalities, added as a Mixin to the 3 | Variational Gene Program Graph Autoencoder neural network module. 4 | """ 5 | 6 | import torch 7 | 8 | 9 | class VGAEModuleMixin: 10 | """ 11 | VGAE module mixin class containing universal VGAE module 12 | functionalities. 13 | """ 14 | def reparameterize(self, 15 | mu: torch.Tensor, 16 | logstd: torch.Tensor) -> torch.Tensor: 17 | """ 18 | Use reparameterization trick for latent space normal distribution. 19 | 20 | Parameters 21 | ---------- 22 | mu: 23 | Expected values of the latent space distribution (dim: n_obs, 24 | n_gps). 25 | logstd: 26 | Log standard deviations of the latent space distribution (dim: n_obs, 27 | n_gps). 28 | 29 | Returns 30 | ---------- 31 | rep: 32 | Reparameterized latent features (dim: n_obs, n_gps).
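Notes
----------
In training mode this computes ´rep = mu + eps * exp(logstd)´ with
´eps ~ N(0, 1)´ sampled via ´torch.randn_like´; in evaluation mode ´mu´
is returned directly.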
33 | """ 34 | if self.training: 35 | std = torch.exp(logstd) 36 | eps = torch.randn_like(mu) 37 | rep = eps.mul(std).add(mu) 38 | return rep 39 | else: 40 | rep = mu 41 | return rep 42 | -------------------------------------------------------------------------------- /src/nichecompass/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregators import (OneHopAttentionNodeLabelAggregator, 2 | OneHopGCNNormNodeLabelAggregator, 3 | OneHopSumNodeLabelAggregator) 4 | from .decoders import (CosineSimGraphDecoder, 5 | FCOmicsFeatureDecoder, 6 | MaskedOmicsFeatureDecoder) 7 | from .encoders import Encoder 8 | from .layercomponents import MaskedLinear 9 | from .layers import AddOnMaskedLayer 10 | 11 | __all__ = ["OneHopAttentionNodeLabelAggregator", 12 | "OneHopGCNNormNodeLabelAggregator", 13 | "OneHopSumNodeLabelAggregator", 14 | "CosineSimGraphDecoder", 15 | "FCOmicsFeatureDecoder", 16 | "MaskedOmicsFeatureDecoder", 17 | "Encoder", 18 | "MaskedLinear", 19 | "AddOnMaskedLayer"] 20 | -------------------------------------------------------------------------------- /src/nichecompass/nn/aggregators.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains gene expression aggregators used by the NicheCompass model. 3 | """ 4 | 5 | from typing import Literal 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch_geometric.nn.conv import MessagePassing 10 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 11 | from torch_geometric.nn.dense.linear import Linear 12 | from torch_geometric.nn.inits import glorot 13 | from torch_geometric.utils import softmax 14 | from torch_sparse import SparseTensor 15 | 16 | 17 | class OneHopAttentionNodeLabelAggregator(MessagePassing): 18 | """ 19 | One-hop Attention Node Label Aggregator class that uses a weighted sum 20 | of the omics features of a node's 1-hop neighbors to build an 21 | aggregated neighbor omics feature vector for a node. The weights are 22 | determined by an additive attention mechanism with learnable weights. 23 | It returns a concatenation of the node's own omics feature vector and 24 | the attention-aggregated neighbor omics feature vector as node labels 25 | for the omics reconstruction task. 26 | 27 | Parts of the implementation are inspired by 28 | https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/gatv2_conv.py#L16 29 | (01.10.2022). 30 | 31 | Parameters 32 | ---------- 33 | modality: 34 | Omics modality that is aggregated. Can be either `rna` or `atac`. 35 | n_input: 36 | Number of omics features used for the Node Label Aggregation. 37 | n_heads: 38 | Number of attention heads for multi-head attention. 39 | leaky_relu_negative_slope: 40 | Slope of the leaky relu activation function. 41 | dropout_rate: 42 | Dropout probability of the normalized attention coefficients which 43 | exposes each node to a stochastically sampled neighborhood during 44 | training.
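Examples
----------
A minimal construction sketch; the parameter values are assumptions,
and the input tensors in the comments are assumed to exist:

    aggregator = OneHopAttentionNodeLabelAggregator(modality="rna",
                                                    n_input=1000,
                                                    n_heads=4)
    # x: node omics features, edge_index: graph edges
    # x_neighbors, alpha = aggregator(x, edge_index, return_agg_weights=True)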
45 | """ 46 | def __init__(self, 47 | modality: Literal["rna", "atac"], 48 | n_input: int, 49 | n_heads: int=4, 50 | leaky_relu_negative_slope: float=0.2, 51 | dropout_rate: float=0.): 52 | super().__init__(node_dim=0) 53 | self.n_input = n_input 54 | self.n_heads = n_heads 55 | self.leaky_relu_negative_slope = leaky_relu_negative_slope 56 | self.linear_l_l = Linear(n_input, 57 | n_input * n_heads, 58 | bias=False, 59 | weight_initializer="glorot") 60 | self.linear_r_l = Linear(n_input, 61 | n_input * n_heads, 62 | bias=False, 63 | weight_initializer="glorot") 64 | self.attn = nn.Parameter(torch.Tensor(1, n_heads, n_input)) 65 | self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) 66 | self.dropout = nn.Dropout(p=dropout_rate) 67 | self._alpha = None 68 | self.reset_parameters() 69 | 70 | print(f"ONE HOP ATTENTION {modality.upper()} NODE LABEL AGGREGATOR -> " 71 | f"n_input: {n_input}, " 72 | f"n_heads: {n_heads}") 73 | 74 | def reset_parameters(self): 75 | """ 76 | Reset weight parameters. 77 | """ 78 | self.linear_l_l.reset_parameters() 79 | self.linear_r_l.reset_parameters() 80 | glorot(self.attn) 81 | 82 | def forward(self, 83 | x: torch.Tensor, 84 | edge_index: torch.Tensor, 85 | return_agg_weights: bool=False) -> torch.Tensor: 86 | """ 87 | Forward pass of the One-hop Attention Node Label Aggregator. 88 | 89 | Parameters 90 | ---------- 91 | x: 92 | Tensor containing the omics features of the nodes in the current 93 | node batch including sampled neighbors. 94 | (Size: n_nodes_batch_and_sampled_neighbors x n_node_features) 95 | edge_index: 96 | Tensor containing the node indices of edges in the current node 97 | batch including sampled neighbors. 98 | (Size: 2 x n_edges_batch_and_sampled_neighbors) 99 | return_agg_weights: 100 | If ´True´, also return the aggregation weights (attention weights). 101 | 102 | Returns 103 | ---------- 104 | x_neighbors: 105 | Tensor containing the node labels of the nodes in the current node 106 | batch excluding sampled neighbors. These labels are used for the 107 | omics feature reconstruction task. 108 | (Size: n_nodes_batch x (2 x n_node_features)) 109 | alpha: 110 | Aggregation weights for edges in ´edge_index´. 111 | """ 112 | x_l = x_r = x 113 | g_l = self.linear_l_l(x_l).view(-1, self.n_heads, self.n_input) 114 | g_r = self.linear_r_l(x_r).view(-1, self.n_heads, self.n_input) 115 | x_l = x_l.repeat(1, self.n_heads).view(-1, self.n_heads, self.n_input) 116 | 117 | output = self.propagate(edge_index, x=(x_l, x_r), g=(g_l, g_r)) 118 | x_neighbors = output.mean(dim=1) 119 | alpha = self._alpha 120 | self._alpha = None 121 | if return_agg_weights: 122 | return x_neighbors, alpha 123 | return x_neighbors, None 124 | 125 | def message(self, 126 | x_j: torch.Tensor, 127 | x_i: torch.Tensor, 128 | g_j: torch.Tensor, 129 | g_i: torch.Tensor, 130 | index: torch.Tensor) -> torch.Tensor: 131 | """ 132 | Message method of the MessagePassing parent class. Variables with "_i" 133 | suffix refer to the central nodes that aggregate information. Variables 134 | with "_j" suffix refer to the neighboring nodes. 135 | 136 | Parameters 137 | ---------- 138 | x_j: 139 | Gene expression of neighboring nodes (dim: n_index x n_heads x 140 | n_node_features). 141 | g_i: 142 | Key vector of central nodes (dim: n_index x n_heads x 143 | n_node_features). 144 | g_j: 145 | Query vector of neighboring nodes (dim: n_index x n_heads x 146 | n_node_features).
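x_i:
    Omics features of the central nodes.
index:
    Indices of the central nodes for each edge; used to normalize the
    attention coefficients via a softmax over each central node's edges.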
147 | """ 148 | g = g_i + g_j 149 | g = self.activation(g) 150 | alpha = (g * self.attn).sum(dim=-1) 151 | alpha = softmax(alpha, index) # index is 2nd dim of edge_index (index of 152 | # central node over which softmax should 153 | # be applied) 154 | self._alpha = alpha 155 | alpha = self.dropout(alpha) 156 | return x_j * alpha.unsqueeze(-1) 157 | 158 | 159 | class OneHopGCNNormNodeLabelAggregator(nn.Module): 160 | """ 161 | One-hop GCN Norm Node Label Aggregator class that uses a symmetrically 162 | normalized sum of the omics feature vector of a node's 1-hop neighbors to 163 | build an aggregated neighbor omics feature vector for a node. It returns a 164 | concatenation of the node's own omics feature vector and the gcn-norm 165 | aggregated neighbor omics feature vector as node labels for the omics 166 | reconstruction task. 167 | Parameters ---------- 168 | modality: 169 | Omics modality that is aggregated. Can be either `rna` or `atac`. 170 | """ 171 | def __init__(self, 172 | modality: Literal["rna", "atac"]): 173 | super().__init__() 174 | print(f"ONE HOP GCN NORM {modality.upper()} NODE LABEL AGGREGATOR") 175 | 176 | def forward(self, 177 | x: torch.Tensor, 178 | edge_index: torch.Tensor, 179 | return_agg_weights: bool=False) -> torch.Tensor: 180 | """ 181 | Forward pass of the One-hop GCN Norm Node Label Aggregator. 182 | 183 | Parameters 184 | ---------- 185 | x: 186 | Tensor containing the omics features of the nodes in the current 187 | node batch including sampled neighbors. 188 | (Size: n_nodes_batch_and_sampled_neighbors x n_node_features) 189 | edge_index: 190 | Tensor containing the node indices of edges in the current node 191 | batch including sampled neighbors. 192 | (Size: 2 x n_edges_batch_and_sampled_neighbors) 193 | return_agg_weights: 194 | If ´True´, also return the aggregation weights (norm weights). 195 | 196 | Returns 197 | ---------- 198 | x_neighbors: 199 | Tensor containing the node labels of the nodes in the current node 200 | batch. These labels are used for the omics reconstruction task. 201 | (Size: n_nodes_batch x (2 x n_node_features)) 202 | alpha: 203 | Neighbor aggregation weights. 204 | """ 205 | adj = SparseTensor.from_edge_index(edge_index, 206 | sparse_sizes=(x.shape[0], 207 | x.shape[0])) 208 | adj_norm = gcn_norm(adj) 209 | x_neighbors = adj_norm.t().matmul(x) 210 | if return_agg_weights: 211 | alpha = adj_norm.coo()[2] 212 | return x_neighbors, alpha 213 | return x_neighbors, None 214 | 215 | 216 | class OneHopSumNodeLabelAggregator(nn.Module): 217 | """ 218 | One-hop Sum Node Label Aggregator class that sums up the omics features of 219 | a node's 1-hop neighbors to build an aggregated neighbor omics feature 220 | vector for a node. It returns a concatenation of the node's own omics 221 | feature vector and the sum-aggregated neighbor omics feature vector as node 222 | labels for the omics reconstruction task. 223 | 224 | Parameters 225 | ---------- 226 | modality: 227 | Omics modality that is aggregated. Can be either `rna` or `atac`. 228 | """ 229 | def __init__(self, 230 | modality: Literal["rna", "atac"]): 231 | super().__init__() 232 | print(f"ONE HOP SUM {modality.upper()} NODE LABEL AGGREGATOR") 233 | 234 | def forward(self, 235 | x: torch.Tensor, 236 | edge_index: torch.Tensor, 237 | return_agg_weights: bool=False) -> torch.Tensor: 238 | """ 239 | Forward pass of the One-hop Sum Node Label Aggregator.
240 | 241 | Parameters 242 | ---------- 243 | x: 244 | Tensor containing the omics features of the nodes in the current 245 | node batch including sampled neighbors. 246 | (Size: n_nodes_batch_and_sampled_neighbors x n_node_features) 247 | edge_index: 248 | Tensor containing the node indices of edges in the current node 249 | batch including sampled neighbors. 250 | (Size: 2 x n_edges_batch_and_sampled_neighbors) 251 | 252 | Returns 253 | ---------- 254 | x_neighbors: 255 | Tensor containing the node labels of the nodes in the current node 256 | batch excluding sampled neighbors. These labels are used for the 257 | omics reconstruction task. 258 | (Size: n_nodes_batch x (2 x n_node_features)) 259 | """ 260 | adj = SparseTensor.from_edge_index(edge_index, 261 | sparse_sizes=(x.shape[0], 262 | x.shape[0])) 263 | x_neighbors = adj.t().matmul(x) 264 | return x_neighbors, None 265 | -------------------------------------------------------------------------------- /src/nichecompass/nn/decoders.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains decoders used by the NicheCompass model. 3 | """ 4 | 5 | from typing import Literal, List, Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .layers import AddOnMaskedLayer 11 | from .utils import compute_cosine_similarity 12 | 13 | 14 | class CosineSimGraphDecoder(nn.Module): 15 | """ 16 | Cosine similarity graph decoder class. 17 | 18 | Takes the concatenated latent feature vectors z of the source and 19 | target nodes as input, and calculates the element-wise cosine similarity 20 | between source and target nodes to return the reconstructed edge logits. 21 | 22 | The sigmoid activation function to compute reconstructed edge probabilities 23 | is integrated into the binary cross entropy loss for computational 24 | efficiency. 25 | 26 | Parameters 27 | ---------- 28 | dropout_rate: 29 | Probability of nodes to be dropped during training. 30 | """ 31 | def __init__(self, 32 | dropout_rate: float=0.): 33 | super().__init__() 34 | print("COSINE SIM GRAPH DECODER -> " 35 | f"dropout_rate: {dropout_rate}") 36 | 37 | self.dropout = nn.Dropout(dropout_rate) 38 | 39 | def forward(self, 40 | z: torch.Tensor) -> torch.Tensor: 41 | """ 42 | Forward pass of the cosine similarity graph decoder. 43 | 44 | Parameters 45 | ---------- 46 | z: 47 | Concatenated latent feature vector of the source and target nodes 48 | (dim: 4 * edge_batch_size x n_gps due to negative edges). 49 | 50 | Returns 51 | ---------- 52 | edge_recon_logits: 53 | Reconstructed edge logits (dim: 2 * edge_batch_size due to negative 54 | edges). 55 | """ 56 | z = self.dropout(z) 57 | 58 | # Compute element-wise cosine similarity 59 | edge_recon_logits = compute_cosine_similarity( 60 | z[:int(z.shape[0]/2)], # ´edge_label_index[0]´ 61 | z[int(z.shape[0]/2):]) # ´edge_label_index[1]´ 62 | return edge_recon_logits 63 | 64 | 65 | class MaskedOmicsFeatureDecoder(nn.Module): 66 | """ 67 | Masked omics feature decoder class. 68 | 69 | Takes the latent space features z (gp scores) as input, and has a masked 70 | layer to decode the parameters of the underlying omics feature distributions. 71 | 72 | Parameters 73 | ---------- 74 | modality: 75 | Omics modality that is decoded. Can be either `rna` or `atac`. 76 | entity: 77 | Entity that is decoded. Can be either `target` or `source`. 78 | n_prior_gp_input: 79 | Number of maskable prior gp input nodes to the decoder (maskable latent 80 | space dimensionality). 
81 | n_addon_gp_input: 82 | Number of non-maskable add-on gp input nodes to the decoder ( 83 | non-maskable latent space dimensionality). 84 | n_cat_covariates_embed_input: 85 | Number of categorical covariates embedding input nodes to the decoder 86 | (categorical covariates embedding dimensionality). 87 | n_output: 88 | Number of output nodes from the decoder (number of omics features). 89 | mask: 90 | Mask that determines which masked input nodes / prior gp latent features 91 | z can contribute to the reconstruction of which omics features. 92 | addon_mask: 93 | Mask that determines which add-on input nodes / add-on gp latent 94 | features z can contribute to the reconstruction of which omics features. 95 | masked_features_idx: 96 | Index of omics features that are included in the mask. 97 | recon_loss: 98 | The loss used for omics reconstruction. If `nb`, uses a negative 99 | binomial loss. 100 | """ 101 | def __init__(self, 102 | modality: Literal["rna", "atac"], 103 | entity: Literal["target", "source"], 104 | n_prior_gp_input: int, 105 | n_addon_gp_input: int, 106 | n_cat_covariates_embed_input: int, 107 | n_output: int, 108 | mask: torch.Tensor, 109 | addon_mask: torch.Tensor, 110 | masked_features_idx: List, 111 | recon_loss: Literal["nb"]): 112 | super().__init__() 113 | print(f"MASKED {entity.upper()} {modality.upper()} DECODER -> " 114 | f"n_prior_gp_input: {n_prior_gp_input}, " 115 | f"n_addon_gp_input: {n_addon_gp_input}, " 116 | f"n_cat_covariates_embed_input: {n_cat_covariates_embed_input}, " 117 | f"n_output: {n_output}") 118 | 119 | self.masked_features_idx = masked_features_idx 120 | self.recon_loss = recon_loss 121 | 122 | self.nb_means_normalized_decoder = AddOnMaskedLayer( 123 | n_input=n_prior_gp_input, 124 | n_addon_input=n_addon_gp_input, 125 | n_cat_covariates_embed_input=n_cat_covariates_embed_input, 126 | n_output=n_output, 127 | bias=False, 128 | mask=mask, 129 | addon_mask=addon_mask, 130 | masked_features_idx=masked_features_idx, 131 | activation=nn.Softmax(dim=-1)) 132 | 133 | def forward(self, 134 | z: torch.Tensor, 135 | log_library_size: torch.Tensor, 136 | cat_covariates_embed: Optional[torch.Tensor]=None, 137 | dynamic_mask: Optional[torch.Tensor]=None) -> torch.Tensor: 138 | """ 139 | Forward pass of the masked omics feature decoder. 140 | 141 | Parameters 142 | ---------- 143 | z: 144 | Tensor containing the latent space features. 145 | log_library_size: 146 | Tensor containing the omics feature log library size of the nodes. 147 | dynamic_mask: 148 | Dynamic mask that can change in each forward pass. Is used for atac 149 | modality: if a gene is removed by regularization in the rna decoder 150 | (its weight is set to 0), the corresponding peaks will be marked as 0 151 | in the `dynamic_mask`. 152 | cat_covariates_embed: 153 | Tensor containing the categorical covariates embedding (all 154 | categorical covariates embeddings concatenated into one embedding). 155 | 156 | Returns 157 | ---------- 158 | nb_means: 159 | The mean parameters of the negative binomial distribution. 
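Notes
----------
The masked layer outputs softmax-normalized means, which are then scaled
by the library size, i.e. ´nb_means = exp(log_library_size) *
nb_means_normalized´.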
160 | """ 161 | # Add categorical covariates embedding to latent feature vector 162 | if cat_covariates_embed is not None: 163 | z = torch.cat((z, cat_covariates_embed), dim=-1) 164 | 165 | nb_means_normalized = self.nb_means_normalized_decoder( 166 | input=z, 167 | dynamic_mask=dynamic_mask) 168 | nb_means = torch.exp(log_library_size) * nb_means_normalized 169 | return nb_means 170 | 171 | 172 | class FCOmicsFeatureDecoder(nn.Module): 173 | """ 174 | Fully connected omics feature decoder class. 175 | 176 | Takes the latent space features z as input, and has a fully connected layer 177 | to decode the parameters of the underlying omics feature distributions. 178 | 179 | Parameters 180 | ---------- 181 | modality: 182 | Omics modality that is decoded. Can be either `rna` or `atac`. 183 | entity: 184 | Entity that is decoded. Can be either `target` or `source`. 185 | n_prior_gp_input: 186 | Number of maskable prior gp input nodes to the decoder (maskable latent 187 | space dimensionality). 188 | n_addon_gp_input: 189 | Number of non-maskable add-on gp input nodes to the decoder ( 190 | non-maskable latent space dimensionality). 191 | n_cat_covariates_embed_input: 192 | Number of categorical covariates embedding input nodes to the decoder 193 | (categorical covariates embedding dimensionality). 194 | n_output: 195 | Number of output nodes from the decoder (number of omics features). 196 | n_layers: 197 | Number of fully connected layers used for decoding. 198 | recon_loss: 199 | The loss used for omics reconstruction. If `nb`, uses a negative 200 | binomial loss. 201 | """ 202 | def __init__(self, 203 | modality: Literal["rna", "atac"], 204 | entity: Literal["target", "source"], 205 | n_prior_gp_input: int, 206 | n_addon_gp_input: int, 207 | n_cat_covariates_embed_input: int, 208 | n_output: int, 209 | n_layers: int, 210 | recon_loss: Literal["nb"]): 211 | super().__init__() 212 | print(f"FC {entity.upper()} {modality.upper()} DECODER -> " 213 | f"n_prior_gp_input: {n_prior_gp_input}, " 214 | f"n_addon_gp_input: {n_addon_gp_input}, " 215 | f"n_cat_covariates_embed_input: {n_cat_covariates_embed_input}, " 216 | f"n_output: {n_output}") 217 | 218 | self.n_input = (n_prior_gp_input 219 | + n_addon_gp_input 220 | + n_cat_covariates_embed_input) 221 | self.recon_loss = recon_loss 222 | 223 | if n_layers == 1: 224 | self.nb_means_normalized_decoder = nn.Sequential( 225 | nn.Linear(self.n_input, n_output, bias=False), 226 | nn.Softmax(dim=-1)) 227 | elif n_layers == 2: 228 | self.nb_means_normalized_decoder = nn.Sequential( 229 | nn.Linear(self.n_input, self.n_input, bias=False), 230 | nn.ReLU(), 231 | nn.Linear(self.n_input, n_output, bias=False), 232 | nn.Softmax(dim=-1)) 233 | 234 | def forward(self, 235 | z: torch.Tensor, 236 | log_library_size: torch.Tensor, 237 | cat_covariates_embed: Optional[torch.Tensor]=None, 238 | **kwargs) -> torch.Tensor: 239 | """ 240 | Forward pass of the fully connected omics feature decoder. 241 | 242 | Parameters 243 | ---------- 244 | z: 245 | Tensor containing the latent space features. 246 | log_library_size: 247 | Tensor containing the log library size of the nodes. 248 | dynamic_mask: 249 | Dynamic mask that can change in each forward pass. Is used for atac 250 | modality. If a gene is removed by regularization in the rna decoder 251 | (its weight is set to 0), the corresponding peaks will be marked as 0 252 | in the `dynamic_mask`. 
253 | cat_covariates_embed: 254 | Tensor containing the categorical covariates embedding (all 255 | categorical covariates embeddings concatenated into one embedding). 256 | 257 | Returns 258 | ---------- 259 | nb_means: 260 | The mean parameters of the negative binomial distribution. 261 | """ 262 | # Add categorical covariates embedding to latent feature vector 263 | if cat_covariates_embed is not None: 264 | z = torch.cat((z, cat_covariates_embed), dim=-1) 265 | 266 | nb_means_normalized = self.nb_means_normalized_decoder(input=z) 267 | nb_means = torch.exp(log_library_size) * nb_means_normalized 268 | return nb_means 269 | -------------------------------------------------------------------------------- /src/nichecompass/nn/encoders.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the encoder used by the NicheCompass model. 3 | """ 4 | 5 | from typing import Literal, Optional, Tuple 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch_geometric.nn import GATv2Conv, GCNConv 10 | 11 | 12 | class Encoder(nn.Module): 13 | """ 14 | Encoder class. 15 | 16 | Takes the input space features x and the edge indices as input, first computes 17 | fully connected layers and then uses message passing layers to output mu and 18 | logstd of the latent space normal distribution. 19 | 20 | Parameters 21 | ---------- 22 | n_input: 23 | Number of input nodes (omics features) to the encoder. 24 | n_cat_covariates_embed_input: 25 | Number of categorical covariates embedding input nodes to the encoder. 26 | n_hidden: 27 | Number of hidden nodes outputted after the fully connected layers and 28 | intermediate message passing layers. 29 | n_latent: 30 | Number of output nodes (prior gps) from the encoder, making up the 31 | first part of the latent space features z. 32 | n_addon_latent: 33 | Number of add-on nodes in the latent space (new gps), making up the 34 | second part of the latent space features z. 35 | n_fc_layers: 36 | Number of fully connected layers before the message passing layers. 37 | conv_layer: 38 | Message passing layer used. 39 | n_layers: 40 | Number of message passing layers. 41 | cat_covariates_embed_mode: 42 | Indicates where to inject the categorical covariates embedding if 43 | injected. 44 | n_attention_heads: 45 | Only relevant if ´conv_layer == gatv2conv´. Number of attention heads 46 | used. 47 | dropout_rate: 48 | Probability of nodes to be dropped in the hidden layer during training. 49 | activation: 50 | Activation function used after the fully connected layers and 51 | intermediate message passing layers. 52 | use_bn: 53 | If ´True´, use a batch normalization layer at the end to normalize ´mu´. 
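Examples
----------
A minimal construction sketch; the sizes are assumptions, and the input
tensors in the comment are assumed to exist:

    encoder = Encoder(n_input=1000,
                      n_cat_covariates_embed_input=0,
                      n_hidden=256,
                      n_latent=100)
    # mu, logstd = encoder(x, edge_index)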
54 | """ 55 | def __init__(self, 56 | n_input: int, 57 | n_cat_covariates_embed_input: int, 58 | n_hidden: int, 59 | n_latent: int, 60 | n_addon_latent: int=100, 61 | n_fc_layers: int=1, 62 | conv_layer: Literal["gcnconv", "gatv2conv"]="gatv2conv", 63 | n_layers: int=1, 64 | cat_covariates_embed_mode: Literal["input", "hidden"]="input", 65 | n_attention_heads: int=4, 66 | dropout_rate: float=0., 67 | activation: nn.Module=nn.ReLU(), 68 | use_bn: bool=True): 69 | super().__init__() 70 | print("ENCODER -> " 71 | f"n_input: {n_input}, " 72 | f"n_cat_covariates_embed_input: {n_cat_covariates_embed_input}, " 73 | f"n_hidden: {n_hidden}, " 74 | f"n_latent: {n_latent}, " 75 | f"n_addon_latent: {n_addon_latent}, " 76 | f"n_fc_layers: {n_fc_layers}, " 77 | f"n_layers: {n_layers}, " 78 | f"conv_layer: {conv_layer}, " 79 | f"n_attention_heads: " 80 | f"{n_attention_heads if conv_layer == 'gatv2conv' else '0'}, " 81 | f"dropout_rate: {dropout_rate}") 82 | 83 | self.n_addon_latent = n_addon_latent 84 | self.n_layers = n_layers 85 | self.n_fc_layers = n_fc_layers 86 | self.cat_covariates_embed_mode = cat_covariates_embed_mode 87 | 88 | if ((cat_covariates_embed_mode == "input") & 89 | (n_cat_covariates_embed_input != 0)): 90 | # Add categorical covariates embedding to input 91 | n_input += n_cat_covariates_embed_input 92 | 93 | if n_fc_layers == 2: 94 | self.fc_l1 = nn.Linear(n_input, int(n_input / 2)) 95 | self.fc_l2 = nn.Linear(int(n_input / 2), n_hidden) 96 | self.fc_l2_bn = nn.BatchNorm1d(n_hidden) 97 | elif n_fc_layers == 1: 98 | self.fc_l1 = nn.Linear(n_input, n_hidden) 99 | 100 | if ((cat_covariates_embed_mode == "hidden") & 101 | (n_cat_covariates_embed_input != 0)): 102 | # Add categorical covariates embedding to hidden after fc_l 103 | n_hidden += n_cat_covariates_embed_input 104 | 105 | if conv_layer == "gcnconv": 106 | if n_layers == 2: 107 | self.conv_l1 = GCNConv(n_hidden, 108 | n_hidden) 109 | self.conv_mu = GCNConv(n_hidden, 110 | n_latent) 111 | self.conv_logstd = GCNConv(n_hidden, 112 | n_latent) 113 | if n_addon_latent != 0: 114 | self.addon_conv_mu = GCNConv(n_hidden, 115 | n_addon_latent) 116 | self.addon_conv_logstd = GCNConv(n_hidden, 117 | n_addon_latent) 118 | elif conv_layer == "gatv2conv": 119 | if n_layers == 2: 120 | self.conv_l1 = GATv2Conv(n_hidden, 121 | n_hidden, 122 | heads=n_attention_heads, 123 | concat=False) 124 | self.conv_mu = GATv2Conv(n_hidden, 125 | n_latent, 126 | heads=n_attention_heads, 127 | concat=False) 128 | self.conv_logstd = GATv2Conv(n_hidden, 129 | n_latent, 130 | heads=n_attention_heads, 131 | concat=False) 132 | if n_addon_latent != 0: 133 | self.addon_conv_mu = GATv2Conv(n_hidden, 134 | n_addon_latent, 135 | heads=n_attention_heads, 136 | concat=False) 137 | self.addon_conv_logstd = GATv2Conv(n_hidden, 138 | n_addon_latent, 139 | heads=n_attention_heads, 140 | concat=False) 141 | self.activation = activation 142 | self.dropout = nn.Dropout(dropout_rate) 143 | 144 | def forward(self, 145 | x: torch.Tensor, 146 | edge_index: torch.Tensor, 147 | cat_covariates_embed: Optional[torch.Tensor]=None 148 | ) -> Tuple[torch.Tensor, torch.Tensor]: 149 | """ 150 | Forward pass of the encoder. 151 | 152 | Parameters 153 | ---------- 154 | x: 155 | Tensor containing the omics features. 156 | edge_index: 157 | Tensor containing the edge indices for message passing. 158 | cat_covariates_embed: 159 | Tensor containing the categorical covariates embedding (all 160 | categorical covariates embeddings concatenated into one embedding).
161 | 162 | Returns 163 | ---------- 164 | mu: 165 | Tensor containing the expected values of the latent space normal 166 | distribution. 167 | logstd: 168 | Tensor containing the log standard deviations of the latent space 169 | normal distribution. 170 | """ 171 | if ((self.cat_covariates_embed_mode == "input") & 172 | (cat_covariates_embed is not None)): 173 | # Add categorical covariates embedding to input vector 174 | x = torch.cat((x, 175 | cat_covariates_embed), 176 | axis=1) 177 | 178 | # FC forward pass shared across all nodes 179 | hidden = self.dropout(self.activation(self.fc_l1(x))) 180 | if self.n_fc_layers == 2: 181 | hidden = self.dropout(self.activation(self.fc_l2(hidden))) 182 | hidden = self.fc_l2_bn(hidden) 183 | 184 | if ((self.cat_covariates_embed_mode == "hidden") & 185 | (cat_covariates_embed is not None)): 186 | # Add categorical covariates embedding to hidden vector 187 | hidden = torch.cat((hidden, 188 | cat_covariates_embed), 189 | axis=1) 190 | 191 | if self.n_layers == 2: 192 | # Part of forward pass shared across all nodes 193 | hidden = self.dropout(self.activation( 194 | self.conv_l1(hidden, edge_index))) 195 | 196 | # Part of forward pass only for maskable latent nodes 197 | mu = self.conv_mu(hidden, edge_index) 198 | logstd = self.conv_logstd(hidden, edge_index) 199 | 200 | # Part of forward pass only for unmaskable add-on latent nodes 201 | if self.n_addon_latent != 0: 202 | mu = torch.cat( 203 | (mu, self.addon_conv_mu(hidden, edge_index)), 204 | dim=1) 205 | logstd = torch.cat( 206 | (logstd, self.addon_conv_logstd(hidden, edge_index)), 207 | dim=1) 208 | return mu, logstd 209 | -------------------------------------------------------------------------------- /src/nichecompass/nn/layercomponents.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains neural network layer components used by the NicheCompass 3 | model. 4 | """ 5 | 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class MaskedLinear(nn.Linear): 13 | """ 14 | Masked linear class. 15 | 16 | Parts of the implementation are adapted from 17 | https://github.com/theislab/scarches/blob/master/scarches/models/expimap/modules.py#L9; 18 | 01.10.2022. 19 | 20 | Uses static and dynamic binary masks to mask connections from the input 21 | layer to the output layer so that only unmasked connections can be used. 22 | 23 | Parameters 24 | ---------- 25 | n_input: 26 | Number of input nodes to the masked layer. 27 | n_output: 28 | Number of output nodes from the masked layer. 29 | mask: 30 | Static mask that is used to mask the node connections from the input 31 | layer to the output layer. 32 | bias: 33 | If ´True´, use a bias. 34 | """ 35 | def __init__(self, 36 | n_input: int, 37 | n_output: int, 38 | mask: torch.Tensor, 39 | bias=False): 40 | # Mask should have dim n_input x n_output 41 | if n_input != mask.shape[0] or n_output != mask.shape[1]: 42 | raise ValueError("Incorrect shape of the mask. Mask should have dim" 43 | " (n_input x n_output). 
Please provide a mask with" 44 | f" dimensions ({n_input} x {n_output}).") 45 | super().__init__(n_input, n_output, bias) 46 | 47 | self.register_buffer("mask", mask.t()) 48 | 49 | # Zero out weights with the mask so that the optimizer does not 50 | # consider them 51 | self.weight.data *= self.mask 52 | 53 | def forward(self, 54 | input: torch.Tensor, 55 | dynamic_mask: Optional[torch.Tensor]=None) -> torch.Tensor: 56 | """ 57 | Forward pass of the masked linear class. 58 | 59 | Parameters 60 | ---------- 61 | input: 62 | Tensor containing the input features to the masked linear class. 63 | dynamic_mask: 64 | Additional optional Tensor containing a mask that changes 65 | during training. 66 | 67 | Returns 68 | ---------- 69 | output: 70 | Tensor containing the output of the masked linear class (linear 71 | transformation of the input by only considering unmasked 72 | connections). 73 | """ 74 | if dynamic_mask is not None: 75 | dynamic_mask = dynamic_mask.t().to(self.mask.device) 76 | self.weight.data *= dynamic_mask 77 | masked_weights = self.weight * self.mask * dynamic_mask 78 | else: 79 | masked_weights = self.weight * self.mask 80 | output = nn.functional.linear(input, masked_weights, self.bias) 81 | return output 82 | -------------------------------------------------------------------------------- /src/nichecompass/nn/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains neural network layers used by the NicheCompass model. 3 | """ 4 | 5 | from typing import List, Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .layercomponents import MaskedLinear 11 | 12 | 13 | class AddOnMaskedLayer(nn.Module): 14 | """ 15 | Add-on masked layer class. 16 | 17 | Parts of the implementation are adapted from 18 | https://github.com/theislab/scarches/blob/7980a187294204b5fb5d61364bb76c0b809eb945/scarches/models/expimap/modules.py#L28; 19 | 01.10.2022. 20 | 21 | Parameters 22 | ---------- 23 | n_input: 24 | Number of mask input nodes to the add-on masked layer. 25 | n_output: 26 | Number of output nodes from the add-on masked layer. 27 | mask: 28 | Mask that is used to mask the node connections for mask inputs from the 29 | input layer to the output layer. 30 | addon_mask: 31 | Mask that is used to mask the node connections for add-on inputs from 32 | the input layer to the output layer. 33 | masked_features_idx: 34 | Index of input features that are included in the mask. 35 | bias: 36 | If ´True´, use a bias for the mask input nodes. 37 | n_addon_input: 38 | Number of add-on input nodes to the add-on masked layer. 39 | n_cat_covariates_embed_input: 40 | Number of categorical covariates embedding input nodes to the add-on 41 | masked layer. 42 | activation: 43 | Activation function used at the end of the add-on masked layer.
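Examples
----------
A minimal construction sketch with a random binary mask; all shapes and
sizes are assumptions:

    import torch

    mask = (torch.rand(50, 1000) > 0.5).float()  # 50 prior gps x 1000 genes
    addon_mask = torch.ones(10, 1000)  # 10 add-on gps x 1000 genes
    layer = AddOnMaskedLayer(n_input=50,
                             n_output=1000,
                             mask=mask,
                             addon_mask=addon_mask,
                             masked_features_idx=list(range(1000)),
                             n_addon_input=10)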
44 | """ 45 | def __init__(self, 46 | n_input: int, 47 | n_output: int, 48 | mask: torch.Tensor, 49 | addon_mask: torch.Tensor, 50 | masked_features_idx: List, 51 | bias: bool=False, 52 | n_addon_input: int=0, 53 | n_cat_covariates_embed_input: int=0, 54 | activation: nn.Module=nn.Softmax(dim=-1)): 55 | super().__init__() 56 | self.n_input = n_input 57 | self.n_addon_input = n_addon_input 58 | self.n_cat_covariates_embed_input = n_cat_covariates_embed_input 59 | self.masked_features_idx = masked_features_idx 60 | 61 | # Masked layer 62 | self.masked_l = MaskedLinear(n_input=n_input, 63 | n_output=n_output, 64 | mask=mask, 65 | bias=bias) 66 | 67 | # Add-on layer 68 | if n_addon_input != 0: 69 | self.addon_l = MaskedLinear(n_input=n_addon_input, 70 | n_output=n_output, 71 | mask=addon_mask, 72 | bias=False) 73 | 74 | # Categorical covariates embedding layer 75 | if n_cat_covariates_embed_input != 0: 76 | self.cat_covariates_embed_l = nn.Linear( 77 | n_cat_covariates_embed_input, 78 | n_output, 79 | bias=False) 80 | 81 | self.activation = activation 82 | 83 | def forward(self, 84 | input: torch.Tensor, 85 | dynamic_mask: Optional[torch.Tensor]=None) -> torch.Tensor: 86 | """ 87 | Forward pass of the add-on masked layer. 88 | 89 | Parameters 90 | ---------- 91 | input: 92 | Input features to the add-on masked layer. Includes add-on input 93 | nodes and categorical covariates embedding input nodes if specified. 94 | dynamic_mask: 95 | Additional optional dynamic mask for the masked layer. 96 | 97 | Returns 98 | ---------- 99 | output: 100 | Output of the add-on masked layer. 101 | """ 102 | if (self.n_addon_input == 0) & (self.n_cat_covariates_embed_input == 0): 103 | mask_input = input 104 | elif ((self.n_addon_input != 0) & 105 | (self.n_cat_covariates_embed_input == 0)): 106 | mask_input, addon_input = torch.split( 107 | input, 108 | [self.n_input, self.n_addon_input], 109 | dim=1) 110 | elif ((self.n_addon_input == 0) & 111 | (self.n_cat_covariates_embed_input != 0)): 112 | mask_input, cat_covariates_embed_input = torch.split( 113 | input, 114 | [self.n_input, self.n_cat_covariates_embed_input], 115 | dim=1) 116 | elif ((self.n_addon_input != 0) & 117 | (self.n_cat_covariates_embed_input != 0)): 118 | mask_input, addon_input, cat_covariates_embed_input = torch.split( 119 | input, 120 | [self.n_input, self.n_addon_input, self.n_cat_covariates_embed_input], 121 | dim=1) 122 | 123 | output = self.masked_l( 124 | input=mask_input, 125 | dynamic_mask=(dynamic_mask[:self.n_input, :] if 126 | dynamic_mask is not None else None)) 127 | # Dynamic mask also has entries for add-on gps 128 | if self.n_addon_input != 0: 129 | # Only unmasked features will have weights != 0. 
130 | output += self.addon_l( 131 | input=addon_input, 132 | dynamic_mask=(dynamic_mask[self.n_input:, :] if 133 | dynamic_mask is not None else None)) 134 | if self.n_cat_covariates_embed_input != 0: 135 | if self.n_addon_input != 0: 136 | output += self.cat_covariates_embed_l( 137 | cat_covariates_embed_input) 138 | else: 139 | # Only add categorical covariates embedding layer output to 140 | # masked features 141 | output[:, self.masked_features_idx] += self.cat_covariates_embed_l( 142 | cat_covariates_embed_input)[:, self.masked_features_idx] 143 | output = self.activation(output) 144 | return output 145 | -------------------------------------------------------------------------------- /src/nichecompass/nn/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains helper functions for the ´nn´ subpackage. 3 | """ 4 | 5 | import torch 6 | 7 | 8 | def compute_cosine_similarity(tensor1: torch.Tensor, 9 | tensor2: torch.Tensor, 10 | eps: float=1e-8) -> torch.Tensor: 11 | """ 12 | Compute the element-wise cosine similarity between two 2D tensors. 13 | 14 | Parameters 15 | ---------- 16 | tensor1: 17 | First tensor for element-wise cosine similarity computation (dim: n_obs 18 | x n_features). 19 | tensor2: 20 | Second tensor for element-wise cosine similarity computation (dim: n_obs 21 | x n_features). 22 | eps: Small value to avoid division by zero for near-zero norms. 23 | Returns 24 | ---------- 25 | cosine_sim: 26 | Result tensor that contains the computed element-wise cosine 27 | similarities (dim: n_obs). 28 | """ 29 | tensor1_norm = tensor1.norm(dim=1)[:, None] 30 | tensor2_norm = tensor2.norm(dim=1)[:, None] 31 | tensor1_normalized = tensor1 / torch.max( 32 | tensor1_norm, eps * torch.ones_like(tensor1_norm)) 33 | tensor2_normalized = tensor2 / torch.max( 34 | tensor2_norm, eps * torch.ones_like(tensor2_norm)) 35 | cosine_sim = torch.mul(tensor1_normalized, tensor2_normalized).sum(1) 36 | return cosine_sim 37 | -------------------------------------------------------------------------------- /src/nichecompass/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import eval_metrics, plot_eval_metrics 2 | from .trainer import Trainer 3 | 4 | __all__ = ["eval_metrics", 5 | "plot_eval_metrics", 6 | "Trainer"] -------------------------------------------------------------------------------- /src/nichecompass/train/basetrainermixin.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains generic trainer functionalities, added as a Mixin to 3 | the Trainer module. 4 | """ 5 | 6 | import inspect 7 | 8 | 9 | class BaseTrainerMixin: 10 | """ 11 | Base trainer mixin class containing universal trainer functionalities. 12 | 13 | Parts of the implementation are adapted from 14 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_base_model.py#L63 15 | (01.10.2022). 16 | """ 17 | def _get_user_attributes(self) -> list: 18 | """ 19 | Get all the attributes defined in a trainer instance. 20 | 21 | Returns 22 | ---------- 23 | attributes: 24 | Attributes defined in a trainer instance. 25 | """ 26 | attributes = inspect.getmembers( 27 | self, lambda a: not (inspect.isroutine(a))) 28 | attributes = [a for a in attributes if not ( 29 | a[0].startswith("__") and a[0].endswith("__"))] 30 | return attributes 31 | 32 | def _get_public_attributes(self) -> dict: 33 | """ 34 | Get only public attributes defined in a trainer instance.
35 | public attributes have a trailing underscore. 36 | 37 | Returns 38 | ---------- 39 | public_attributes: 40 | Public attributes defined in a trainer instance. 41 | """ 42 | public_attributes = self._get_user_attributes() 43 | public_attributes = {a[0]: a[1] for a in public_attributes if 44 | a[0][-1] == "_"} 45 | return public_attributes -------------------------------------------------------------------------------- /src/nichecompass/train/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains metrics to evaluate the NicheCompass model training. 3 | """ 4 | 5 | from typing import Optional, Union 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import sklearn.metrics as skm 10 | import torch 11 | from matplotlib.ticker import MaxNLocator 12 | 13 | 14 | def eval_metrics( 15 | edge_recon_probs: Union[torch.Tensor, np.ndarray], 16 | edge_labels: Union[torch.Tensor, np.ndarray], 17 | edge_same_cat_covariates_cat: Optional[Union[torch.Tensor, np.ndarray]]=None, 18 | edge_incl: Optional[Union[torch.Tensor, np.ndarray]]=None, 19 | omics_pred_dict: Optional[dict]=None) -> dict: 20 | """ 21 | Get the evaluation metrics for a (balanced) sample of positive and negative 22 | edges and a sample of nodes. 23 | 24 | Parameters 25 | ---------- 26 | edge_recon_probs: 27 | Tensor or array containing reconstructed edge probabilities. 28 | edge_labels: 29 | Tensor or array containing ground truth labels of edges. 30 | edge_same_cat_covariates_cat: 31 | Tensors or arrays (one per categorical covariate) indicating for each 32 | edge whether both nodes belong to the same category (1) or to different 33 | categories (0) of the respective categorical covariate. 34 | edge_incl: 35 | Boolean tensor or array indicating whether the edge should be included 36 | in the evaluation. 37 | omics_pred_dict: 38 | Dictionary containing predicted and ground truth omics features under 39 | keys such as ´target_rna_preds´ and ´target_rna´ for the predicted 40 | and ground truth gene expression of target nodes, 41 | ´source_rna_preds´ and ´source_rna´ for source nodes, and 42 | ´target_atac_preds´, ´target_atac´, ´source_atac_preds´ and 43 | ´source_atac´ for the predicted and ground truth chromatin 44 | accessibility of target and source nodes. 45 | 46 | Returns 47 | ---------- 48 | eval_dict: 49 | Dictionary containing the evaluation metrics ´auroc_score´ (area under 50 | the receiver operating characteristic curve), ´auprc_score´ (area under 51 | the precision-recall curve), ´best_acc_score´ (accuracy under optimal 52 | classification threshold) and ´best_f1_score´ (F1 score under optimal 53 | classification threshold), plus optional omics reconstruction and categorical covariate metrics.
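Examples
----------
A minimal illustrative call with small synthetic arrays (the values below
are made up and only demonstrate the expected input format):

>>> import numpy as np
>>> probs = np.array([0.9, 0.8, 0.3, 0.2])
>>> labels = np.array([1, 1, 0, 0])
>>> eval_dict = eval_metrics(edge_recon_probs=probs, edge_labels=labels)
>>> round(eval_dict["auroc_score"], 2)
1.0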
54 | """ 55 | eval_dict = {} 56 | 57 | if isinstance(edge_recon_probs, torch.Tensor): 58 | edge_recon_probs = edge_recon_probs.detach().cpu().numpy() 59 | if isinstance(edge_labels, torch.Tensor): 60 | edge_labels = edge_labels.detach().cpu().numpy() 61 | if isinstance(edge_incl, torch.Tensor): 62 | edge_incl = edge_incl.detach().cpu().numpy() 63 | 64 | if omics_pred_dict is not None: 65 | for key, value in omics_pred_dict.items(): 66 | if isinstance(value, torch.Tensor): 67 | omics_pred_dict[key] = value.detach().cpu().numpy() 68 | 69 | if "target_rna_preds" in omics_pred_dict.keys(): 70 | # Calculate the gene expression mean squared error 71 | eval_dict["target_rna_mse_score"] = skm.mean_squared_error( 72 | omics_pred_dict["target_rna"], 73 | omics_pred_dict["target_rna_preds"]) 74 | eval_dict["source_rna_mse_score"] = skm.mean_squared_error( 75 | omics_pred_dict["source_rna"], 76 | omics_pred_dict["source_rna_preds"]) 77 | 78 | if "target_atac_preds" in omics_pred_dict.keys(): 79 | # Calculate the gene expression mean squared error 80 | eval_dict["target_atac_mse_score"] = skm.mean_squared_error( 81 | omics_pred_dict["target_atac"], 82 | omics_pred_dict["target_atac_preds"]) 83 | eval_dict["source_atac_mse_score"] = skm.mean_squared_error( 84 | omics_pred_dict["source_atac"], 85 | omics_pred_dict["source_atac_preds"]) 86 | 87 | if edge_same_cat_covariates_cat is not None: 88 | for i, edge_same_cat_covariate_cat in enumerate(edge_same_cat_covariates_cat): 89 | # Only include negative sampled edges (edge label is 0) 90 | edge_same_cat_covariate_cat_incl = edge_labels == 0 91 | edge_same_cat_covariate_cat_recon_probs = edge_recon_probs[edge_same_cat_covariate_cat_incl] 92 | edge_same_cat_covariate_cat_labels = edge_same_cat_covariate_cat[edge_same_cat_covariate_cat_incl] 93 | same_cat_mask = edge_same_cat_covariate_cat_labels == 1 94 | diff_cat_mask = edge_same_cat_covariate_cat_labels == 0 95 | same_cat_mean = np.mean(edge_same_cat_covariate_cat_recon_probs[same_cat_mask]) 96 | diff_cat_mean = np.mean(edge_same_cat_covariate_cat_recon_probs[diff_cat_mask]) 97 | eval_dict[f"cat_covariate{i}_mean_sim_diff"] = diff_cat_mean - same_cat_mean 98 | 99 | if edge_incl is not None: 100 | edge_incl = edge_incl.astype(bool) 101 | # Remove edges whose node pair has different categories in categorical 102 | # covariates for which no cross-category edges are present 103 | edge_recon_probs = edge_recon_probs[edge_incl] 104 | edge_labels = edge_labels[edge_incl] 105 | 106 | # Calculate threshold independent metrics 107 | eval_dict["auroc_score"] = skm.roc_auc_score(edge_labels, edge_recon_probs) 108 | eval_dict["auprc_score"] = skm.average_precision_score(edge_labels, 109 | edge_recon_probs) 110 | 111 | # Get the optimal classification probability threshold above which an edge 112 | # is classified as positive so that the threshold optimizes the accuracy 113 | # over the sampled (balanced) set of positive and negative edges. 
114 | best_acc_score = 0 115 | best_threshold = 0 116 | for threshold in np.arange(0.01, 1, 0.005): 117 | pred_labels = (edge_recon_probs > threshold).astype("int") 118 | acc_score = skm.accuracy_score(edge_labels, pred_labels) 119 | if acc_score > best_acc_score: 120 | best_threshold = threshold 121 | best_acc_score = acc_score 122 | eval_dict["best_acc_score"] = best_acc_score 123 | eval_dict["best_acc_threshold"] = best_threshold 124 | 125 | # Get the optimal classification probability threshold above which an edge 126 | # is classified as positive so that the threshold optimizes the F1 score 127 | # over the sampled (balanced) set of positive and negative edges. 128 | best_f1_score = 0 129 | for threshold in np.arange(0.01, 1, 0.005): 130 | pred_labels = (edge_recon_probs > threshold).astype("int") 131 | f1_score = skm.f1_score(edge_labels, pred_labels) 132 | if f1_score > best_f1_score: 133 | best_f1_score = f1_score 134 | eval_dict["best_f1_score"] = best_f1_score 135 | return eval_dict 136 | 137 | 138 | def plot_eval_metrics(eval_dict: dict) -> plt.Figure: 139 | """ 140 | Plot evaluation metrics. 141 | 142 | Parameters 143 | ---------- 144 | eval_dict: 145 | Dictionary containing the eval metric scores to be plotted. 146 | 147 | Returns 148 | ---------- 149 | fig: 150 | Matplotlib figure containing a plot of the evaluation metrics. 151 | """ 152 | # Plot epochs as integers 153 | ax = plt.figure().gca() 154 | ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 155 | 156 | # Plot eval metrics 157 | for metric_key, metric_scores in eval_dict.items(): 158 | plt.plot(metric_scores, label=metric_key) 159 | plt.title("Evaluation metrics over epochs") 160 | plt.ylabel("metric score") 161 | plt.xlabel("epoch") 162 | plt.legend(loc="lower right") 163 | 164 | # Retrieve figure 165 | fig = plt.gcf() 166 | plt.close() 167 | return fig -------------------------------------------------------------------------------- /src/nichecompass/train/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains helper functions for the ´train´ subpackage. 3 | """ 4 | 5 | import sys 6 | from typing import Tuple 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from matplotlib.ticker import MaxNLocator 11 | 12 | 13 | class EarlyStopping: 14 | """ 15 | EarlyStopping class for early stopping of NicheCompass training. 16 | 17 | Parts of the implementation are adapted from 18 | https://github.com/theislab/scarches/blob/cb54fa0df3255ad1576a977b17e9d77d4907ceb0/scarches/utils/monitor.py#L4 19 | (01.10.2022). 20 | 21 | Parameters 22 | ---------- 23 | early_stopping_metric: 24 | The metric on which the early stopping criterion is calculated. 25 | metric_improvement_threshold: 26 | The minimum improvement that counts as a metric improvement. 27 | patience: 28 | Number of epochs which are allowed to have no metric improvement until 29 | the training is stopped. 30 | reduce_lr_on_plateau: 31 | If ´True´, the learning rate gets adjusted by ´lr_factor´ after a given 32 | number of epochs with no metric improvement. 33 | lr_patience: 34 | Number of epochs which are allowed to have no metric improvement until 35 | the learning rate is adjusted. 36 | lr_factor: 37 | Scaling factor for adjusting the learning rate.
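Examples
----------
A minimal sketch of wiring the early stopping logic into a training loop
(´n_epochs´, ´compute_val_loss´ and ´optimizer´ are hypothetical
placeholders, not part of this module):

>>> early_stopping = EarlyStopping(patience=8, lr_patience=4)
>>> for epoch in range(n_epochs):  # doctest: +SKIP
...     continue_training, reduce_lr = early_stopping.step(compute_val_loss())
...     if reduce_lr:
...         for param_group in optimizer.param_groups:
...             param_group["lr"] *= early_stopping.lr_factor
...     if not continue_training:
...         break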
38 | """ 39 | def __init__(self, 40 | early_stopping_metric: str="val_global_loss", 41 | metric_improvement_threshold: float=0., 42 | patience: int=8, 43 | reduce_lr_on_plateau: bool=True, 44 | lr_patience: int=4, 45 | lr_factor: float=0.1): 46 | self.early_stopping_metric = early_stopping_metric 47 | self.metric_improvement_threshold = metric_improvement_threshold 48 | self.patience = patience 49 | self.reduce_lr_on_plateau = reduce_lr_on_plateau 50 | self.lr_patience = lr_patience 51 | self.lr_factor = lr_factor 52 | self.epochs = 0 53 | self.epochs_not_improved = 0 54 | self.epochs_not_improved_lr = 0 55 | self.current_performance = np.inf 56 | self.best_performance = np.inf 57 | self.best_performance_state = np.inf 58 | 59 | def step(self, current_metric: float) -> Tuple[bool, bool]: 60 | self.epochs += 1 61 | 62 | # Calculate metric improvement 63 | self.current_performance = current_metric 64 | metric_improvement = (self.best_performance - 65 | self.current_performance) 66 | # Update best performance 67 | if metric_improvement > 0: 68 | self.best_performance = self.current_performance 69 | # Update epochs not improved 70 | if metric_improvement < self.metric_improvement_threshold: 71 | self.epochs_not_improved += 1 72 | self.epochs_not_improved_lr += 1 73 | else: 74 | self.epochs_not_improved = 0 75 | self.epochs_not_improved_lr = 0 76 | 77 | # Determine whether to continue training and whether to reduce the 78 | # learning rate 79 | if self.epochs < self.patience: 80 | continue_training = True 81 | reduce_lr = False 82 | elif self.epochs_not_improved >= self.patience: 83 | continue_training = False 84 | reduce_lr = False 85 | else: 86 | if self.reduce_lr_on_plateau == False: 87 | reduce_lr = False 88 | elif self.epochs_not_improved_lr >= self.lr_patience: 89 | reduce_lr = True 90 | self.epochs_not_improved_lr = 0 91 | print("\nReducing learning rate: metric has not improved more " 92 | f"than {self.metric_improvement_threshold} in the last " 93 | f"{self.lr_patience} epochs.") 94 | else: 95 | reduce_lr = False 96 | continue_training = True 97 | if not continue_training: 98 | print("\nStopping early: metric has not improved more than " 99 | + str(self.metric_improvement_threshold) + 100 | " in the last " + str(self.patience) + " epochs.") 101 | print("If the early stopping criterion is too strong, " 102 | "please instantiate it with different parameters " 103 | "in the train method.") 104 | return continue_training, reduce_lr 105 | 106 | def update_state(self, current_metric: float) -> bool: 107 | improved = (self.best_performance_state - current_metric) > 0 108 | if improved: 109 | self.best_performance_state = current_metric 110 | return improved 111 | 112 | 113 | def print_progress(epoch: int, logs: dict, n_epochs: int): 114 | """ 115 | Create message for '_print_progress_bar()' and print it out with a progress 116 | bar. 117 | 118 | Implementation is adapted from 119 | https://github.com/theislab/scarches/blob/master/scarches/trainers/trvae/_utils.py#L11 120 | (01.10.2022). 121 | 122 | Parameters 123 | ---------- 124 | epoch: 125 | Current epoch. 126 | logs: 127 | Dictionary with all logs (losses & metrics). 128 | n_epochs: 129 | Total number of epochs. 
130 | """ 131 | # Define progress message 132 | message = "" 133 | for key in logs: 134 | message += f"{key:s}: {logs[key][-1]:.4f}; " 135 | message = message[:-2] + "\n" 136 | 137 | # Display progress bar 138 | _print_progress_bar(epoch + 1, 139 | n_epochs, 140 | prefix=f"Epoch {epoch + 1}/{n_epochs}", 141 | suffix=message, 142 | decimals=1, 143 | length=20) 144 | 145 | 146 | def _print_progress_bar(epoch: int, 147 | n_epochs: int, 148 | prefix: str="", 149 | suffix: str="", 150 | decimals: int=1, 151 | length: int=100, 152 | fill: str="█"): 153 | """ 154 | Print out a message with a progress bar. 155 | 156 | Implementation is adapted from 157 | https://github.com/theislab/scarches/blob/master/scarches/trainers/trvae/_utils.py#L41 158 | (01.10.2022). 159 | 160 | Parameters 161 | ---------- 162 | epoch: 163 | Current epoch. 164 | n_epochs: 165 | Total number of epochs. 166 | prefix: 167 | String before the progress bar. 168 | suffix: 169 | String after the progress bar. 170 | decimals: 171 | Digits after comma for the percent display. 172 | length: 173 | Length of the progress bar. 174 | fill: 175 | Symbol for filling the bar. 176 | """ 177 | percent = ("{0:." + str(decimals) + "f}").format(100 * ( 178 | epoch / float(n_epochs))) 179 | filled_len = int(length * epoch // n_epochs) 180 | bar = fill * filled_len + '-' * (length - filled_len) 181 | sys.stdout.write("\r%s |%s| %s%s %s" % (prefix, bar, percent, "%", suffix)), 182 | if epoch == n_epochs: 183 | sys.stdout.write("\n") 184 | sys.stdout.flush() 185 | 186 | 187 | def plot_loss_curves(loss_dict: dict) -> plt.figure: 188 | """ 189 | Plot loss curves. 190 | 191 | Parameters 192 | ---------- 193 | loss_dict: 194 | Dictionary containing the training and validation losses. 195 | 196 | Returns 197 | ---------- 198 | fig: 199 | Matplotlib figure of loss curves. 
200 | """ 201 | # Plot epochs as integers 202 | ax = plt.figure().gca() 203 | ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 204 | 205 | # Plot loss 206 | for loss_key, loss in loss_dict.items(): 207 | plt.plot(loss, label = loss_key) 208 | plt.title(f"Loss curves") 209 | plt.ylabel("loss") 210 | plt.xlabel("epoch") 211 | plt.legend(loc = "upper right") 212 | 213 | # Retrieve figure 214 | fig = plt.gcf() 215 | plt.close() 216 | return fig 217 | 218 | 219 | def _cycle_iterable(iterable): 220 | iterator = iter(iterable) 221 | while True: 222 | try: 223 | yield next(iterator) 224 | except StopIteration: 225 | iterator = iter(iterable) -------------------------------------------------------------------------------- /src/nichecompass/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .analysis import (aggregate_obsp_matrix_per_cell_type, 2 | create_cell_type_chord_plot_from_df, 3 | create_new_color_dict, 4 | compute_communication_gp_network, 5 | visualize_communication_gp_network, 6 | generate_enriched_gp_info_plots, 7 | plot_non_zero_gene_count_means_dist) 8 | from .multimodal_mapping import (add_multimodal_mask_to_adata, 9 | get_gene_annotations, 10 | generate_multimodal_mapping_dict) 11 | from .gene_programs import (add_gps_from_gp_dict_to_adata, 12 | extract_gp_dict_from_collectri_tf_network, 13 | extract_gp_dict_from_nichenet_lrt_interactions, 14 | extract_gp_dict_from_mebocost_ms_interactions, 15 | extract_gp_dict_from_omnipath_lr_interactions, 16 | filter_and_combine_gp_dict_gps, 17 | filter_and_combine_gp_dict_gps_v2, 18 | get_unique_genes_from_gp_dict) 19 | 20 | __all__ = ["add_gps_from_gp_dict_to_adata", 21 | "add_multimodal_mask_to_adata", 22 | "aggregate_obsp_matrix_per_cell_type", 23 | "create_cell_type_chord_plot_from_df", 24 | "create_new_color_dict", 25 | "compute_communication_gp_network", 26 | "visualize_communication_gp_network", 27 | "extract_gp_dict_from_collectri_tf_network", 28 | "extract_gp_dict_from_nichenet_lrt_interactions", 29 | "extract_gp_dict_from_mebocost_ms_interactions", 30 | "extract_gp_dict_from_omnipath_lr_interactions", 31 | "filter_and_combine_gp_dict_gps", 32 | "filter_and_combine_gp_dict_gps_v2", 33 | "get_gene_annotations", 34 | "generate_enriched_gp_info_plots", 35 | "plot_non_zero_gene_count_means_dist", 36 | "generate_multimodal_mapping_dict", 37 | "get_unique_genes_from_gp_dict"] -------------------------------------------------------------------------------- /src/nichecompass/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains helper functions for the ´utils´ subpackage. 3 | """ 4 | 5 | import os 6 | from typing import Optional 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import pandas as pd 11 | import pyreadr 12 | import seaborn as sns 13 | from anndata import AnnData 14 | 15 | 16 | def load_R_file_as_df(R_file_path: str, 17 | url: Optional[str]=None, 18 | save_df_to_disk: bool=False, 19 | df_save_path: Optional[str]=None) -> pd.DataFrame: 20 | """ 21 | Helper to load an R file either from ´url´ if specified or from 22 | ´R_file_path´ on disk and convert it to a pandas DataFrame. 23 | 24 | Parameters 25 | ---------- 26 | R_file_path: 27 | File path to the R file to be loaded as df. 28 | url: 29 | URL of the R file to be loaded as df. 30 | save_df_to_disk: 31 | If ´True´, save df to disk. 32 | df_save_path: 33 | Path where the df will be saved if ´save_df_to_disk´ is ´True´. 
34 | 35 | Returns 36 | ---------- 37 | df: 38 | Content of R file loaded into a pandas DataFrame. 39 | """ 40 | if url is None: 41 | if not os.path.exists(R_file_path): 42 | raise ValueError("Please specify a valid ´R_file_path´ or ´url´.") 43 | result_odict = pyreadr.read_r(R_file_path) 44 | else: 45 | result_odict = pyreadr.read_r(pyreadr.download_file(url, R_file_path)) 46 | os.remove(R_file_path) 47 | 48 | df = result_odict[None] 49 | 50 | if save_df_to_disk: 51 | if df_save_path is None: 52 | raise ValueError("Please specify ´df_save_path´ or set " 53 | "´save_df_to_disk´ to False.") 54 | df.to_csv(df_save_path) 55 | return df 56 | 57 | 58 | def create_gp_gene_count_distribution_plots( 59 | gp_dict: Optional[dict]=None, 60 | adata: Optional[AnnData]=None, 61 | gp_targets_mask_key: Optional[str]="nichecompass_gp_targets", 62 | gp_sources_mask_key: Optional[str]="nichecompass_gp_sources", 63 | gp_plot_label: str="", 64 | save_path: Optional[str]=None): 65 | """ 66 | Create distribution plots of the gene counts for sources and targets 67 | of all gene programs in either a gp dict or an adata object. 68 | 69 | Parameters 70 | ---------- 71 | gp_dict: 72 | A gene program dictionary. 73 | adata: 74 | An AnnData object. 75 | gp_targets_mask_key: 76 | Key in ´adata.varm´ where the gene program targets mask is stored. 77 | gp_sources_mask_key: 78 | Key in ´adata.varm´ where the gene program sources mask is stored. 79 | gp_plot_label: 80 | Label of the gene program plot for title. 81 | save_path: 82 | Path under which the figure is saved if specified. 83 | """ 84 | # Get number of source and target genes for each gene program 85 | if gp_dict is not None: 86 | n_sources_list = [] 87 | n_targets_list = [] 88 | for _, gp_sources_targets_dict in gp_dict.items(): 89 | n_sources_list.append(len(gp_sources_targets_dict["sources"])) 90 | n_targets_list.append(len(gp_sources_targets_dict["targets"])) 91 | elif adata is not None: 92 | n_targets_list = adata.varm[gp_targets_mask_key].sum(axis=0) 93 | n_sources_list = adata.varm[gp_sources_mask_key].sum(axis=0) 94 | else: 95 | raise ValueError("Please specify either ´gp_dict´ or ´adata´.") 96 | 97 | # Convert the arrays to a pandas DataFrame 98 | targets_df = pd.DataFrame({"values": n_targets_list}) 99 | sources_df = pd.DataFrame({"values": n_sources_list}) 100 | 101 | # Determine plot configurations 102 | max_n_targets = max(n_targets_list) 103 | max_n_sources = max(n_sources_list) 104 | if max_n_targets > 200: 105 | targets_x_ticks_range = 100 106 | xticklabels_rotation = 45 107 | elif max_n_targets > 100: 108 | targets_x_ticks_range = 20 109 | xticklabels_rotation = 0 110 | elif max_n_targets > 10: 111 | targets_x_ticks_range = 10 112 | xticklabels_rotation = 0 113 | else: 114 | targets_x_ticks_range = 1 115 | xticklabels_rotation = 0 116 | if max_n_sources > 200: 117 | sources_x_ticks_range = 100 118 | elif max_n_sources > 100: 119 | sources_x_ticks_range = 20 120 | elif max_n_sources > 10: 121 | sources_x_ticks_range = 10 122 | else: 123 | sources_x_ticks_range = 1 124 | 125 | # Create subplot 126 | fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 5)) 127 | plt.suptitle( 128 | f"{gp_plot_label} Gene Programs – Gene Count Distribution Plots") 129 | sns.histplot(x="values", data=targets_df, ax=ax1) 130 | ax1.set_title("Gene Program Targets Distribution", 131 | fontsize=10) 132 | ax1.set(xlabel="Number of Targets", 133 | ylabel="Number of Gene Programs") 134 | ax1.set_xticks( 135 | np.arange(0, 136 | max_n_targets + targets_x_ticks_range, 137 | targets_x_ticks_range)) 138 | ax1.set_xticklabels( 139 | np.arange(0, 140 | max_n_targets + targets_x_ticks_range, 141 | targets_x_ticks_range), 142 | rotation=xticklabels_rotation) 143 | sns.histplot(x="values", data=sources_df, ax=ax2) 144 | ax2.set_title("Gene Program Sources Distribution", 145 | fontsize=10) 146 | ax2.set(xlabel="Number of Sources",
147 | ylabel="Number of Gene Programs") 148 | ax2.set_xticks( 149 | np.arange(0, 150 | max_n_sources + sources_x_ticks_range, 151 | sources_x_ticks_range)) 152 | ax2.set_xticklabels( 153 | np.arange(0, 154 | max_n_sources + sources_x_ticks_range, 155 | sources_x_ticks_range), 156 | rotation=xticklabels_rotation) 157 | plt.subplots_adjust(wspace=0.35) 158 | if save_path: 159 | plt.savefig(save_path) 160 | plt.show() -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import nichecompass 4 | 5 | 6 | def test_package_has_version(): 7 | assert nichecompass.__version__ is not None 8 | 9 | 10 | @pytest.mark.skip(reason="This decorator should be removed when test passes.") 11 | def test_example(): 12 | assert 1 == 0 # This test is designed to fail. 13 | --------------------------------------------------------------------------------