├── .codecov.yaml
├── .editorconfig
├── .github
│   ├── actions
│   │   └── setup
│   │       └── action.yaml
│   └── workflows
│       ├── build.yaml
│       ├── release.yaml
│       └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── data
│   ├── gene_annotations
│   │   ├── gencode.vM25.chr_patch_hapl_scaff.annotation.gtf.gz
│   │   └── human_mouse_gene_orthologs.csv
│   └── gene_programs
│       └── metabolite_enzyme_sensor_gps
│           ├── human_metabolite_enzymes.tsv
│           ├── human_metabolite_sensors.tsv
│           ├── mouse_metabolite_enzymes.tsv
│           └── mouse_metabolite_sensors.tsv
├── docs
│   ├── Makefile
│   ├── _static
│   │   ├── .gitkeep
│   │   ├── css
│   │   │   └── custom.css
│   │   ├── nichecompass_fig1.png
│   │   ├── nichecompass_logo.svg
│   │   └── nichecompass_logo_readme.png
│   ├── _templates
│   │   ├── .gitkeep
│   │   └── autosummary
│   │       └── class.rst
│   ├── api
│   │   ├── developer.md
│   │   ├── index.md
│   │   └── user.md
│   ├── conf.py
│   ├── contributing.md
│   ├── extensions
│   │   └── typed_returns.py
│   ├── index.md
│   ├── installation.md
│   ├── references.bib
│   ├── references.md
│   ├── release_notes
│   │   └── index.md
│   ├── requirements.txt
│   ├── tutorials
│   │   ├── index.md
│   │   ├── multimodal_tutorials.md
│   │   ├── notebooks
│   │   │   ├── mouse_brain_multimodal.ipynb
│   │   │   ├── mouse_cns_sample_integration.ipynb
│   │   │   ├── mouse_cns_single_sample.ipynb
│   │   │   └── mouse_cns_spatial_reference_mapping.ipynb
│   │   ├── sample_integration_tutorials.md
│   │   ├── single_sample_tutorials.md
│   │   └── spatial_reference_mapping_tutorials.md
│   └── user_guide
│       └── index.md
├── environment.yaml
├── pyproject.toml
├── src
│   └── nichecompass
│       ├── __init__.py
│       ├── benchmarking
│       │   ├── __init__.py
│       │   ├── cas.py
│       │   ├── clisis.py
│       │   ├── gcs.py
│       │   ├── metrics.py
│       │   ├── mlami.py
│       │   ├── nasw.py
│       │   └── utils.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── dataloaders.py
│       │   ├── dataprocessors.py
│       │   ├── datareaders.py
│       │   ├── datasets.py
│       │   └── utils.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── basemodelmixin.py
│       │   ├── nichecompass.py
│       │   └── utils.py
│       ├── modules
│       │   ├── __init__.py
│       │   ├── basemodulemixin.py
│       │   ├── losses.py
│       │   ├── vgaemodulemixin.py
│       │   └── vgpgae.py
│       ├── nn
│       │   ├── __init__.py
│       │   ├── aggregators.py
│       │   ├── decoders.py
│       │   ├── encoders.py
│       │   ├── layercomponents.py
│       │   ├── layers.py
│       │   └── utils.py
│       ├── train
│       │   ├── __init__.py
│       │   ├── basetrainermixin.py
│       │   ├── metrics.py
│       │   ├── trainer.py
│       │   └── utils.py
│       └── utils
│           ├── __init__.py
│           ├── analysis.py
│           ├── gene_programs.py
│           ├── multimodal_mapping.py
│           └── utils.py
└── tests
    └── test_basic.py
/.codecov.yaml:
--------------------------------------------------------------------------------
1 | # Based on pydata/xarray
2 | codecov:
3 |   require_ci_to_pass: no
4 |
5 | coverage:
6 |   status:
7 |     project:
8 |       default:
9 |         # Require 1% coverage, i.e., always succeed
10 |         target: 1
11 |     patch: false
12 |     changes: false
13 |
14 | comment:
15 |   layout: diff, flags, files
16 |   behavior: once
17 |   require_base: no
18 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = false
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
--------------------------------------------------------------------------------
/.github/actions/setup/action.yaml:
--------------------------------------------------------------------------------
1 | name: Setup
2 |
3 | inputs:
4 |   python-version:
5 |     required: false
6 |     default: '3.9'
7 |   torch-version:
8 |     required: false
9 |     default: '2.0.0'
10 |   cuda-version:
11 |     required: false
12 |     default: cpu
13 |   full_install:
14 |     required: false
15 |     default: true
16 |
17 | runs:
18 |   using: composite
19 |
20 |   steps:
21 |     - name: Set up Python ${{ inputs.python-version }}
22 |       uses: actions/setup-python@v4.3.0
23 |       with:
24 |         python-version: ${{ inputs.python-version }}
25 |         check-latest: true
26 |         cache: pip
27 |         cache-dependency-path: |
28 |           pyproject.toml
29 |
30 |     - name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }}
31 |       run: |
32 |         pip install torch==${{ inputs.torch-version }} --extra-index-url https://download.pytorch.org/whl/${{ inputs.cuda-version }}
33 |         python -c "import torch; print('PyTorch:', torch.__version__)"
34 |         python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
35 |         python -c "import torch; print('CUDA:', torch.version.cuda)"
36 |       shell: bash
37 |
38 |     - name: Install extension packages
39 |       if: ${{ inputs.full_install == 'true' }}
40 |       run: |
41 |         pip install scipy
42 |         pip install --no-index --upgrade torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-${{ inputs.torch-version }}+${{ inputs.cuda-version }}.html
43 |       shell: bash
44 |
--------------------------------------------------------------------------------
/.github/workflows/build.yaml:
--------------------------------------------------------------------------------
1 | name: Build
2 |
3 | on:
4 |   push:
5 |     branches: [main]
6 |   pull_request:
7 |     branches: [main]
8 |
9 | concurrency:
10 |   group: ${{ github.workflow }}-${{ github.ref }}
11 |   cancel-in-progress: true
12 |
13 | jobs:
14 |   package:
15 |     runs-on: ubuntu-latest
16 |     steps:
17 |       - uses: actions/checkout@v4
18 |       - name: Set up Python 3.9
19 |         uses: actions/setup-python@v5
20 |         with:
21 |           python-version: "3.9"
22 |           cache: "pip"
23 |           cache-dependency-path: "**/pyproject.toml"
24 |       - name: Install build dependencies
25 |         run: python -m pip install --upgrade pip wheel twine build
26 |       - name: Build package
27 |         run: python -m build
28 |       - name: Check package
29 |         run: twine check --strict dist/*.whl
30 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 |   release:
5 |     types: [published]
6 |
7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/
8 | jobs:
9 |   release:
10 |     name: Upload release to PyPI
11 |     runs-on: ubuntu-latest
12 |     environment:
13 |       name: pypi
14 |       url: https://pypi.org/p/nichecompass/
15 |     permissions:
16 |       id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
17 |     steps:
18 |       - uses: actions/checkout@v4
19 |         with:
20 |           filter: blob:none
21 |           fetch-depth: 0
22 |       - uses: actions/setup-python@v5
23 |         with:
24 |           python-version: "3.9"
25 |           cache: "pip"
26 |       - run: pip install build
27 |       - run: python -m build
28 |       - name: Publish package distributions to PyPI
29 |         uses: pypa/gh-action-pypi-publish@release/v1
30 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 |   push:
5 |     branches: [main]
6 |   pull_request:
7 |     branches: [main]
8 |   schedule:
9 |     - cron: "0 5 1,15 * *"
10 |
11 | concurrency:
12 |   group: ${{ github.workflow }}-${{ github.ref }}
13 |   cancel-in-progress: true
14 |
15 | jobs:
16 |   test:
17 |     runs-on: ${{ matrix.os }}
18 |     defaults:
19 |       run:
20 |         shell: bash -e {0} # -e to fail on error
21 |
22 |     strategy:
23 |       fail-fast: false
24 |       matrix:
25 |         include:
26 |           - os: ubuntu-latest
27 |             python: "3.9"
28 |
29 |     name: ${{ matrix.name }} Python ${{ matrix.python }}
30 |
31 |     env:
32 |       OS: ${{ matrix.os }}
33 |       PYTHON: ${{ matrix.python }}
34 |
35 |     steps:
36 |       - uses: actions/checkout@v4
37 |       - name: Set up Python ${{ matrix.python }}
38 |         uses: actions/setup-python@v5
39 |         with:
40 |           python-version: ${{ matrix.python }}
41 |           cache: "pip"
42 |           cache-dependency-path: "**/pyproject.toml"
43 |
44 |       - name: Set up external packages
45 |         if: steps.changed-files-specific.outputs.only_changed != 'true'
46 |         uses: ./.github/actions/setup
47 |       - name: Install test dependencies
48 |         run: |
49 |           python -m pip install --upgrade pip wheel
50 |       - name: Install dependencies
51 |         run: |
52 |           pip install ${{ matrix.pip-flags }} ".[dev,tests]"
53 |       - name: Test
54 |         env:
55 |           MPLBACKEND: agg
56 |           PLATFORM: ${{ matrix.os }}
57 |           DISPLAY: :42
58 |         run: |
59 |           coverage run -m pytest -v --color=yes
60 |       - name: Report coverage
61 |         run: |
62 |           coverage report
63 |       - name: Upload coverage
64 |         uses: codecov/codecov-action@v3
65 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # DS_Store
2 | .DS_Store
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | venv/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | # vscode
133 | .vscode/settings.json
134 |
135 | # data
136 | data/gene_programs/*
137 | data/spatial_omics/*
138 | !data/gene_programs/metabolite_enzyme_sensor_gps/
139 |
140 | # artifacts
141 | artifacts
142 |
143 | # charliecloud image for HPC
144 | /env/nichecompass.tar.gz
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/psf/black
3 |     rev: 22.6.0
4 |     hooks:
5 |       - id: black
6 |   - repo: https://github.com/PyCQA/flake8
7 |     rev: 4.0.1
8 |     hooks:
9 |       - id: flake8
10 |   - repo: https://github.com/pycqa/isort
11 |     rev: 5.10.1
12 |     hooks:
13 |       - id: isort
14 |         name: isort (python)
15 |         additional_dependencies: [toml]
16 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html
2 | version: 2
3 | build:
4 |   os: ubuntu-20.04
5 |   tools:
6 |     python: "3.9"
7 | sphinx:
8 |   configuration: docs/conf.py
9 |   # disable this for more lenient docs builds
10 |   fail_on_warning: false
11 | python:
12 |   install:
13 |     - requirements: docs/requirements.txt
14 |     - method: pip
15 |       path: .
16 |       extra_requirements:
17 |         - docsbuild
18 | submodules:
19 |   include: [docs/tutorials/notebooks]
20 |   recursive: true
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, Sebastian Birk, Carlos Talavera-López, Mohammad Lotfollahi
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [](https://github.com/Lotfollahi-lab/nichecompass/blob/main/LICENSE)
4 | [](https://github.com/Lotfollahi-lab/nichecompass/stargazers)
5 | [](https://pypi.org/project/nichecompass)
6 | [](https://pepy.tech/project/nichecompass)
7 | [](https://nichecompass.readthedocs.io/en/stable/?badge=stable)
8 | [](https://github.com/psf/black)
9 | [](https://github.com/pre-commit/pre-commit)
10 |
11 | NicheCompass (**N**iche **I**dentification based on **C**ellular grap**H** **E**mbeddings of **COM**munication **P**rograms **A**ligned across **S**patial **S**amples) is a package for end-to-end analysis of spatial multi-omics data, including spatial atlas building, niche identification & characterization, cell-cell communication inference and spatial reference mapping. It is built on top of [PyG](https://pytorch-geometric.readthedocs.io/en/latest/) and [AnnData](https://anndata.readthedocs.io/en/latest/).
12 |
13 | ## Resources
14 | - An installation guide, tutorials, and API documentation are available in the [documentation](https://nichecompass.readthedocs.io/).
15 | - Please use [issues](https://github.com/Lotfollahi-lab/nichecompass/issues) to submit bug reports.
16 | - If you would like to contribute, check out the [contributing guide](https://nichecompass.readthedocs.io/en/latest/contributing.html).
17 | - If you find NicheCompass useful for your research, please consider citing the NicheCompass manuscript.
18 |
19 | ## Reference
20 | ```
21 | @article{Birk2025,
22 | author = {Birk, S. and Bonafonte-Pard{\`a}s, I. and Feriz, A. M. and et al.},
23 | title = {Quantitative characterization of cell niches in spatially resolved omics data},
24 | journal = {Nature Genetics},
25 | year = {2025},
26 | doi = {10.1038/s41588-025-02120-6},
27 | url = {https://doi.org/10.1038/s41588-025-02120-6}
28 | }
29 | ```
30 |
--------------------------------------------------------------------------------
/data/gene_annotations/gencode.vM25.chr_patch_hapl_scaff.annotation.gtf.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lotfollahi-lab/nichecompass/2231b021122a9c06aaed55b9a81188877a230cd4/data/gene_annotations/gencode.vM25.chr_patch_hapl_scaff.annotation.gtf.gz
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lotfollahi-lab/nichecompass/2231b021122a9c06aaed55b9a81188877a230cd4/docs/_static/.gitkeep
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */
2 | div.cell_output table.dataframe {
3 |   font-size: 0.8em;
4 | }
5 |
6 | .wy-nav-side {
7 |   background: #242335;
8 | }
9 |
10 | .wy-side-nav-search {
11 |   background-color: #F3F4F7;
12 | }
13 |
14 | .icon.icon-home,
15 | .icon.icon-home .logo {
16 |   color: #000000; /* Black color */
17 | }
18 |
--------------------------------------------------------------------------------
/docs/_static/nichecompass_fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lotfollahi-lab/nichecompass/2231b021122a9c06aaed55b9a81188877a230cd4/docs/_static/nichecompass_fig1.png
--------------------------------------------------------------------------------
/docs/_static/nichecompass_logo.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/_static/nichecompass_logo_readme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lotfollahi-lab/nichecompass/2231b021122a9c06aaed55b9a81188877a230cd4/docs/_static/nichecompass_logo_readme.png
--------------------------------------------------------------------------------
/docs/_templates/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lotfollahi-lab/nichecompass/2231b021122a9c06aaed55b9a81188877a230cd4/docs/_templates/.gitkeep
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | Attributes table
12 | ~~~~~~~~~~~~~~~~~~
13 |
14 | .. autosummary::
15 | {% for item in attributes %}
16 |     ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | Methods table
24 | ~~~~~~~~~~~~~
25 |
26 | .. autosummary::
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 |     ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
35 | {% block attributes_documentation %}
36 | {% if attributes %}
37 | Attributes
38 | ~~~~~~~~~~~
39 |
40 | {% for item in attributes %}
41 |
42 | .. autoattribute:: {{ [objname, item] | join(".") }}
43 | {%- endfor %}
44 |
45 | {% endif %}
46 | {% endblock %}
47 |
48 | {% block methods_documentation %}
49 | {% if methods %}
50 | Methods
51 | ~~~~~~~
52 |
53 | {% for item in methods %}
54 | {%- if item != '__init__' %}
55 |
56 | .. automethod:: {{ [objname, item] | join(".") }}
57 | {%- endif -%}
58 | {%- endfor %}
59 |
60 | {% endif %}
61 | {% endblock %}
62 |
--------------------------------------------------------------------------------
/docs/api/developer.md:
--------------------------------------------------------------------------------
1 | # Developer
2 |
3 | ## Benchmarking
4 |
5 | ```{eval-rst}
6 | .. module:: nichecompass.benchmarking
7 | .. currentmodule:: nichecompass
8 |
9 | .. autosummary::
10 |     :toctree: generated
11 |
12 |     benchmarking.utils.compute_knn_graph_connectivities_and_distances
13 | ```
14 |
15 | ## Data
16 |
17 | ```{eval-rst}
18 | .. module:: nichecompass.data
19 | .. currentmodule:: nichecompass
20 |
21 | .. autosummary::
22 |     :toctree: generated
23 |
24 |     data.initialize_dataloaders
25 |     data.edge_level_split
26 |     data.node_level_split_mask
27 |     data.prepare_data
28 |     data.SpatialAnnTorchDataset
29 | ```
30 |
31 | ## Models
32 |
33 | ```{eval-rst}
34 | .. module:: nichecompass.models
35 | .. currentmodule:: nichecompass
36 |
37 | .. autosummary::
38 |     :toctree: generated
39 |
40 |     models.utils.load_saved_files
41 |     models.utils.validate_var_names
42 |     models.utils.initialize_model
43 | ```
44 |
45 | ## Modules
46 |
47 | ```{eval-rst}
48 | .. module:: nichecompass.modules
49 | .. currentmodule:: nichecompass
50 |
51 | .. autosummary::
52 |     :toctree: generated
53 |
54 |     modules.VGPGAE
55 |     modules.VGAEModuleMixin
56 |     modules.BaseModuleMixin
57 |     modules.compute_cat_covariates_contrastive_loss
58 |     modules.compute_edge_recon_loss
59 |     modules.compute_gp_group_lasso_reg_loss
60 |     modules.compute_gp_l1_reg_loss
61 |     modules.compute_kl_reg_loss
62 |     modules.compute_omics_recon_nb_loss
63 | ```
64 |
65 | ## NN
66 |
67 | ```{eval-rst}
68 | .. module:: nichecompass.nn
69 | .. currentmodule:: nichecompass
70 |
71 | .. autosummary::
72 |     :toctree: generated
73 |
74 |     nn.OneHopAttentionNodeLabelAggregator
75 |     nn.OneHopGCNNormNodeLabelAggregator
76 |     nn.OneHopSumNodeLabelAggregator
77 |     nn.CosineSimGraphDecoder
78 |     nn.FCOmicsFeatureDecoder
79 |     nn.MaskedOmicsFeatureDecoder
80 |     nn.Encoder
81 |     nn.MaskedLinear
82 |     nn.AddOnMaskedLayer
83 | ```
84 |
85 | ## Train
86 |
87 | ```{eval-rst}
88 | .. module:: nichecompass.train
89 | .. currentmodule:: nichecompass
90 |
91 | .. autosummary::
92 |     :toctree: generated
93 |
94 |     train.Trainer
95 |     train.eval_metrics
96 |     train.plot_eval_metrics
97 | ```
98 |
--------------------------------------------------------------------------------
/docs/api/index.md:
--------------------------------------------------------------------------------
1 | # API
2 |
3 | Import NicheCompass as:
4 |
5 | ```
6 | import nichecompass
7 | ```
8 |
9 | ```{toctree}
10 | :maxdepth: 2
11 |
12 | user
13 | developer
14 | ```
--------------------------------------------------------------------------------
/docs/api/user.md:
--------------------------------------------------------------------------------
1 | # User
2 |
3 | ## Benchmarking
4 |
5 | ```{eval-rst}
6 | .. module:: nichecompass.benchmarking
7 | .. currentmodule:: nichecompass
8 |
9 | .. autosummary::
10 |     :toctree: generated
11 |
12 |     benchmarking.compute_benchmarking_metrics
13 |     benchmarking.compute_cas
14 |     benchmarking.compute_clisis
15 |     benchmarking.compute_gcs
16 |     benchmarking.compute_mlami
17 |     benchmarking.compute_nasw
18 | ```
19 |
20 | ## Data
21 | ```{eval-rst}
22 | .. module:: nichecompass.data
23 | .. currentmodule:: nichecompass
24 |
25 | .. autosummary::
26 |     :toctree: generated
27 |
28 |     data.load_spatial_adata_from_csv
29 | ```
30 |
31 | ## Models
32 |
33 | ```{eval-rst}
34 | .. module:: nichecompass.models
35 | .. currentmodule:: nichecompass
36 |
37 | .. autosummary::
38 |     :toctree: generated
39 |
40 |     models.NicheCompass
41 | ```
42 |
43 | ## Gene Program Utilities
44 |
45 | ```{eval-rst}
46 | .. module:: nichecompass.utils
47 | .. currentmodule:: nichecompass
48 |
49 | .. autosummary::
50 |     :toctree: generated
51 |
52 |     utils.add_gps_from_gp_dict_to_adata
53 |     utils.add_multimodal_mask_to_adata
54 |     utils.extract_gp_dict_from_collectri_tf_network
55 |     utils.extract_gp_dict_from_nichenet_lrt_interactions
56 |     utils.extract_gp_dict_from_mebocost_es_interactions
57 |     utils.extract_gp_dict_from_omnipath_lr_interactions
58 |     utils.filter_and_combine_gp_dict_gps
59 |     utils.get_gene_annotations
60 |     utils.generate_multimodal_mapping_dict
61 |     utils.get_unique_genes_from_gp_dict
62 |     utils.generate_enriched_gp_info_plots
63 | ```
64 |
65 | ## Cell-Cell Communication Utilities
66 |
67 | ```{eval-rst}
68 | .. module:: nichecompass.utils
69 | .. currentmodule:: nichecompass
70 |
71 | .. autosummary::
72 |     :toctree: generated
73 |
74 |     utils.compute_communication_gp_network
75 |     utils.visualize_communication_gp_network
76 | ```
77 |
78 | ## Other Utilities
79 | ```{eval-rst}
80 | .. module:: nichecompass.utils
81 | .. currentmodule:: nichecompass
82 |
83 | .. autosummary::
84 |     :toctree: generated
85 |
86 |     utils.create_new_color_dict
87 | ```
88 |
--------------------------------------------------------------------------------
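The user-facing utilities above are typically chained into a single workflow: extract prior gene programs, attach them to the AnnData object, then train the model. A minimal sketch of that flow follows; the function and class names are taken from the API listing above, but every keyword argument and the input file name are illustrative assumptions rather than documented signatures.

```python
# Illustrative flow over the user API listed above. Function names come from
# the API reference; the arguments shown are hypothetical placeholders.
import scanpy as sc

from nichecompass.models import NicheCompass
from nichecompass.utils import (
    add_gps_from_gp_dict_to_adata,
    extract_gp_dict_from_omnipath_lr_interactions,
)

adata = sc.read_h5ad("my_spatial_data.h5ad")  # hypothetical input file

# Retrieve ligand-receptor prior gene programs and attach them to the AnnData
# object so the model can use them as decoder masks.
gp_dict = extract_gp_dict_from_omnipath_lr_interactions()
add_gps_from_gp_dict_to_adata(gp_dict=gp_dict, adata=adata)

# Initialize and train the model, then use the latent space for niche analysis.
model = NicheCompass(adata)
model.train()
```
--------------------------------------------------------------------------------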
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 | import sys
9 | from datetime import datetime
10 | from importlib.metadata import metadata
11 | from pathlib import Path
12 |
13 | HERE = Path(__file__).parent
14 | sys.path.insert(0, str(HERE / "extensions"))
15 |
16 | # -- Project information -----------------------------------------------------
17 |
18 | # NOTE: If you installed your project in editable mode, this might be stale.
19 | # If this is the case, reinstall it to refresh the metadata
20 | info = metadata("nichecompass")
21 | project_name = info.get("Name", "NicheCompass")
22 | project = project_name
23 | author = info.get("Author", "Sebastian Birk")
24 | copyright = f"{datetime.now():%Y}, {author}."
25 | version = info["Version"]
26 | urls = dict(pu.split(", ") for pu in info.get_all("Project-URL"))
27 | repository_url = urls["Source"]
28 |
29 | # The full version, including alpha/beta/rc tags
30 | release = info["Version"]
31 |
32 | bibtex_bibfiles = ["references.bib"]
33 | templates_path = ["_templates"]
34 | nitpicky = True # Warn about broken links
35 | needs_sphinx = "4.0"
36 |
37 | html_context = {
38 |     "display_github": True,  # Integrate GitHub
39 |     "github_user": "Lotfollahi-lab",  # Username
40 |     "github_repo": project_name,  # Repo name
41 |     "github_version": "main",  # Version
42 |     "conf_py_path": "/docs/",  # Path in the checkout to the docs root
43 | }
44 |
45 | # -- General configuration ---------------------------------------------------
46 |
47 | # Add any Sphinx extension module names here, as strings.
48 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
49 | extensions = [
50 |     "myst_nb",
51 |     "sphinx_copybutton",
52 |     "sphinx.ext.autodoc",
53 |     "sphinx.ext.intersphinx",
54 |     "sphinx.ext.autosummary",
55 |     "sphinx.ext.extlinks",
56 |     "sphinx.ext.napoleon",
57 |     "sphinxcontrib.bibtex",
58 |     "sphinx_autodoc_typehints",
59 |     "sphinx.ext.mathjax",
60 |     "sphinx_design",
61 |     "IPython.sphinxext.ipython_console_highlighting",
62 |     "sphinxext.opengraph",
63 |     "hoverxref.extension",
64 |     *[p.stem for p in (HERE / "extensions").glob("*.py")],
65 | ]
66 |
67 | autosummary_generate = True
68 | autodoc_member_order = "groupwise"
69 | default_role = "literal"
70 | napoleon_google_docstring = False
71 | napoleon_numpy_docstring = True
72 | napoleon_include_init_with_doc = False
73 | napoleon_use_rtype = True # having a separate entry generally helps readability
74 | napoleon_use_param = True
75 | myst_heading_anchors = 6 # create anchors for h1-h6
76 | myst_enable_extensions = [
77 |     "amsmath",
78 |     "colon_fence",
79 |     "deflist",
80 |     "dollarmath",
81 |     "html_image",
82 |     "html_admonition",
83 | ]
84 | myst_url_schemes = ("http", "https", "mailto")
85 | nb_output_stderr = "remove"
86 | nb_execution_mode = "off"
87 | nb_merge_streams = True
88 | typehints_defaults = "braces"
89 |
90 | source_suffix = {
91 |     ".rst": "restructuredtext",
92 |     ".ipynb": "myst-nb",
93 |     ".myst": "myst-nb",
94 | }
95 |
96 | intersphinx_mapping = {
97 |     "python": ("https://docs.python.org/3", None),
98 |     "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
99 |     "numpy": ("https://numpy.org/doc/stable/", None),
100 | }
101 |
102 | # List of patterns, relative to source directory, that match files and
103 | # directories to ignore when looking for source files.
104 | # This pattern also affects html_static_path and html_extra_path.
105 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]
106 |
107 | # -- Options for HTML output -------------------------------------------------
108 |
109 | # The theme to use for HTML and HTML Help pages. See the documentation for
110 | # a list of builtin themes.
111 | #
112 | html_theme = "sphinx_rtd_theme"
113 | html_static_path = ["_static"]
114 | html_css_files = ["css/custom.css"]
115 |
116 | html_title = "NicheCompass"
117 | html_logo = "_static/nichecompass_logo.svg"
118 |
119 | html_theme_options = {
120 |     "repository_url": repository_url,
121 |     "use_repository_button": True,
122 |     "path_to_docs": "docs/",
123 |     "navigation_with_keys": False,
124 | }
125 |
126 | pygments_style = "default"
127 |
128 | nitpick_ignore = [
129 |     # If building the documentation fails because of a missing link that is outside your control,
130 |     # you can add an exception to this list.
131 |     # ("py:class", "igraph.Graph"),
132 | ]
133 |
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | This will be added shortly.
4 |
5 |
6 |
7 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui
8 | [codecov]: https://about.codecov.io/sign-up/
9 | [codecov docs]: https://docs.codecov.com/docs
10 | [codecov bot]: https://docs.codecov.com/docs/team-bot
11 | [codecov app]: https://github.com/apps/codecov
12 | [pre-commit.ci]: https://pre-commit.ci/
13 | [readthedocs.org]: https://readthedocs.org/
14 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/
15 | [jupytext]: https://jupytext.readthedocs.io/en/latest/
16 | [pre-commit]: https://pre-commit.com/
17 | [pytest]: https://docs.pytest.org/
18 | [semver]: https://semver.org/
19 | [sphinx]: https://www.sphinx-doc.org/en/master/
20 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html
21 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
22 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html
23 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints
24 | [pypi]: https://pypi.org/
25 |
--------------------------------------------------------------------------------
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | from __future__ import annotations
4 |
5 | import re
6 | from collections.abc import Generator, Iterable
7 |
8 | from sphinx.application import Sphinx
9 | from sphinx.ext.napoleon import NumpyDocstring
10 |
11 |
12 | def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
13 |     for line in lines:
14 |         if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line):
15 |             yield f'-{m["param"]} (:class:`~{m["type"]}`)'
16 |         else:
17 |             yield line
18 |
19 |
20 | def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]:
21 |     lines_raw = self._dedent(self._consume_to_next_section())
22 |     if lines_raw[0] == ":":
23 |         del lines_raw[0]
24 |     lines = self._format_block(":returns: ", list(_process_return(lines_raw)))
25 |     if lines and lines[-1]:
26 |         lines.append("")
27 |     return lines
28 |
29 |
30 | def setup(app: Sphinx):
31 |     """Set app."""
32 |     NumpyDocstring._parse_returns_section = _parse_returns_section
33 |
--------------------------------------------------------------------------------
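To make the rewrite concrete, here is a tiny standalone illustration (not part of the repository) of what `_process_return` does to a numpydoc-style return line:

```python
# A "name : type" line from a numpydoc Returns section is rewritten into a
# list entry with a cross-referenced class, mirroring _process_return above.
import re

line = "adata : anndata.AnnData"
m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line)
print(f'-{m["param"]} (:class:`~{m["type"]}`)')
# Output: -adata (:class:`~anndata.AnnData`)
```
--------------------------------------------------------------------------------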
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Documentation
2 |
3 | NicheCompass (**N**iche **I**dentification based on **C**ellular grap**H** **E**mbeddings of **COM**munication **P**rograms **A**ligned across **S**patial **S**amples) is a package for end-to-end analysis of spatial multi-omics data, including spatial atlas building, niche identification & characterization, cell-cell communication inference and spatial reference mapping. It is built on top of [PyG](https://pytorch-geometric.readthedocs.io/en/latest/) and [AnnData](https://anndata.readthedocs.io/en/latest/).
4 | The package is developed and maintained by the [Lotfollahi Lab](https://github.com/Lotfollahi-lab) at the Wellcome Sanger Institute.
5 |
6 | ::::{grid} 1 2 3 3
7 | :gutter: 2
8 |
9 | :::{grid-item-card} Installation {octicon}`plug;1em;`
10 | :link: installation
11 | :link-type: doc
12 |
13 | Check out the installation guide.
14 | :::
15 |
16 | :::{grid-item-card} Tutorials {octicon}`play;1em;`
17 | :link: tutorials/index
18 | :link-type: doc
19 |
20 | Learn by following example applications of NicheCompass.
21 | :::
22 |
23 | :::{grid-item-card} User Guide {octicon}`book;1em;`
24 | :link: user_guide/index
25 | :link-type: doc
26 |
27 | Review good practices for training NicheCompass models on your own data.
28 | :::
29 |
30 | :::{grid-item-card} API {octicon}`info;1em;`
31 | :link: api/index
32 | :link-type: doc
33 |
34 | Find a detailed description of NicheCompass APIs.
35 | :::
36 |
37 | :::{grid-item-card} Release Notes {octicon}`tag;1em;`
38 | :link: release_notes/index
39 | :link-type: doc
40 |
41 | Follow the latest changes to NicheCompass.
42 | :::
43 |
44 | :::{grid-item-card} Contributing {octicon}`code;1em;`
45 | :link: contributing
46 | :link-type: doc
47 |
48 | Help improve NicheCompass.
49 | :::
50 | ::::
51 |
52 | If you find NicheCompass useful for your research, please consider citing the NicheCompass manuscript.
53 |
54 | ```{toctree}
55 | :hidden: true
56 | :maxdepth: 3
57 | :titlesonly: true
58 |
59 | installation
60 | tutorials/index
61 | api/index
62 | release_notes/index
63 | contributing.md
64 | references.md
65 | ```
66 |
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | NicheCompass is available for Python 3.9. It does not yet support Apple silicon.
4 |
5 |
6 | We do not recommend installing into your system Python. Please set up a virtual
7 | environment, e.g. via conda through the [Mambaforge] distribution, or create a
8 | [Docker] image.
9 |
10 | ## Additional Libraries
11 |
12 | To use NicheCompass, you need to install some external libraries. These include:
13 | - [PyTorch]
14 | - [PyTorch Scatter]
15 | - [PyTorch Sparse]
16 | - [bedtools]
17 |
18 | We recommend installing the PyTorch libraries with GPU support. If you have
19 | CUDA, this can be done as:
20 |
21 | ```
22 | pip install torch==${TORCH}+${CUDA} --extra-index-url https://download.pytorch.org/whl/${CUDA}
23 | pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
24 | ```
25 | where `${TORCH}` and `${CUDA}` should be replaced by the specific PyTorch and
26 | CUDA versions, respectively.
27 |
28 | For example, for PyTorch 2.0.0 and CUDA 11.7, type:
29 | ```
30 | pip install torch==2.0.0+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
31 | pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
32 | ```
33 |
34 | To install bedtools, you can use conda:
35 | ```
36 | conda install bedtools=2.30.0 -c bioconda
37 | ```
38 |
39 | Alternatively, we have provided a conda environment file with all required
40 | external libraries, which you can use as:
41 | ```
42 | conda env create -f environment.yaml
43 | ```
44 |
45 | To enable GPU support for JAX, run after the installation:
46 | ```
47 | pip install jaxlib==${JAXLIB}+cuda${CUDA}.cudnn${CUDNN} -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
48 | ```
49 | where `${JAXLIB}`, `${CUDA}` and `${CUDNN}` should be replaced by the specific jaxlib, CUDA and cuDNN versions, respectively.
50 |
51 | For example, for jaxlib 0.4.7, CUDA 11 and cuDNN 8.6, type:
52 | ```
52 | pip install jaxlib==0.4.7+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
53 | ```
54 |
55 |
56 | ## Installation via PyPI
57 |
58 | Subsequently, install NicheCompass via pip:
59 | ```
60 | pip install nichecompass
61 | ```
62 |
63 | Install optional dependencies required for benchmarking, multimodal analysis, running tutorials etc. with:
64 | ```
65 | pip install nichecompass[all]
66 | ```
67 |
68 | [Mambaforge]: https://github.com/conda-forge/miniforge
69 | [Docker]: https://www.docker.com
70 | [PyTorch]: http://pytorch.org
71 | [PyTorch Scatter]: https://github.com/rusty1s/pytorch_scatter
72 | [PyTorch Sparse]: https://github.com/rusty1s/pytorch_sparse
73 | [bedtools]: https://bedtools.readthedocs.io
74 |
--------------------------------------------------------------------------------
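After following the steps above, it can be useful to verify that the core external libraries and NicheCompass itself import correctly. A minimal sketch, assuming the PyTorch, PyG extension, and NicheCompass installations described in this guide:

```python
# Post-install sanity check: verify PyTorch, the PyG extension wheels, and
# NicheCompass are importable, and report their versions.
import torch
import torch_scatter  # noqa: F401  # ImportError here means the PyG wheels are missing
import torch_sparse  # noqa: F401

import nichecompass

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("NicheCompass:", nichecompass.__version__)
```
--------------------------------------------------------------------------------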
/docs/references.bib:
--------------------------------------------------------------------------------
1 | @article{Virshup_2023,
2 | doi = {10.1038/s41587-023-01733-8},
3 | url = {https://doi.org/10.1038%2Fs41587-023-01733-8},
4 | year = 2023,
5 | month = {apr},
6 | publisher = {Springer Science and Business Media {LLC}},
7 | author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. Theis},
8 | title = {The scverse project provides a computational ecosystem for single-cell omics data analysis},
9 | journal = {Nature Biotechnology}
10 | }
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
1 | # References
2 |
3 | ```{bibliography}
4 | :cited:
5 | ```
6 |
--------------------------------------------------------------------------------
/docs/release_notes/index.md:
--------------------------------------------------------------------------------
1 | # Release notes
2 |
3 | All notable changes to this project will be documented in this file. The format
4 | is based on [keep a changelog], and this project adheres to
5 | [Semantic Versioning]. Full commit history is available in the [commit logs].
6 |
7 | ### 0.2.3 (17.02.2025)
8 |
9 | - Added a numpy<2 dependency, as upgrading NumPy to major version 2 breaks the required scanpy version.
10 | [@sebastianbirk]
11 |
12 | ### 0.2.2 (09.01.2025)
13 |
14 | - Synced repository with Zenodo to mint DOI for publication.
15 | [@sebastianbirk]
16 |
17 | ### 0.2.1 (15.10.2024)
18 |
19 | - Added a user guide to the package documentation.
20 | [@sebastianbirk]
21 |
22 | ### 0.2.0 (22.08.2024)
23 |
24 | - Fixed a bug in the configuration of random seeds.
25 | - Fixed a bug in the definition of MEBOCOST prior gene programs.
26 | - Raised the default quality filter threshold for the retrieval of OmniPath gene programs.
27 | - Modified the GP filtering logic to combine GPs with the same source genes and drop GPs that do not have source genes if they have a specified overlap in target genes with other GPs.
28 | - Changed the default hyperparameters for model training based on ablation experiments.
29 | [@sebastianbirk]
30 |
31 | ### 0.1.2 (13.02.2024)
32 |
33 | - The version was incremented due to package upload requirements.
34 | [@sebastianbirk]
35 |
36 | ### 0.1.1 (13.02.2024)
37 |
38 | - The version was incremented due to package upload requirements.
39 | [@sebastianbirk]
40 |
41 | ### 0.1.0 (13.02.2024)
42 |
43 | - First NicheCompass version.
44 | [@sebastianbirk]
45 |
46 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/
47 | [Semantic Versioning]: https://semver.org/spec/v2.0.0.html
48 | [commit logs]: https://github.com/Lotfollahi-lab/nichecompass/commits
49 | [@sebastianbirk]: https://github.com/sebastianbirk
50 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | https://download.pytorch.org/whl/cpu/torch-1.13.0%2Bcpu-cp39-cp39-linux_x86_64.whl
2 | https://data.pyg.org/whl/torch-1.13.0%2Bcpu/torch_sparse-0.6.17%2Bpt113cpu-cp39-cp39-linux_x86_64.whl
3 | https://data.pyg.org/whl/torch-1.13.0%2Bcpu/torch_scatter-2.1.1%2Bpt113cpu-cp39-cp39-linux_x86_64.whl
4 |
--------------------------------------------------------------------------------
/docs/tutorials/index.md:
--------------------------------------------------------------------------------
1 | # Tutorials
2 |
3 | Get started with NicheCompass by following our tutorials.
4 |
5 | ```{toctree}
6 | :maxdepth: 2
7 |
8 | single_sample_tutorials
9 | sample_integration_tutorials
10 | spatial_reference_mapping_tutorials
11 | multimodal_tutorials
12 | ```
--------------------------------------------------------------------------------
/docs/tutorials/multimodal_tutorials.md:
--------------------------------------------------------------------------------
1 | # Multimodal Tutorials
2 |
3 | ```{toctree}
4 | :maxdepth: 1
5 |
6 | notebooks/mouse_brain_multimodal
7 | ```
--------------------------------------------------------------------------------
/docs/tutorials/sample_integration_tutorials.md:
--------------------------------------------------------------------------------
1 | # Sample Integration Tutorials
2 | ```{toctree}
3 | :maxdepth: 1
4 |
5 | notebooks/mouse_cns_sample_integration
6 | ```
--------------------------------------------------------------------------------
/docs/tutorials/single_sample_tutorials.md:
--------------------------------------------------------------------------------
1 | # Single Sample Tutorials
2 |
3 | ```{toctree}
4 | :maxdepth: 1
5 |
6 | notebooks/mouse_cns_single_sample
7 | ```
--------------------------------------------------------------------------------
/docs/tutorials/spatial_reference_mapping_tutorials.md:
--------------------------------------------------------------------------------
1 | # Spatial Reference Mapping Tutorials
2 |
3 | ```{toctree}
4 | :maxdepth: 1
5 |
6 | notebooks/mouse_cns_spatial_reference_mapping
7 | ```
--------------------------------------------------------------------------------
/docs/user_guide/index.md:
--------------------------------------------------------------------------------
1 | # User guide
2 |
3 | ## Hyperparameter selection
4 |
5 | We conducted various ablation experiments on both simulated and real spatial transcriptomics data to evaluate important model design choices and hyperparameters. The detailed results and interpretations can be found in the NicheCompass manuscript.
6 |
7 | In summary, we recommend the following:
8 | - Regarding the loss, we observed that finding a balance between gene expression and edge reconstruction is a key element for good niche identification (NID) and GP recovery (GPR) performance, while regularization of de novo GP weights is essential for GPR. The loss weights have been specified in the NicheCompass package accordingly and in most cases do not need to be changed by the user.
9 | - With respect to the size of the KNN neighborhood graph, a smaller number of neighbors is more efficient for NID and de novo GP detection, while a larger number of neighbors can facilitate GPR of prior GPs. Here, we recommend specifying a neighborhood size based on the expected range of interactions in the tissue. Empirically, we observed a neighborhood size between 4 and 12 to work well.
10 | - The inclusion of de novo GPs is crucial for GPR. However, the number of de novo GPs should not be too high, as important genes might be split across multiple GPs. The default in the NicheCompass package is 100 and we do not recommend increasing this number.
11 | - GP pruning can slightly improve GPR and NID while reducing the embedding size of the model; we therefore recommend using the default setting of weak GP pruning.
12 | - We recommend defining prior GPs solely based on the biology of interest (as opposed to including as many prior GPs as possible).
13 | - We recommend using a GATv2 encoder layer (as opposed to a GCNConv encoder layer) unless performance is a bottleneck, or unless niche characterization is not a priority and the data has single-cell resolution.
14 | - Since the use of prior GPs can significantly improve NID compared to a scenario without prior GPs, we recommend using the default set of prior GPs even if interpretability is not a main objective.
--------------------------------------------------------------------------------
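A minimal sketch of how these recommendations might map onto code. The neighborhood-graph call uses squidpy's public API; the NicheCompass keyword arguments shown are illustrative placeholders for the settings discussed above (encoder type, number of de novo GPs), not documented parameter names:

```python
# Illustrative training setup following the recommendations above. The
# NicheCompass keyword arguments are hypothetical placeholders; consult the
# API reference for the actual parameter names.
import scanpy as sc
import squidpy as sq

from nichecompass.models import NicheCompass

adata = sc.read_h5ad("my_spatial_data.h5ad")  # hypothetical input file

# KNN neighborhood graph: 4-12 neighbors works well empirically; choose based
# on the expected range of interactions in the tissue.
sq.gr.spatial_neighbors(adata, coord_type="generic", n_neighs=8)

model = NicheCompass(
    adata,
    conv_layer_encoder="gatv2",  # hypothetical name: GATv2 over GCNConv
    n_addon_gp=100,  # hypothetical name: default number of de novo GPs
)
model.train()
```
--------------------------------------------------------------------------------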
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: nichecompass
2 | channels:
3 |   - pytorch
4 |   - pyg
5 |   - nvidia
6 |   - bioconda
7 | dependencies:
8 |   - python=3.9
9 |   - pip
10 |   - pytorch=2.0.0
11 |   - pytorch-scatter=2.1.1
12 |   - pytorch-sparse=0.6.17
13 |   - pytorch-cuda=11.7
14 |   - bedtools=2.30.0
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "hatchling.build"
3 | requires = ["hatchling"]
4 |
5 | [project]
6 | name = "nichecompass"
7 | version = "0.2.3"
8 | description = "End-to-end analysis of spatial multi-omics data"
9 | readme = "README.md"
10 | requires-python = ">=3.9"
11 | license = {file = "LICENSE"}
12 | authors = [
13 |     {name = "Sebastian Birk"},
14 | ]
15 | maintainers = [
16 |     {name = "Sebastian Birk", email = "sebastian.birk@outlook.com"},
17 | ]
18 | urls.Documentation = "https://nichecompass.readthedocs.io/"
19 | urls.Source = "https://github.com/Lotfollahi-lab/nichecompass"
20 | urls.Home-page = "https://github.com/Lotfollahi-lab/nichecompass"
21 | classifiers = [
22 |     "Development Status :: 3 - Alpha",
23 |     "Intended Audience :: Science/Research",
24 |     "Natural Language :: English",
25 |     "Programming Language :: Python :: 3.9",
26 |     "Operating System :: POSIX :: Linux",
27 |     "Topic :: Scientific/Engineering :: Bio-Informatics",
28 | ]
29 | dependencies = [
30 |     "mlflow>=1.28.0",
31 |     "pyreadr>=0.4.6",
32 |     "scanpy>=1.9.3,<1.10.0",
33 |     "torch-geometric>=2.2.0",
34 |     "omnipath>=1.0.7",
35 |     "decoupler>=1.4.0",
36 |     "numpy<2",
37 | ]
38 | optional-dependencies.dev = [
39 |     "pre-commit",
40 |     "twine>=4.0.2",
41 | ]
42 | optional-dependencies.docs = [
43 |     "docutils>=0.8,!=0.18.*,!=0.19.*",
44 |     "sphinx>=4.1",
45 |     "sphinx-book-theme>=1.0.0",
46 |     "myst-nb",
47 |     "myst-parser",
48 |     "sphinxcontrib-bibtex>=1.0.0",
49 |     "sphinx-autodoc-typehints",
50 |     "sphinx_rtd_theme",
51 |     "sphinxext-opengraph",
52 |     "sphinx-copybutton",
53 |     "sphinx-design",
54 |     "sphinx-hoverxref",
55 |     # For notebooks
56 |     "ipykernel",
57 |     "ipython",
58 |     "pandas",
59 | ]
60 | optional-dependencies.docsbuild = [
61 |     "nichecompass[docs,benchmarking,multimodal]"
62 | ] # docs build dependencies
63 | optional-dependencies.tests = [
64 |     "pytest",
65 |     "coverage",
66 | ]
67 | optional-dependencies.benchmarking = [
68 |     "scib-metrics>=0.3.3",
69 |     "pynndescent>=0.5.8",
70 |     "scikit-misc>=0.3.0",
71 |     "squidpy>=1.2.2",
72 |     "jax==0.4.7",
73 |     "jaxlib==0.4.7"
74 | ]
75 | optional-dependencies.multimodal = [
76 |     "scglue>=0.3.2",
77 | ]
78 | optional-dependencies.tutorials = [
79 |     "jupyter",
80 | ]
81 | optional-dependencies.all = ["nichecompass[dev,docs,tests,benchmarking,multimodal,tutorials]"]
82 |
83 | [tool.coverage.run]
84 | source = ["nichecompass"]
85 | omit = [
86 |     "**/test_*.py",
87 | ]
88 |
89 | [tool.pytest.ini_options]
90 | testpaths = ["tests"]
91 | xfail_strict = true
92 | addopts = [
93 |     "--import-mode=importlib", # allow using test files with same name
94 | ]
95 |
96 | [tool.ruff]
97 | line-length = 120
98 | src = ["src"]
99 | extend-include = ["*.ipynb"]
100 |
101 | [tool.ruff.format]
102 | docstring-code-format = true
103 |
104 | [tool.ruff.lint]
105 | select = [
106 |     "F",  # Errors detected by Pyflakes
107 |     "E",  # Error detected by Pycodestyle
108 |     "W",  # Warning detected by Pycodestyle
109 |     "I",  # isort
110 |     "D",  # pydocstyle
111 |     "B",  # flake8-bugbear
112 |     "TID",  # flake8-tidy-imports
113 |     "C4",  # flake8-comprehensions
114 |     "BLE",  # flake8-blind-except
115 |     "UP",  # pyupgrade
116 |     "RUF100",  # Report unused noqa directives
117 | ]
118 | ignore = [
119 |     # line too long -> we accept long comment lines; formatter gets rid of long code lines
120 |     "E501",
121 |     # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
122 |     "E731",
123 |     # allow I, O, l as variable names -> I is the identity matrix
124 |     "E741",
125 |     # Missing docstring in public package
126 |     "D104",
127 |     # Missing docstring in public module
128 |     "D100",
129 |     # Missing docstring in __init__
130 |     "D107",
131 |     # Errors from function calls in argument defaults. These are fine when the result is immutable.
132 |     "B008",
133 |     # __magic__ methods are often self-explanatory, allow missing docstrings
134 |     "D105",
135 |     # first line should end with a period [Bug: doesn't work with single-line docstrings]
136 |     "D400",
137 |     # First line should be in imperative mood; try rephrasing
138 |     "D401",
139 |     ## Disable one in each pair of mutually incompatible rules
140 |     # We don't want a blank line before a class docstring
141 |     "D203",
142 |     # We want docstrings to start immediately after the opening triple quote
143 |     "D213",
144 | ]
145 |
146 | [tool.ruff.lint.pydocstyle]
147 | convention = "numpy"
148 |
149 | [tool.ruff.lint.per-file-ignores]
150 | "docs/*" = ["I"]
151 | "tests/*" = ["D"]
152 | "*/__init__.py" = ["F401"]
153 |
154 | [tool.cruft]
155 | skip = [
156 |     "tests",
157 |     "src/**/__init__.py",
158 |     "src/**/basic.py",
159 |     "docs/api.md",
160 |     "docs/changelog.md",
161 |     "docs/references.bib",
162 |     "docs/references.md",
163 |     "docs/notebooks/example.ipynb",
164 | ]
165 |
--------------------------------------------------------------------------------
/src/nichecompass/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version
2 |
3 | from . import data, models, modules, nn, train, utils
4 |
5 | __all__ = ["data", "models", "modules", "nn", "train", "utils"]
6 |
7 | __version__ = version("nichecompass")
--------------------------------------------------------------------------------
/src/nichecompass/benchmarking/__init__.py:
--------------------------------------------------------------------------------
1 | # This is a trick to make jax use the right cudnn version (needs to be executed
2 | # before importing scanpy)
3 | #import jax.numpy as jnp
4 | #temp_array = jnp.array([1, 2, 3])
5 | #temp_idx = jnp.array([1])
6 | #temp_array[temp_idx]
7 |
8 | from .cas import compute_avg_cas, compute_cas
9 | from .clisis import compute_clisis
10 | from .gcs import compute_avg_gcs, compute_gcs
11 | from .metrics import compute_benchmarking_metrics
12 | from .mlami import compute_mlami
13 | from .nasw import compute_nasw
14 |
15 | __all__ = ["compute_avg_cas",
16 |            "compute_avg_gcs",
17 |            "compute_benchmarking_metrics",
18 |            "compute_cas",
19 |            "compute_clisis",
20 |            "compute_gcs",
21 |            "compute_mlami",
22 |            "compute_nasw"]
--------------------------------------------------------------------------------
/src/nichecompass/benchmarking/clisis.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the Cell Type Local Inverse Simpson's Index Similarity
3 | (CLISIS) benchmark for testing how accurately the latent nearest neighbor graph
4 | preserves local neighborhood cell type heterogeneity from the spatial (physical)
5 | nearest neighbor graph. It is a measure of local cell type neighborhood
6 | preservation.
7 | """
8 |
9 | from typing import Optional
10 |
11 | import numpy as np
12 | from anndata import AnnData
13 | from scib_metrics import lisi_knn
14 |
15 | from .utils import compute_knn_graph_connectivities_and_distances
16 |
17 |
18 | def compute_clisis(
19 |         adata: AnnData,
20 |         cell_type_key: str="cell_type",
21 |         batch_key: Optional[str]=None,
22 |         spatial_knng_key: str="spatial_knng",
23 |         latent_knng_key: str="nichecompass_latent_knng",
24 |         spatial_key: Optional[str]="spatial",
25 |         latent_key: Optional[str]="nichecompass_latent",
26 |         n_neighbors: Optional[int]=90,
27 |         n_jobs: int=1,
28 |         seed: int=0) -> float:
29 | """
30 | Compute the Cell Type Local Inverse Simpson's Index Similarity (CLISIS). The
31 | CLISIS measures how accurately the latent nearest neighbor graph preserves
32 | local neighborhood cell type heterogeneity from the spatial nearest neighbor
33 | graph. The CLISIS ranges between '0' and '1' with higher values indicating
34 | better local neighborhood cell type heterogeneity preservation. It is
35 | computed by first calculating the Cell Type Local Inverse Simpson's Index
36 | (CLISI) as proposed by Luecken, M. D. et al. Benchmarking atlas-level data
37 | integration in single-cell genomics. Nat. Methods 19, 41–50 (2022) on the
38 | latent and spatial nearest neighbor graph(s) respectively.* Afterwards, the
39 | ratio of the two CLISI scores is taken and logarithmized as proposed by
40 | Heidari, E. et al. Supervised spatial inference of dissociated single-cell
41 | data with SageNet. bioRxiv 2022.04.14.488419 (2022)
42 | doi:10.1101/2022.04.14.488419, leveraging the properties of the log that
43 | np.log2(x/y) = -np.log2(y/x) and np.log2(x/x) = 0. At this stage, values
44 | closer to 0 indicate better local neighborhood cell type heterogeneity
45 | preservation. We then normalize the resulting value by the maximum possible
46 | value that would occur in the case of minimal local neighborhood cell type
47 | preservation to scale our metric between '0' and '1'. Finally, we compute
48 | the median of the absolute normalized scores and subtract it from 1 so that
49 | values closer to '1' indicate better local neighborhood cell type
50 | heterogeneity preservation.
51 |
52 | If a ´batch_key´ is provided, separate spatial nearest neighbor graphs per
53 | batch will be computed and the spatial clisi scores are computed for each
54 | batch separately.
55 |
56 | If existent, uses precomputed nearest neighbor graphs stored in
57 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and
58 | ´adata.obsp[latent_knng_key + '_connectivities']´.
59 | Alternatively, computes them on the fly using ´spatial_key´, ´latent_key´
60 | and ´n_neighbors´, and stores them in
61 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and
62 | ´adata.obsp[latent_knng_key + '_connectivities']´ respectively.
63 |
64 | * The Inverse Simpson's Index measures the expected number of
65 | samples needed to be sampled before two are drawn from the same category.
66 | The Local Inverse Simpson's Index combines perplexity-based neighborhood
67 | construction with the Inverse Simpson's Index to account for distances
68 | between neighbors. The CLISI score is the LISI applied to cell nearest
69 | neighbor graphs with cell types as categories, and indicates the effective
70 | number of different cell types represented in the local neighborhood of each
71 | cell. If the cells are well mixed, we might expect the CLISI score to be
72 | close to the number of unique cell types (e.g. neighborhoods with an equal
73 | number of cells from 2 cell types get a CLISI of 2). Note, however, that
74 | even under perfect mixing, the value would be smaller than the number of
75 | unique cell types if the absolute number of cells is different for different
76 | cell types.
77 |
78 | Parameters
79 | ----------
80 | adata:
81 | AnnData object with cell type annotations stored in
82 | ´adata.obs[cell_type_key]´, precomputed nearest neighbor graphs stored
83 | in ´adata.obsp[spatial_knng_key + '_connectivities']´ and
84 | ´adata.obsp[latent_knng_key + '_connectivities']´ or spatial coordinates
85 | stored in ´adata.obsm[spatial_key]´ and the latent representation from a
86 | model stored in ´adata.obsm[latent_key]´.
87 | cell_type_key:
88 | Key under which the cell type annotations are stored in ´adata.obs´.
89 | batch_key:
90 | Key under which the batches are stored in ´adata.obs´. If ´None´, the
91 | adata is assumed to only have one unique batch.
92 | spatial_knng_key:
93 | Key under which the spatial nearest neighbor graph is / will be stored
94 | in ´adata.obsp´ with the suffix '_connectivities'.
95 | latent_knng_key:
96 | Key under which the latent nearest neighbor graph is / will be stored in
97 | ´adata.obsp´ with the suffix '_connectivities'.
98 | spatial_key:
99 | Key under which the spatial coordinates are stored in ´adata.obsm´.
100 | latent_key:
101 | Key under which the latent representation from a model is stored in
102 | ´adata.obsm´.
103 | n_neighbors:
104 | Number of neighbors used for the construction of the nearest neighbor
105 | graphs from the spatial coordinates and the latent representation from
106 | a model in case they are constructed.
107 | n_jobs:
108 | Number of jobs to use for parallelization of neighbor search.
109 | seed:
110 | Random seed for reproducibility.
111 |
112 | Returns
113 | ----------
114 | clisis:
115 | The Cell Type Local Inverse Simpson's Index Similarity.
116 | """
117 | # Adding '_connectivities' as expected / added by
118 | # 'compute_knn_graph_connectivities_and_distances'
119 | spatial_knng_connectivities_key = spatial_knng_key + "_connectivities"
120 | latent_knng_connectivities_key = latent_knng_key + "_connectivities"
121 |
122 | if spatial_knng_connectivities_key in adata.obsp:
123 | print("Using precomputed spatial nearest neighbor graph...")
124 |
125 | print("Computing spatial cell CLISI scores for entire dataset...")
126 | spatial_cell_clisi_scores = lisi_knn(
127 | X=adata.obsp[f"{spatial_knng_key}_distances"],
128 | labels=adata.obs[cell_type_key])
129 |
130 | elif batch_key is None:
131 | print("Computing spatial nearest neighbor graph for entire dataset...")
132 | compute_knn_graph_connectivities_and_distances(
133 | adata=adata,
134 | feature_key=spatial_key,
135 | knng_key=spatial_knng_key,
136 | n_neighbors=n_neighbors,
137 | random_state=seed,
138 | n_jobs=n_jobs)
139 |
140 | print("Computing spatial cell CLISI scores for entire dataset...")
141 | spatial_cell_clisi_scores = lisi_knn(
142 | X=adata.obsp[f"{spatial_knng_key}_distances"],
143 | labels=adata.obs[cell_type_key])
144 |
145 | elif batch_key is not None:
146 | # Compute cell CLISI scores for spatial nearest neighbor graph
147 | # of each batch separately and add to array
148 | unique_batches = adata.obs[batch_key].unique().tolist()
149 | spatial_cell_clisi_scores = np.zeros(len(adata))
150 | adata.obs["index"] = np.arange(len(adata))
151 | for batch in unique_batches:
152 | adata_batch = adata[adata.obs[batch_key] == batch]
153 |
154 | print("Computing spatial nearest neighbor graph for "
155 | f"{batch_key} {batch}...")
156 | # Compute batch-specific spatial (ground truth) nearest
157 | # neighbor graph
158 | compute_knn_graph_connectivities_and_distances(
159 | adata=adata_batch,
160 | feature_key=spatial_key,
161 | knng_key=spatial_knng_key,
162 | n_neighbors=n_neighbors,
163 | random_state=seed,
164 | n_jobs=n_jobs)
165 |
166 | print("Computing spatial cell CLISI scores for "
167 | f"{batch_key} {batch}...")
168 | batch_spatial_cell_clisi_scores = lisi_knn(
169 | X=adata_batch.obsp[f"{spatial_knng_key}_distances"],
170 | labels=adata_batch.obs[cell_type_key])
171 |
172 | # Save results
173 | spatial_cell_clisi_scores[adata_batch.obs["index"].values] = (
174 | batch_spatial_cell_clisi_scores)
175 |
176 | if latent_knng_connectivities_key in adata.obsp:
177 | print("Using precomputed latent nearest neighbor graph...")
178 | else:
179 | print("Computing latent nearest neighbor graph...")
180 | # Compute latent connectivities
181 | compute_knn_graph_connectivities_and_distances(
182 | adata=adata,
183 | feature_key=latent_key,
184 | knng_key=latent_knng_key,
185 | n_neighbors=n_neighbors,
186 | random_state=seed,
187 | n_jobs=n_jobs)
188 |
189 | print("Computing latent cell CLISI scores...")
190 | latent_cell_clisi_scores = lisi_knn(
191 | X=adata.obsp[f"{latent_knng_key}_distances"],
192 | labels=adata.obs[cell_type_key])
193 |
194 | print("Computing CLISIS...")
195 | cell_rclisi_scores = latent_cell_clisi_scores / spatial_cell_clisi_scores
196 | cell_log_rclisi_scores = np.log2(cell_rclisi_scores)
197 |
198 | n_cell_types = adata.obs[cell_type_key].nunique()
199 | max_cell_log_rclisi = np.log2(n_cell_types / 1)
200 | norm_cell_log_rclisi_scores = cell_log_rclisi_scores / max_cell_log_rclisi
201 |
202 | clisis = (1 - np.nanmedian(abs(norm_cell_log_rclisi_scores)))
203 | return clisis
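
To make the normalization at the end of ´compute_clisis´ concrete, here is a
minimal numeric sketch with hypothetical CLISI values for three cells and
eight cell types (numpy only; the values are illustrative assumptions):

    import numpy as np

    # Hypothetical per-cell CLISI scores on the latent and spatial graphs
    latent_cell_clisi_scores = np.array([4.0, 2.0, 1.0])
    spatial_cell_clisi_scores = np.array([2.0, 2.0, 2.0])
    n_cell_types = 8

    # Log ratio: 0 means perfect preservation for a cell
    cell_log_rclisi_scores = np.log2(
        latent_cell_clisi_scores / spatial_cell_clisi_scores)  # [1., 0., -1.]

    # Normalize by the worst possible log ratio, log2(n_cell_types / 1)
    norm_scores = cell_log_rclisi_scores / np.log2(n_cell_types)  # [1/3, 0, -1/3]

    # Median of absolute normalized scores, flipped so that 1 is best
    clisis = 1 - np.nanmedian(np.abs(norm_scores))  # 1 - 1/3 = 0.667
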
--------------------------------------------------------------------------------
/src/nichecompass/benchmarking/gcs.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the Graph Connectivity Similarity (GCS) benchmark for
3 | testing how accurately the latent nearest neighbor graph preserves edges and
4 | non-edges from the spatial (physical) nearest neighbor graph.
5 | """
6 |
7 | from typing import Optional
8 |
9 | import numpy as np
10 | import scipy.sparse as sp
11 | from anndata import AnnData
12 |
13 | from .utils import compute_knn_graph_connectivities_and_distances
14 |
15 |
16 | def compute_gcs(
17 | adata: AnnData,
18 | batch_key: Optional[str]=None,
19 | spatial_knng_key: str="spatial_knng",
20 | latent_knng_key: str="nichecompass_latent_knng",
21 | spatial_key: Optional[str]="spatial",
22 | latent_key: Optional[str]="nichecompass_latent",
23 | n_neighbors: Optional[int]=15,
24 | n_jobs: int=1,
25 | seed: int=0) -> float:
26 | """
27 | Compute the graph connectivity similarity (GCS). The GCS measures how
28 | accurately the latent nearest neighbor graph preserves edges and non-edges
29 | from the spatial (ground truth) nearest neighbor graph. A value of '1'
30 | indicates perfect graph similarity and a value of '0' indicates no graph
31 | connectivity similarity at all.
32 |
33 | If a ´batch_key´ is provided, the GCS will be computed on each batch
34 | separately, and the average across all batches is returned.
35 |
36 | If existent, uses precomputed nearest neighbor graphs stored in
37 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and
38 | ´adata.obsp[latent_knng_key + '_connectivities']´.
39 | Alternatively, computes them on the fly using ´spatial_key´, ´latent_key´
40 | and ´n_neighbors´, and stores them in
41 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and
42 | ´adata.obsp[latent_knng_key + '_connectivities']´ respectively.
43 |
44 | Parameters
45 | ----------
46 | adata:
47 | AnnData object with precomputed nearest neighbor graphs stored in
48 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and
49 | ´adata.obsp[latent_knng_key + '_connectivities']´ or spatial coordinates
50 | stored in ´adata.obsm[spatial_key]´ and the latent representation from a
51 | model stored in ´adata.obsm[latent_key]´.
52 | batch_key:
53 | Key under which the batches are stored in ´adata.obs´.
54 | spatial_knng_key:
55 | Key under which the spatial nearest neighbor graph is / will be stored
56 | in ´adata.obsp´ with the suffix '_connectivities'.
57 | latent_knng_key:
58 | Key under which the latent nearest neighbor graph is / will be stored in
59 | ´adata.obsp´ with the suffix '_connectivities'.
60 | spatial_key:
61 | Key under which the spatial coordinates are stored in ´adata.obsm´.
62 | latent_key:
63 | Key under which the latent representation from a model is stored in
64 | ´adata.obsm´.
65 | n_neighbors:
66 | Number of neighbors used for the construction of the nearest neighbor
67 | graphs from the spatial coordinates and the latent representation from
68 | a model in case they are constructed.
69 | n_jobs:
70 | Number of jobs to use for parallelization of neighbor search.
71 | seed:
72 | Random seed for reproducibility.
73 |
74 | Returns
75 | ----------
76 | gcs:
77 | Normalized matrix similarity between the spatial nearest neighbor graph
78 | and the latent nearest neighbor graph as measured by one minus the
79 | size-normalized Frobenius norm of the element-wise matrix differences.
80 | """
81 | # Adding '_connectivities' as expected / added by
82 | # 'compute_knn_graph_connectivities_and_distances'
83 | spatial_knng_connectivities_key = spatial_knng_key + "_connectivities"
84 | latent_knng_connectivities_key = latent_knng_key + "_connectivities"
85 |
86 | if batch_key is not None:
87 | adata_batch_list = []
88 | unique_batches = adata.obs[batch_key].unique().tolist()
89 | for batch in unique_batches:
90 | adata_batch = adata[adata.obs[batch_key] == batch]
91 | adata_batch_list.append(adata_batch)
92 |
93 | if spatial_knng_connectivities_key in adata.obsp:
94 | print("Using precomputed spatial nearest neighbor graph...")
95 |
96 | elif batch_key is None:
97 | print("Computing spatial nearest neighbor graph for entire dataset...")
98 | compute_knn_graph_connectivities_and_distances(
99 | adata=adata,
100 | feature_key=spatial_key,
101 | knng_key=spatial_knng_key,
102 | n_neighbors=n_neighbors,
103 | random_state=seed,
104 | n_jobs=n_jobs)
105 |
106 | elif batch_key is not None:
107 | # Compute spatial nearest neighbor graph for each batch separately
108 | for i, batch in enumerate(unique_batches):
109 | print("Computing spatial nearest neighbor graph for "
110 | f"{batch_key} {batch}...")
111 | compute_knn_graph_connectivities_and_distances(
112 | adata=adata_batch_list[i],
113 | feature_key=spatial_key,
114 | knng_key=spatial_knng_key,
115 | n_neighbors=n_neighbors,
116 | random_state=seed,
117 | n_jobs=n_jobs)
118 |
119 | if latent_knng_connectivities_key in adata.obsp:
120 | print("Using precomputed latent nearest neighbor graph...")
121 | elif batch_key is None:
122 | print("Computing latent nearest neighbor graph for entire dataset...")
123 | compute_knn_graph_connectivities_and_distances(
124 | adata=adata,
125 | feature_key=latent_key,
126 | knng_key=latent_knng_key,
127 | n_neighbors=n_neighbors,
128 | random_state=seed,
129 | n_jobs=n_jobs)
130 | elif batch_key is not None:
131 | # Compute latent nearest neighbor graph for each batch separately
132 | for i, batch in enumerate(unique_batches):
133 | print("Computing latent nearest neighbor graph for "
134 | f"{batch_key} {batch}...")
135 | compute_knn_graph_connectivities_and_distances(
136 | adata=adata_batch_list[i],
137 | feature_key=latent_key,
138 | knng_key=latent_knng_key,
139 | n_neighbors=n_neighbors,
140 | random_state=seed,
141 | n_jobs=n_jobs)
142 |
143 | if batch_key is None:
144 | print("Computing GCS for entire dataset...")
145 | n_neighbors = adata.uns[f"{latent_knng_key}_n_neighbors"]
146 | # Compute Frobenius norm of the matrix of differences to quantify
147 | # distance (square root of the sum of absolute squares)
148 | connectivities_diff = (
149 | adata.obsp[latent_knng_connectivities_key] -
150 | adata.obsp[spatial_knng_connectivities_key])
151 | gcd = sp.linalg.norm(connectivities_diff,
152 | ord="fro")
153 |
154 | # Normalize gcd to be between 0 and 1 and convert to gcs by subtracting
155 | # from 1. Maximum number of differences per node is 2 * n_neighbors (
156 | # sc.pp.neighbors returns a weighted symmetric knn graph with the node-
157 | # wise sums of weights not exceeding the number of neighbors; the
158 | # maximum difference for a node is reached if none of the neighbors
159 | # coincide for the node)
160 | gcs = 1 - (gcd ** 2 / (n_neighbors * 2 * connectivities_diff.shape[0]))
161 | elif batch_key is not None:
162 | # Compute GCS per batch and average
163 | gcs_list = []
164 | for i, batch in enumerate(unique_batches):
165 | print(f"Computing GCS for {batch_key} {batch}...")
166 | n_neighbors = adata_batch_list[i].uns[
167 | f"{latent_knng_key}_n_neighbors"]
168 | batch_connectivities_diff = (
169 | adata_batch_list[i].obsp[latent_knng_connectivities_key] -
170 | adata_batch_list[i].obsp[spatial_knng_connectivities_key])
171 | batch_gcd = sp.linalg.norm(batch_connectivities_diff,
172 | ord="fro")
173 | batch_gcs = 1 - (
174 | batch_gcd ** 2 / (
175 | n_neighbors * 2 * batch_connectivities_diff.shape[0]))
176 | gcs_list.append(batch_gcs)
177 | gcs = np.mean(gcs_list)
178 | return gcs
179 |
180 |
181 | def compute_avg_gcs(
182 | adata: AnnData,
183 | batch_key: Optional[str]=None,
184 | spatial_key: str="spatial",
185 | latent_key: str="nichecompass_latent",
186 | min_n_neighbors: int=1,
187 | max_n_neighbors: int=15,
188 | seed: int=0) -> float:
189 | """
190 | Compute multiple Graph Connectivity Similarities (GCS) by varying the number
191 | of neighbors used for nearest neighbor graph construction (between
192 | ´min_n_neighbors´ and ´max_n_neighbors´) and return the average GCS. Can use
193 | precomputed spatial and latent nearest neighbor graphs stored in
194 | ´adata.obsp[f'{spatial_key}_{n_neighbors}nng_connectivities']´ and
195 | ´adata.obsp[f'{latent_key}_{n_neighbors}nng_connectivities']´ respectively.
196 |
197 | Parameters
198 | ----------
199 | adata:
200 | AnnData object with spatial coordinates stored in
201 | ´adata.obsm[spatial_key]´ and the latent representation from a model
202 | stored in ´adata.obsm[latent_key]´. Precomputed nearest neighbor graphs
203 | can optionally be stored in
204 | ´adata.obsp[f'{spatial_key}_{n_neighbors}nng_connectivities']´
205 | and ´adata.obsp[f'{latent_key}_{n_neighbors}nng_connectivities']´.
206 | batch_key:
207 | Key under which the batches are stored in ´adata.obs´. If ´None´, the
208 | adata is assumed to only have one unique batch.
209 | spatial_key:
210 | Key under which the spatial coordinates are stored in ´adata.obsm´.
211 | latent_key:
212 | Key under which the latent representation from a model is stored in
213 | ´adata.obsm´.
214 | min_n_neighbors:
215 | Minimum number of neighbors used for computing the average GCS.
216 | max_n_neighbors:
217 | Maximum number of neighbors used for computing the average GCS.
218 | seed:
219 | Random seed for reproducibility.
220 |
221 | Returns
222 | ----------
223 | avg_gcs:
224 | Average GCS computed over different nearest neighbor graphs with varying
225 | number of neighbors.
226 | """
227 | gcs_list = []
228 | for n_neighbors in range(min_n_neighbors, max_n_neighbors + 1):
229 | gcs_list.append(compute_gcs(
230 | adata=adata,
231 | batch_key=batch_key,
232 | spatial_knng_key=f"{spatial_key}_{n_neighbors}nng",
233 | latent_knng_key=f"{latent_key}_{n_neighbors}nng",
234 | spatial_key=spatial_key,
235 | latent_key=latent_key,
236 | n_neighbors=n_neighbors,
237 | seed=seed))
238 | avg_gcs = np.mean(gcs_list)
239 | return avg_gcs
240 |
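
A hypothetical usage sketch of ´compute_avg_gcs´ on synthetic data (random
coordinates and a random latent space, so the resulting GCS should be low;
assumes the package and its pinned dependencies are installed):

    import numpy as np
    from anndata import AnnData
    from nichecompass.benchmarking import compute_avg_gcs

    rng = np.random.default_rng(0)
    adata = AnnData(X=rng.poisson(1.0, size=(100, 50)).astype(np.float32))
    adata.obsm["spatial"] = rng.uniform(0, 100, size=(100, 2)).astype(np.float32)
    adata.obsm["nichecompass_latent"] = rng.normal(size=(100, 10)).astype(np.float32)

    # Average GCS over 5- to 10-nearest-neighbor graphs
    avg_gcs = compute_avg_gcs(adata=adata,
                              min_n_neighbors=5,
                              max_n_neighbors=10)
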
--------------------------------------------------------------------------------
/src/nichecompass/benchmarking/mlami.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the Maximum Leiden Adjusted Mutual Info (MLAMI) benchmark
3 | for testing how accurately the latent feature space preserves global spatial
4 | organization from the spatial (physical) feature space by comparing clustering
5 | overlaps.
6 | """
7 |
8 | from typing import Optional
9 |
10 | import numpy as np
11 | import scanpy as sc
12 | from anndata import AnnData
13 | from sklearn.metrics import adjusted_mutual_info_score
14 |
15 | from .utils import compute_knn_graph_connectivities_and_distances
16 |
17 |
18 | def compute_mlami(
19 | adata: AnnData,
20 | batch_key: Optional[str]=None,
21 | spatial_knng_key: str="spatial_knng",
22 | latent_knng_key: str="nichecompass_latent_knng",
23 | spatial_key: Optional[str]="spatial",
24 | latent_key: Optional[str]="nichecompass_latent",
25 | n_neighbors: Optional[int]=15,
26 | min_res: float=0.1,
27 | max_res: float=1.0,
28 | res_num: int=3,
29 | n_jobs: int=1,
30 | seed: int=0) -> float:
31 | """
32 | Compute the Maximum Leiden Adjusted Mutual Info (MLAMI). The MLAMI ranges
33 | between '0' and '1' with higher values indicating that the latent feature
34 | space more accurately preserves global spatial organization from the spatial
35 | (ground truth) feature space. To compute the MLAMI, Leiden clusterings with
36 | different resolutions are computed for both nearest neighbor graphs. The
37 | Adjusted Mutual Info (AMI) between all clustering resolution pairs is
38 | computed to quantify cluster overlap and the maximum value is returned as
39 | metric for spatial organization preservation.
40 |
41 | If a ´batch_key´ is provided, the MLAMI will be computed on each batch
42 | separately (with latent Leiden clusters computed on the integrated latent
43 | space), and the average across all batches is returned.
44 |
45 | If existent, uses precomputed nearest neighbor graphs stored in
46 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and
47 | ´adata.obsp[latent_knng_key + '_connectivities']´.
48 | Alternatively, computes them on the fly using ´spatial_key´, ´latent_key´
49 | and ´n_neighbors´, and stores them in
50 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and
51 | ´adata.obsp[latent_knng_key + '_connectivities']´ respectively.
52 |
53 | Parameters
54 | ----------
55 | adata:
56 | AnnData object with precomputed nearest neighbor graphs stored in
57 | ´adata.obsp[spatial_knng_key + '_connectivities']´ and
58 | ´adata.obsp[latent_knng_key + '_connectivities']´ or spatial coordinates
59 | stored in ´adata.obsm[spatial_key]´ and the latent representation from a
60 | model stored in ´adata.obsm[latent_key]´.
61 | batch_key:
62 | Key under which the batches are stored in ´adata.obs´. If ´None´, the
63 | adata is assumed to only have one unique batch.
64 | spatial_knng_key:
65 | Key under which the spatial nearest neighbor graph is / will be stored
66 | in ´adata.obsp´ with the suffix '_connectivities'.
67 | latent_knng_key:
68 | Key under which the latent nearest neighbor graph is / will be stored in
69 | ´adata.obsp´ with the suffix '_connectivities'.
70 | spatial_key:
71 | Key under which the spatial coordinates are stored in ´adata.obsm´.
72 | latent_key:
73 | Key under which the latent representation from a model is stored in
74 | ´adata.obsm´.
75 | n_neighbors:
76 | Number of neighbors used for the construction of the nearest neighbor
77 | graphs from the spatial coordinates and the latent representation from
78 | a model in case they are constructed.
79 | min_res:
80 | Minimum resolution for Leiden clustering.
81 | max_res:
82 | Maximum resolution for Leiden clustering.
83 | res_num:
84 | Number of linearly spaced Leiden resolutions between ´min_res´ and
85 | ´max_res´ for which Leiden clusterings will be computed.
86 | n_jobs:
87 | Number of jobs to use for parallelization of neighbor search.
88 | seed:
89 | Random seed for reproducibility.
90 |
91 | Returns
92 | ----------
93 | mlami:
94 | Maximum LAMI across all clustering resolution pairs; if a ´batch_key´ is provided, the average of the per-batch maximum LAMIs.
95 | """
96 | # Adding '_connectivities' as expected / added by
97 | # 'compute_knn_graph_connectivities_and_distances'
98 | spatial_knng_connectivities_key = spatial_knng_key + "_connectivities"
99 | latent_knng_connectivities_key = latent_knng_key + "_connectivities"
100 |
101 | if batch_key is not None:
102 | adata_batch_list = []
103 | unique_batches = adata.obs[batch_key].unique().tolist()
104 | for batch in unique_batches:
105 | adata_batch = adata[adata.obs[batch_key] == batch]
106 | adata_batch_list.append(adata_batch)
107 |
108 | if spatial_knng_connectivities_key in adata.obsp:
109 | print("Using precomputed spatial nearest neighbor graph...")
110 | elif batch_key is None:
111 | print("Computing spatial nearest neighbor graph for entire dataset...")
112 | compute_knn_graph_connectivities_and_distances(
113 | adata=adata,
114 | feature_key=spatial_key,
115 | knng_key=spatial_knng_key,
116 | n_neighbors=n_neighbors,
117 | random_state=seed,
118 | n_jobs=n_jobs)
119 | elif batch_key is not None:
120 | # Compute spatial nearest neighbor graph for each batch separately
121 | for i, batch in enumerate(unique_batches):
122 | print("Computing spatial nearest neighbor graph for "
123 | f"{batch_key} {batch}...")
124 | compute_knn_graph_connectivities_and_distances(
125 | adata=adata_batch_list[i],
126 | feature_key=spatial_key,
127 | knng_key=spatial_knng_key,
128 | n_neighbors=n_neighbors,
129 | random_state=seed,
130 | n_jobs=n_jobs)
131 |
132 | if latent_knng_connectivities_key in adata.obsp:
133 | print("Using precomputed latent nearest neighbor graph...")
134 | else:
135 | print("Computing latent nearest neighbor graph...")
136 | compute_knn_graph_connectivities_and_distances(
137 | adata=adata,
138 | feature_key=latent_key,
139 | knng_key=latent_knng_key,
140 | n_neighbors=n_neighbors,
141 | random_state=seed,
142 | n_jobs=n_jobs)
143 |
144 | # Define search space of clustering resolutions
145 | clustering_resolutions = np.linspace(start=min_res,
146 | stop=max_res,
147 | num=res_num,
148 | dtype=np.float32)
149 |
150 | if batch_key is None:
151 | print("Computing spatial Leiden clusterings for entire dataset...")
152 | # Calculate spatial Leiden clustering for different resolutions
153 | for resolution in clustering_resolutions:
154 | sc.tl.leiden(adata=adata,
155 | resolution=resolution,
156 | random_state=seed,
157 | key_added=f"leiden_spatial_{str(resolution)}",
158 | adjacency=adata.obsp[spatial_knng_connectivities_key])
159 | elif batch_key is not None:
160 | # Compute spatial Leiden clustering for each batch separately
161 | for i, batch in enumerate(unique_batches):
162 | print("Computing spatial Leiden clusterings for "
163 | f"{batch_key} {batch}...")
164 | # Calculate spatial Leiden clustering for different resolutions
165 | for resolution in clustering_resolutions:
166 | sc.tl.leiden(
167 | adata=adata_batch_list[i],
168 | resolution=resolution,
169 | random_state=seed,
170 | key_added=f"leiden_spatial_{str(resolution)}",
171 | adjacency=adata_batch_list[i].obsp[spatial_knng_connectivities_key])
172 |
173 | print("Computing latent Leiden clusterings...")
174 | # Calculate latent Leiden clustering for different resolutions
175 | for resolution in clustering_resolutions:
176 | sc.tl.leiden(adata,
177 | resolution=resolution,
178 | random_state=seed,
179 | key_added=f"leiden_latent_{str(resolution)}",
180 | adjacency=adata.obsp[latent_knng_connectivities_key])
181 | if batch_key is not None:
182 | for i, batch in enumerate(unique_batches):
183 | adata_batch_list[i].obs[f"leiden_latent_{str(resolution)}"] = (
184 | adata.obs[f"leiden_latent_{str(resolution)}"])
185 |
186 | if batch_key is None:
187 | print("Computing MLAMI for entire dataset...")
188 | # Calculate max LAMI over all clustering resolutions
189 | lami_list = []
190 | for spatial_resolution in clustering_resolutions:
191 | for latent_resolution in clustering_resolutions:
192 | lami_list.append(_compute_ami(
193 | adata=adata,
194 | cluster_group1_key=f"leiden_spatial_{str(spatial_resolution)}",
195 | cluster_group2_key=f"leiden_latent_{str(latent_resolution)}"))
196 | mlami = np.max(lami_list)
197 | elif batch_key is not None:
198 | batch_mlami_list = []
199 | for i, batch in enumerate(unique_batches):
200 | print(f"Computing MLAMI for {batch_key} {batch}...")
201 | batch_lami_list = []
202 | for spatial_resolution in clustering_resolutions:
203 | for latent_resolution in clustering_resolutions:
204 | batch_lami_list.append(_compute_ami(
205 | adata=adata_batch_list[i],
206 | cluster_group1_key=f"leiden_spatial_{str(spatial_resolution)}",
207 | cluster_group2_key=f"leiden_latent_{str(latent_resolution)}"))
208 | batch_mlami_list.append(np.max(batch_lami_list))
209 | mlami = np.mean(batch_mlami_list)
210 | return mlami
211 |
212 | def _compute_ami(adata: AnnData,
213 | cluster_group1_key: str,
214 | cluster_group2_key: str) -> float:
215 | """
216 | Compute the Adjusted Mutual Information (AMI) between two different
217 | cluster assignments. AMI compares the overlap of two clusterings. For
218 | details, see documentation at
219 | https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_mutual_info_score.html#sklearn.metrics.adjusted_mutual_info_score.
220 |
221 | Parameters
222 | ----------
223 | adata:
224 | AnnData object with clustering labels stored in
225 | ´adata.obs[cluster_group1_key]´ and ´adata.obs[cluster_group2_key]´.
226 | cluster_group1_key:
227 | Key under which the clustering labels from the first clustering
228 | assignment are stored in ´adata.obs´.
229 | cluster_group2_key:
230 | Key under which the clustering labels from the second clustering
231 | assignment are stored in ´adata.obs´.
232 |
233 | Returns
234 | ----------
235 | ami:
236 | AMI score as calculated by the sklearn implementation.
237 | """
238 | cluster_group1 = adata.obs[cluster_group1_key].tolist()
239 | cluster_group2 = adata.obs[cluster_group2_key].tolist()
240 |
241 | ami = adjusted_mutual_info_score(cluster_group1,
242 | cluster_group2,
243 | average_method="arithmetic")
244 | return ami
245 |
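
As a toy illustration of the AMI that ´_compute_ami´ wraps: two clusterings
that agree up to a relabeling score 1.0, while unrelated clusterings score
close to 0.0 (sketch, sklearn only):

    from sklearn.metrics import adjusted_mutual_info_score

    # Identical partitions under different label names
    spatial_clusters = ["0", "0", "1", "1", "2", "2"]
    latent_clusters = ["a", "a", "b", "b", "c", "c"]
    ami = adjusted_mutual_info_score(spatial_clusters,
                                     latent_clusters,
                                     average_method="arithmetic")  # 1.0
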
--------------------------------------------------------------------------------
/src/nichecompass/benchmarking/nasw.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the Niche Average Silhouette Width (NASW) benchmark
3 | for testing how well the latent feature space can be clustered into distinct
4 | and compact clusters.
5 | """
6 |
7 | from typing import Optional
8 |
9 | import numpy as np
10 | import scanpy as sc
11 | import scib_metrics
12 | from anndata import AnnData
13 |
14 | from .utils import compute_knn_graph_connectivities_and_distances
15 |
16 |
17 | def compute_nasw(
18 | adata: AnnData,
19 | latent_knng_key: str="nichecompass_latent_knng",
20 | latent_key: Optional[str]="nichecompass_latent",
21 | n_neighbors: Optional[int]=15,
22 | min_res: float=0.1,
23 | max_res: float=1.0,
24 | res_num: int=3,
25 | n_jobs: int=1,
26 | seed: int=0) -> float:
27 | """
28 | Compute the Niche Average Silhouette Width (NASW). The NASW ranges between
29 | '0' and '1' with higher values indicating more distinct and compact
30 | clusters in the latent feature space. To compute the NASW, Leiden
31 | clusterings with different resolutions are computed for the latent nearest
32 | neighbor graph. The NASW for all clustering resolutions is computed and the
33 | average value is returned as metric for clusterability.
34 |
35 | If existent, uses a precomputed latent nearest neighbor graph stored in
36 | ´adata.obsp[latent_knng_key + '_connectivities']´.
37 | Alternatively, computes it on the fly using ´latent_key´ and ´n_neighbors´,
38 | and stores it in ´adata.obsp[latent_knng_key + '_connectivities']´.
39 |
40 | Parameters
41 | ----------
42 | adata:
43 | AnnData object with a precomputed latent nearest neighbor graph stored
44 | in ´adata.obsp[latent_knng_key + '_connectivities']´ or the latent
45 | representation from a model stored in ´adata.obsm[latent_key]´.
46 | latent_knng_key:
47 | Key under which the latent nearest neighbor graph is / will be stored in
48 | ´adata.obsp´ with the suffix '_connectivities'.
49 | latent_key:
50 | Key under which the latent representation from a model is stored in
51 | ´adata.obsm´.
52 | n_neighbors:
53 | Number of neighbors used for the construction of the latent nearest
54 | neighbor graph from the latent representation from a model in case it
55 | is constructed on the fly.
56 | min_res:
57 | Minimum resolution for Leiden clustering.
58 | max_res:
59 | Maximum resolution for Leiden clustering.
60 | res_num:
61 | Number of linearly spaced Leiden resolutions between ´min_res´ and
62 | ´max_res´ for which Leiden clusterings will be computed.
63 | n_jobs:
64 | Number of jobs to use for parallelization of neighbor search.
65 | seed:
66 | Random seed for reproducibility.
67 |
68 | Returns
69 | ----------
70 | nasw:
71 | Average NASW across all clustering resolutions.
72 | """
73 | # Adding '_connectivities' as expected / added by
74 | # 'compute_knn_graph_connectivities_and_distances'
75 | latent_knng_connectivities_key = latent_knng_key + "_connectivities"
76 |
77 | if latent_knng_connectivities_key in adata.obsp:
78 | print("Using precomputed latent nearest neighbor graph...")
79 | else:
80 | print("Computing latent nearest neighbor graph...")
81 | compute_knn_graph_connectivities_and_distances(
82 | adata=adata,
83 | feature_key=latent_key,
84 | knng_key=latent_knng_key,
85 | n_neighbors=n_neighbors,
86 | random_state=seed,
87 | n_jobs=n_jobs)
88 |
89 | # Define search space of clustering resolutions
90 | clustering_resolutions = np.linspace(start=min_res,
91 | stop=max_res,
92 | num=res_num,
93 | dtype=np.float32)
94 |
95 | print("Computing latent Leiden clusterings...")
96 | # Calculate latent Leiden clustering for different resolutions
97 | for resolution in clustering_resolutions:
98 | if f"leiden_latent_{str(resolution)}" not in adata.obs:
99 | sc.tl.leiden(adata,
100 | resolution=resolution,
101 | random_state=seed,
102 | key_added=f"leiden_latent_{str(resolution)}",
103 | adjacency=adata.obsp[latent_knng_connectivities_key])
104 | else:
105 | print("Using precomputed latent Leiden clusters for resolution "
106 | f"{str(resolution)}.")
107 |
108 | print("Computing NASW...")
109 | # Calculate mean NASW over all clustering resolutions
110 | nasw_list = []
111 | for resolution in clustering_resolutions:
112 | nasw_list.append(scib_metrics.silhouette_label(
113 | X=adata.obsm[latent_key],
114 | labels=adata.obs[f"leiden_latent_{str(resolution)}"]))
115 | nasw = np.mean(nasw_list)
116 | return nasw
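
A hypothetical usage sketch of ´compute_nasw´: two well-separated Gaussian
blobs in the latent space should yield a comparatively high NASW (´adata.X´
is a dummy here since only ´adata.obsm´ is used):

    import numpy as np
    from anndata import AnnData
    from nichecompass.benchmarking import compute_nasw

    rng = np.random.default_rng(0)
    latent = np.vstack([rng.normal(0.0, 1.0, size=(50, 10)),
                        rng.normal(8.0, 1.0, size=(50, 10))]).astype(np.float32)
    adata = AnnData(X=np.zeros((100, 1), dtype=np.float32))
    adata.obsm["nichecompass_latent"] = latent

    nasw = compute_nasw(adata=adata,
                        latent_key="nichecompass_latent",
                        n_neighbors=15)
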
--------------------------------------------------------------------------------
/src/nichecompass/benchmarking/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains helper functions for the ´benchmarking´ subpackage.
3 | """
4 |
5 | from typing import Optional
6 |
7 | import numpy as np
8 | import scanpy as sc
9 | from anndata import AnnData
10 | from scib_metrics.nearest_neighbors import pynndescent
11 |
12 |
13 | def compute_knn_graph_connectivities_and_distances(
14 | adata: AnnData,
15 | feature_key: str="nichecompass_latent",
16 | knng_key: str="nichecompass_latent_15knng",
17 | n_neighbors: int=15,
18 | random_state: int=0,
19 | n_jobs: int=1) -> None:
20 | """
21 | Compute approximate k-nearest-neighbors graph.
22 |
23 | Parameters
24 | ----------
25 | adata:
26 | AnnData object with the features for knn graph computation stored in
27 | ´adata.obsm[feature_key]´.
28 | feature_key:
29 | Key in ´adata.obsm´ that will be used to compute the knn graph.
30 | knng_key:
31 | Key under which the knn graph connectivities will be stored
32 | in ´adata.obsp´ with the suffix '_connectivities', the knn graph
33 | distances will be stored in ´adata.obsp´ with the suffix '_distances',
34 | and the number of neighbors will be stored in ´adata.uns´ with the
35 | suffix '_n_neighbors'.
36 | n_neighbors:
37 | Number of neighbors of the knn graph.
38 | random_state:
39 | Random state for reproducibility.
40 | n_jobs:
41 | Number of jobs to use for parallelization of neighbor search.
42 | """
43 | neigh_output = pynndescent(
44 | adata.obsm[feature_key],
45 | n_neighbors=n_neighbors,
46 | random_state=random_state,
47 | n_jobs=n_jobs)
48 | indices, distances = neigh_output.indices, neigh_output.distances
49 |
50 | # This is a trick to get lisi metrics to work by adding the tiniest possible value
51 | # to zero-distance neighbors so that each cell keeps the same number of neighbors
52 | # (otherwise some cells lose neighbors with distance 0 due to sparse representation)
53 | row_idx = np.where(distances == 0)[0]
54 | col_idx = np.where(distances == 0)[1]
55 | new_row_idx = row_idx[np.where(row_idx != indices[row_idx, col_idx])[0]]
56 | new_col_idx = col_idx[np.where(row_idx != indices[row_idx, col_idx])[0]]
57 | distances[new_row_idx, new_col_idx] = (distances[new_row_idx, new_col_idx] +
58 | np.nextafter(0, 1, dtype=np.float32))
59 |
60 | sp_distances, sp_conns = sc.neighbors._compute_connectivities_umap(
61 | indices[:, :n_neighbors],
62 | distances[:, :n_neighbors],
63 | adata.n_obs,
64 | n_neighbors=n_neighbors)
65 | adata.obsp[f"{knng_key}_connectivities"] = sp_conns
66 | adata.obsp[f"{knng_key}_distances"] = sp_distances
67 | adata.uns[f"{knng_key}_n_neighbors"] = n_neighbors
68 |
69 |
70 | def convert_to_one_hot(vector: np.ndarray,
71 | n_classes: Optional[int]=None) -> np.ndarray:
72 | """
73 | Converts an input 1-D vector of integer labels into a 2-D array of one-hot
74 | vectors, where for an i'th input value of j, a '1' will be inserted in the
75 | i'th row and j'th column of the output one-hot vector.
76 |
77 | Implementation is adapted from
78 | https://github.com/theislab/scib/blob/29f79d0135f33426481f9ff05dd1ae55c8787142/scib/metrics/lisi.py#L498
79 | (05.12.22).
80 |
81 | Parameters
82 | ----------
83 | vector:
84 | Vector to be one-hot-encoded.
85 | n_classes:
86 | Number of classes to be considered for one-hot-encoding. If ´None´, the
87 | number of classes will be inferred from ´vector´.
88 |
89 | Returns
90 | ----------
91 | one_hot:
92 | 2-D NumPy array of one-hot-encoded vectors.
93 |
94 | Example:
95 | ´´´
96 | vector = np.array((1, 0, 4))
97 | one_hot = convert_to_one_hot(vector)
98 | print(one_hot)
99 | [[0 1 0 0 0]
100 | [1 0 0 0 0]
101 | [0 0 0 0 1]]
102 | ´´´
103 | """
104 | if n_classes is None:
105 | n_classes = np.max(vector) + 1
106 |
107 | one_hot = np.zeros(shape=(len(vector), n_classes))
108 | one_hot[np.arange(len(vector)), vector] = 1
109 | return one_hot.astype(int)
110 |
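
A minimal sketch of ´compute_knn_graph_connectivities_and_distances´ and the
keys it writes (imported from the module directly, since it is not
re-exported by the subpackage ´__init__´):

    import numpy as np
    from anndata import AnnData
    from nichecompass.benchmarking.utils import (
        compute_knn_graph_connectivities_and_distances)

    rng = np.random.default_rng(0)
    adata = AnnData(X=np.zeros((200, 1), dtype=np.float32))
    adata.obsm["nichecompass_latent"] = rng.normal(size=(200, 10)).astype(np.float32)

    compute_knn_graph_connectivities_and_distances(
        adata=adata,
        feature_key="nichecompass_latent",
        knng_key="nichecompass_latent_15knng",
        n_neighbors=15)

    # Written results:
    # adata.obsp["nichecompass_latent_15knng_connectivities"]
    # adata.obsp["nichecompass_latent_15knng_distances"]
    # adata.uns["nichecompass_latent_15knng_n_neighbors"]
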
--------------------------------------------------------------------------------
/src/nichecompass/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloaders import initialize_dataloaders
2 | from .dataprocessors import (edge_level_split,
3 | node_level_split_mask,
4 | prepare_data)
5 | from .datareaders import load_spatial_adata_from_csv
6 | from .datasets import SpatialAnnTorchDataset
7 |
8 | __all__ = ["initialize_dataloaders",
9 | "edge_level_split",
10 | "node_level_split_mask",
11 | "prepare_data",
12 | "load_spatial_adata_from_csv",
13 | "SpatialAnnTorchDataset"]
14 |
--------------------------------------------------------------------------------
/src/nichecompass/data/dataloaders.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains dataloaders for the training of a NicheCompass model.
3 | """
4 |
5 | from typing import Optional
6 |
7 | from torch_geometric.data import Data
8 | from torch_geometric.loader import LinkNeighborLoader, NeighborLoader
9 |
10 |
11 | def initialize_dataloaders(node_masked_data: Data,
12 | edge_train_data: Optional[Data]=None,
13 | edge_val_data: Optional[Data]=None,
14 | edge_batch_size: Optional[int]=64,
15 | node_batch_size: int=64,
16 | n_direct_neighbors: int=-1,
17 | n_hops: int=1,
18 | shuffle: bool=True,
19 | edges_directed: bool=False,
20 | neg_edge_sampling_ratio: float=1.) -> dict:
21 | """
22 | Initialize edge-level and node-level training and validation dataloaders.
23 |
24 | Parameters
25 | ----------
26 | node_masked_data:
27 | PyG Data object with node-level split masks.
28 | edge_train_data:
29 | PyG Data object containing the edge-level training set.
30 | edge_val_data:
31 | PyG Data object containing the edge-level validation set.
32 | edge_batch_size:
33 | Batch size for the edge-level dataloaders.
34 | node_batch_size:
35 | Batch size for the node-level dataloaders.
36 | n_direct_neighbors:
37 | Number of sampled direct neighbors of the current batch nodes to be
38 | included in the batch. Defaults to ´-1´, which means to include all
39 | direct neighbors.
40 | n_hops:
41 | Number of neighbor hops / levels for neighbor sampling of nodes to be
42 | included in the current batch. E.g. ´2´ means to not only include
43 | sampled direct neighbors of current batch nodes but also sampled
44 | neighbors of the direct neighbors.
45 | shuffle:
46 | If `True`, shuffle the dataloaders.
47 | edges_directed:
48 | If `False`, both symmetric edge index pairs are included in the same
49 | edge-level batch (1 edge has 2 symmetric edge index pairs).
50 | neg_edge_sampling_ratio:
51 | Negative sampling ratio of edges. This is currently implemented in an
52 | approximate way, i.e. negative edges may contain false negatives.
53 |
54 | Returns
55 | ----------
56 | loader_dict:
57 | Dictionary containing training and validation PyG LinkNeighborLoader
58 | (for edge reconstruction) and NeighborLoader (for gene expression
59 | reconstruction) objects.
60 | """
61 | loader_dict = {}
62 |
63 | # Node-level dataloaders
64 | loader_dict["node_train_loader"] = NeighborLoader(
65 | node_masked_data,
66 | num_neighbors=[n_direct_neighbors] * n_hops,
67 | batch_size=node_batch_size,
68 | directed=False,
69 | shuffle=shuffle,
70 | input_nodes=node_masked_data.train_mask)
71 | if node_masked_data.val_mask.sum() != 0:
72 | loader_dict["node_val_loader"] = NeighborLoader(
73 | node_masked_data,
74 | num_neighbors=[n_direct_neighbors] * n_hops,
75 | batch_size=node_batch_size,
76 | directed=False,
77 | shuffle=shuffle,
78 | input_nodes=node_masked_data.val_mask)
79 |
80 | # Edge-level dataloaders
81 | if edge_train_data is not None:
82 | loader_dict["edge_train_loader"] = LinkNeighborLoader(
83 | edge_train_data,
84 | num_neighbors=[n_direct_neighbors] * n_hops,
85 | batch_size=edge_batch_size,
86 | edge_label=None, # will automatically be added as 1 for all edges
87 | edge_label_index=edge_train_data.edge_label_index[:, edge_train_data.edge_label.bool()], # limit the edges to the ones from the edge_label_adj
88 | directed=edges_directed,
89 | shuffle=shuffle,
90 | neg_sampling_ratio=neg_edge_sampling_ratio)
91 | if edge_val_data is not None and edge_val_data.edge_label.sum() != 0:
92 | loader_dict["edge_val_loader"] = LinkNeighborLoader(
93 | edge_val_data,
94 | num_neighbors=[n_direct_neighbors] * n_hops,
95 | batch_size=edge_batch_size,
96 | edge_label=None, # will automatically be added as 1 for all edges
97 | edge_label_index=edge_val_data.edge_label_index[:, edge_val_data.edge_label.bool()], # limit the edges to the ones from the edge_label_adj
98 | directed=edges_directed,
99 | shuffle=shuffle,
100 | neg_sampling_ratio=neg_edge_sampling_ratio)
101 |
102 | return loader_dict
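
A minimal node-level-only sketch of ´initialize_dataloaders´ on a hand-built
4-node graph (edge-level loaders are skipped by leaving ´edge_train_data´ and
´edge_val_data´ as ´None´; the graph and masks are illustrative assumptions):

    import torch
    from torch_geometric.data import Data
    from nichecompass.data import initialize_dataloaders

    # Symmetric 4-node chain graph with node-level split masks
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])
    node_masked_data = Data(x=torch.randn(4, 8), edge_index=edge_index)
    node_masked_data.train_mask = torch.tensor([True, True, True, False])
    node_masked_data.val_mask = torch.tensor([False, False, False, True])

    loader_dict = initialize_dataloaders(node_masked_data=node_masked_data,
                                         node_batch_size=2)
    for batch in loader_dict["node_train_loader"]:
        print(batch)  # PyG Data object with sampled neighborhoods
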
--------------------------------------------------------------------------------
/src/nichecompass/data/dataprocessors.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains data processors for the training of a NicheCompass model.
3 | """
4 |
5 | from typing import List, Optional, Tuple
6 |
7 | import scipy.sparse as sp
8 | import torch
9 | from anndata import AnnData
10 | from torch_geometric.data import Data
11 | from torch_geometric.transforms import RandomNodeSplit, RandomLinkSplit
12 | from torch_geometric.utils import add_self_loops, remove_self_loops
13 |
14 | from .datasets import SpatialAnnTorchDataset
15 |
16 |
17 | def edge_level_split(data: Data,
18 | edge_label_adj: Optional[sp.csr_matrix],
19 | val_ratio: float=0.1,
20 | test_ratio: float=0.,
21 | is_undirected: bool=True,
22 | neg_sampling_ratio: float=0.) -> Tuple[Data, Data, Data]:
23 | """
24 | Split a PyG Data object into training, validation and test PyG Data objects
25 | using an edge-level split. The training split does not include edges in the
26 | validation and test splits and the validation split does not include edges
27 | in the test split. However, nodes will not be split and all node features
28 | will be accessible from all splits.
29 |
30 | Check https://github.com/pyg-team/pytorch_geometric/issues/3668 for more
31 | context how RandomLinkSplit works.
32 |
33 | Parameters
34 | ----------
35 | data:
36 | PyG Data object to be split.
37 | edge_label_adj:
38 | Adjacency matrix which contains edges for edge reconstruction. If
39 | ´None´, uses the 'normal' adjacency matrix used for message passing.
40 | val_ratio:
41 | Ratio of edges to be included in the validation split.
42 | test_ratio:
43 | Ratio of edges to be included in the test split.
44 | is_undirected:
45 | If ´True´, the graph is assumed to be undirected, and positive and
46 | negative samples will not leak (reverse) edge connectivity across
47 | different splits. Self loops are removed before the split and re-added
48 | afterwards, as RandomLinkSplit would otherwise replicate them.
49 | neg_sampling_ratio:
50 | Ratio of negative sampling. This should be set to 0 if negative sampling
51 | is done by the dataloader.
52 |
53 | Returns
54 | ----------
55 | train_data:
56 | Training PyG Data object.
57 | val_data:
58 | Validation PyG Data object.
59 | test_data:
60 | Test PyG Data object.
61 | """
62 | # Clone data to not modify in place to not affect node split
63 | data_no_self_loops = data.clone()
64 |
65 | # Remove self loops temporarily as we don't want them as edge labels. There
66 | # is also an issue with RandomLinkSplit (self loops will be replicated for
67 | # message passing). We will add the self loops again after the split
68 | data_no_self_loops.edge_index, data_no_self_loops.edge_attr = (
69 | remove_self_loops(edge_index=data.edge_index,
70 | edge_attr=data.edge_attr))
71 |
72 | if edge_label_adj is not None:
73 | # Add edge label which is 1 for edges from edge_label_adj and 0 otherwise.
74 | # This will be used by dataloader to only sample edges from edge_label_adj
75 | # as opposed to from adj.
76 | data_no_self_loops.edge_label = torch.tensor(
77 | [(edge_label_adj[edge_index[0].item(), edge_index[1].item()] == 1.0) for
78 | edge_index in data_no_self_loops.edge_attr]).int()
79 |
80 | random_link_split = RandomLinkSplit(
81 | num_val=val_ratio,
82 | num_test=test_ratio,
83 | is_undirected=is_undirected,
84 | key="edge_label", # if ´edge_label´ is not existent, it will be added with 1s
85 | neg_sampling_ratio=neg_sampling_ratio)
86 | train_data, val_data, test_data = random_link_split(data_no_self_loops)
87 |
88 | # Re-add self loops for message passing
89 | for split_data in [train_data, val_data, test_data]:
90 | split_data.edge_index = add_self_loops(
91 | edge_index=split_data.edge_index,
92 | num_nodes=split_data.x.shape[0])[0]
93 | split_data.edge_attr = add_self_loops(
94 | edge_index=split_data.edge_attr.t(),
95 | num_nodes=split_data.x.shape[0])[0].t()
96 | return train_data, val_data, test_data
97 |
98 |
99 | def node_level_split_mask(data: Data,
100 | val_ratio: float=0.1,
101 | test_ratio: float=0.,
102 | split_key: str="x") -> Data:
103 | """
104 | Split data on node-level into training, validation and test sets by adding
105 | node-level masks (train_mask, val_mask, test_mask) to the PyG Data object.
106 |
107 | Parameters
108 | ----------
109 | data:
110 | PyG Data object to be split.
111 | val_ratio:
112 | Ratio of nodes to be included in the validation split.
113 | test_ratio:
114 | Ratio of nodes to be included in the test split.
115 | split_key:
116 | The attribute key of the PyG Data object that holds the ground
117 | truth labels. Only nodes in which the key is present will be split.
118 |
119 | Returns
120 | ----------
121 | data:
122 | PyG Data object with ´train_mask´, ´val_mask´ and ´test_mask´ attributes
123 | added.
124 | """
125 | random_node_split = RandomNodeSplit(
126 | num_val=val_ratio,
127 | num_test=test_ratio,
128 | key=split_key)
129 | data = random_node_split(data)
130 | return data
131 |
132 |
133 | def prepare_data(adata: AnnData,
134 | cat_covariates_label_encoders: List[dict],
135 | adata_atac: Optional[AnnData]=None,
136 | counts_key: Optional[str]="counts",
137 | adj_key: str="spatial_connectivities",
138 | cat_covariates_keys: Optional[List[str]]=None,
139 | edge_val_ratio: float=0.1,
140 | edge_test_ratio: float=0.,
141 | node_val_ratio: float=0.1,
142 | node_test_ratio: float=0.) -> dict:
143 | """
144 | Prepare data for model training including edge-level and node-level train,
145 | validation, and test splits.
146 |
147 | Parameters
148 | ----------
149 | adata:
150 | AnnData object with counts stored in ´adata.layers[counts_key]´ or
151 | ´adata.X´ depending on ´counts_key´, and sparse adjacency matrix stored
152 | in ´adata.obsp[adj_key]´.
153 | adata_atac:
154 | Additional optional AnnData object with paired spatial ATAC data.
155 | cat_covariates_label_encoders:
156 | List of categorical covariates label encoders from the model (label
157 | encoding indices need to be aligned with the ones from the model to get
158 | the correct categorical covariates embeddings).
159 | counts_key:
160 | Key under which the counts are stored in ´adata.layers´. If ´None´, uses
161 | ´adata.X´ as counts.
162 | adj_key:
163 | Key under which the sparse adjacency matrix is stored in ´adata.obsp´.
164 | cat_covariates_keys:
165 | Keys under which the categorical covariates are stored in ´adata.obs´.
166 | edge_val_ratio:
167 | Fraction of the data that is used as validation set on edge-level.
168 | edge_test_ratio:
169 | Fraction of the data that is used as test set on edge-level.
170 | node_val_ratio:
171 | Fraction of the data that is used as validation set on node-level.
172 | node_test_ratio:
173 | Fraction of the data that is used as test set on node-level.
174 |
175 | Returns
176 | ----------
177 | data_dict:
178 | Dictionary containing edge-level training, validation and test PyG
179 | Data objects and node-level PyG Data object with split masks under keys
180 | ´edge_train_data´, ´edge_val_data´, ´edge_test_data´, and
181 | ´node_masked_data´ respectively. The edge-level PyG Data objects contain
182 | edges in the ´edge_label_index´ attribute and edge labels in the
183 | ´edge_label´ attribute.
184 | """
185 | data_dict = {}
186 | dataset = SpatialAnnTorchDataset(
187 | adata=adata,
188 | adata_atac=adata_atac,
189 | counts_key=counts_key,
190 | adj_key=adj_key,
191 | cat_covariates_keys=cat_covariates_keys,
192 | cat_covariates_label_encoders=cat_covariates_label_encoders)
193 |
194 | # PyG Data object (has 2 edge index pairs for one edge because of symmetry;
195 | # one edge index pair will be removed in the edge-level split).
196 | data = Data(x=dataset.x,
197 | edge_index=dataset.edge_index,
198 | edge_attr=dataset.edge_index.t()) # store index of edge nodes as
199 | # edge attribute for
200 | # aggregation weight retrieval
201 | # in mini batches
202 |
203 | if cat_covariates_keys is not None:
204 | data.cat_covariates_cats = dataset.cat_covariates_cats
205 |
206 | # Edge-level split for edge reconstruction
207 | edge_train_data, edge_val_data, edge_test_data = edge_level_split(
208 | data=data,
209 | edge_label_adj=dataset.edge_label_adj,
210 | val_ratio=edge_val_ratio,
211 | test_ratio=edge_test_ratio)
212 | data_dict["edge_train_data"] = edge_train_data
213 | data_dict["edge_val_data"] = edge_val_data
214 | data_dict["edge_test_data"] = edge_test_data
215 |
216 | # Node-level split for gene expression reconstruction
217 | data_dict["node_masked_data"] = node_level_split_mask(
218 | data=data,
219 | val_ratio=node_val_ratio,
220 | test_ratio=node_test_ratio)
221 | return data_dict
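
A hypothetical end-to-end sketch of ´prepare_data´ on synthetic data with a
symmetric 2-nearest-neighbor spatial adjacency matrix (the returned split
Data objects can then be passed to ´initialize_dataloaders´):

    import numpy as np
    import scipy.sparse as sp
    from anndata import AnnData
    from nichecompass.data import prepare_data

    rng = np.random.default_rng(0)
    adata = AnnData(X=rng.poisson(1.0, size=(50, 20)).astype(np.float32))
    adata.layers["counts"] = adata.X.copy()

    # Symmetric 2-nearest-neighbor adjacency from random coordinates
    coords = rng.uniform(0, 100, size=(50, 2))
    dists = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(-1)
    knn_idx = np.argsort(dists, axis=1)[:, 1:3]  # exclude self at index 0
    adj = np.zeros((50, 50))
    adj[np.repeat(np.arange(50), 2), knn_idx.ravel()] = 1.0
    adj = np.maximum(adj, adj.T)  # symmetrize
    adata.obsp["spatial_connectivities"] = sp.csr_matrix(adj)

    data_dict = prepare_data(adata=adata,
                             cat_covariates_label_encoders=[])
    # Keys: edge_train_data, edge_val_data, edge_test_data, node_masked_data
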
--------------------------------------------------------------------------------
/src/nichecompass/data/datareaders.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains data readers for the training of a NicheCompass model.
3 | """
4 |
5 | from typing import Optional
6 |
7 | import anndata as ad
8 | import pandas as pd
9 | import scipy.sparse as sp
10 |
11 |
12 | def load_spatial_adata_from_csv(counts_file_path: str,
13 | adj_file_path: str,
14 | cell_type_file_path: Optional[str]=None,
15 | adj_key: str="spatial_connectivities",
16 | cell_type_col: str="cell_type",
17 | cell_type_key: str="cell_type") -> ad.AnnData:
18 | """
19 | Create AnnData object from two csv files containing gene expression feature
20 | matrix and adjacency matrix respectively. Optionally, a third csv file with
21 | cell types can be provided.
22 |
23 | Parameters
24 | ----------
25 | counts_file_path:
26 | File path of the csv file which contains gene expression feature matrix
27 | data.
28 | adj_file_path:
29 | File path of the csv file which contains adjacency matrix data.
30 | cell_type_file_path:
31 | File path of the csv file which contains cell type data.
32 | adj_key:
33 | Key under which the sparse adjacency matrix will be stored in
34 | ´adata.obsp´.
35 | cell_type_col:
36 | Column under which the cell type is stored in the ´cell_type_file´.
37 | cell_type_key:
38 | Key under which the cell types will be stored in ´adata.obs´.
39 |
40 | Returns
41 | ----------
42 | adata:
43 | AnnData object with gene expression data stored in ´adata.X´ and sparse
44 | adjacency matrix (coo format) stored in ´adata.obsp[adj_key]´.
45 | """
46 | adata = ad.read_csv(counts_file_path)
47 | adj_df = pd.read_csv(adj_file_path, sep=",", header=0)
48 | adj = adj_df.values
49 | adata.obsp[adj_key] = sp.csr_matrix(adj).tocoo()
50 |
51 | if cell_type_file_path:
52 | cell_type_df = pd.read_csv(cell_type_file_path, sep=",", header=0)
53 | adata.obs[cell_type_key] = cell_type_df[cell_type_col].values
54 | return adata
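
A hypothetical call sketch (the file paths are placeholders): ´counts.csv´
holds the cell-by-gene matrix, ´adj.csv´ the cell-by-cell adjacency matrix,
and ´cell_types.csv´ a table with a 'cell_type' column:

    from nichecompass.data import load_spatial_adata_from_csv

    adata = load_spatial_adata_from_csv(
        counts_file_path="counts.csv",
        adj_file_path="adj.csv",
        cell_type_file_path="cell_types.csv")

    # adata.X: gene expression matrix
    # adata.obsp["spatial_connectivities"]: sparse adjacency matrix (coo)
    # adata.obs["cell_type"]: cell type annotations
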
--------------------------------------------------------------------------------
/src/nichecompass/data/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the SpatialAnnTorchDataset class to provide a standardized
3 | dataset format for the training of an NicheCompass model.
4 | """
5 |
6 | from typing import List, Optional
7 |
8 | import scipy.sparse as sp
9 | import torch
10 | from anndata import AnnData
11 | from torch_geometric.utils import add_self_loops, remove_self_loops
12 |
13 | from .utils import encode_labels, sparse_mx_to_sparse_tensor
14 |
15 |
16 | class SpatialAnnTorchDataset():
17 | """
18 | Spatially annotated torch dataset class to extract node features, node
19 | labels, adjacency matrix and edge indices in a standardized format from an
20 | AnnData object.
21 |
22 | Parameters
23 | ----------
24 | adata:
25 | AnnData object with counts stored in ´adata.layers[counts_key]´ or
26 | ´adata.X´ depending on ´counts_key´, and sparse adjacency matrix stored
27 | in ´adata.obsp[adj_key]´.
28 | adata_atac:
29 | Additional optional AnnData object with paired spatial ATAC data.
30 | cat_covariates_label_encoders:
31 | List of categorical covariates label encoders from the model (label
32 | encoding indices need to be aligned with the ones from the model to get
33 | the correct categorical covariates embeddings).
34 | counts_key:
35 | Key under which the counts are stored in ´adata.layers´. If ´None´, uses
36 | ´adata.X´ as counts.
37 | adj_key:
38 | Key under which the sparse adjacency matrix is stored in ´adata.obsp´.
39 | self_loops:
40 | If ´True´, add self loops to the adjacency matrix to model autocrine
41 | communication.
42 | cat_covariates_keys:
43 | Keys under which the categorical covariates are stored in ´adata.obs´.
44 | """
45 | def __init__(self,
46 | adata: AnnData,
47 | cat_covariates_label_encoders: List[dict],
48 | adata_atac: Optional[AnnData]=None,
49 | counts_key: Optional[str]="counts",
50 | adj_key: str="spatial_connectivities",
51 | edge_label_adj_key: str="edge_label_spatial_connectivities",
52 | self_loops: bool=True,
53 | cat_covariates_keys: Optional[List[str]]=None):
54 | if counts_key is None:
55 | x = adata.X
56 | else:
57 | x = adata.layers[counts_key]
58 |
59 | # Store features in dense format
60 | if sp.issparse(x):
61 | self.x = torch.tensor(x.toarray())
62 | else:
63 | self.x = torch.tensor(x)
64 |
65 | # Concatenate ATAC feature vector in dense format if provided
66 | if adata_atac is not None:
67 | if sp.issparse(adata_atac.X):
68 | self.x = torch.cat(
69 | (self.x, torch.tensor(adata_atac.X.toarray())), axis=1)
70 | else:
71 | self.x = torch.cat((self.x, torch.tensor(adata_atac.X)), axis=1)
72 |
73 | # Store adjacency matrix in torch_sparse SparseTensor format
74 | if sp.issparse(adata.obsp[adj_key]):
75 | self.adj = sparse_mx_to_sparse_tensor(adata.obsp[adj_key])
76 | else:
77 | self.adj = sparse_mx_to_sparse_tensor(
78 | sp.csr_matrix(adata.obsp[adj_key]))
79 |
80 | # Store edge label adjacency matrix
81 | if edge_label_adj_key in adata.obsp:
82 | self.edge_label_adj = sp.csr_matrix(adata.obsp[edge_label_adj_key])
83 | else:
84 | self.edge_label_adj = None
85 |
86 | # Validate adjacency matrix symmetry
87 | adj_mx = sp.csr_matrix(adata.obsp[adj_key])
88 | if (adj_mx != adj_mx.T).nnz != 0:
89 | raise ValueError("The input adjacency matrix has to be symmetric.")
90 | self.edge_index = self.adj.to_torch_sparse_coo_tensor()._indices()
91 |
92 | if self_loops:
93 | # Add self loops to account for autocrine communication
94 | # Remove self loops in case there are already before adding new ones
95 | self.edge_index, _ = remove_self_loops(self.edge_index)
96 | self.edge_index, _ = add_self_loops(self.edge_index,
97 | num_nodes=self.x.size(0))
98 |
99 | if cat_covariates_keys is not None:
100 | self.cat_covariates_cats = []
101 | for cat_covariate_key, cat_covariate_label_encoder in zip(
102 | cat_covariates_keys,
103 | cat_covariates_label_encoders):
104 | cat_covariate_cats = torch.tensor(
105 | encode_labels(adata,
106 | cat_covariate_label_encoder,
107 | cat_covariate_key), dtype=torch.long)
108 | self.cat_covariates_cats.append(cat_covariate_cats)
109 | self.cat_covariates_cats = torch.stack(self.cat_covariates_cats,
110 | dim=1)
111 |
112 | self.n_node_features = self.x.size(1)
113 | self.size_factors = self.x.sum(1) # fix for ATAC case
114 |
115 | def __len__(self):
116 | """Return the number of observations stored in SpatialAnnTorchDataset"""
117 | return self.x.size(0)
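
A minimal sketch of ´SpatialAnnTorchDataset´ on a 3-cell, 2-gene toy example
with a symmetric chain adjacency matrix (illustrative assumption):

    import numpy as np
    import scipy.sparse as sp
    from anndata import AnnData
    from nichecompass.data import SpatialAnnTorchDataset

    adata = AnnData(X=np.array([[1., 0.], [0., 2.], [3., 1.]], dtype=np.float32))
    adata.layers["counts"] = adata.X.copy()
    adata.obsp["spatial_connectivities"] = sp.csr_matrix(
        np.array([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]], dtype=np.float32))

    dataset = SpatialAnnTorchDataset(adata=adata,
                                     cat_covariates_label_encoders=[])
    len(dataset)              # 3 observations
    dataset.x.shape           # torch.Size([3, 2]): dense node features
    dataset.edge_index.shape  # torch.Size([2, 7]): 4 directed edges + 3 self loops
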
--------------------------------------------------------------------------------
/src/nichecompass/data/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains helper functions for the ´data´ subpackage.
3 | """
4 |
5 | import anndata as ad
6 | import numpy as np
7 | import torch
8 | from scipy.sparse import csr_matrix
9 | from torch_sparse import SparseTensor
10 |
11 |
12 | def encode_labels(adata: ad.AnnData,
13 | label_encoder: dict,
14 | label_key: str) -> np.ndarray:
15 | """
16 | Encode labels from an `adata` object stored in `adata.obs` to integers.
17 |
18 | Implementation is adapted from
19 | https://github.com/theislab/scarches/blob/c21492d409150cec73d26409f9277b3ac971f4a7/scarches/dataset/trvae/_utils.py#L4
20 | (20.01.2023).
21 |
22 | Parameters
23 | ----------
24 | adata:
25 | AnnData object with labels stored in `adata.obs[label_key]`.
26 | label_encoder:
27 | Dictionary where keys are labels and values are label encodings.
28 | label_key:
29 | Key where in `adata.obs` the labels to be encoded are stored.
30 |
31 | Returns
32 | -------
33 | encoded_labels:
34 | Integer-encoded labels.
35 | """
36 | unique_labels = list(np.unique(adata.obs[label_key]))
37 | encoded_labels = np.zeros(adata.shape[0])
38 |
39 | if not set(unique_labels).issubset(set(label_encoder.keys())):
40 | print(f"Warning: Labels in adata.obs[{label_key}] are not a subset of "
41 | "the label encoder!")
42 | print("Therefore integer value of those labels is set to '-1'.")
43 | for unique_label in unique_labels:
44 | if unique_label not in label_encoder.keys():
45 | encoded_labels[adata.obs[label_key] == unique_label] = -1
46 |
47 | for label, label_encoding in label_encoder.items():
48 | encoded_labels[adata.obs[label_key] == label] = label_encoding
49 | return encoded_labels
50 |
51 |
52 | def sparse_mx_to_sparse_tensor(sparse_mx: csr_matrix) -> SparseTensor:
53 | """
54 | Convert a scipy sparse matrix into a torch_sparse SparseTensor.
55 |
56 | Parameters
57 | ----------
58 | sparse_mx:
59 | Sparse scipy csr_matrix.
60 |
61 | Returns
62 |     -------
63 | sparse_tensor:
64 | torch_sparse SparseTensor object that can be utilized by PyG.
65 | """
66 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
67 | indices = torch.from_numpy(
68 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
69 | values = torch.from_numpy(sparse_mx.data)
70 | shape = torch.Size(sparse_mx.shape)
71 |     torch_sparse_coo_tensor = torch.sparse_coo_tensor(indices, values, shape)
72 | sparse_tensor = SparseTensor.from_torch_sparse_coo_tensor(
73 | torch_sparse_coo_tensor)
74 | return sparse_tensor
--------------------------------------------------------------------------------
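# --- Illustrative sketch (not part of the repository) -----------------------
# Hypothetical usage of the two helpers above: integer-encode a categorical
# annotation and convert a scipy adjacency matrix into a torch_sparse
# SparseTensor. The toy AnnData, obs column and encoder dict are assumptions
# for demonstration only.
import anndata as ad
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

from nichecompass.data.utils import encode_labels, sparse_mx_to_sparse_tensor

adata = ad.AnnData(
    X=np.zeros((3, 2), dtype=np.float32),
    obs=pd.DataFrame({"batch": ["a", "b", "a"]}, index=["c1", "c2", "c3"]))
labels = encode_labels(adata,
                       label_encoder={"a": 0, "b": 1},
                       label_key="batch")  # -> array([0., 1., 0.])
adj = sparse_mx_to_sparse_tensor(csr_matrix(np.eye(3, dtype=np.float32)))
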
/src/nichecompass/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .nichecompass import NicheCompass
2 | from .basemodelmixin import BaseModelMixin
3 |
4 | __all__ = ["NicheCompass",
5 | "BaseModelMixin"]
6 |
--------------------------------------------------------------------------------
/src/nichecompass/models/basemodelmixin.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains generic base model functionalities, added as a Mixin to the
3 | NicheCompass model.
4 | """
5 |
6 | import inspect
7 | import os
8 | import warnings
9 | from typing import Optional
10 |
11 | import numpy as np
12 | import pickle
13 | import scipy.sparse as sp
14 | import torch
15 | from anndata import AnnData
16 |
17 | from .utils import initialize_model, load_saved_files, validate_var_names
18 |
19 |
20 | class BaseModelMixin():
21 | """
22 |     Base model mixin class for universal model functionalities.
23 |
24 | Parts of the implementation are adapted from
25 | https://github.com/theislab/scarches/blob/master/scarches/models/base/_base.py#L15
26 | (01.10.2022) and
27 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_base_model.py#L63
28 | (01.10.2022).
29 | """
30 | def _get_user_attributes(self) -> list:
31 | """
32 | Get all the attributes defined in a model instance, for example
33 | self.is_trained_.
34 |
35 | Returns
36 |         -------
37 | attributes:
38 | Attributes defined in a model instance.
39 | """
40 | attributes = inspect.getmembers(
41 | self, lambda a: not (inspect.isroutine(a)))
42 | attributes = [a for a in attributes if not (
43 | a[0].startswith("__") and a[0].endswith("__"))]
44 | return attributes
45 |
46 | def _get_public_attributes(self) -> dict:
47 | """
48 | Get only public attributes defined in a model instance. By convention
49 | public attributes have a trailing underscore.
50 |
51 | Returns
52 |         -------
53 | public_attributes:
54 | Public attributes defined in a model instance.
55 | """
56 | public_attributes = self._get_user_attributes()
57 | public_attributes = {a[0]: a[1] for a in public_attributes if
58 | a[0][-1] == "_"}
59 | return public_attributes
60 |
61 | def _get_init_params(self, locals: dict) -> dict:
62 | """
63 | Get the model init signature with associated passed in values from
64 | locals (except the AnnData object passed in).
65 |
66 | Parameters
67 | ----------
68 | locals:
69 | Dictionary returned by calling the ´locals()´ method.
70 |
71 | Returns
72 |         -------
73 | user_params:
74 | Model initialization attributes defined in a model instance.
75 | """
76 | init = self.__init__
77 | sig = inspect.signature(init)
78 | init_params = [p for p in sig.parameters]
79 | user_params = {p: locals[p] for p in locals if p in init_params}
80 | user_params = {k: v for (k, v) in user_params.items() if not
81 | isinstance(v, AnnData)}
82 | return user_params
83 |
84 | def save(self,
85 | dir_path: str,
86 | overwrite: bool=False,
87 | save_adata: bool=False,
88 | adata_file_name: str="adata.h5ad",
89 | save_adata_atac: bool=False,
90 | adata_atac_file_name: str="adata_atac.h5ad",
91 | **anndata_write_kwargs):
92 | """
93 | Save model to disk (the Trainer optimizer state is not saved).
94 |
95 | Parameters
96 | ----------
97 | dir_path:
98 | Path of the directory where the model will be saved.
99 | overwrite:
100 | If `True`, overwrite existing data. If `False` and directory
101 | already exists at `dir_path`, error will be raised.
102 | save_adata:
103 | If `True`, also saves the AnnData object.
104 | adata_file_name:
105 | File name under which the AnnData object will be saved.
106 | save_adata_atac:
107 | If `True`, also saves the ATAC AnnData object.
108 | adata_atac_file_name:
109 | File name under which the ATAC AnnData object will be saved.
110 |         anndata_write_kwargs:
111 |             Kwargs for the AnnData write function.
112 | """
113 | if not os.path.exists(dir_path) or overwrite:
114 | os.makedirs(dir_path, exist_ok=overwrite)
115 | else:
116 |             raise ValueError(f"Directory '{dir_path}' already exists. "
117 |                              "Please provide another directory for saving.")
118 |
119 | model_save_path = os.path.join(dir_path, "model_params.pt")
120 | attr_save_path = os.path.join(dir_path, "attr.pkl")
121 | var_names_save_path = os.path.join(dir_path, "var_names.csv")
122 |
123 | if save_adata:
124 | # Convert storage format of adjacency matrix to be writable by
125 | # adata.write()
126 | if self.adata.obsp["spatial_connectivities"] is not None:
127 | self.adata.obsp["spatial_connectivities"] = sp.csr_matrix(
128 | self.adata.obsp["spatial_connectivities"])
129 | self.adata.write(
130 | os.path.join(dir_path, adata_file_name), **anndata_write_kwargs)
131 |
132 | if save_adata_atac:
133 | self.adata_atac.write(
134 | os.path.join(dir_path, adata_atac_file_name))
135 |
136 | var_names = self.adata.var_names.astype(str).to_numpy()
137 | public_attributes = self._get_public_attributes()
138 |
139 | torch.save(self.model.state_dict(), model_save_path)
140 | np.savetxt(var_names_save_path, var_names, fmt="%s")
141 | with open(attr_save_path, "wb") as f:
142 | pickle.dump(public_attributes, f)
143 |
144 | @classmethod
145 | def load(cls,
146 | dir_path: str,
147 | adata: Optional[AnnData]=None,
148 | adata_atac: Optional[AnnData]=None,
149 | adata_file_name: str="adata.h5ad",
150 | adata_atac_file_name: Optional[str]=None,
151 | use_cuda: bool=False,
152 | n_addon_gps: int=0,
153 | gp_names_key: Optional[str]=None,
154 | genes_idx_key: Optional[str]=None,
155 | unfreeze_all_weights: bool=False,
156 | unfreeze_addon_gp_weights: bool=False,
157 | unfreeze_cat_covariates_embedder_weights: bool=False
158 | ) -> torch.nn.Module:
159 | """
160 | Instantiate a model from saved output. Can be used for transfer learning
161 | scenarios and to learn de-novo gene programs by adding add-on gene
162 | programs and freezing non add-on weights.
163 |
164 | Parameters
165 | ----------
166 | dir_path:
167 | Path to saved outputs.
168 | adata:
169 | AnnData organized in the same way as data used to train the model.
170 | If ´None´, will check for and load adata saved with the model.
171 | adata_atac:
172 | ATAC AnnData organized in the same way as data used to train the
173 | model. If ´None´ and ´adata_atac_file_name´ is not ´None´, will
174 | check for and load adata_atac saved with the model.
175 | adata_file_name:
176 | File name of the AnnData object to be loaded.
177 | adata_atac_file_name:
178 | File name of the ATAC AnnData object to be loaded.
179 | use_cuda:
180 | If `True`, load model on GPU.
181 | n_addon_gps:
182 | Number of (new) add-on gene programs to be added to the model's
183 | architecture.
184 | gp_names_key:
185 | Key under which the gene program names are stored in ´adata.uns´.
186 | unfreeze_all_weights:
187 | If `True`, unfreeze all weights.
188 | unfreeze_addon_gp_weights:
189 | If `True`, unfreeze addon gp weights.
190 | unfreeze_cat_covariates_embedder_weights:
191 | If `True`, unfreeze categorical covariates embedder weights.
192 |
193 | Returns
194 | -------
195 | model:
196 | Model with loaded state dictionaries and, if specified, frozen non
197 | add-on weights.
198 | """
199 | load_adata = adata is None
200 | load_adata_atac = ((adata_atac is None) &
201 | (adata_atac_file_name is not None))
202 | use_cuda = use_cuda and torch.cuda.is_available()
203 | map_location = torch.device("cpu") if use_cuda is False else None
204 |
205 | model_state_dict, var_names, attr_dict, new_adata, new_adata_atac = (
206 | load_saved_files(dir_path,
207 | load_adata,
208 | adata_file_name,
209 | load_adata_atac,
210 | adata_atac_file_name,
211 | map_location=map_location))
212 | adata = new_adata if new_adata is not None else adata
213 | adata_atac = (new_adata_atac if new_adata_atac is not None else
214 | adata_atac)
215 |
216 | validate_var_names(adata, var_names)
217 |
218 | # Include all genes in gene expression reconstruction if addon nodes
219 | # are present
220 | if n_addon_gps != 0:
221 | if genes_idx_key not in adata.uns:
222 |                 raise ValueError("Please specify a valid 'genes_idx_key' if "
223 | "'n_addon_gps' > 0, so that all genes can be "
224 | "included in the genes idx.")
225 | adata.uns[genes_idx_key] = np.arange(adata.n_vars * 2)
226 |
227 | # Add new categorical covariates categories from query data
228 | cat_covariates_cats = attr_dict["cat_covariates_cats_"]
229 | cat_covariates_keys = attr_dict["init_params_"]["cat_covariates_keys"]
230 | new_cat_covariates_cats = []
231 | if cat_covariates_keys is not None:
232 | for i, cat_covariate_key in enumerate(cat_covariates_keys):
233 | new_cat_covariate_cats = []
234 | adata_cat_covariate_cats = adata.obs[cat_covariate_key].unique().tolist()
235 | for cat_covariate_cat in adata_cat_covariate_cats:
236 | if cat_covariate_cat not in cat_covariates_cats[i]:
237 | new_cat_covariate_cats.append(cat_covariate_cat)
238 | for cat_covariate_cat in new_cat_covariate_cats:
239 | new_cat_covariates_cats.append(cat_covariate_cat)
240 | cat_covariates_cats[i].append(cat_covariate_cat)
241 | attr_dict["init_params_"]["cat_covariates_cats"] = cat_covariates_cats
242 |
243 | if n_addon_gps != 0:
244 | attr_dict["n_addon_gps_"] += n_addon_gps
245 | attr_dict["init_params_"]["n_addon_gps"] += n_addon_gps
246 |
247 | if gp_names_key is None:
248 | raise ValueError("Please specify 'gp_names_key' so that addon "
249 | "gps can be added to the gene program list.")
250 |
251 | gps = list(adata.uns[gp_names_key])
252 |
253 | if any("addon_GP_" in gp for gp in gps):
254 |                 addon_gp_idx = int(gps[-1].rsplit("_", 1)[-1]) + 1
255 | adata.uns[gp_names_key] = np.array(
256 | gps + ["addon_GP_" + str(addon_gp_idx + i) for i in
257 | range(n_addon_gps)])
258 | else:
259 | adata.uns[gp_names_key] = np.array(
260 | gps + ["addon_GP_" + str(i) for i in range(n_addon_gps)])
261 |
262 | model = initialize_model(cls, adata, attr_dict, adata_atac)
263 |
264 |         # Set saved attributes for the loaded model
265 | for attr, val in attr_dict.items():
266 | setattr(model, attr, val)
267 |
268 | if n_addon_gps != 0 or len(new_cat_covariates_cats) > 0:
269 | model.model.load_and_expand_state_dict(model_state_dict)
270 | else:
271 | model.model.load_state_dict(model_state_dict)
272 |
273 | if use_cuda:
274 | model.model.cuda()
275 | model.model.eval()
276 |
277 | # First freeze all parameters and then subsequently unfreeze based on
278 | # load settings
279 | for param_name, param in model.model.named_parameters():
280 | param.requires_grad = False
281 | model.freeze_ = True
282 | if unfreeze_all_weights:
283 | for param_name, param in model.model.named_parameters():
284 | param.requires_grad = True
285 | model.freeze_ = False
286 | if unfreeze_addon_gp_weights:
287 | # allow updates of addon gp weights
288 | for param_name, param in model.model.named_parameters():
289 | if "addon" in param_name or \
290 | "theta" in param_name or \
291 | "aggregator" in param_name:
292 | param.requires_grad = True
293 | if unfreeze_cat_covariates_embedder_weights:
294 | # Allow updates of categorical covariates embedder weights
295 | for param_name, param in model.model.named_parameters():
296 | if ("cat_covariate" in param_name) & ("embedder" in param_name):
297 | param.requires_grad = True
298 |
299 | if model.freeze_ and not model.is_trained_:
300 | raise ValueError("The model has not been pre-trained and therefore "
301 | "weights should not be frozen.")
302 |
303 | return model
304 |
305 | def _check_if_trained(self,
306 | warn: bool=True):
307 | """
308 | Check if the model is trained.
309 |
310 | Parameters
311 |         ----------
312 | warn:
313 |             If ´True´ and the model is not trained, emit a warning; otherwise
314 |             raise a RuntimeError.
315 | """
316 | message = ("Trying to query inferred values from an untrained model. "
317 | "Please train the model first.")
318 | if not self.is_trained_:
319 | if warn:
320 | warnings.warn(message)
321 | else:
322 | raise RuntimeError(message)
--------------------------------------------------------------------------------
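# --- Illustrative sketch (not part of the repository) -----------------------
# Hypothetical save/load round trip mirroring the docstrings above: a trained
# reference model is saved, then re-loaded with additional add-on gene
# programs and frozen non-add-on weights for transfer learning. ´model´ is
# assumed to be a trained NicheCompass instance; the paths, ´adata_query´ and
# the ´adata.uns´ keys are assumptions for demonstration only.
from nichecompass.models import NicheCompass

model.save(dir_path="./reference_model", overwrite=True, save_adata=True)
query_model = NicheCompass.load(
    dir_path="./reference_model",
    adata=adata_query,
    n_addon_gps=10,
    gp_names_key="nichecompass_gp_names",
    genes_idx_key="nichecompass_genes_idx",
    unfreeze_addon_gp_weights=True)
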
/src/nichecompass/models/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains helper functions for the ´models´ subpackage.
3 | """
4 |
5 | import logging
6 | import os
7 | import pickle
8 | from collections import OrderedDict
9 | from typing import Optional, Tuple, Literal
10 |
11 | import anndata as ad
12 | import numpy as np
13 | import torch
14 |
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def load_saved_files(dir_path: str,
20 | load_adata: bool,
21 | adata_file_name: str="adata.h5ad",
22 | load_adata_atac: bool=False,
23 | adata_atac_file_name: str="adata_atac.h5ad",
24 | map_location: Optional[Literal["cpu", "cuda"]]=None
25 |                      ) -> Tuple[OrderedDict, np.ndarray, dict, ad.AnnData, ad.AnnData]:
26 | """
27 | Helper to load saved model files.
28 |
29 | Parts of the implementation are adapted from
30 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_utils.py#L55
31 | (01.10.2022)
32 |
33 | Parameters
34 | ----------
35 | dir_path:
36 | Path where the saved model files are stored.
37 | load_adata:
38 | If `True`, also load the stored AnnData object.
39 | adata_file_name:
40 | File name under which the AnnData object is saved.
41 | load_adata_atac:
42 | If `True`, also load the stored ATAC AnnData object.
43 | adata_atac_file_name:
44 | File name under which the ATAC AnnData object is saved.
45 | map_location:
46 | Memory location where to map the model files to.
47 |
48 | Returns
49 |     -------
50 | model_state_dict:
51 | The stored model state dict.
52 | var_names:
53 | The stored variable names.
54 | attr_dict:
55 | The stored attributes.
56 | adata:
57 | The stored AnnData object.
58 | adata_atac:
59 | The stored ATAC AnnData object.
60 | """
61 | attr_path = os.path.join(dir_path, "attr.pkl")
62 | adata_path = os.path.join(dir_path, adata_file_name)
63 | var_names_path = os.path.join(dir_path, "var_names.csv")
64 | model_path = os.path.join(dir_path, "model_params.pt")
65 |
66 | if os.path.exists(adata_path) and load_adata:
67 | adata = ad.read(adata_path)
68 | elif not os.path.exists(adata_path) and load_adata:
69 | raise ValueError("Dir path contains no saved anndata and no adata was "
70 | "passed.")
71 | else:
72 | adata = None
73 |
74 | if load_adata_atac:
75 | adata_atac_path = os.path.join(dir_path, adata_atac_file_name)
76 | if os.path.exists(adata_atac_path):
77 | adata_atac = ad.read(adata_atac_path)
78 | else:
79 | raise ValueError("Dir path contains no saved 'adata_atac' and no "
80 | "'adata_atac' was passed.")
81 | else:
82 | adata_atac = None
83 |
84 | model_state_dict = torch.load(model_path, map_location=map_location)
85 | var_names = np.genfromtxt(var_names_path, delimiter=",", dtype=str)
86 | with open(attr_path, "rb") as handle:
87 | attr_dict = pickle.load(handle)
88 | return model_state_dict, var_names, attr_dict, adata, adata_atac
89 |
90 |
91 | def validate_var_names(adata: ad.AnnData, source_var_names: np.ndarray):
92 | """
93 | Helper to validate variable names.
94 |
95 | Parts of the implementation are adapted from
96 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_utils.py#L141
97 | (01.10.2022)
98 |
99 | Parameters
100 | ----------
101 |     source_var_names:
102 |         Variable names against which to validate ´adata.var_names´.
103 | """
104 | user_var_names = adata.var_names.astype(str)
105 | if not np.array_equal(source_var_names, user_var_names):
106 | logger.warning(
107 | "The ´var_names´ of the passed in adata do not match the "
108 | "´var_names´ of the adata used to train the model. For valid "
109 | "results, the var_names need to be the same and in the same order "
110 | "as the adata used to train the model.")
111 |
112 |
113 | def initialize_model(cls,
114 | adata: ad.AnnData,
115 | attr_dict: dict,
116 | adata_atac: Optional[ad.AnnData]=None) -> torch.nn.Module:
117 | """
118 | Helper to initialize a model. Adapted from
119 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_utils.py#L103.
120 |
121 | Parameters
122 | ----------
123 | adata:
124 | AnnData object to be used for initialization.
125 | attr_dict:
126 | Dictionary with attributes for model initialization.
127 | adata_atac:
128 | ATAC AnnData object to be used for initialization.
129 | """
130 | if "init_params_" not in attr_dict.keys():
131 | raise ValueError("No init_params_ were saved by the model.")
132 | # Get the parameters for the class init signature
133 | init_params = attr_dict.pop("init_params_")
134 |
135 | # Grab all the parameters except for kwargs (is a dict)
136 | non_kwargs = {k: v for k, v in init_params.items() if not isinstance(v, dict)}
137 | # Expand out kwargs
138 | kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
139 | kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
140 | if adata_atac is None:
141 | # Support for legacy models
142 | model = cls(adata=adata,
143 | **non_kwargs,
144 | **kwargs)
145 | else:
146 | model = cls(adata=adata,
147 | adata_atac=adata_atac,
148 | **non_kwargs,
149 | **kwargs)
150 | return model
--------------------------------------------------------------------------------
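# --- Illustrative sketch (not part of the repository) -----------------------
# How ´initialize_model´ splits the saved ´init_params_´ into plain arguments
# and flattened kwargs: every dict-valued entry is expanded into its items.
# The toy dictionary is an assumption for demonstration only.
init_params = {"n_hidden": 32, "extra_kwargs": {"dropout_rate": 0.1}}
non_kwargs = {k: v for k, v in init_params.items() if not isinstance(v, dict)}
kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
assert non_kwargs == {"n_hidden": 32} and kwargs == {"dropout_rate": 0.1}
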
/src/nichecompass/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .basemodulemixin import BaseModuleMixin
2 | from .losses import (compute_cat_covariates_contrastive_loss,
3 | compute_edge_recon_loss,
4 | compute_gp_group_lasso_reg_loss,
5 | compute_gp_l1_reg_loss,
6 | compute_kl_reg_loss,
7 | compute_omics_recon_nb_loss)
8 | from .vgaemodulemixin import VGAEModuleMixin
9 | from .vgpgae import VGPGAE
10 |
11 | __all__ = ["BaseModuleMixin",
12 | "compute_cat_covariates_contrastive_loss",
13 | "compute_edge_recon_loss",
14 | "compute_gp_group_lasso_reg_loss",
15 | "compute_gp_l1_reg_loss",
16 | "compute_kl_reg_loss",
17 | "compute_omics_recon_nb_loss",
18 | "VGAEModuleMixin",
19 | "VGPGAE"]
20 |
--------------------------------------------------------------------------------
/src/nichecompass/modules/basemodulemixin.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains generic module functionalities, added as a Mixin to the
3 | Variational Gene Program Graph Autoencoder module.
4 | """
5 |
6 | import inspect
7 | from collections import OrderedDict
8 |
9 | import torch
10 |
11 |
12 | class BaseModuleMixin:
13 | """
14 |     Base module mixin class containing universal module functionalities.
15 |
16 | Parts of the implementation are adapted from
17 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_base_model.py#L63
18 | (01.10.2022).
19 | """
20 | def _get_user_attributes(self) -> list:
21 | """
22 | Get all the attributes defined in a module instance.
23 |
24 | Returns
25 |         -------
26 | attributes:
27 | Attributes defined in a module instance.
28 | """
29 | attributes = inspect.getmembers(
30 | self, lambda a: not (inspect.isroutine(a)))
31 | attributes = [a for a in attributes if not (
32 | a[0].startswith("__") and a[0].endswith("__"))]
33 | return attributes
34 |
35 | def _get_public_attributes(self) -> dict:
36 | """
37 | Get only public attributes defined in a module instance. By convention
38 | public attributes have a trailing underscore.
39 |
40 | Returns
41 |         -------
42 | public_attributes:
43 | Public attributes defined in a module instance.
44 | """
45 | public_attributes = self._get_user_attributes()
46 | public_attributes = {a[0]: a[1] for a in public_attributes if
47 | a[0][-1] == "_"}
48 | return public_attributes
49 |
50 | def load_and_expand_state_dict(self,
51 | model_state_dict: OrderedDict):
52 | """
53 | Load model state dictionary into model and expand it to account for
54 | architectural changes through e.g. add-on nodes.
55 |
56 | Parts of the implementation are adapted from
57 | https://github.com/theislab/scarches/blob/master/scarches/models/base/_base.py#L92
58 | (01.10.2022).
59 | """
60 |         # State dict of the old (saved) model architecture
61 |         load_state_dict = model_state_dict.copy()
62 | new_state_dict = self.state_dict() # new model architecture state dict
63 | device = next(self.parameters()).device
64 |
65 | # Update parameter tensors from old model architecture with changes from
66 | # new model architecture
67 | for key, load_param_tensor in load_state_dict.items():
68 | new_param_tensor = new_state_dict[key]
69 | if new_param_tensor.size() == load_param_tensor.size():
70 | continue # nothing needs to be updated
71 | else:
72 | # new model architecture parameter tensors are different from
73 | # old model architecture parameter tensors; updates are
74 | # necessary
75 | load_param_tensor = load_param_tensor.to(device)
76 | n_dims = len(new_param_tensor.shape)
77 | idx_slicers = [slice(None)] * n_dims
78 | for i in range(n_dims):
79 | dim_diff = (new_param_tensor.shape[i] -
80 | load_param_tensor.shape[i])
81 | idx_slicers[i] = slice(-dim_diff, None)
82 | if dim_diff > 0:
83 | break
84 | expanded_param_tensor = torch.cat(
85 | [load_param_tensor, new_param_tensor[tuple(idx_slicers)]],
86 | dim=i)
87 | load_state_dict[key] = expanded_param_tensor
88 |
89 | # Add parameter tensors from new model architecture to old model
90 | # architecture state dict
91 | for key, new_param_tensor in new_state_dict.items():
92 | if key not in load_state_dict:
93 | load_state_dict[key] = new_param_tensor
94 |
95 | self.load_state_dict(load_state_dict)
96 |
--------------------------------------------------------------------------------
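# --- Illustrative sketch (not part of the repository) -----------------------
# The core of ´load_and_expand_state_dict´ for a single 2D parameter tensor:
# a saved (4 x 3) weight is expanded to a new (4 x 5) architecture by
# appending the new model's extra columns (e.g. add-on gene program weights).
import torch

load_param_tensor = torch.zeros(4, 3)  # saved (old architecture) weights
new_param_tensor = torch.randn(4, 5)   # freshly initialized new architecture
dim_diff = new_param_tensor.shape[1] - load_param_tensor.shape[1]
expanded_param_tensor = torch.cat(
    [load_param_tensor, new_param_tensor[:, -dim_diff:]], dim=1)
assert expanded_param_tensor.shape == (4, 5)
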
/src/nichecompass/modules/vgaemodulemixin.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains generic VGAE functionalities, added as a Mixin to the
3 | Variational Gene Program Graph Autoencoder neural network module.
4 | """
5 |
6 | import torch
7 |
8 |
9 | class VGAEModuleMixin:
10 | """
11 |     VGAE module mixin class containing universal VGAE module
12 | functionalities.
13 | """
14 | def reparameterize(self,
15 | mu: torch.Tensor,
16 | logstd: torch.Tensor) -> torch.Tensor:
17 | """
18 | Use reparameterization trick for latent space normal distribution.
19 |
20 | Parameters
21 | ----------
22 | mu:
23 | Expected values of the latent space distribution (dim: n_obs,
24 | n_gps).
25 | logstd:
26 | Log standard deviations of the latent space distribution (dim: n_obs,
27 | n_gps).
28 |
29 | Returns
30 |         -------
31 | rep:
32 | Reparameterized latent features (dim: n_obs, n_gps).
33 | """
34 | if self.training:
35 | std = torch.exp(logstd)
36 | eps = torch.randn_like(mu)
37 | rep = eps.mul(std).add(mu)
38 | return rep
39 | else:
40 | rep = mu
41 | return rep
42 |
--------------------------------------------------------------------------------
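# --- Illustrative sketch (not part of the repository) -----------------------
# The reparameterization trick used above, stated directly: a differentiable
# sample z = mu + eps * std with eps ~ N(0, I), so gradients can flow through
# ´mu´ and ´logstd´. The shapes are toy values for demonstration only.
import torch

mu = torch.zeros(2, 3, requires_grad=True)
logstd = torch.zeros(2, 3, requires_grad=True)
eps = torch.randn_like(mu)
z = eps.mul(torch.exp(logstd)).add(mu)  # equivalent to mu + eps * std
z.sum().backward()  # gradients reach mu and logstd
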
/src/nichecompass/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregators import (OneHopAttentionNodeLabelAggregator,
2 | OneHopGCNNormNodeLabelAggregator,
3 | OneHopSumNodeLabelAggregator)
4 | from .decoders import (CosineSimGraphDecoder,
5 | FCOmicsFeatureDecoder,
6 | MaskedOmicsFeatureDecoder)
7 | from .encoders import Encoder
8 | from .layercomponents import MaskedLinear
9 | from .layers import AddOnMaskedLayer
10 |
11 | __all__ = ["OneHopAttentionNodeLabelAggregator",
12 | "OneHopGCNNormNodeLabelAggregator",
13 | "OneHopSumNodeLabelAggregator",
14 | "CosineSimGraphDecoder",
15 | "FCOmicsFeatureDecoder",
16 | "MaskedOmicsFeatureDecoder",
17 | "Encoder",
18 | "MaskedLinear",
19 | "AddOnMaskedLayer"]
20 |
--------------------------------------------------------------------------------
/src/nichecompass/nn/aggregators.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains gene expression aggregators used by the NicheCompass model.
3 | """
4 |
5 | from typing import Literal, Optional, Tuple
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch_geometric.nn.conv import MessagePassing
10 | from torch_geometric.nn.conv.gcn_conv import gcn_norm
11 | from torch_geometric.nn.dense.linear import Linear
12 | from torch_geometric.nn.inits import glorot
13 | from torch_geometric.utils import softmax
14 | from torch_sparse import SparseTensor
15 |
16 |
17 | class OneHopAttentionNodeLabelAggregator(MessagePassing):
18 | """
19 |     One-hop Attention Node Label Aggregator class that uses a weighted sum
20 |     of the omics features of a node's 1-hop neighbors to build an
21 |     aggregated neighbor omics feature vector for a node. The weights are
22 |     determined by an additive attention mechanism with learnable weights.
23 |     The aggregated neighbor omics feature vector is concatenated with the
24 |     node's own omics feature vector to form the node labels for the omics
25 |     reconstruction task.
26 |
27 | Parts of the implementation are inspired by
28 | https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/gatv2_conv.py#L16
29 | (01.10.2022).
30 |
31 | Parameters
32 | ----------
33 | modality:
34 | Omics modality that is aggregated. Can be either `rna` or `atac`.
35 | n_input:
36 | Number of omics features used for the Node Label Aggregation.
37 | n_heads:
38 | Number of attention heads for multi-head attention.
39 | leaky_relu_negative_slope:
40 | Slope of the leaky relu activation function.
41 | dropout_rate:
42 | Dropout probability of the normalized attention coefficients which
43 | exposes each node to a stochastically sampled neighborhood during
44 | training.
45 | """
46 | def __init__(self,
47 | modality: Literal["rna", "atac"],
48 | n_input: int,
49 | n_heads: int=4,
50 | leaky_relu_negative_slope: float=0.2,
51 | dropout_rate: float=0.):
52 | super().__init__(node_dim=0)
53 | self.n_input = n_input
54 | self.n_heads = n_heads
55 | self.leaky_relu_negative_slope = leaky_relu_negative_slope
56 | self.linear_l_l = Linear(n_input,
57 | n_input * n_heads,
58 | bias=False,
59 | weight_initializer="glorot")
60 | self.linear_r_l = Linear(n_input,
61 | n_input * n_heads,
62 | bias=False,
63 | weight_initializer="glorot")
64 | self.attn = nn.Parameter(torch.Tensor(1, n_heads, n_input))
65 | self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
66 | self.dropout = nn.Dropout(p=dropout_rate)
67 | self._alpha = None
68 | self.reset_parameters()
69 |
70 | print(f"ONE HOP ATTENTION {modality.upper()} NODE LABEL AGGREGATOR -> "
71 | f"n_input: {n_input}, "
72 | f"n_heads: {n_heads}")
73 |
74 | def reset_parameters(self):
75 | """
76 | Reset weight parameters.
77 | """
78 | self.linear_l_l.reset_parameters()
79 | self.linear_r_l.reset_parameters()
80 | glorot(self.attn)
81 |
82 | def forward(self,
83 | x: torch.Tensor,
84 | edge_index: torch.Tensor,
85 |                 return_agg_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
86 | """
87 | Forward pass of the One-hop Attention Node Label Aggregator.
88 |
89 | Parameters
90 | ----------
91 | x:
92 | Tensor containing the omics features of the nodes in the current
93 | node batch including sampled neighbors.
94 | (Size: n_nodes_batch_and_sampled_neighbors x n_node_features)
95 | edge_index:
96 | Tensor containing the node indices of edges in the current node
97 | batch including sampled neighbors.
98 | (Size: 2 x n_edges_batch_and_sampled_neighbors)
99 | return_agg_weights:
100 | If ´True´, also return the aggregation weights (attention weights).
101 |
102 | Returns
103 |         -------
104 |         x_neighbors:
105 |             Tensor containing the attention-aggregated neighbor omics features
106 |             of the nodes in the current node batch excluding sampled neighbors,
107 |             used for the omics feature reconstruction task.
108 |             (Size: n_nodes_batch x n_node_features)
109 | alpha:
110 | Aggregation weights for edges in ´edge_index´.
111 | """
112 | x_l = x_r = x
113 | g_l = self.linear_l_l(x_l).view(-1, self.n_heads, self.n_input)
114 | g_r = self.linear_r_l(x_r).view(-1, self.n_heads, self.n_input)
115 | x_l = x_l.repeat(1, self.n_heads).view(-1, self.n_heads, self.n_input)
116 |
117 | output = self.propagate(edge_index, x=(x_l, x_r), g=(g_l, g_r))
118 | x_neighbors = output.mean(dim=1)
119 | alpha = self._alpha
120 | self._alpha = None
121 | if return_agg_weights:
122 | return x_neighbors, alpha
123 | return x_neighbors, None
124 |
125 | def message(self,
126 | x_j: torch.Tensor,
127 | x_i: torch.Tensor,
128 | g_j: torch.Tensor,
129 | g_i: torch.Tensor,
130 | index: torch.Tensor) -> torch.Tensor:
131 | """
132 | Message method of the MessagePassing parent class. Variables with "_i"
133 | suffix refer to the central nodes that aggregate information. Variables
134 |         with "_j" suffix refer to the neighboring nodes.
135 |
136 | Parameters
137 | ----------
138 | x_j:
139 | Gene expression of neighboring nodes (dim: n_index x n_heads x
140 | n_node_features).
141 | g_i:
142 | Key vector of central nodes (dim: n_index x n_heads x
143 | n_node_features).
144 | g_j:
145 | Query vector of neighboring nodes (dim: n_index x n_heads x
146 | n_node_features).
147 | """
148 | g = g_i + g_j
149 | g = self.activation(g)
150 | alpha = (g * self.attn).sum(dim=-1)
151 | alpha = softmax(alpha, index) # index is 2nd dim of edge_index (index of
152 | # central node over which softmax should
153 | # be applied)
154 | self._alpha = alpha
155 | alpha = self.dropout(alpha)
156 | return x_j * alpha.unsqueeze(-1)
157 |
158 |
159 | class OneHopGCNNormNodeLabelAggregator(nn.Module):
160 | """
161 |     One-hop GCN Norm Node Label Aggregator class that uses a symmetrically
162 |     normalized sum of the omics features of a node's 1-hop neighbors to build
163 |     an aggregated neighbor omics feature vector, which is concatenated with
164 |     the node's own omics features to form node labels for omics reconstruction.
165 | 
166 |     Parameters
167 |     ----------
168 |     modality:
169 | Omics modality that is aggregated. Can be either `rna` or `atac`.
170 | """
171 | def __init__(self,
172 | modality: Literal["rna", "atac"]):
173 | super().__init__()
174 | print(f"ONE HOP GCN NORM {modality.upper()} NODE LABEL AGGREGATOR")
175 |
176 | def forward(self,
177 | x: torch.Tensor,
178 | edge_index: torch.Tensor,
179 |                 return_agg_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
180 | """
181 | Forward pass of the One-hop GCN Norm Node Label Aggregator.
182 |
183 | Parameters
184 | ----------
185 | x:
186 | Tensor containing the omics features of the nodes in the current
187 | node batch including sampled neighbors.
188 | (Size: n_nodes_batch_and_sampled_neighbors x n_node_features)
189 | edge_index:
190 | Tensor containing the node indices of edges in the current node
191 | batch including sampled neighbors.
192 | (Size: 2 x n_edges_batch_and_sampled_neighbors)
193 | return_agg_weights:
194 | If ´True´, also return the aggregation weights (norm weights).
195 |
196 | Returns
197 |         -------
198 |         x_neighbors:
199 |             Tensor containing the gcn-norm-aggregated neighbor omics features
200 |             of the nodes in the current node batch, used for the omics
201 |             reconstruction task. (Size: n_nodes_batch x n_node_features)
202 | alpha:
203 | Neighbor aggregation weights.
204 | """
205 | adj = SparseTensor.from_edge_index(edge_index,
206 | sparse_sizes=(x.shape[0],
207 | x.shape[0]))
208 | adj_norm = gcn_norm(adj)
209 | x_neighbors = adj_norm.t().matmul(x)
210 | if return_agg_weights:
211 | alpha = adj_norm.coo()[2]
212 | return x_neighbors, alpha
213 | return x_neighbors, None
214 |
215 |
216 | class OneHopSumNodeLabelAggregator(nn.Module):
217 | """
218 | One-hop Sum Node Label Aggregator class that sums up the omics features of
219 | a node's 1-hop neighbors to build an aggregated neighbor omics feature
220 |     vector for a node. The sum-aggregated neighbor omics feature vector is
221 |     concatenated with the node's own omics feature vector to form the node
222 |     labels for the omics reconstruction task.
223 |
224 | Parameters
225 | ----------
226 | modality:
227 | Omics modality that is aggregated. Can be either `rna` or `atac`.
228 | """
229 | def __init__(self,
230 | modality: Literal["rna", "atac"]):
231 | super().__init__()
232 | print(f"ONE HOP SUM {modality.upper()} NODE LABEL AGGREGATOR")
233 |
234 | def forward(self,
235 | x: torch.Tensor,
236 |                 edge_index: torch.Tensor,
237 |                 return_agg_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
238 | """
239 | Forward pass of the One-hop Sum Node Label Aggregator.
240 |
241 | Parameters
242 | ----------
243 | x:
244 | Tensor containing the omics features of the nodes in the current
245 | node batch including sampled neighbors.
246 | (Size: n_nodes_batch_and_sampled_neighbors x n_node_features)
247 | edge_index:
248 | Tensor containing the node indices of edges in the current node
249 | batch including sampled neighbors.
250 | (Size: 2 x n_edges_batch_and_sampled_neighbors)
251 |
252 | Returns
253 |         -------
254 |         x_neighbors:
255 |             Tensor containing the sum-aggregated neighbor omics features of the
256 |             nodes in the current node batch excluding sampled neighbors, used
257 |             for the omics reconstruction task.
258 |             (Size: n_nodes_batch x n_node_features)
259 | """
260 | adj = SparseTensor.from_edge_index(edge_index,
261 | sparse_sizes=(x.shape[0],
262 | x.shape[0]))
263 | x_neighbors = adj.t().matmul(x)
264 | return x_neighbors, None
265 |
--------------------------------------------------------------------------------
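# --- Illustrative sketch (not part of the repository) -----------------------
# Hypothetical usage of ´OneHopSumNodeLabelAggregator´ on a toy 3-node graph:
# row i of the output sums the features of node i's incoming 1-hop neighbors.
import torch

from nichecompass.nn import OneHopSumNodeLabelAggregator

agg = OneHopSumNodeLabelAggregator(modality="rna")
x = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # edges 0->1, 1->2, 2->0
x_neighbors, _ = agg(x, edge_index)  # -> [[1., 1.], [1., 0.], [0., 1.]]
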
/src/nichecompass/nn/decoders.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains decoders used by the NicheCompass model.
3 | """
4 |
5 | from typing import Literal, List, Optional
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from .layers import AddOnMaskedLayer
11 | from .utils import compute_cosine_similarity
12 |
13 |
14 | class CosineSimGraphDecoder(nn.Module):
15 | """
16 | Cosine similarity graph decoder class.
17 |
18 | Takes the concatenated latent feature vectors z of the source and
19 | target nodes as input, and calculates the element-wise cosine similarity
20 | between source and target nodes to return the reconstructed edge logits.
21 |
22 | The sigmoid activation function to compute reconstructed edge probabilities
23 | is integrated into the binary cross entropy loss for computational
24 | efficiency.
25 |
26 | Parameters
27 | ----------
28 | dropout_rate:
29 | Probability of nodes to be dropped during training.
30 | """
31 | def __init__(self,
32 | dropout_rate: float=0.):
33 | super().__init__()
34 | print("COSINE SIM GRAPH DECODER -> "
35 | f"dropout_rate: {dropout_rate}")
36 |
37 | self.dropout = nn.Dropout(dropout_rate)
38 |
39 | def forward(self,
40 | z: torch.Tensor) -> torch.Tensor:
41 | """
42 | Forward pass of the cosine similarity graph decoder.
43 |
44 | Parameters
45 | ----------
46 | z:
47 | Concatenated latent feature vector of the source and target nodes
48 | (dim: 4 * edge_batch_size x n_gps due to negative edges).
49 |
50 | Returns
51 |         -------
52 | edge_recon_logits:
53 | Reconstructed edge logits (dim: 2 * edge_batch_size due to negative
54 | edges).
55 | """
56 | z = self.dropout(z)
57 |
58 | # Compute element-wise cosine similarity
59 | edge_recon_logits = compute_cosine_similarity(
60 | z[:int(z.shape[0]/2)], # ´edge_label_index[0]´
61 | z[int(z.shape[0]/2):]) # ´edge_label_index[1]´
62 | return edge_recon_logits
63 |
64 |
65 | class MaskedOmicsFeatureDecoder(nn.Module):
66 | """
67 | Masked omics feature decoder class.
68 |
69 | Takes the latent space features z (gp scores) as input, and has a masked
70 | layer to decode the parameters of the underlying omics feature distributions.
71 |
72 | Parameters
73 | ----------
74 | modality:
75 | Omics modality that is decoded. Can be either `rna` or `atac`.
76 | entity:
77 | Entity that is decoded. Can be either `target` or `source`.
78 | n_prior_gp_input:
79 | Number of maskable prior gp input nodes to the decoder (maskable latent
80 | space dimensionality).
81 | n_addon_gp_input:
82 | Number of non-maskable add-on gp input nodes to the decoder (
83 | non-maskable latent space dimensionality).
84 | n_cat_covariates_embed_input:
85 | Number of categorical covariates embedding input nodes to the decoder
86 | (categorical covariates embedding dimensionality).
87 | n_output:
88 | Number of output nodes from the decoder (number of omics features).
89 | mask:
90 | Mask that determines which masked input nodes / prior gp latent features
91 | z can contribute to the reconstruction of which omics features.
92 | addon_mask:
93 | Mask that determines which add-on input nodes / add-on gp latent
94 | features z can contribute to the reconstruction of which omics features.
95 | masked_features_idx:
96 | Index of omics features that are included in the mask.
97 | recon_loss:
98 | The loss used for omics reconstruction. If `nb`, uses a negative
99 | binomial loss.
100 | """
101 | def __init__(self,
102 | modality: Literal["rna", "atac"],
103 | entity: Literal["target", "source"],
104 | n_prior_gp_input: int,
105 | n_addon_gp_input: int,
106 | n_cat_covariates_embed_input: int,
107 | n_output: int,
108 | mask: torch.Tensor,
109 | addon_mask: torch.Tensor,
110 | masked_features_idx: List,
111 | recon_loss: Literal["nb"]):
112 | super().__init__()
113 | print(f"MASKED {entity.upper()} {modality.upper()} DECODER -> "
114 | f"n_prior_gp_input: {n_prior_gp_input}, "
115 | f"n_addon_gp_input: {n_addon_gp_input}, "
116 | f"n_cat_covariates_embed_input: {n_cat_covariates_embed_input}, "
117 | f"n_output: {n_output}")
118 |
119 | self.masked_features_idx = masked_features_idx
120 | self.recon_loss = recon_loss
121 |
122 | self.nb_means_normalized_decoder = AddOnMaskedLayer(
123 | n_input=n_prior_gp_input,
124 | n_addon_input=n_addon_gp_input,
125 | n_cat_covariates_embed_input=n_cat_covariates_embed_input,
126 | n_output=n_output,
127 | bias=False,
128 | mask=mask,
129 | addon_mask=addon_mask,
130 | masked_features_idx=masked_features_idx,
131 | activation=nn.Softmax(dim=-1))
132 |
133 | def forward(self,
134 | z: torch.Tensor,
135 | log_library_size: torch.Tensor,
136 | cat_covariates_embed: Optional[torch.Tensor]=None,
137 | dynamic_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
138 | """
139 | Forward pass of the masked omics feature decoder.
140 |
141 | Parameters
142 | ----------
143 | z:
144 | Tensor containing the latent space features.
145 | log_library_size:
146 | Tensor containing the omics feature log library size of the nodes.
147 |         cat_covariates_embed:
148 |             Tensor containing the categorical covariates embedding (all
149 |             categorical covariates embeddings concatenated into one embedding).
150 |         dynamic_mask:
151 |             Dynamic mask that can change in each forward pass. Is used for the
152 |             atac modality: if a gene is removed by regularization in the rna
153 |             decoder (its weight is set to 0), the corresponding peaks will be
154 |             marked as 0 in the `dynamic_mask`.
155 |
156 | Returns
157 |         -------
158 | nb_means:
159 | The mean parameters of the negative binomial distribution.
160 | """
161 | # Add categorical covariates embedding to latent feature vector
162 | if cat_covariates_embed is not None:
163 | z = torch.cat((z, cat_covariates_embed), dim=-1)
164 |
165 | nb_means_normalized = self.nb_means_normalized_decoder(
166 | input=z,
167 | dynamic_mask=dynamic_mask)
168 | nb_means = torch.exp(log_library_size) * nb_means_normalized
169 | return nb_means
170 |
171 |
172 | class FCOmicsFeatureDecoder(nn.Module):
173 | """
174 | Fully connected omics feature decoder class.
175 |
176 | Takes the latent space features z as input, and has a fully connected layer
177 | to decode the parameters of the underlying omics feature distributions.
178 |
179 | Parameters
180 | ----------
181 | modality:
182 | Omics modality that is decoded. Can be either `rna` or `atac`.
183 | entity:
184 | Entity that is decoded. Can be either `target` or `source`.
185 | n_prior_gp_input:
186 | Number of maskable prior gp input nodes to the decoder (maskable latent
187 | space dimensionality).
188 | n_addon_gp_input:
189 | Number of non-maskable add-on gp input nodes to the decoder (
190 | non-maskable latent space dimensionality).
191 | n_cat_covariates_embed_input:
192 | Number of categorical covariates embedding input nodes to the decoder
193 | (categorical covariates embedding dimensionality).
194 | n_output:
195 | Number of output nodes from the decoder (number of omics features).
196 | n_layers:
197 | Number of fully connected layers used for decoding.
198 | recon_loss:
199 | The loss used for omics reconstruction. If `nb`, uses a negative
200 | binomial loss.
201 | """
202 | def __init__(self,
203 | modality: Literal["rna", "atac"],
204 | entity: Literal["target", "source"],
205 | n_prior_gp_input: int,
206 | n_addon_gp_input: int,
207 | n_cat_covariates_embed_input: int,
208 | n_output: int,
209 | n_layers: int,
210 | recon_loss: Literal["nb"]):
211 | super().__init__()
212 | print(f"FC {entity.upper()} {modality.upper()} DECODER -> "
213 | f"n_prior_gp_input: {n_prior_gp_input}, "
214 | f"n_addon_gp_input: {n_addon_gp_input}, "
215 | f"n_cat_covariates_embed_input: {n_cat_covariates_embed_input}, "
216 | f"n_output: {n_output}")
217 |
218 | self.n_input = (n_prior_gp_input
219 | + n_addon_gp_input
220 | + n_cat_covariates_embed_input)
221 | self.recon_loss = recon_loss
222 |
223 | if n_layers == 1:
224 | self.nb_means_normalized_decoder = nn.Sequential(
225 | nn.Linear(self.n_input, n_output, bias=False),
226 | nn.Softmax(dim=-1))
227 | elif n_layers == 2:
228 | self.nb_means_normalized_decoder = nn.Sequential(
229 | nn.Linear(self.n_input, self.n_input, bias=False),
230 | nn.ReLU(),
231 | nn.Linear(self.n_input, n_output, bias=False),
232 | nn.Softmax(dim=-1))
233 |
234 | def forward(self,
235 | z: torch.Tensor,
236 | log_library_size: torch.Tensor,
237 | cat_covariates_embed: Optional[torch.Tensor]=None,
238 | **kwargs) -> torch.Tensor:
239 | """
240 | Forward pass of the fully connected omics feature decoder.
241 |
242 | Parameters
243 | ----------
244 | z:
245 | Tensor containing the latent space features.
246 | log_library_size:
247 | Tensor containing the log library size of the nodes.
248 |         cat_covariates_embed:
249 |             Tensor containing the categorical covariates embedding (all
250 |             categorical covariates embeddings concatenated into one
251 |             embedding).
252 |         kwargs:
253 |             Additional keyword arguments such as ´dynamic_mask´, accepted for
254 |             interface compatibility with the masked omics feature decoder but
255 |             ignored by the fully connected decoder.
256 |
257 | Returns
258 |         -------
259 | nb_means:
260 | The mean parameters of the negative binomial distribution.
261 | """
262 | # Add categorical covariates embedding to latent feature vector
263 | if cat_covariates_embed is not None:
264 | z = torch.cat((z, cat_covariates_embed), dim=-1)
265 |
266 | nb_means_normalized = self.nb_means_normalized_decoder(input=z)
267 | nb_means = torch.exp(log_library_size) * nb_means_normalized
268 | return nb_means
269 |
--------------------------------------------------------------------------------
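# --- Illustrative sketch (not part of the repository) -----------------------
# The final step shared by both decoders above: softmax-normalized means are
# rescaled by each node's library size, so every row of ´nb_means´ sums to
# that node's observed counts. The toy tensors are assumptions only.
import torch

nb_means_normalized = torch.softmax(torch.randn(2, 4), dim=-1)
log_library_size = torch.log(torch.tensor([[1000.], [2000.]]))
nb_means = torch.exp(log_library_size) * nb_means_normalized
# rows of nb_means sum (up to float error) to 1000. and 2000.
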
/src/nichecompass/nn/encoders.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the encoder used by the NicheCompass model.
3 | """
4 |
5 | from typing import Literal, Optional, Tuple
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch_geometric.nn import GATv2Conv, GCNConv
10 |
11 |
12 | class Encoder(nn.Module):
13 | """
14 | Encoder class.
15 |
16 |     Takes the input space features x and the edge indices as input, first
17 |     applies fully connected layers and then message passing layers to output
18 |     mu and logstd of the latent space normal distribution.
19 |
20 | Parameters
21 | ----------
22 | n_input:
23 | Number of input nodes (omics features) to the encoder.
24 | n_cat_covariates_embed_input:
25 | Number of categorical covariates embedding input nodes to the encoder.
26 | n_hidden:
27 |         Number of hidden nodes output by the fully connected layers and
28 | intermediate message passing layers.
29 | n_latent:
30 | Number of output nodes (prior gps) from the encoder, making up the
31 | first part of the latent space features z.
32 | n_addon_latent:
33 | Number of add-on nodes in the latent space (new gps), making up the
34 | second part of the latent space features z.
35 | n_fc_layers:
36 | Number of fully connected layers before the message passing layers.
37 | conv_layer:
38 | Message passing layer used.
39 | n_layers:
40 | Number of message passing layers.
41 | cat_covariates_embed_mode:
42 | Indicates where to inject the categorical covariates embedding if
43 | injected.
44 | n_attention_heads:
45 | Only relevant if ´conv_layer == gatv2conv´. Number of attention heads
46 | used.
47 | dropout_rate:
48 | Probability of nodes to be dropped in the hidden layer during training.
49 | activation:
50 | Activation function used after the fully connected layers and
51 | intermediate message passing layers.
52 | use_bn:
53 |         If ´True´, use a batch normalization layer at the end to normalize ´mu´ (currently unused).
54 | """
55 | def __init__(self,
56 | n_input: int,
57 | n_cat_covariates_embed_input: int,
58 | n_hidden: int,
59 | n_latent: int,
60 | n_addon_latent: int=100,
61 | n_fc_layers: int=1,
62 | conv_layer: Literal["gcnconv", "gatv2conv"]="gatv2conv",
63 | n_layers: int=1,
64 | cat_covariates_embed_mode: Literal["input", "hidden"]="input",
65 | n_attention_heads: int=4,
66 | dropout_rate: float=0.,
67 |                  activation: nn.Module=nn.ReLU(),
68 | use_bn: bool=True):
69 | super().__init__()
70 | print("ENCODER -> "
71 | f"n_input: {n_input}, "
72 | f"n_cat_covariates_embed_input: {n_cat_covariates_embed_input}, "
73 | f"n_hidden: {n_hidden}, "
74 | f"n_latent: {n_latent}, "
75 | f"n_addon_latent: {n_addon_latent}, "
76 | f"n_fc_layers: {n_fc_layers}, "
77 | f"n_layers: {n_layers}, "
78 | f"conv_layer: {conv_layer}, "
79 | f"n_attention_heads: "
80 | f"{n_attention_heads if conv_layer == 'gatv2conv' else '0'}, "
81 |               f"dropout_rate: {dropout_rate}")
82 |
83 | self.n_addon_latent = n_addon_latent
84 | self.n_layers = n_layers
85 | self.n_fc_layers = n_fc_layers
86 | self.cat_covariates_embed_mode = cat_covariates_embed_mode
87 |
88 | if ((cat_covariates_embed_mode == "input") &
89 | (n_cat_covariates_embed_input != 0)):
90 | # Add categorical covariates embedding to input
91 | n_input += n_cat_covariates_embed_input
92 |
93 | if n_fc_layers == 2:
94 | self.fc_l1 = nn.Linear(n_input, int(n_input / 2))
95 | self.fc_l2 = nn.Linear(int(n_input / 2), n_hidden)
96 | self.fc_l2_bn = nn.BatchNorm1d(n_hidden)
97 | elif n_fc_layers == 1:
98 | self.fc_l1 = nn.Linear(n_input, n_hidden)
99 |
100 | if ((cat_covariates_embed_mode == "hidden") &
101 | (n_cat_covariates_embed_input != 0)):
102 | # Add categorical covariates embedding to hidden after fc_l
103 | n_hidden += n_cat_covariates_embed_input
104 |
105 | if conv_layer == "gcnconv":
106 | if n_layers == 2:
107 | self.conv_l1 = GCNConv(n_hidden,
108 | n_hidden)
109 | self.conv_mu = GCNConv(n_hidden,
110 | n_latent)
111 | self.conv_logstd = GCNConv(n_hidden,
112 | n_latent)
113 | if n_addon_latent != 0:
114 | self.addon_conv_mu = GCNConv(n_hidden,
115 | n_addon_latent)
116 | self.addon_conv_logstd = GCNConv(n_hidden,
117 | n_addon_latent)
118 | elif conv_layer == "gatv2conv":
119 | if n_layers == 2:
120 | self.conv_l1 = GATv2Conv(n_hidden,
121 | n_hidden,
122 | heads=n_attention_heads,
123 | concat=False)
124 | self.conv_mu = GATv2Conv(n_hidden,
125 | n_latent,
126 | heads=n_attention_heads,
127 | concat=False)
128 | self.conv_logstd = GATv2Conv(n_hidden,
129 | n_latent,
130 | heads=n_attention_heads,
131 | concat=False)
132 | if n_addon_latent != 0:
133 | self.addon_conv_mu = GATv2Conv(n_hidden,
134 | n_addon_latent,
135 | heads=n_attention_heads,
136 | concat=False)
137 | self.addon_conv_logstd = GATv2Conv(n_hidden,
138 | n_addon_latent,
139 | heads=n_attention_heads,
140 | concat=False)
141 | self.activation = activation
142 | self.dropout = nn.Dropout(dropout_rate)
143 |
144 | def forward(self,
145 | x: torch.Tensor,
146 | edge_index: torch.Tensor,
147 | cat_covariates_embed: Optional[torch.Tensor]=None
148 | ) -> Tuple[torch.Tensor, torch.Tensor]:
149 | """
150 | Forward pass of the encoder.
151 |
152 | Parameters
153 | ----------
154 | x:
155 | Tensor containing the omics features.
156 | edge_index:
157 | Tensor containing the edge indices for message passing.
158 | cat_covariates_embed:
159 | Tensor containing the categorical covariates embedding (all
160 | categorical covariates embeddings concatenated into one embedding).
161 |
162 | Returns
163 |         -------
164 | mu:
165 | Tensor containing the expected values of the latent space normal
166 | distribution.
167 | logstd:
168 | Tensor containing the log standard deviations of the latent space
169 | normal distribution.
170 | """
171 | if ((self.cat_covariates_embed_mode == "input") &
172 | (cat_covariates_embed is not None)):
173 | # Add categorical covariates embedding to input vector
174 | x = torch.cat((x,
175 | cat_covariates_embed),
176 | axis=1)
177 |
178 | # FC forward pass shared across all nodes
179 | hidden = self.dropout(self.activation(self.fc_l1(x)))
180 | if self.n_fc_layers == 2:
181 | hidden = self.dropout(self.activation(self.fc_l2(hidden)))
182 | hidden = self.fc_l2_bn(hidden)
183 |
184 | if ((self.cat_covariates_embed_mode == "hidden") &
185 | (cat_covariates_embed is not None)):
186 | # Add categorical covariates embedding to hidden vector
187 | hidden = torch.cat((hidden,
188 | cat_covariates_embed),
189 | axis=1)
190 |
191 | if self.n_layers == 2:
192 | # Part of forward pass shared across all nodes
193 | hidden = self.dropout(self.activation(
194 | self.conv_l1(hidden, edge_index)))
195 |
196 | # Part of forward pass only for maskable latent nodes
197 | mu = self.conv_mu(hidden, edge_index)
198 | logstd = self.conv_logstd(hidden, edge_index)
199 |
200 | # Part of forward pass only for unmaskable add-on latent nodes
201 | if self.n_addon_latent != 0:
202 | mu = torch.cat(
203 | (mu, self.addon_conv_mu(hidden, edge_index)),
204 | dim=1)
205 | logstd = torch.cat(
206 | (logstd, self.addon_conv_logstd(hidden, edge_index)),
207 | dim=1)
208 | return mu, logstd
209 |
--------------------------------------------------------------------------------
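# --- Illustrative sketch (not part of the repository) -----------------------
# Hypothetical forward pass through ´Encoder´ on a toy graph with add-on
# latent nodes disabled; all hyperparameters are assumptions for
# demonstration only.
import torch
import torch.nn as nn

from nichecompass.nn import Encoder

encoder = Encoder(n_input=50,
                  n_cat_covariates_embed_input=0,
                  n_hidden=32,
                  n_latent=10,
                  n_addon_latent=0,
                  conv_layer="gcnconv",
                  activation=nn.ReLU())
x = torch.randn(8, 50)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
mu, logstd = encoder(x, edge_index)  # shapes: (8, 10), (8, 10)
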
/src/nichecompass/nn/layercomponents.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains neural network layer components used by the NicheCompass
3 | model.
4 | """
5 |
6 | from typing import Optional
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class MaskedLinear(nn.Linear):
13 | """
14 | Masked linear class.
15 |
16 | Parts of the implementation are adapted from
17 | https://github.com/theislab/scarches/blob/master/scarches/models/expimap/modules.py#L9;
18 | 01.10.2022.
19 |
20 | Uses static and dynamic binary masks to mask connections from the input
21 | layer to the output layer so that only unmasked connections can be used.
22 |
23 | Parameters
24 | ----------
25 | n_input:
26 | Number of input nodes to the masked layer.
27 | n_output:
28 | Number of output nodes from the masked layer.
29 | mask:
30 | Static mask that is used to mask the node connections from the input
31 | layer to the output layer.
32 | bias:
33 | If ´True´, use a bias.
34 | """
35 | def __init__(self,
36 | n_input: int,
37 | n_output: int,
38 | mask: torch.Tensor,
39 | bias=False):
40 | # Mask should have dim n_input x n_output
41 | if n_input != mask.shape[0] or n_output != mask.shape[1]:
42 | raise ValueError("Incorrect shape of the mask. Mask should have dim"
43 | " (n_input x n_output). Please provide a mask with"
44 | f" dimensions ({n_input} x {n_output}).")
45 | super().__init__(n_input, n_output, bias)
46 |
47 | self.register_buffer("mask", mask.t())
48 |
49 | # Zero out weights with the mask so that the optimizer does not
50 | # consider them
51 | self.weight.data *= self.mask
52 |
53 | def forward(self,
54 | input: torch.Tensor,
55 | dynamic_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
56 | """
57 | Forward pass of the masked linear class.
58 |
59 | Parameters
60 | ----------
61 | input:
62 | Tensor containing the input features to the masked linear class.
63 | dynamic_mask:
64 | Additional optional Tensor containing a mask that changes
65 | during training.
66 |
67 | Returns
68 |         -------
69 | output:
70 | Tensor containing the output of the masked linear class (linear
71 | transformation of the input by only considering unmasked
72 | connections).
73 | """
74 | if dynamic_mask is not None:
75 | dynamic_mask = dynamic_mask.t().to(self.mask.device)
76 | self.weight.data *= dynamic_mask
77 | masked_weights = self.weight * self.mask * dynamic_mask
78 | else:
79 | masked_weights = self.weight * self.mask
80 | output = nn.functional.linear(input, masked_weights, self.bias)
81 | return output
82 |
--------------------------------------------------------------------------------
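# --- Illustrative sketch (not part of the repository) -----------------------
# Hypothetical usage of ´MaskedLinear´: the static binary mask (dim: n_input
# x n_output) keeps only the allowed input-to-output connections; masked
# connections contribute nothing to the output.
import torch

from nichecompass.nn import MaskedLinear

mask = torch.tensor([[1., 0.],
                     [0., 1.],
                     [1., 1.]])  # 3 input nodes, 2 output nodes
layer = MaskedLinear(n_input=3, n_output=2, mask=mask)
output = layer(torch.randn(4, 3))  # dim: 4 x 2
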
/src/nichecompass/nn/layers.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains neural network layers used by the NicheCompass model.
3 | """
4 |
5 | from typing import List, Optional
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from .layercomponents import MaskedLinear
11 |
12 |
13 | class AddOnMaskedLayer(nn.Module):
14 | """
15 | Add-on masked layer class.
16 |
17 | Parts of the implementation are adapted from
18 | https://github.com/theislab/scarches/blob/7980a187294204b5fb5d61364bb76c0b809eb945/scarches/models/expimap/modules.py#L28;
19 | 01.10.2022.
20 |
21 | Parameters
22 | ----------
23 | n_input:
24 | Number of mask input nodes to the add-on masked layer.
25 | n_output:
26 | Number of output nodes from the add-on masked layer.
27 | mask:
28 | Mask that is used to mask the node connections for mask inputs from the
29 | input layer to the output layer.
30 | addon_mask:
31 | Mask that is used to mask the node connections for add-on inputs from
32 | the input layer to the output layer.
33 | masked_features_idx:
34 | Index of input features that are included in the mask.
35 | bias:
36 | If ´True´, use a bias for the mask input nodes.
37 | n_addon_input:
38 | Number of add-on input nodes to the add-on masked layer.
39 | n_cat_covariates_embed_input:
40 | Number of categorical covariates embedding input nodes to the addon
41 | masked layer.
42 | activation:
43 |         Activation function used at the end of the add-on masked layer.
44 | """
45 | def __init__(self,
46 | n_input: int,
47 | n_output: int,
48 | mask: torch.Tensor,
49 | addon_mask: torch.Tensor,
50 | masked_features_idx: List,
51 | bias: bool=False,
52 | n_addon_input: int=0,
53 | n_cat_covariates_embed_input: int=0,
54 | activation: nn.Module=nn.Softmax(dim=-1)):
55 | super().__init__()
56 | self.n_input = n_input
57 | self.n_addon_input = n_addon_input
58 | self.n_cat_covariates_embed_input = n_cat_covariates_embed_input
59 | self.masked_features_idx = masked_features_idx
60 |
61 | # Masked layer
62 | self.masked_l = MaskedLinear(n_input=n_input,
63 | n_output=n_output,
64 | mask=mask,
65 | bias=bias)
66 |
67 | # Add-on layer
68 | if n_addon_input != 0:
69 | self.addon_l = MaskedLinear(n_input=n_addon_input,
70 | n_output=n_output,
71 | mask=addon_mask,
72 | bias=False)
73 |
74 | # Categorical covariates embedding layer
75 | if n_cat_covariates_embed_input != 0:
76 | self.cat_covariates_embed_l = nn.Linear(
77 | n_cat_covariates_embed_input,
78 | n_output,
79 | bias=False)
80 |
81 | self.activation = activation
82 |
83 | def forward(self,
84 | input: torch.Tensor,
85 | dynamic_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
86 | """
87 | Forward pass of the add-on masked layer.
88 |
89 | Parameters
90 | ----------
91 | input:
92 | Input features to the add-on masked layer. Includes add-on input
93 | nodes and categorical covariates embedding input nodes if specified.
94 | dynamic_mask:
95 | Additional optional dynamic mask for the masked layer.
96 |
97 | Returns
98 | ----------
99 | output:
100 | Output of the add-on masked layer.
101 | """
102 |         if (self.n_addon_input == 0) and (self.n_cat_covariates_embed_input == 0):
103 |             mask_input = input
104 |         elif ((self.n_addon_input != 0) and
105 |               (self.n_cat_covariates_embed_input == 0)):
106 |             mask_input, addon_input = torch.split(
107 |                 input,
108 |                 [self.n_input, self.n_addon_input],
109 |                 dim=1)
110 |         elif ((self.n_addon_input == 0) and
111 |               (self.n_cat_covariates_embed_input != 0)):
112 |             mask_input, cat_covariates_embed_input = torch.split(
113 |                 input,
114 |                 [self.n_input, self.n_cat_covariates_embed_input],
115 |                 dim=1)
116 |         else:
117 |             # Both add-on and categorical covariates embedding inputs present
118 |             mask_input, addon_input, cat_covariates_embed_input = torch.split(
119 |                 input,
120 |                 [self.n_input, self.n_addon_input, self.n_cat_covariates_embed_input],
121 |                 dim=1)
122 |
123 | output = self.masked_l(
124 | input=mask_input,
125 | dynamic_mask=(dynamic_mask[:self.n_input, :] if
126 | dynamic_mask is not None else None))
127 | # Dynamic mask also has entries for add-on gps
128 | if self.n_addon_input != 0:
129 | # Only unmasked features will have weights != 0.
130 | output += self.addon_l(
131 | input=addon_input,
132 | dynamic_mask=(dynamic_mask[self.n_input:, :] if
133 | dynamic_mask is not None else None))
134 | if self.n_cat_covariates_embed_input != 0:
135 | if self.n_addon_input != 0:
136 | output += self.cat_covariates_embed_l(
137 | cat_covariates_embed_input)
138 | else:
139 | # Only add categorical covariates embedding layer output to
140 | # masked features
141 | output[:, self.masked_features_idx] += self.cat_covariates_embed_l(
142 | cat_covariates_embed_input)[:, self.masked_features_idx]
143 | output = self.activation(output)
144 | return output
145 |
--------------------------------------------------------------------------------
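
The forward pass above relies on the input blocks being concatenated along the feature dimension in a fixed order: mask inputs, then add-on inputs, then categorical covariates embedding inputs. A hypothetical sketch of that contract with made-up dimensions (mask shapes assumed to be n_input x n_output, matching the docstring):

import torch
import torch.nn as nn

from nichecompass.nn.layers import AddOnMaskedLayer

n_input, n_addon_input, n_cat_covariates_embed_input, n_output = 10, 3, 4, 6

layer = AddOnMaskedLayer(
    n_input=n_input,
    n_output=n_output,
    mask=torch.ones(n_input, n_output),  # illustrative all-ones masks
    addon_mask=torch.ones(n_addon_input, n_output),
    masked_features_idx=list(range(n_output)),  # hypothetical: all masked
    n_addon_input=n_addon_input,
    n_cat_covariates_embed_input=n_cat_covariates_embed_input,
    activation=nn.Softmax(dim=-1))

# Concatenate the blocks in the order expected by the torch.split calls
x = torch.cat([torch.randn(2, n_input),  # mask inputs
               torch.randn(2, n_addon_input),  # add-on inputs
               torch.randn(2, n_cat_covariates_embed_input)],  # cat covs embed
              dim=1)
out = layer(x)  # shape: (2, n_output)

--------------------------------------------------------------------------------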
/src/nichecompass/nn/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains helper functions for the ´nn´ subpackage.
3 | """
4 |
5 | import torch
6 |
7 |
8 | def compute_cosine_similarity(tensor1: torch.Tensor,
9 | tensor2: torch.Tensor,
10 | eps: float=1e-8) -> torch.Tensor:
11 |     """
12 |     Compute the row-wise cosine similarity between two 2D tensors.
13 |
14 |     Parameters
15 |     ----------
16 |     tensor1:
17 |         First tensor for row-wise cosine similarity computation (dim: n_obs
18 |         x n_features).
19 |     tensor2:
20 |         Second tensor for row-wise cosine similarity computation (dim: n_obs
21 |         x n_features).
22 |     eps:
23 |         Small value used to clamp the row norms and avoid division by zero
24 |         for all-zero rows.
25 |
26 |     Returns
27 |     ----------
28 |     cosine_sim:
29 |         Result tensor that contains the computed row-wise cosine
30 |         similarities (dim: n_obs).
31 |     """
32 |     tensor1_norm = tensor1.norm(dim=1)[:, None]
33 |     tensor2_norm = tensor2.norm(dim=1)[:, None]
34 |     tensor1_normalized = tensor1 / torch.max(
35 |         tensor1_norm, eps * torch.ones_like(tensor1_norm))
36 |     tensor2_normalized = tensor2 / torch.max(
37 |         tensor2_norm, eps * torch.ones_like(tensor2_norm))
38 |     cosine_sim = torch.mul(tensor1_normalized, tensor2_normalized).sum(1)
39 |     return cosine_sim
40 |
--------------------------------------------------------------------------------
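
For intuition, identical rows score 1 and orthogonal rows score 0:

import torch

from nichecompass.nn.utils import compute_cosine_similarity

t1 = torch.tensor([[1., 0.], [1., 1.]])
t2 = torch.tensor([[1., 0.], [-1., 1.]])
print(compute_cosine_similarity(t1, t2))  # tensor([1., 0.])

--------------------------------------------------------------------------------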
/src/nichecompass/train/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import eval_metrics, plot_eval_metrics
2 | from .trainer import Trainer
3 |
4 | __all__ = ["eval_metrics",
5 | "plot_eval_metrics",
6 | "Trainer"]
--------------------------------------------------------------------------------
/src/nichecompass/train/basetrainermixin.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains generic trainer functionalities, added as a Mixin to
3 | the Trainer module.
4 | """
5 |
6 | import inspect
7 |
8 |
9 | class BaseTrainerMixin:
10 | """
11 |     Base trainer mixin class containing universal trainer functionalities.
12 |
13 | Parts of the implementation are adapted from
14 | https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_base_model.py#L63
15 | (01.10.2022).
16 | """
17 | def _get_user_attributes(self) -> list:
18 | """
19 | Get all the attributes defined in a trainer instance.
20 |
21 | Returns
22 | ----------
23 | attributes:
24 | Attributes defined in a trainer instance.
25 | """
26 | attributes = inspect.getmembers(
27 | self, lambda a: not (inspect.isroutine(a)))
28 | attributes = [a for a in attributes if not (
29 | a[0].startswith("__") and a[0].endswith("__"))]
30 | return attributes
31 |
32 | def _get_public_attributes(self) -> dict:
33 | """
34 | Get only public attributes defined in a trainer instance. By convention
35 | public attributes have a trailing underscore.
36 |
37 | Returns
38 | ----------
39 | public_attributes:
40 | Public attributes defined in a trainer instance.
41 | """
42 | public_attributes = self._get_user_attributes()
43 | public_attributes = {a[0]: a[1] for a in public_attributes if
44 | a[0][-1] == "_"}
45 | return public_attributes
--------------------------------------------------------------------------------
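
A small illustration of the trailing-underscore convention used by ´_get_public_attributes´ (the ´DemoTrainer´ class is hypothetical):

from nichecompass.train.basetrainermixin import BaseTrainerMixin

class DemoTrainer(BaseTrainerMixin):
    def __init__(self):
        self.n_epochs_ = 10  # trailing underscore -> treated as public
        self.device = "cpu"  # no trailing underscore -> filtered out

print(DemoTrainer()._get_public_attributes())  # {'n_epochs_': 10}

--------------------------------------------------------------------------------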
/src/nichecompass/train/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains metrics to evaluate the NicheCompass model training.
3 | """
4 |
5 | from typing import List, Optional, Union
6 |
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import sklearn.metrics as skm
10 | import torch
11 | from matplotlib.ticker import MaxNLocator
12 |
13 |
14 | def eval_metrics(
15 | edge_recon_probs: Union[torch.Tensor, np.ndarray],
16 | edge_labels: Union[torch.Tensor, np.ndarray],
17 |     edge_same_cat_covariates_cat: Optional[List[Union[torch.Tensor, np.ndarray]]]=None,
18 | edge_incl: Optional[Union[torch.Tensor, np.ndarray]]=None,
19 | omics_pred_dict: Optional[dict]=None) -> dict:
20 | """
21 | Get the evaluation metrics for a (balanced) sample of positive and negative
22 | edges and a sample of nodes.
23 |
24 | Parameters
25 | ----------
26 | edge_recon_probs:
27 | Tensor or array containing reconstructed edge probabilities.
28 | edge_labels:
29 | Tensor or array containing ground truth labels of edges.
30 |     edge_same_cat_covariates_cat:
31 |         Optional list of tensors or arrays indicating for each categorical
32 |         covariate whether the two nodes of an edge belong to the same
33 |         category (1) or to different categories (0).
34 |     edge_incl:
35 |         Boolean tensor or array indicating whether the edge should be included
36 |         in the evaluation.
37 |     omics_pred_dict:
38 |         Optional dictionary containing ground truth and predicted omics
39 |         features, e.g. ´target_rna´ and ´target_rna_preds´ for target gene
40 |         expression, ´source_rna´ and ´source_rna_preds´ for source gene
41 |         expression, as well as ´target_atac´ / ´target_atac_preds´ and
42 |         ´source_atac´ / ´source_atac_preds´ for chromatin accessibility.
43 |         Mean squared errors are computed for all modalities that are
44 |         present.
45 |
46 | Returns
47 | ----------
48 | eval_dict:
49 |         Dictionary containing the evaluation metrics ´auroc_score´ (area under
50 |         the receiver operating characteristic curve), ´auprc_score´ (area under
51 |         the precision-recall curve), ´best_acc_score´ and ´best_f1_score´
52 |         (accuracy and F1 score under the optimal classification threshold), as
53 |         well as mse and covariate similarity metrics if the inputs are given.
54 | """
55 | eval_dict = {}
56 |
57 | if isinstance(edge_recon_probs, torch.Tensor):
58 | edge_recon_probs = edge_recon_probs.detach().cpu().numpy()
59 | if isinstance(edge_labels, torch.Tensor):
60 | edge_labels = edge_labels.detach().cpu().numpy()
61 | if isinstance(edge_incl, torch.Tensor):
62 | edge_incl = edge_incl.detach().cpu().numpy()
63 |
64 | if omics_pred_dict is not None:
65 | for key, value in omics_pred_dict.items():
66 | if isinstance(value, torch.Tensor):
67 | omics_pred_dict[key] = value.detach().cpu().numpy()
68 |
69 | if "target_rna_preds" in omics_pred_dict.keys():
70 | # Calculate the gene expression mean squared error
71 | eval_dict["target_rna_mse_score"] = skm.mean_squared_error(
72 | omics_pred_dict["target_rna"],
73 | omics_pred_dict["target_rna_preds"])
74 | eval_dict["source_rna_mse_score"] = skm.mean_squared_error(
75 | omics_pred_dict["source_rna"],
76 | omics_pred_dict["source_rna_preds"])
77 |
78 | if "target_atac_preds" in omics_pred_dict.keys():
79 |             # Calculate the chromatin accessibility mean squared error
80 | eval_dict["target_atac_mse_score"] = skm.mean_squared_error(
81 | omics_pred_dict["target_atac"],
82 | omics_pred_dict["target_atac_preds"])
83 | eval_dict["source_atac_mse_score"] = skm.mean_squared_error(
84 | omics_pred_dict["source_atac"],
85 | omics_pred_dict["source_atac_preds"])
86 |
87 | if edge_same_cat_covariates_cat is not None:
88 | for i, edge_same_cat_covariate_cat in enumerate(edge_same_cat_covariates_cat):
89 |             # Only include negatively sampled edges (edge label is 0)
90 | edge_same_cat_covariate_cat_incl = edge_labels == 0
91 | edge_same_cat_covariate_cat_recon_probs = edge_recon_probs[edge_same_cat_covariate_cat_incl]
92 | edge_same_cat_covariate_cat_labels = edge_same_cat_covariate_cat[edge_same_cat_covariate_cat_incl]
93 | same_cat_mask = edge_same_cat_covariate_cat_labels == 1
94 | diff_cat_mask = edge_same_cat_covariate_cat_labels == 0
95 | same_cat_mean = np.mean(edge_same_cat_covariate_cat_recon_probs[same_cat_mask])
96 | diff_cat_mean = np.mean(edge_same_cat_covariate_cat_recon_probs[diff_cat_mask])
97 | eval_dict[f"cat_covariate{i}_mean_sim_diff"] = diff_cat_mean - same_cat_mean
98 |
99 | if edge_incl is not None:
100 | edge_incl = edge_incl.astype(bool)
101 | # Remove edges whose node pair has different categories in categorical
102 | # covariates for which no cross-category edges are present
103 | edge_recon_probs = edge_recon_probs[edge_incl]
104 | edge_labels = edge_labels[edge_incl]
105 |
106 | # Calculate threshold independent metrics
107 | eval_dict["auroc_score"] = skm.roc_auc_score(edge_labels, edge_recon_probs)
108 | eval_dict["auprc_score"] = skm.average_precision_score(edge_labels,
109 | edge_recon_probs)
110 |
111 | # Get the optimal classification probability threshold above which an edge
112 | # is classified as positive so that the threshold optimizes the accuracy
113 | # over the sampled (balanced) set of positive and negative edges.
114 | best_acc_score = 0
115 | best_threshold = 0
116 | for threshold in np.arange(0.01, 1, 0.005):
117 | pred_labels = (edge_recon_probs > threshold).astype("int")
118 | acc_score = skm.accuracy_score(edge_labels, pred_labels)
119 | if acc_score > best_acc_score:
120 | best_threshold = threshold
121 | best_acc_score = acc_score
122 | eval_dict["best_acc_score"] = best_acc_score
123 | eval_dict["best_acc_threshold"] = best_threshold
124 |
125 | # Get the optimal classification probability threshold above which an edge
126 | # is classified as positive so that the threshold optimizes the F1 score
127 | # over the sampled (balanced) set of positive and negative edges.
128 | best_f1_score = 0
129 | for threshold in np.arange(0.01, 1, 0.005):
130 | pred_labels = (edge_recon_probs > threshold).astype("int")
131 | f1_score = skm.f1_score(edge_labels, pred_labels)
132 | if f1_score > best_f1_score:
133 | best_f1_score = f1_score
134 | eval_dict["best_f1_score"] = best_f1_score
135 | return eval_dict
136 |
137 |
138 | def plot_eval_metrics(eval_dict: dict) -> plt.Figure:
139 | """
140 | Plot evaluation metrics.
141 |
142 | Parameters
143 | ----------
144 | eval_dict:
145 | Dictionary containing the eval metric scores to be plotted.
146 |
147 | Returns
148 | ----------
149 | fig:
150 | Matplotlib figure containing a plot of the evaluation metrics.
151 | """
152 | # Plot epochs as integers
153 | ax = plt.figure().gca()
154 | ax.xaxis.set_major_locator(MaxNLocator(integer=True))
155 |
156 | # Plot eval metrics
157 | for metric_key, metric_scores in eval_dict.items():
158 | plt.plot(metric_scores, label=metric_key)
159 | plt.title("Evaluation metrics over epochs")
160 | plt.ylabel("metric score")
161 | plt.xlabel("epoch")
162 | plt.legend(loc="lower right")
163 |
164 | # Retrieve figure
165 | fig = plt.gcf()
166 | plt.close()
167 | return fig
--------------------------------------------------------------------------------
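
A usage sketch of ´eval_metrics´ with synthetic, perfectly separable edge probabilities (all values made up):

import numpy as np

from nichecompass.train import eval_metrics

edge_recon_probs = np.array([0.9, 0.8, 0.3, 0.2])
edge_labels = np.array([1, 1, 0, 0])

eval_dict = eval_metrics(edge_recon_probs=edge_recon_probs,
                         edge_labels=edge_labels)
print(eval_dict["auroc_score"])  # 1.0 for this toy case
print(eval_dict["best_acc_score"])  # 1.0

--------------------------------------------------------------------------------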
/src/nichecompass/train/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains helper functions for the ´train´ subpackage.
3 | """
4 |
5 | import sys
6 | from typing import Tuple
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | from matplotlib.ticker import MaxNLocator
11 |
12 |
13 | class EarlyStopping:
14 | """
15 | EarlyStopping class for early stopping of NicheCompass training.
16 |
17 | Parts of the implementation are adapted from
18 | https://github.com/theislab/scarches/blob/cb54fa0df3255ad1576a977b17e9d77d4907ceb0/scarches/utils/monitor.py#L4
19 | (01.10.2022).
20 |
21 | Parameters
22 | ----------
23 | early_stopping_metric:
24 | The metric on which the early stopping criterion is calculated.
25 | metric_improvement_threshold:
26 |         The minimum value which counts as a metric improvement.
27 | patience:
28 | Number of epochs which are allowed to have no metric improvement until
29 | the training is stopped.
30 | reduce_lr_on_plateau:
31 | If ´True´, the learning rate gets adjusted by ´lr_factor´ after a given
32 | number of epochs with no metric improvement.
33 | lr_patience:
34 | Number of epochs which are allowed to have no metric improvement until
35 | the learning rate is adjusted.
36 | lr_factor:
37 | Scaling factor for adjusting the learning rate.
38 | """
39 | def __init__(self,
40 | early_stopping_metric: str="val_global_loss",
41 | metric_improvement_threshold: float=0.,
42 | patience: int=8,
43 | reduce_lr_on_plateau: bool=True,
44 | lr_patience: int=4,
45 | lr_factor: float=0.1):
46 | self.early_stopping_metric = early_stopping_metric
47 | self.metric_improvement_threshold = metric_improvement_threshold
48 | self.patience = patience
49 | self.reduce_lr_on_plateau = reduce_lr_on_plateau
50 | self.lr_patience = lr_patience
51 | self.lr_factor = lr_factor
52 | self.epochs = 0
53 | self.epochs_not_improved = 0
54 | self.epochs_not_improved_lr = 0
55 | self.current_performance = np.inf
56 | self.best_performance = np.inf
57 | self.best_performance_state = np.inf
58 |
59 | def step(self, current_metric: float) -> Tuple[bool, bool]:
60 | self.epochs += 1
61 |
62 | # Calculate metric improvement
63 | self.current_performance = current_metric
64 | metric_improvement = (self.best_performance -
65 | self.current_performance)
66 | # Update best performance
67 | if metric_improvement > 0:
68 | self.best_performance = self.current_performance
69 | # Update epochs not improved
70 | if metric_improvement < self.metric_improvement_threshold:
71 | self.epochs_not_improved += 1
72 | self.epochs_not_improved_lr += 1
73 | else:
74 | self.epochs_not_improved = 0
75 | self.epochs_not_improved_lr = 0
76 |
77 | # Determine whether to continue training and whether to reduce the
78 | # learning rate
79 | if self.epochs < self.patience:
80 | continue_training = True
81 | reduce_lr = False
82 | elif self.epochs_not_improved >= self.patience:
83 | continue_training = False
84 | reduce_lr = False
85 | else:
86 |             if not self.reduce_lr_on_plateau:
87 | reduce_lr = False
88 | elif self.epochs_not_improved_lr >= self.lr_patience:
89 | reduce_lr = True
90 | self.epochs_not_improved_lr = 0
91 | print("\nReducing learning rate: metric has not improved more "
92 | f"than {self.metric_improvement_threshold} in the last "
93 | f"{self.lr_patience} epochs.")
94 | else:
95 | reduce_lr = False
96 | continue_training = True
97 | if not continue_training:
98 | print("\nStopping early: metric has not improved more than "
99 | + str(self.metric_improvement_threshold) +
100 | " in the last " + str(self.patience) + " epochs.")
101 | print("If the early stopping criterion is too strong, "
102 | "please instantiate it with different parameters "
103 | "in the train method.")
104 | return continue_training, reduce_lr
105 |
106 | def update_state(self, current_metric: float) -> bool:
107 | improved = (self.best_performance_state - current_metric) > 0
108 | if improved:
109 | self.best_performance_state = current_metric
110 | return improved
111 |
112 |
113 | def print_progress(epoch: int, logs: dict, n_epochs: int):
114 |     Create a message for ´_print_progress_bar()´ and print it together with
115 |     a progress bar.
116 | bar.
117 |
118 | Implementation is adapted from
119 | https://github.com/theislab/scarches/blob/master/scarches/trainers/trvae/_utils.py#L11
120 | (01.10.2022).
121 |
122 | Parameters
123 | ----------
124 | epoch:
125 | Current epoch.
126 | logs:
127 | Dictionary with all logs (losses & metrics).
128 | n_epochs:
129 | Total number of epochs.
130 | """
131 | # Define progress message
132 | message = ""
133 | for key in logs:
134 | message += f"{key:s}: {logs[key][-1]:.4f}; "
135 | message = message[:-2] + "\n"
136 |
137 | # Display progress bar
138 | _print_progress_bar(epoch + 1,
139 | n_epochs,
140 | prefix=f"Epoch {epoch + 1}/{n_epochs}",
141 | suffix=message,
142 | decimals=1,
143 | length=20)
144 |
145 |
146 | def _print_progress_bar(epoch: int,
147 | n_epochs: int,
148 | prefix: str="",
149 | suffix: str="",
150 | decimals: int=1,
151 | length: int=100,
152 | fill: str="█"):
153 | """
154 | Print out a message with a progress bar.
155 |
156 | Implementation is adapted from
157 | https://github.com/theislab/scarches/blob/master/scarches/trainers/trvae/_utils.py#L41
158 | (01.10.2022).
159 |
160 | Parameters
161 | ----------
162 | epoch:
163 | Current epoch.
164 | n_epochs:
165 | Total number of epochs.
166 | prefix:
167 | String before the progress bar.
168 | suffix:
169 | String after the progress bar.
170 | decimals:
171 |         Digits after the decimal point for the percent display.
172 | length:
173 | Length of the progress bar.
174 | fill:
175 | Symbol for filling the bar.
176 | """
177 | percent = ("{0:." + str(decimals) + "f}").format(100 * (
178 | epoch / float(n_epochs)))
179 | filled_len = int(length * epoch // n_epochs)
180 | bar = fill * filled_len + '-' * (length - filled_len)
181 |     sys.stdout.write("\r%s |%s| %s%s %s" % (prefix, bar, percent, "%", suffix))
182 | if epoch == n_epochs:
183 | sys.stdout.write("\n")
184 | sys.stdout.flush()
185 |
186 |
187 | def plot_loss_curves(loss_dict: dict) -> plt.Figure:
188 | """
189 | Plot loss curves.
190 |
191 | Parameters
192 | ----------
193 | loss_dict:
194 | Dictionary containing the training and validation losses.
195 |
196 | Returns
197 | ----------
198 | fig:
199 | Matplotlib figure of loss curves.
200 | """
201 | # Plot epochs as integers
202 | ax = plt.figure().gca()
203 | ax.xaxis.set_major_locator(MaxNLocator(integer=True))
204 |
205 | # Plot loss
206 | for loss_key, loss in loss_dict.items():
207 |         plt.plot(loss, label=loss_key)
208 |     plt.title("Loss curves")
209 |     plt.ylabel("loss")
210 |     plt.xlabel("epoch")
211 |     plt.legend(loc="upper right")
212 |
213 | # Retrieve figure
214 | fig = plt.gcf()
215 | plt.close()
216 | return fig
217 |
218 |
219 | def _cycle_iterable(iterable):
220 |     """Cycle endlessly over ´iterable´, restarting it upon exhaustion."""
221 |     iterator = iter(iterable)
222 |     while True:
223 |         try:
224 |             yield next(iterator)
225 |         except StopIteration:
226 |             iterator = iter(iterable)
--------------------------------------------------------------------------------
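
A minimal sketch of how ´EarlyStopping.step´ drives a training loop, with a fabricated validation loss sequence:

from nichecompass.train.utils import EarlyStopping

early_stopping = EarlyStopping(patience=3,
                               reduce_lr_on_plateau=True,
                               lr_patience=2,
                               lr_factor=0.1)
lr = 1e-3

# Losses stop improving after epoch 2: the learning rate is reduced at
# epoch 4 (lr_patience exceeded) and training stops at epoch 5 (patience
# exceeded)
for val_loss in [1.0, 0.9, 0.95, 0.96, 0.97]:
    continue_training, reduce_lr = early_stopping.step(val_loss)
    if reduce_lr:
        lr *= early_stopping.lr_factor
    if not continue_training:
        break

--------------------------------------------------------------------------------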
/src/nichecompass/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .analysis import (aggregate_obsp_matrix_per_cell_type,
2 | create_cell_type_chord_plot_from_df,
3 | create_new_color_dict,
4 | compute_communication_gp_network,
5 | visualize_communication_gp_network,
6 | generate_enriched_gp_info_plots,
7 | plot_non_zero_gene_count_means_dist)
8 | from .multimodal_mapping import (add_multimodal_mask_to_adata,
9 | get_gene_annotations,
10 | generate_multimodal_mapping_dict)
11 | from .gene_programs import (add_gps_from_gp_dict_to_adata,
12 | extract_gp_dict_from_collectri_tf_network,
13 | extract_gp_dict_from_nichenet_lrt_interactions,
14 | extract_gp_dict_from_mebocost_ms_interactions,
15 | extract_gp_dict_from_omnipath_lr_interactions,
16 | filter_and_combine_gp_dict_gps,
17 | filter_and_combine_gp_dict_gps_v2,
18 | get_unique_genes_from_gp_dict)
19 |
20 | __all__ = ["add_gps_from_gp_dict_to_adata",
21 | "add_multimodal_mask_to_adata",
22 | "aggregate_obsp_matrix_per_cell_type",
23 | "create_cell_type_chord_plot_from_df",
24 | "create_new_color_dict",
25 | "compute_communication_gp_network",
26 | "visualize_communication_gp_network",
27 | "extract_gp_dict_from_collectri_tf_network",
28 | "extract_gp_dict_from_nichenet_lrt_interactions",
29 | "extract_gp_dict_from_mebocost_ms_interactions",
30 | "extract_gp_dict_from_omnipath_lr_interactions",
31 | "filter_and_combine_gp_dict_gps",
32 | "filter_and_combine_gp_dict_gps_v2",
33 | "get_gene_annotations",
34 | "generate_enriched_gp_info_plots",
35 | "plot_non_zero_gene_count_means_dist",
36 | "generate_multimodal_mapping_dict",
37 | "get_unique_genes_from_gp_dict"]
--------------------------------------------------------------------------------
/src/nichecompass/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains helper functions for the ´utils´ subpackage.
3 | """
4 |
5 | import os
6 | from typing import Optional
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import pandas as pd
11 | import pyreadr
12 | import seaborn as sns
13 | from anndata import AnnData
14 |
15 |
16 | def load_R_file_as_df(R_file_path: str,
17 | url: Optional[str]=None,
18 | save_df_to_disk: bool=False,
19 | df_save_path: Optional[str]=None) -> pd.DataFrame:
20 | """
21 | Helper to load an R file either from ´url´ if specified or from
22 | ´R_file_path´ on disk and convert it to a pandas DataFrame.
23 |
24 | Parameters
25 | ----------
26 | R_file_path:
27 | File path to the R file to be loaded as df.
28 | url:
29 | URL of the R file to be loaded as df.
30 | save_df_to_disk:
31 | If ´True´, save df to disk.
32 | df_save_path:
33 | Path where the df will be saved if ´save_df_to_disk´ is ´True´.
34 |
35 | Returns
36 | ----------
37 | df:
38 | Content of R file loaded into a pandas DataFrame.
39 | """
40 | if url is None:
41 | if not os.path.exists(R_file_path):
42 | raise ValueError("Please specify a valid ´R_file_path´ or ´url´.")
43 | result_odict = pyreadr.read_r(R_file_path)
44 | else:
45 | result_odict = pyreadr.read_r(pyreadr.download_file(url, R_file_path))
46 | os.remove(R_file_path)
47 |
48 | df = result_odict[None]
49 |
50 | if save_df_to_disk:
51 |         if df_save_path is None:
52 |             raise ValueError("Please specify ´df_save_path´ or set "
53 |                              "´save_df_to_disk´ to ´False´.")
54 | df.to_csv(df_save_path)
55 | return df
56 |
57 |
58 | def create_gp_gene_count_distribution_plots(
59 | gp_dict: Optional[dict]=None,
60 | adata: Optional[AnnData]=None,
61 | gp_targets_mask_key: Optional[str]="nichecompass_gp_targets",
62 | gp_sources_mask_key: Optional[str]="nichecompass_gp_sources",
63 | gp_plot_label: str="",
64 | save_path: Optional[str]=None):
65 | """
66 | Create distribution plots of the gene counts for sources and targets
67 | of all gene programs in either a gp dict or an adata object.
68 |
69 | Parameters
70 | ----------
71 | gp_dict:
72 | A gene program dictionary.
73 |     adata:
74 |         An AnnData object with gp masks stored in ´adata.varm´.
75 |     gp_plot_label:
76 |         Label of the gene program plot, used in the plot title.
77 | """
78 | # Get number of source and target genes for each gene program
79 | if gp_dict is not None:
80 | n_sources_list = []
81 | n_targets_list = []
82 | for _, gp_sources_targets_dict in gp_dict.items():
83 | n_sources_list.append(len(gp_sources_targets_dict["sources"]))
84 | n_targets_list.append(len(gp_sources_targets_dict["targets"]))
85 | elif adata is not None:
86 | n_targets_list = adata.varm[gp_targets_mask_key].sum(axis=0)
87 | n_sources_list = adata.varm[gp_sources_mask_key].sum(axis=0)
88 |
89 |
90 | # Convert the arrays to a pandas DataFrame
91 | targets_df = pd.DataFrame({"values": n_targets_list})
92 | sources_df = pd.DataFrame({"values": n_sources_list})
93 |
94 | # Determine plot configurations
95 | max_n_targets = max(n_targets_list)
96 | max_n_sources = max(n_sources_list)
97 | if max_n_targets > 200:
98 | targets_x_ticks_range = 100
99 | xticklabels_rotation = 45
100 | elif max_n_targets > 100:
101 | targets_x_ticks_range = 20
102 | xticklabels_rotation = 0
103 | elif max_n_targets > 10:
104 | targets_x_ticks_range = 10
105 | xticklabels_rotation = 0
106 | else:
107 | targets_x_ticks_range = 1
108 | xticklabels_rotation = 0
109 | if max_n_sources > 200:
110 | sources_x_ticks_range = 100
111 | elif max_n_sources > 100:
112 | sources_x_ticks_range = 20
113 | elif max_n_sources > 10:
114 | sources_x_ticks_range = 10
115 | else:
116 | sources_x_ticks_range = 1
117 |
118 | # Create subplot
119 | fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 5))
120 | plt.suptitle(
121 | f"{gp_plot_label} Gene Programs – Gene Count Distribution Plots")
122 | sns.histplot(x="values", data=targets_df, ax=ax1)
123 | ax1.set_title("Gene Program Targets Distribution",
124 | fontsize=10)
125 | ax1.set(xlabel="Number of Targets",
126 | ylabel="Number of Gene Programs")
127 | ax1.set_xticks(
128 | np.arange(0,
129 | max_n_targets + targets_x_ticks_range,
130 | targets_x_ticks_range))
131 | ax1.set_xticklabels(
132 | np.arange(0,
133 | max_n_targets + targets_x_ticks_range,
134 | targets_x_ticks_range),
135 | rotation=xticklabels_rotation)
136 | sns.histplot(x="values", data=sources_df, ax=ax2)
137 | ax2.set_title("Gene Program Sources Distribution",
138 | fontsize=10)
139 | ax2.set(xlabel="Number of Sources",
140 | ylabel="Number of Gene Programs")
141 | ax2.set_xticks(
142 | np.arange(0,
143 | max_n_sources + sources_x_ticks_range,
144 | sources_x_ticks_range))
145 | ax2.set_xticklabels(
146 | np.arange(0,
147 | max_n_sources + sources_x_ticks_range,
148 | sources_x_ticks_range),
149 | rotation=xticklabels_rotation)
150 | plt.subplots_adjust(wspace=0.35)
151 | if save_path:
152 | plt.savefig(save_path)
153 | plt.show()
--------------------------------------------------------------------------------
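
For example, ´create_gp_gene_count_distribution_plots´ can be called with a toy gene program dictionary (gene names invented):

from nichecompass.utils.utils import create_gp_gene_count_distribution_plots

gp_dict = {
    "gp_1": {"sources": ["Lig1"], "targets": ["Rec1", "Tgt1"]},
    "gp_2": {"sources": ["Lig2", "Lig3"], "targets": ["Rec2"]},
}
create_gp_gene_count_distribution_plots(gp_dict=gp_dict,
                                        gp_plot_label="Toy")

--------------------------------------------------------------------------------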
/tests/test_basic.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import nichecompass
4 |
5 |
6 | def test_package_has_version():
7 | assert nichecompass.__version__ is not None
8 |
9 |
10 | @pytest.mark.skip(reason="This decorator should be removed when test passes.")
11 | def test_example():
12 | assert 1 == 0 # This test is designed to fail.
13 |
--------------------------------------------------------------------------------