├── .codecov.yaml ├── .cruft.json ├── .editorconfig ├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml └── workflows │ ├── build.yaml │ ├── release.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _static │ ├── .gitkeep │ ├── cellcharter.png │ ├── cellcharter_workflow.png │ ├── css │ │ └── custom.css │ └── spatial_clusters.png ├── _templates │ ├── .gitkeep │ └── autosummary │ │ └── class.rst ├── api.md ├── changelog.md ├── conf.py ├── contributing.md ├── extensions │ └── typed_returns.py ├── graph.md ├── index.md ├── notebooks │ ├── codex_mouse_spleen.ipynb │ ├── cosmx_human_nsclc.ipynb │ ├── data │ │ └── .gitkeep │ ├── example.ipynb │ └── tutorial_models │ │ └── codex_mouse_spleen_trvae │ │ ├── attr.pkl │ │ ├── model_params.pt │ │ └── var_names.csv ├── plotting.md ├── references.bib ├── references.md └── tools.md ├── pyproject.toml ├── src └── cellcharter │ ├── __init__.py │ ├── _constants │ └── _pkg_constants.py │ ├── _utils.py │ ├── datasets │ ├── __init__.py │ └── _dataset.py │ ├── gr │ ├── __init__.py │ ├── _aggr.py │ ├── _build.py │ ├── _group.py │ ├── _nhood.py │ └── _utils.py │ ├── pl │ ├── __init__.py │ ├── _autok.py │ ├── _group.py │ ├── _nhood.py │ ├── _shape.py │ └── _utils.py │ └── tl │ ├── __init__.py │ ├── _autok.py │ ├── _gmm.py │ ├── _shape.py │ ├── _trvae.py │ └── _utils.py └── tests ├── _data └── test_data.h5ad ├── _models ├── cellcharter_autok_imc │ ├── attributes.pickle │ ├── best_models │ │ ├── GaussianMixture_k1 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k10 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k11 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── 
parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k12 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k13 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k14 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k15 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k16 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k2 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k3 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k4 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k5 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k6 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k7 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ ├── GaussianMixture_k8 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ │ ├── config.json │ │ │ │ └── parameters.pt │ │ │ └── params.pickle │ │ └── GaussianMixture_k9 │ │ │ ├── attributes.pickle │ │ │ ├── model │ │ │ ├── config.json │ │ │ └── parameters.pt │ │ │ └── params.pickle │ └── params.pickle └── cellcharter_autok_mibitof │ ├── attributes.pickle │ 
├── best_models │ ├── GaussianMixture_k1 │ │ ├── attributes.pickle │ │ ├── model │ │ │ ├── config.json │ │ │ └── parameters.pt │ │ └── params.json │ ├── GaussianMixture_k2 │ │ ├── attributes.pickle │ │ ├── model │ │ │ ├── config.json │ │ │ └── parameters.pt │ │ └── params.json │ ├── GaussianMixture_k3 │ │ ├── attributes.pickle │ │ ├── model │ │ │ ├── config.json │ │ │ └── parameters.pt │ │ └── params.json │ ├── GaussianMixture_k4 │ │ ├── attributes.pickle │ │ ├── model │ │ │ ├── config.json │ │ │ └── parameters.pt │ │ └── params.json │ ├── GaussianMixture_k5 │ │ ├── attributes.pickle │ │ ├── model │ │ │ ├── config.json │ │ │ └── parameters.pt │ │ └── params.json │ └── GaussianMixture_k6 │ │ ├── attributes.pickle │ │ ├── model │ │ ├── config.json │ │ └── parameters.pt │ │ └── params.json │ └── params.pickle ├── conftest.py ├── graph ├── test_aggregate_neighbors.py ├── test_build.py ├── test_diff_nhood.py ├── test_group.py └── test_nhood.py ├── plotting ├── test_group.py ├── test_plot_nhood.py ├── test_plot_stability.py └── test_shape.py └── tools ├── test_autok.py ├── test_gmm.py └── test_shape.py /.codecov.yaml: -------------------------------------------------------------------------------- 1 | # Based on pydata/xarray 2 | codecov: 3 | require_ci_to_pass: no 4 | 5 | coverage: 6 | status: 7 | project: 8 | default: 9 | # Require 1% coverage, i.e., always succeed 10 | target: 1 11 | patch: false 12 | changes: false 13 | 14 | comment: 15 | layout: diff, flags, files 16 | behavior: once 17 | require_base: no 18 | -------------------------------------------------------------------------------- /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "https://github.com/scverse/cookiecutter-scverse", 3 | "commit": "87a407a65408d75a949c0b54b19fd287475a56f8", 4 | "checkout": "v0.4.0", 5 | "context": { 6 | "cookiecutter": { 7 | "project_name": "cellcharter", 8 | "package_name": "cellcharter", 9 | 
"project_description": "A Python package for the identification, characterization and comparison of spatial clusters from spatial -omics data.", 10 | "author_full_name": "Marco Varrone", 11 | "author_email": "marco.varrone@unil.ch", 12 | "github_user": "marcovarrone", 13 | "project_repo": "https://github.com/marcovarrone/cellcharter", 14 | "license": "BSD 3-Clause License", 15 | "_copy_without_render": [ 16 | ".github/workflows/build.yaml", 17 | ".github/workflows/test.yaml", 18 | "docs/_templates/autosummary/**.rst" 19 | ], 20 | "_render_devdocs": false, 21 | "_jinja2_env_vars": { 22 | "lstrip_blocks": true, 23 | "trim_blocks": true 24 | }, 25 | "_template": "https://github.com/scverse/cookiecutter-scverse" 26 | } 27 | }, 28 | "directory": null 29 | } 30 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 4 6 | end_of_line = lf 7 | charset = utf-8 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [*.{yml,yaml}] 12 | indent_size = 2 13 | 14 | [.cruft.json] 15 | indent_size = 2 16 | 17 | [Makefile] 18 | indent_style = tab 19 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | # Can't yet be moved to the pyproject.toml due to https://github.com/PyCQA/flake8/issues/234 2 | [flake8] 3 | max-line-length = 120 4 | ignore = 5 | # line break before a binary operator -> black does not adhere to PEP8 6 | W503 7 | # line break occured after a binary operator -> black does not adhere to PEP8 8 | W504 9 | # line too long -> we accept long comment lines; black gets rid of long code lines 10 | E501 11 | # whitespace before : -> black does not adhere to PEP8 12 | E203 13 | # line break before binary operator -> black does not adhere 
to PEP8 14 | W503 15 | # missing whitespace after ,', ';', or ':' -> black does not adhere to PEP8 16 | E231 17 | # continuation line over-indented for hanging indent -> black does not adhere to PEP8 18 | E126 19 | # too many leading '#' for block comment -> this is fine for indicating sections 20 | E262 21 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient 22 | E731 23 | # allow I, O, l as variable names -> I is the identity matrix 24 | E741 25 | # Missing docstring in public package 26 | D104 27 | # Missing docstring in public module 28 | D100 29 | # Missing docstring in __init__ 30 | D107 31 | # Errors from function calls in argument defaults. These are fine when the result is immutable. 32 | B008 33 | # Missing docstring in magic method 34 | D105 35 | # format string does contain unindexed parameters 36 | P101 37 | # first line should end with a period [Bug: doesn't work with single-line docstrings] 38 | D400 39 | # First line should be in imperative mood; try rephrasing 40 | D401 41 | # Block quote ends without a blank line; unexpected unindent 42 | RST201 43 | # Definition list ends without a blank line; unexpected unindent 44 | RST203 45 | # Block quote ends without a blank line; unexpected unindent 46 | RST301 47 | exclude = .git,__pycache__,build,docs/_build,dist 48 | per-file-ignores = 49 | tests/*: D 50 | */__init__.py: F401 51 | src/cellcharter/_constants/_pkg_constants.py: D101,D102,D106 52 | rst-roles = 53 | attr, 54 | class, 55 | func, 56 | ref, 57 | cite:p, 58 | cite:t, 59 | mod, 60 | rst-directives = 61 | envvar, 62 | exception, 63 | rst-substitutions = 64 | version, 65 | extend-ignore = 66 | RST307, 67 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Report something that is broken or incorrect 3 | labels: bug 4 
| body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) 9 | detailing how to provide the necessary information for us to reproduce your bug. In brief: 10 | * Please provide exact steps how to reproduce the bug in a clean Python environment. 11 | * In case it's not clear what's causing this bug, please provide the data or the data generation procedure. 12 | * Sometimes it is not possible to share the data, but usually it is possible to replicate problems on publicly 13 | available datasets or to share a subset of your data. 14 | 15 | - type: textarea 16 | id: report 17 | attributes: 18 | label: Report 19 | description: A clear and concise description of what the bug is. 20 | validations: 21 | required: true 22 | 23 | - type: textarea 24 | id: versions 25 | attributes: 26 | label: Version information 27 | description: | 28 | Please paste below the output of 29 | 30 | ```python 31 | import session_info 32 | session_info.show(html=False, dependencies=True) 33 | ``` 34 | placeholder: | 35 | ----- 36 | anndata 0.8.0rc2.dev27+ge524389 37 | session_info 1.0.0 38 | ----- 39 | asttokens NA 40 | awkward 1.8.0 41 | backcall 0.2.0 42 | cython_runtime NA 43 | dateutil 2.8.2 44 | debugpy 1.6.0 45 | decorator 5.1.1 46 | entrypoints 0.4 47 | executing 0.8.3 48 | h5py 3.7.0 49 | ipykernel 6.15.0 50 | jedi 0.18.1 51 | mpl_toolkits NA 52 | natsort 8.1.0 53 | numpy 1.22.4 54 | packaging 21.3 55 | pandas 1.4.2 56 | parso 0.8.3 57 | pexpect 4.8.0 58 | pickleshare 0.7.5 59 | pkg_resources NA 60 | prompt_toolkit 3.0.29 61 | psutil 5.9.1 62 | ptyprocess 0.7.0 63 | pure_eval 0.2.2 64 | pydev_ipython NA 65 | pydevconsole NA 66 | pydevd 2.8.0 67 | pydevd_file_utils NA 68 | pydevd_plugins NA 69 | pydevd_tracing NA 70 | pygments 2.12.0 71 | pytz 2022.1 72 | scipy 1.8.1 73 | setuptools 62.5.0 74 | setuptools_scm NA 75 | six 1.16.0 76 | stack_data 0.3.0 77 | tornado 6.1 78 | traitlets 
5.3.0 79 | wcwidth 0.2.5 80 | zmq 23.1.0 81 | ----- 82 | IPython 8.4.0 83 | jupyter_client 7.3.4 84 | jupyter_core 4.10.0 85 | ----- 86 | Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50) [GCC 10.3.0] 87 | Linux-5.18.6-arch1-1-x86_64-with-glibc2.35 88 | ----- 89 | Session information updated at 2022-07-07 17:55 90 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Scverse Community Forum 4 | url: https://discourse.scverse.org/ 5 | about: If you have questions about “How to do X”, please ask them here. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Propose a new feature for cellcharter 3 | labels: enhancement 4 | body: 5 | - type: textarea 6 | id: description 7 | attributes: 8 | label: Description of feature 9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered. 
10 | validations: 11 | required: true 12 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Check Build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | package: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python 3.10 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.10" 22 | cache: "pip" 23 | cache-dependency-path: "**/pyproject.toml" 24 | - name: Install build dependencies 25 | run: python -m pip install --upgrade pip wheel twine build 26 | - name: Build package 27 | run: python -m build 28 | - name: Check package 29 | run: twine check --strict dist/*.whl 30 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/ 8 | jobs: 9 | release: 10 | name: Upload release to PyPI 11 | runs-on: ubuntu-latest 12 | environment: 13 | name: pypi 14 | url: https://pypi.org/p/cellcharter 15 | permissions: 16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 17 | steps: 18 | - uses: actions/checkout@v4 19 | with: 20 | filter: blob:none 21 | fetch-depth: 0 22 | - uses: actions/setup-python@v4 23 | with: 24 | python-version: "3.x" 25 | cache: "pip" 26 | - run: pip install build 27 | - run: python -m build 28 | - name: Publish package distributions to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | -------------------------------------------------------------------------------- 
/.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | schedule: 9 | - cron: "0 5 1,15 * *" 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | test: 17 | runs-on: ${{ matrix.os }} 18 | defaults: 19 | run: 20 | shell: bash -e {0} # -e to fail on error 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | include: 26 | - os: ubuntu-latest 27 | python: "3.10" 28 | - os: ubuntu-latest 29 | python: "3.12" 30 | - os: ubuntu-latest 31 | python: "3.12" 32 | pip-flags: "--pre" 33 | name: PRE-RELEASE DEPENDENCIES 34 | 35 | name: ${{ matrix.name }} Python ${{ matrix.python }} 36 | 37 | env: 38 | OS: ${{ matrix.os }} 39 | PYTHON: ${{ matrix.python }} 40 | 41 | steps: 42 | - uses: actions/checkout@v3 43 | with: 44 | lfs: true # Ensure LFS files are fetched 45 | 46 | - name: Set up Python ${{ matrix.python }} 47 | uses: actions/setup-python@v4 48 | with: 49 | python-version: ${{ matrix.python }} 50 | cache: "pip" 51 | cache-dependency-path: "**/pyproject.toml" 52 | 53 | - name: Install test dependencies 54 | run: | 55 | python -m pip install --upgrade pip wheel 56 | - name: Install dependencies 57 | run: | 58 | pip install ${{ matrix.pip-flags }} ".[dev,test]" 59 | - name: Test 60 | env: 61 | MPLBACKEND: agg 62 | PLATFORM: ${{ matrix.os }} 63 | DISPLAY: :42 64 | run: | 65 | coverage run -m pytest -v --color=yes 66 | - name: Report coverage 67 | run: | 68 | coverage report 69 | - name: Upload coverage 70 | uses: codecov/codecov-action@v3 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temp files 2 | .DS_Store 3 | *~ 4 | buck-out/ 5 | 6 | # Compiled files 7 | .venv/ 8 | __pycache__/ 9 | .mypy_cache/ 10 | 
.ruff_cache/ 11 | 12 | # Distribution / packaging 13 | /build/ 14 | /dist/ 15 | /*.egg-info/ 16 | 17 | # Tests and coverage 18 | /.pytest_cache/ 19 | /.cache/ 20 | /data/ 21 | /node_modules/ 22 | 23 | # docs 24 | /docs/generated/ 25 | /docs/_build/ 26 | 27 | # IDEs 28 | /.idea/ 29 | /.vscode/ 30 | 31 | docs/notebooks/tutorial_data 32 | docs/jupyter_execute 33 | 34 | .ipynb_checkpoints 35 | *.ipynb 36 | 37 | trash 38 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | node: 16.14.2 5 | default_stages: 6 | - pre-commit 7 | - pre-push 8 | minimum_pre_commit_version: 2.16.0 9 | repos: 10 | - repo: https://github.com/psf/black 11 | rev: 25.1.0 12 | hooks: 13 | - id: black 14 | - repo: https://github.com/pre-commit/mirrors-prettier 15 | rev: v4.0.0-alpha.8 16 | hooks: 17 | - id: prettier 18 | exclude: '.*\.py$|.*\.toml$' 19 | - repo: https://github.com/asottile/blacken-docs 20 | rev: 1.19.1 21 | hooks: 22 | - id: blacken-docs 23 | - repo: https://github.com/PyCQA/isort 24 | rev: 6.0.1 25 | hooks: 26 | - id: isort 27 | - repo: https://github.com/asottile/yesqa 28 | rev: v1.5.0 29 | hooks: 30 | - id: yesqa 31 | additional_dependencies: 32 | - flake8-tidy-imports 33 | - flake8-docstrings 34 | - flake8-rst-docstrings 35 | - flake8-comprehensions 36 | - flake8-bugbear 37 | - flake8-blind-except 38 | - repo: https://github.com/pre-commit/pre-commit-hooks 39 | rev: v5.0.0 40 | hooks: 41 | - id: detect-private-key 42 | - id: check-ast 43 | - id: end-of-file-fixer 44 | - id: mixed-line-ending 45 | args: [--fix=lf] 46 | - id: trailing-whitespace 47 | - id: check-case-conflict 48 | - repo: https://github.com/PyCQA/autoflake 49 | rev: v2.3.1 50 | hooks: 51 | - id: autoflake 52 | args: 53 | - --in-place 54 | - --remove-all-unused-imports 55 | - --remove-unused-variable 56 | 
- --ignore-init-module-imports 57 | - repo: https://github.com/PyCQA/flake8 58 | rev: 7.2.0 59 | hooks: 60 | - id: flake8 61 | additional_dependencies: 62 | - flake8-tidy-imports 63 | - flake8-docstrings 64 | - flake8-rst-docstrings 65 | - flake8-comprehensions 66 | - flake8-bugbear 67 | - flake8-blind-except 68 | - repo: https://github.com/asottile/pyupgrade 69 | rev: v3.19.1 70 | hooks: 71 | - id: pyupgrade 72 | args: [--py3-plus, --py38-plus, --keep-runtime-typing] 73 | - repo: local 74 | hooks: 75 | - id: forbid-to-commit 76 | name: Don't commit rej files 77 | entry: | 78 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. 79 | Fix the merge conflicts manually and remove the .rej files. 80 | language: fail 81 | files: '.*\.rej$' 82 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html 2 | version: 2 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.10" 7 | sphinx: 8 | configuration: docs/conf.py 9 | # disable this for more lenient docs builds 10 | fail_on_warning: false 11 | python: 12 | install: 13 | - method: pip 14 | path: . 15 | extra_requirements: 16 | - doc 17 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog][], 6 | and this project adheres to [Semantic Versioning][]. 
7 | 8 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/ 9 | [semantic versioning]: https://semver.org/spec/v2.0.0.html 10 | 11 | ## [Unreleased] 12 | 13 | ### Added 14 | 15 | - Basic tool, preprocessing and plotting functions 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Marco Varrone 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | **A Python package for the identification, characterization and comparison of spatial clusters from spatial omics data.** 5 | 6 | --- 7 | 8 |

9 | Documentation • 10 | Examples • 11 | Paper • 12 | Preprint 13 |

14 | 15 | [![Tests][badge-tests]][link-tests] 16 | [![Documentation][badge-docs]][link-docs] 17 | [![PyPI Downloads](https://static.pepy.tech/badge/cellcharter)](https://pepy.tech/projects/cellcharter) 18 | 19 | [badge-tests]: https://img.shields.io/github/actions/workflow/status/CSOgroup/cellcharter/test.yaml?branch=main 20 | [link-tests]: https://github.com/CSOgroup/cellcharter/actions/workflows/test.yaml 21 | [badge-docs]: https://img.shields.io/readthedocs/cellcharter 22 | 23 | 
24 | 25 | ## Background 26 | 27 |

28 | Spatial clustering (or spatial domain identification) determines cellular niches characterized by specific admixing of these populations. It assigns cells to clusters based on both their intrinsic features (e.g., protein or mRNA expression), and the features of neighboring cells in the tissue. 29 |

30 |

31 | 32 |

33 | 34 |

35 | CellCharter is able to automatically identify spatial domains and offers a suite of approaches for cluster characterization and comparison. 36 |

37 |

38 | 39 |

40 | 41 | ## Features 42 | 43 | - **Identify niches for multiple samples**: By combining the power of scVI and scArches, CellCharter can identify domains for multiple samples simultaneously, even in the presence of batch effects. 44 | - **Scalability**: CellCharter can handle large datasets with millions of cells and thousands of features. The possibility to run it on GPUs makes it even faster. 45 | - **Flexibility**: CellCharter can be used with different types of spatial omics data, such as spatial transcriptomics, proteomics, epigenomics and multiomics data. The only difference is the method used for dimensionality reduction and batch effect removal. 46 | - Spatial transcriptomics: CellCharter has been tested on [scVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI) with Zero-inflated negative binomial distribution. 47 | - Spatial proteomics: CellCharter has been tested on a version of [scArches](https://docs.scarches.org/en/latest/api/models.html#scarches.models.TRVAE), modified to use Mean Squared Error loss instead of the default Negative Binomial loss. 48 | - Spatial epigenomics: CellCharter has been tested on [scVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI) with Poisson distribution. 49 | - Spatial multiomics: it's possible to use multi-omics models such as [MultiVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.MULTIVI.html#scvi.model.MULTIVI), or use the concatenation of the results from the different models. 50 | - **Best candidates for the number of domains**: CellCharter offers a [method to find multiple best candidates](https://cellcharter.readthedocs.io/en/latest/generated/cellcharter.tl.ClusterAutoK.html) for the number of domains, based on the stability of a certain number of domains across multiple runs. 
51 | - **Domain characterization**: CellCharter provides a set of tools to characterize and compare the spatial domains, such as domain proportion, cell type enrichment, (differential) neighborhood enrichment, and domain shape characterization. 52 | 53 | Since CellCharter 0.3.0, we moved the implementation of the Gaussian Mixture Model (GMM) from [PyCave](https://github.com/borchero/pycave), not maintained anymore, to [TorchGMM](https://github.com/CSOgroup/torchgmm), a fork of PyCave maintained by the CSOgroup. This change allows us to have a more stable and maintained implementation of GMM that is compatible with the most recent versions of PyTorch. 54 | 55 | ## Getting started 56 | 57 | Please refer to the [documentation][link-docs]. In particular, the 58 | 59 | - [API documentation][link-api]. 60 | - [Tutorials][link-tutorial] 61 | 62 | ## Installation 63 | 64 | 1. Create a conda or pyenv environment 65 | 2. Install Python >= 3.10,<3.13 and [PyTorch](https://pytorch.org) >= 1.12.0. If you are planning to use a GPU, make sure to download and install the correct version of PyTorch first from [here](https://pytorch.org/get-started/locally/). 66 | 3. Install the library used for dimensionality reduction and batch effect removal according to the data type you are planning to analyze: 67 | - [scVI](https://github.com/scverse/scvi-tools) for spatial transcriptomics and/or epigenomics data such as 10x Visium and Xenium, Nanostring CosMx, Vizgen MERSCOPE, Stereo-seq, DBiT-seq, MERFISH and seqFISH data. 68 | - A modified version of [scArches](https://github.com/theislab/scarches)'s TRVAE model for spatial proteomics data such as Akoya CODEX, Lunaphore COMET, CyCIF, IMC and MIBI-TOF data. 69 | 4. Install CellCharter using pip: 70 | 71 | ```bash 72 | pip install cellcharter 73 | ``` 74 | 75 | We report here an example of an installation aimed at analyzing spatial transcriptomics data (and thus installing `scvi-tools`). 
76 | This example is based on a Linux CentOS 7 system with an NVIDIA A100 GPU. 77 | It will install Pytorch for GPU by default. 78 | 79 | ```bash 80 | conda create -n cellcharter-env -c conda-forge python=3.12 81 | conda activate cellcharter-env 82 | pip install scvi-tools 83 | pip install cellcharter 84 | ``` 85 | 86 | Note: a different system may require different commands to install PyTorch and JAX. Refer to their respective documentation for more details. 87 | 88 | ## Contribution 89 | 90 | If you found a bug or you want to propose a new feature, please use the [issue tracker][issue-tracker]. 91 | 92 | [issue-tracker]: https://github.com/CSOgroup/cellcharter/issues 93 | [link-docs]: https://cellcharter.readthedocs.io 94 | [link-api]: https://cellcharter.readthedocs.io/en/latest/api.html 95 | [link-tutorial]: https://cellcharter.readthedocs.io/en/latest/notebooks/codex_mouse_spleen.html 96 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/_static/.gitkeep -------------------------------------------------------------------------------- /docs/_static/cellcharter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/_static/cellcharter.png -------------------------------------------------------------------------------- /docs/_static/cellcharter_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/_static/cellcharter_workflow.png -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */ 2 | div.cell_output table.dataframe { 3 | font-size: 0.8em; 4 | } 5 | -------------------------------------------------------------------------------- /docs/_static/spatial_clusters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/_static/spatial_clusters.png -------------------------------------------------------------------------------- /docs/_templates/.gitkeep: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/_templates/.gitkeep -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | Attributes table 12 | ~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autosummary:: 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | Methods table 24 | ~~~~~~~~~~~~~ 25 | 26 | .. autosummary:: 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | 35 | {% block attributes_documentation %} 36 | {% if attributes %} 37 | Attributes 38 | ~~~~~~~~~~~ 39 | 40 | {% for item in attributes %} 41 | 42 | .. autoattribute:: {{ [objname, item] | join(".") }} 43 | {%- endfor %} 44 | 45 | {% endif %} 46 | {% endblock %} 47 | 48 | {% block methods_documentation %} 49 | {% if methods %} 50 | Methods 51 | ~~~~~~~ 52 | 53 | {% for item in methods %} 54 | {%- if item != '__init__' %} 55 | 56 | .. automethod:: {{ [objname, item] | join(".") }} 57 | {%- endif -%} 58 | {%- endfor %} 59 | 60 | {% endif %} 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ## Graph 4 | 5 | ```{eval-rst} 6 | .. module:: cellcharter.gr 7 | .. currentmodule:: cellcharter 8 | 9 | .. 
autosummary:: 10 | :toctree: generated 11 | 12 | gr.aggregate_neighbors 13 | gr.nhood_enrichment 14 | gr.remove_long_links 15 | gr.remove_intra_cluster_links 16 | 17 | ``` 18 | 19 | ## Tools 20 | 21 | ```{eval-rst} 22 | .. module:: cellcharter.tl 23 | .. currentmodule:: cellcharter 24 | 25 | .. autosummary:: 26 | :toctree: generated 27 | 28 | tl.Cluster 29 | tl.ClusterAutoK 30 | tl.TRVAE 31 | ``` 32 | 33 | ## Plotting 34 | 35 | ```{eval-rst} 36 | .. module:: cellcharter.pl 37 | .. currentmodule:: cellcharter 38 | 39 | .. autosummary:: 40 | :toctree: generated 41 | 42 | pl.autok_stability 43 | pl.nhood_enrichment 44 | pl.proportion 45 | ``` 46 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CHANGELOG.md 2 | 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | import sys 9 | from datetime import datetime 10 | from importlib.metadata import metadata 11 | from pathlib import Path 12 | 13 | HERE = Path(__file__).parent 14 | sys.path.insert(0, str(HERE / "extensions")) 15 | 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | info = metadata("cellcharter") 20 | project_name = info["Name"] 21 | author = info["Author"] 22 | copyright = f"{datetime.now():%Y}, {author}." 
23 | version = info["Version"] 24 | repository_url = f"https://github.com/CSOgroup/{project_name}" 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = info["Version"] 28 | 29 | bibtex_bibfiles = ["references.bib"] 30 | templates_path = ["_templates"] 31 | nitpicky = True # Warn about broken links 32 | needs_sphinx = "4.0" 33 | 34 | html_context = { 35 | "display_github": True, # Integrate GitHub 36 | "github_user": "CSOgroup", # Username 37 | "github_repo": project_name, # Repo name 38 | "github_version": "main", # Version 39 | "conf_py_path": "/docs/", # Path in the checkout to the docs root 40 | } 41 | 42 | # -- General configuration --------------------------------------------------- 43 | 44 | # Add any Sphinx extension module names here, as strings. 45 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 46 | extensions = [ 47 | "myst_nb", 48 | "sphinx_copybutton", 49 | "sphinx.ext.autodoc", 50 | "sphinx.ext.intersphinx", 51 | "sphinx.ext.autosummary", 52 | "sphinx.ext.napoleon", 53 | "sphinxcontrib.bibtex", 54 | "sphinx_autodoc_typehints", 55 | "sphinx.ext.mathjax", 56 | "IPython.sphinxext.ipython_console_highlighting", 57 | "sphinxext.opengraph", 58 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 59 | ] 60 | 61 | autosummary_generate = True 62 | autodoc_member_order = "groupwise" 63 | default_role = "literal" 64 | napoleon_google_docstring = False 65 | napoleon_numpy_docstring = True 66 | napoleon_include_init_with_doc = False 67 | napoleon_use_rtype = True # having a separate entry generally helps readability 68 | napoleon_use_param = True 69 | myst_heading_anchors = 6 # create anchors for h1-h6 70 | myst_enable_extensions = [ 71 | "amsmath", 72 | "colon_fence", 73 | "deflist", 74 | "dollarmath", 75 | "html_image", 76 | "html_admonition", 77 | ] 78 | myst_url_schemes = ("http", "https", "mailto") 79 | nb_output_stderr = "remove" 80 | nb_execution_mode = "off" 81 | nb_merge_streams = True 82 
| typehints_defaults = "braces" 83 | 84 | source_suffix = { 85 | ".rst": "restructuredtext", 86 | ".ipynb": "myst-nb", 87 | ".myst": "myst-nb", 88 | } 89 | 90 | intersphinx_mapping = { 91 | "python": ("https://docs.python.org/3", None), 92 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None), 93 | "numpy": ("https://numpy.org/doc/stable/", None), 94 | } 95 | 96 | # List of patterns, relative to source directory, that match files and 97 | # directories to ignore when looking for source files. 98 | # This pattern also affects html_static_path and html_extra_path. 99 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 100 | 101 | 102 | # -- Options for HTML output ------------------------------------------------- 103 | 104 | # The theme to use for HTML and HTML Help pages. See the documentation for 105 | # a list of builtin themes. 106 | # 107 | html_theme = "sphinx_book_theme" 108 | html_static_path = ["_static"] 109 | html_css_files = ["css/custom.css"] 110 | 111 | html_title = project_name 112 | 113 | html_theme_options = { 114 | "repository_url": repository_url, 115 | "use_repository_button": True, 116 | "path_to_docs": "docs/", 117 | "navigation_with_keys": False, 118 | } 119 | 120 | pygments_style = "default" 121 | 122 | nitpick_ignore = [ 123 | # If building the documentation fails because of a missing link that is outside your control, 124 | # you can add an exception to this list. 125 | # ("py:class", "igraph.Graph"), 126 | ] 127 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing guide 2 | 3 | Scanpy provides extensive [developer documentation][scanpy developer guide], most of which applies to this project, too. 4 | This document will not reproduce the entire content from there. 
Instead, it aims at summarizing the most important 5 | information to get you started on contributing. 6 | 7 | We assume that you are already familiar with git and with making pull requests on GitHub. If not, please refer 8 | to the [scanpy developer guide][]. 9 | 10 | ## Installing dev dependencies 11 | 12 | In addition to the packages needed to _use_ this package, you need additional python packages to _run tests_ and _build 13 | the documentation_. It's easy to install them using `pip`: 14 | 15 | ```bash 16 | cd cellcharter 17 | pip install -e ".[dev,test,doc]" 18 | ``` 19 | 20 | ## Code-style 21 | 22 | This package uses [pre-commit][] to enforce consistent code-styles. 23 | On every commit, pre-commit checks will either automatically fix issues with the code, or raise an error message. 24 | 25 | To enable pre-commit locally, simply run 26 | 27 | ```bash 28 | pre-commit install 29 | ``` 30 | 31 | in the root of the repository. Pre-commit will automatically download all dependencies when it is run for the first time. 32 | 33 | Alternatively, you can rely on the [pre-commit.ci][] service enabled on GitHub. If you didn't run `pre-commit` before 34 | pushing changes to GitHub it will automatically commit fixes to your pull request, or show an error message. 35 | 36 | If pre-commit.ci added a commit on a branch you still have been working on locally, simply use 37 | 38 | ```bash 39 | git pull --rebase 40 | ``` 41 | 42 | to integrate the changes into yours. 43 | While the [pre-commit.ci][] is useful, we strongly encourage installing and running pre-commit locally first to understand its usage. 44 | 45 | Finally, most editors have an _autoformat on save_ feature. Consider enabling this option for [ruff][ruff-editors] 46 | and [prettier][prettier-editors]. 
47 | 48 | [ruff-editors]: https://docs.astral.sh/ruff/integrations/ 49 | [prettier-editors]: https://prettier.io/docs/en/editors.html 50 | 51 | ## Writing tests 52 | 53 | ```{note} 54 | Remember to first install the package with `pip install -e '.[dev,test]'` 55 | ``` 56 | 57 | This package uses the [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added 58 | to the package. 59 | 60 | Most IDEs integrate with pytest and provide a GUI to run tests. Alternatively, you can run all tests from the 61 | command line by executing 62 | 63 | ```bash 64 | pytest 65 | ``` 66 | 67 | in the root of the repository. 68 | 69 | ### Continuous integration 70 | 71 | Continuous integration will automatically run the tests on all pull requests and test 72 | against the minimum and maximum supported Python version. 73 | 74 | Additionally, there's a CI job that tests against pre-releases of all dependencies 75 | (if there are any). The purpose of this check is to detect incompatibilities 76 | of new package versions early on and gives you time to fix the issue or reach 77 | out to the developers of the dependency before the package is released to a wider audience. 78 | 79 | [scanpy-test-docs]: https://scanpy.readthedocs.io/en/latest/dev/testing.html#writing-tests 80 | 81 | ## Publishing a release 82 | 83 | ### Updating the version number 84 | 85 | Before making a release, you need to update the version number in the `pyproject.toml` file. Please adhere to [Semantic Versioning][semver], in brief 86 | 87 | > Given a version number MAJOR.MINOR.PATCH, increment the: 88 | > 89 | > 1. MAJOR version when you make incompatible API changes, 90 | > 2. MINOR version when you add functionality in a backwards compatible manner, and 91 | > 3. PATCH version when you make backwards compatible bug fixes. 92 | > 93 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format. 
94 | 95 | Once you are done, commit and push your changes and navigate to the "Releases" page of this project on GitHub. 96 | Specify `vX.X.X` as a tag name and create a release. For more information, see [managing GitHub releases][]. This will automatically create a git tag and trigger a Github workflow that creates a release on PyPI. 97 | 98 | ## Writing documentation 99 | 100 | Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features: 101 | 102 | - the [myst][] extension allows to write documentation in markdown/Markedly Structured Text 103 | - [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension). 104 | - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks)) 105 | - [Sphinx autodoc typehints][], to automatically reference annotated input and output types 106 | - Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/) 107 | 108 | See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information 109 | on how to write documentation. 110 | 111 | ### Tutorials with myst-nb and jupyter notebooks 112 | 113 | The documentation is set-up to render jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][]. 114 | Currently, only notebooks in `.ipynb` format are supported that will be included with both their input and output cells. 115 | It is your responsibility to update and re-run the notebook whenever necessary. 116 | 117 | If you are interested in automatically running notebooks as part of the continuous integration, please check 118 | out [this feature request](https://github.com/scverse/cookiecutter-scverse/issues/40) in the `cookiecutter-scverse` 119 | repository.
120 | 121 | #### Hints 122 | 123 | - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only 124 | if you do so can sphinx automatically create a link to the external documentation. 125 | - If building the documentation fails because of a missing link that is outside your control, you can add an entry to 126 | the `nitpick_ignore` list in `docs/conf.py` 127 | 128 | #### Building the docs locally 129 | 130 | ```bash 131 | cd docs 132 | make html 133 | open _build/html/index.html 134 | ``` 135 | 136 | 137 | 138 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html 139 | [cookiecutter-scverse-instance]: https://cookiecutter-scverse-instance.readthedocs.io/en/latest/template_usage.html 140 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui 141 | [codecov]: https://about.codecov.io/sign-up/ 142 | [codecov docs]: https://docs.codecov.com/docs 143 | [codecov bot]: https://docs.codecov.com/docs/team-bot 144 | [codecov app]: https://github.com/apps/codecov 145 | [pre-commit.ci]: https://pre-commit.ci/ 146 | [readthedocs.org]: https://readthedocs.org/ 147 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/ 148 | [jupytext]: https://jupytext.readthedocs.io/en/latest/ 149 | [pre-commit]: https://pre-commit.com/ 150 | [anndata]: https://github.com/scverse/anndata 151 | [mudata]: https://github.com/scverse/mudata 152 | [pytest]: https://docs.pytest.org/ 153 | [semver]: https://semver.org/ 154 | [sphinx]: https://www.sphinx-doc.org/en/master/ 155 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html 156 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 157 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html 158 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints 159 | [pypi]: https://pypi.org/ 160 | [managing GitHub releases]: 
https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository 161 | -------------------------------------------------------------------------------- /docs/extensions/typed_returns.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py 2 | # with some minor adjustment 3 | from __future__ import annotations 4 | 5 | import re 6 | from collections.abc import Generator, Iterable 7 | 8 | from sphinx.application import Sphinx 9 | from sphinx.ext.napoleon import NumpyDocstring 10 | 11 | 12 | def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: 13 | for line in lines: 14 | if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line): 15 | yield f'-{m["param"]} (:class:`~{m["type"]}`)' 16 | else: 17 | yield line 18 | 19 | 20 | def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]: 21 | lines_raw = self._dedent(self._consume_to_next_section()) 22 | if lines_raw[0] == ":": 23 | del lines_raw[0] 24 | lines = self._format_block(":returns: ", list(_process_return(lines_raw))) 25 | if lines and lines[-1]: 26 | lines.append("") 27 | return lines 28 | 29 | 30 | def setup(app: Sphinx): 31 | """Set app.""" 32 | NumpyDocstring._parse_returns_section = _parse_returns_section 33 | -------------------------------------------------------------------------------- /docs/graph.md: -------------------------------------------------------------------------------- 1 | # Graph 2 | 3 | ```{eval-rst} 4 | .. module:: cellcharter.gr 5 | .. currentmodule:: cellcharter 6 | 7 | ..
autosummary:: 8 | :toctree: generated 9 | 10 | gr.aggregate_neighbors 11 | gr.connected_components 12 | gr.diff_nhood_enrichment 13 | gr.enrichment 14 | gr.nhood_enrichment 15 | gr.remove_long_links 16 | gr.remove_intra_cluster_links 17 | ``` 18 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ```{include} ../README.md 2 | 3 | ``` 4 | 5 | ```{toctree} 6 | :hidden: true 7 | :maxdepth: 2 8 | :caption: API 9 | 10 | graph.md 11 | tools.md 12 | plotting.md 13 | ``` 14 | 15 | ```{toctree} 16 | :hidden: true 17 | :maxdepth: 2 18 | :caption: Tutorials 19 | 20 | notebooks/codex_mouse_spleen 21 | notebooks/cosmx_human_nsclc 22 | 23 | 24 | ``` 25 | -------------------------------------------------------------------------------- /docs/notebooks/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/notebooks/data/.gitkeep -------------------------------------------------------------------------------- /docs/notebooks/example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Example notebook" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "from anndata import AnnData\n", 18 | "import cellcharter" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "adata = AnnData(np.random.normal(size=(20, 10)))" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 3, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 
| "text": [ 39 | "Implement a preprocessing function here." 40 | ] 41 | }, 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "0" 46 | ] 47 | }, 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "cellcharter.pp.basic_preproc(adata)" 55 | ] 56 | } 57 | ], 58 | "metadata": { 59 | "kernelspec": { 60 | "display_name": "Python 3", 61 | "language": "python", 62 | "name": "python3" 63 | }, 64 | "language_info": { 65 | "codemirror_mode": { 66 | "name": "ipython", 67 | "version": 3 68 | }, 69 | "file_extension": ".py", 70 | "mimetype": "text/x-python", 71 | "name": "python", 72 | "nbconvert_exporter": "python", 73 | "pygments_lexer": "ipython3", 74 | "version": "3.8.13" 75 | } 76 | }, 77 | "nbformat": 4, 78 | "nbformat_minor": 4 79 | } 80 | -------------------------------------------------------------------------------- /docs/notebooks/tutorial_models/codex_mouse_spleen_trvae/attr.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/notebooks/tutorial_models/codex_mouse_spleen_trvae/attr.pkl -------------------------------------------------------------------------------- /docs/notebooks/tutorial_models/codex_mouse_spleen_trvae/model_params.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/notebooks/tutorial_models/codex_mouse_spleen_trvae/model_params.pt -------------------------------------------------------------------------------- /docs/notebooks/tutorial_models/codex_mouse_spleen_trvae/var_names.csv: -------------------------------------------------------------------------------- 1 | CD45 2 | Ly6C 3 | TCR 4 | Ly6G 5 | CD19 6 | CD169 7 | CD106 8 | CD3 9 | CD1632 10 | CD8a 11 | CD90 12 | F480 13 | CD11c 14 | Ter119 15 | CD11b 16 | IgD 17 | 
CD27 18 | CD5 19 | B220 20 | CD71 21 | CD31 22 | CD4 23 | IgM 24 | CD79b 25 | ERTR7 26 | CD35 27 | CD2135 28 | CD44 29 | NKp46 30 | -------------------------------------------------------------------------------- /docs/plotting.md: -------------------------------------------------------------------------------- 1 | # Plotting 2 | 3 | ```{eval-rst} 4 | .. module:: cellcharter.pl 5 | .. currentmodule:: cellcharter 6 | 7 | .. autosummary:: 8 | :toctree: generated 9 | 10 | pl.autok_stability 11 | pl.diff_nhood_enrichment 12 | pl.enrichment 13 | pl.nhood_enrichment 14 | pl.proportion 15 | pl.boundaries 16 | pl.shape_metrics 17 | ``` 18 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{Wolf2018, 2 | author = {Wolf, F. Alexander 3 | and Angerer, Philipp 4 | and Theis, Fabian J.}, 5 | title = {SCANPY: large-scale single-cell gene expression data analysis}, 6 | journal = {Genome Biology}, 7 | year = {2018}, 8 | month = {Feb}, 9 | day = {06}, 10 | volume = {19}, 11 | number = {1}, 12 | pages = {15}, 13 | abstract = {Scanpy is a scalable toolkit for analyzing single-cell gene expression data. It includes methods for preprocessing, visualization, clustering, pseudotime and trajectory inference, differential expression testing, and simulation of gene regulatory networks. Its Python-based implementation efficiently deals with data sets of more than one million cells (https://github.com/theislab/Scanpy). 
Along with Scanpy, we present AnnData, a generic class for handling annotated data matrices (https://github.com/theislab/anndata).}, 14 | issn = {1474-760X}, 15 | doi = {10.1186/s13059-017-1382-0}, 16 | url = {https://doi.org/10.1186/s13059-017-1382-0} 17 | } 18 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | ```{bibliography} 4 | :cited: 5 | ``` 6 | -------------------------------------------------------------------------------- /docs/tools.md: -------------------------------------------------------------------------------- 1 | # Tools 2 | 3 | ```{eval-rst} 4 | .. module:: cellcharter.tl 5 | .. currentmodule:: cellcharter 6 | 7 | .. autosummary:: 8 | :toctree: generated 9 | 10 | tl.boundaries 11 | tl.Cluster 12 | tl.ClusterAutoK 13 | tl.curl 14 | tl.elongation 15 | tl.linearity 16 | tl.purity 17 | tl.TRVAE 18 | ``` 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | [tool.hatch.build] 6 | exclude = [ 7 | "docs*", 8 | "tests*", 9 | ] 10 | 11 | [project] 12 | name = "cellcharter" 13 | version = "0.3.4" 14 | description = "A Python package for the identification, characterization and comparison of spatial clusters from spatial -omics data." 
15 | readme = "README.md" 16 | requires-python = ">=3.10,<3.14" 17 | license = {file = "LICENSE"} 18 | authors = [ 19 | {name = "CSO group"}, 20 | ] 21 | maintainers = [ 22 | {name = "Marco Varrone", email = "marco.varrone@unil.ch"}, 23 | ] 24 | urls.Documentation = "https://cellcharter.readthedocs.io/" 25 | urls.Source = "https://github.com/CSOgroup/cellcharter" 26 | urls.Home-page = "https://github.com/CSOgroup/cellcharter" 27 | dependencies = [ 28 | "anndata", 29 | "scikit-learn", 30 | "squidpy >= 1.6.3", 31 | "torchgmm >= 0.1.2", 32 | # for debug logging (referenced from the issue template) 33 | "session-info", 34 | "spatialdata", 35 | "spatialdata-plot", 36 | "rasterio", 37 | "sknw", 38 | ] 39 | 40 | [project.optional-dependencies] 41 | dev = [ 42 | "pre-commit", 43 | "twine>=4.0.2" 44 | ] 45 | doc = [ 46 | "docutils>=0.8,!=0.18.*,!=0.19.*", 47 | "sphinx>=4", 48 | "sphinx-book-theme>=1.0.0", 49 | "myst-nb>=1.1.0", 50 | "sphinxcontrib-bibtex>=1.0.0", 51 | "sphinx-autodoc-typehints", 52 | "sphinxext-opengraph", 53 | # For notebooks 54 | "ipykernel", 55 | "ipython", 56 | "sphinx-copybutton", 57 | "sphinx-design" 58 | ] 59 | test = [ 60 | "pytest", 61 | "pytest-cov", 62 | ] 63 | transcriptomics = [ 64 | "scvi-tools", 65 | ] 66 | proteomics = [ 67 | "scarches", 68 | ] 69 | 70 | [tool.coverage.run] 71 | source = ["cellcharter"] 72 | omit = [ 73 | "**/test_*.py", 74 | ] 75 | 76 | [tool.pytest.ini_options] 77 | testpaths = ["tests"] 78 | xfail_strict = true 79 | addopts = [ 80 | "--import-mode=importlib", # allow using test files with same name 81 | ] 82 | filterwarnings = [ 83 | "ignore::anndata.OldFormatWarning", 84 | "ignore:.*this fit will run with no optimizer.*", 85 | "ignore:.*Consider increasing the value of the `num_workers` argument.*", 86 | ] 87 | 88 | [tool.isort] 89 | include_trailing_comma = true 90 | multi_line_output = 3 91 | profile = "black" 92 | skip_glob = ["docs/*"] 93 | 94 | [tool.black] 95 | line-length = 120 96 | target-version = ['py38'] 97 | 
include = '\.pyi?$' 98 | exclude = ''' 99 | ( 100 | /( 101 | \.eggs 102 | | \.git 103 | | \.hg 104 | | \.mypy_cache 105 | | \.tox 106 | | \.venv 107 | | _build 108 | | buck-out 109 | | build 110 | | dist 111 | )/ 112 | ) 113 | ''' 114 | 115 | [tool.jupytext] 116 | formats = "ipynb,md" 117 | 118 | [tool.cruft] 119 | skip = [ 120 | "tests", 121 | "src/**/__init__.py", 122 | "src/**/basic.py", 123 | "docs/api.md", 124 | "docs/changelog.md", 125 | "docs/references.bib", 126 | "docs/references.md", 127 | "docs/notebooks" 128 | ] 129 | -------------------------------------------------------------------------------- /src/cellcharter/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | 3 | from . import datasets, gr, pl, tl 4 | 5 | __all__ = ["gr", "pl", "tl", "datasets"] 6 | 7 | __version__ = version("cellcharter") 8 | -------------------------------------------------------------------------------- /src/cellcharter/_constants/_pkg_constants.py: -------------------------------------------------------------------------------- 1 | """Internal constants not exposed to the user.""" 2 | 3 | 4 | class Key: 5 | class obs: 6 | @classmethod 7 | @property 8 | def sample(cls) -> str: 9 | return "sample" 10 | -------------------------------------------------------------------------------- /src/cellcharter/_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | import numpy as np 6 | 7 | AnyRandom = Union[int, np.random.RandomState, None] 8 | 9 | 10 | def str2list(value: Union[str, list]) -> list: 11 | """Check whether value is a string. 
If so, converts into a list containing value.""" 12 | return [value] if isinstance(value, str) else value 13 | -------------------------------------------------------------------------------- /src/cellcharter/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from ._dataset import * # noqa: F403 2 | -------------------------------------------------------------------------------- /src/cellcharter/datasets/_dataset.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | 3 | from squidpy.datasets._utils import AMetadata 4 | 5 | _codex_mouse_spleen = AMetadata( 6 | name="codex_mouse_spleen", 7 | doc_header="Pre-processed CODEX dataset of mouse spleen from `Goltsev et al " 8 | "`__.", 9 | shape=(707474, 29), 10 | url="https://figshare.com/ndownloader/files/38538101", 11 | ) 12 | 13 | for name, var in copy(locals()).items(): 14 | if isinstance(var, AMetadata): 15 | var._create_function(name, globals()) 16 | 17 | 18 | __all__ = ["codex_mouse_spleen"] # noqa: F822 19 | -------------------------------------------------------------------------------- /src/cellcharter/gr/__init__.py: -------------------------------------------------------------------------------- 1 | from ._aggr import aggregate_neighbors 2 | from ._build import connected_components, remove_intra_cluster_links, remove_long_links 3 | from ._group import enrichment 4 | from ._nhood import diff_nhood_enrichment, nhood_enrichment 5 | -------------------------------------------------------------------------------- /src/cellcharter/gr/_aggr.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | import scipy.sparse as sps 8 | from anndata import AnnData 9 | from scipy.sparse import spdiags 10 | from squidpy._constants._pkg_constants import Key 
as sqKey 11 | from squidpy._docs import d 12 | from tqdm.auto import tqdm 13 | 14 | from cellcharter._constants._pkg_constants import Key 15 | from cellcharter._utils import str2list 16 | 17 | 18 | def _aggregate_mean(adj, x): 19 | return adj @ x 20 | 21 | 22 | def _aggregate_var(adj, x): 23 | mean = adj @ x 24 | mean_squared = adj @ (x * x) 25 | return mean_squared - mean * mean 26 | 27 | 28 | def _aggregate(adj, x, method): 29 | if method == "mean": 30 | return _aggregate_mean(adj, x) 31 | elif method == "var": 32 | return _aggregate_var(adj, x) 33 | else: 34 | raise NotImplementedError 35 | 36 | 37 | def _mul_broadcast(mat1, mat2): 38 | return spdiags(mat2, 0, len(mat2), len(mat2)) * mat1 39 | 40 | 41 | def _hop(adj_hop, adj, adj_visited=None): 42 | adj_hop = adj_hop @ adj 43 | 44 | if adj_visited is not None: 45 | adj_hop = adj_hop > adj_visited # Logical not for sparse matrices 46 | adj_visited = adj_visited + adj_hop 47 | 48 | return adj_hop, adj_visited 49 | 50 | 51 | def _normalize(adj): 52 | deg = np.array(np.sum(adj, axis=1)).squeeze() 53 | 54 | with warnings.catch_warnings(): 55 | # If a cell doesn't have neighbors deg = 0 -> divide by zero 56 | warnings.filterwarnings(action="ignore", category=RuntimeWarning) 57 | deg_inv = 1 / deg 58 | deg_inv[deg_inv == float("inf")] = 0 59 | 60 | return _mul_broadcast(adj, deg_inv) 61 | 62 | 63 | def _setdiag(array, value): 64 | if isinstance(array, sps.csr_matrix): 65 | array = array.tolil() 66 | array.setdiag(value) 67 | array = array.tocsr() 68 | if value == 0: 69 | array.eliminate_zeros() 70 | return array 71 | 72 | 73 | def _aggregate_neighbors( 74 | adj: sps.spmatrix, 75 | X: np.ndarray, 76 | nhood_layers: list, 77 | aggregations: Optional[Union[str, list]] = "mean", 78 | disable_tqdm: bool = True, 79 | ) -> np.ndarray: 80 | adj = adj.astype(bool) 81 | adj = _setdiag(adj, 0) 82 | adj_hop = adj.copy() 83 | adj_visited = _setdiag(adj.copy(), 1) 84 | 85 | Xs = [] 86 | for i in tqdm(range(0, max(nhood_layers) + 1), 
disable=disable_tqdm): 87 | if i in nhood_layers: 88 | if i == 0: 89 | Xs.append(X) 90 | else: 91 | if i > 1: 92 | adj_hop, adj_visited = _hop(adj_hop, adj, adj_visited) 93 | adj_hop_norm = _normalize(adj_hop) 94 | 95 | for agg in aggregations: 96 | x = _aggregate(adj_hop_norm, X, agg) 97 | Xs.append(x) 98 | if sps.issparse(X): 99 | return sps.hstack(Xs) 100 | else: 101 | return np.hstack(Xs) 102 | 103 | 104 | @d.dedent 105 | def aggregate_neighbors( 106 | adata: AnnData, 107 | n_layers: Union[int, list], 108 | aggregations: Optional[Union[str, list]] = "mean", 109 | connectivity_key: Optional[str] = None, 110 | use_rep: Optional[str] = None, 111 | sample_key: Optional[str] = None, 112 | out_key: Optional[str] = "X_cellcharter", 113 | copy: bool = False, 114 | ) -> np.ndarray | None: 115 | """ 116 | Aggregate the features from each neighborhood layers and concatenate them, and optionally with the cells' features, into a single vector. 117 | 118 | Parameters 119 | ---------- 120 | %(adata)s 121 | n_layers 122 | Which neighborhood layers to aggregate from. 123 | If :class:`int`, the output vector includes the cells' features and the aggregated features of the neighbors until the layer at distance ``n_layers``, i.e. cells | 1-hop neighbors | ... | ``n_layers``-hop. 124 | If :class:`list`, every element corresponds to the distances at which the neighbors' features will be aggregated and concatenated. For example, [0, 1, 3] corresponds to cells|1-hop neighbors|3-hop neighbors. 125 | aggregations 126 | Which functions to use to aggregate the neighbors features. Default: ```mean``. 127 | connectivity_key 128 | Key in :attr:`anndata.AnnData.obsp` where spatial connectivities are stored. 129 | use_rep 130 | Key of the features. If :class:`None`, adata.X is used. Else, the key is used to access the field in the AnnData .obsm mapping. 131 | sample_key 132 | Key in :attr:`anndata.AnnData.obs` where the sample labels are stored. 
Must be different from :class:`None` if adata contains multiple samples. 133 | out_key 134 | Key in :attr:`anndata.AnnData.obsm` where the output matrix is stored if ``copy = False``. 135 | %(copy)s 136 | 137 | Returns 138 | ------- 139 | If ``copy = True``, returns a :class:`numpy.ndarray` of the features aggregated and concatenated. 140 | 141 | Otherwise, modifies the ``adata`` with the following key: 142 | - :attr:`anndata.AnnData.obsm` ``['{{out_key}}']`` - the above mentioned :class:`numpy.ndarray`. 143 | """ 144 | connectivity_key = sqKey.obsp.spatial_conn(connectivity_key) 145 | sample_key = Key.obs.sample if sample_key is None else sample_key 146 | aggregations = str2list(aggregations) 147 | 148 | X = adata.X if use_rep is None else adata.obsm[use_rep] 149 | 150 | if isinstance(n_layers, int): 151 | n_layers = list(range(n_layers + 1)) 152 | 153 | if sps.issparse(X): 154 | X_aggregated = sps.dok_matrix( 155 | (X.shape[0], X.shape[1] * ((len(n_layers) - 1) * len(aggregations) + 1)), dtype=np.float32 156 | ) 157 | else: 158 | X_aggregated = np.empty( 159 | (X.shape[0], X.shape[1] * ((len(n_layers) - 1) * len(aggregations) + 1)), dtype=np.float32 160 | ) 161 | 162 | if sample_key in adata.obs: 163 | samples = adata.obs[sample_key].unique() 164 | sample_idxs = [adata.obs[sample_key] == sample for sample in samples] 165 | else: 166 | sample_idxs = [np.arange(adata.shape[0])] 167 | 168 | for idxs in tqdm(sample_idxs, disable=(len(sample_idxs) == 1)): 169 | X_sample_aggregated = _aggregate_neighbors( 170 | adj=adata[idxs].obsp[connectivity_key], 171 | X=X[idxs], 172 | nhood_layers=n_layers, 173 | aggregations=aggregations, 174 | disable_tqdm=(len(sample_idxs) != 1), 175 | ) 176 | X_aggregated[idxs] = X_sample_aggregated 177 | 178 | if isinstance(X_aggregated, sps.dok_matrix): 179 | X_aggregated = X_aggregated.tocsr() 180 | 181 | if copy: 182 | return X_aggregated 183 | 184 | adata.obsm[out_key] = X_aggregated 185 | 
@d.dedent
def remove_long_links(
    adata: AnnData,
    distance_percentile: float = 99.0,
    connectivity_key: str | None = None,
    distances_key: str | None = None,
    neighs_key: str | None = None,
    copy: bool = False,
) -> tuple[csr_matrix, csr_matrix] | None:
    """
    Remove links between cells at a distance bigger than a certain percentile of all positive distances.

    It is designed for data with generic coordinates.

    Parameters
    ----------
    %(adata)s

    distance_percentile
        Percentile of the distances between cells over which links are trimmed after the network is built.
    %(conn_key)s

    distances_key
        Key in :attr:`anndata.AnnData.obsp` where spatial distances are stored.
        Default is: :attr:`anndata.AnnData.obsp` ``['{{Key.obsp.spatial_dist()}}']``.
    neighs_key
        Key in :attr:`anndata.AnnData.uns` where the parameters from gr.spatial_neighbors are stored.
        Default is: :attr:`anndata.AnnData.uns` ``['{{Key.uns.spatial_neighs()}}']``.

    %(copy)s

    Returns
    -------
    If ``copy = True``, returns a :class:`tuple` with the new spatial connectivities and distances matrices.

    Otherwise, modifies the ``adata`` with the following keys:
        - :attr:`anndata.AnnData.obsp` ``['{{connectivity_key}}']`` - the new spatial connectivities.
        - :attr:`anndata.AnnData.obsp` ``['{{distances_key}}']`` - the new spatial distances.
        - :attr:`anndata.AnnData.uns` ``['{{neighs_key}}']`` - :class:`dict` containing parameters.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    distances_key = Key.obsp.spatial_dist(distances_key)
    neighs_key = Key.uns.spatial_neighs(neighs_key)
    _assert_connectivity_key(adata, connectivity_key)
    # NOTE(review): the distances key is validated with the *connectivity* assert, so a
    # missing distances matrix raises a message about connectivities; a dedicated
    # distances assert (cf. cellcharter.gr._utils._assert_distances_key) would be clearer.
    _assert_connectivity_key(adata, distances_key)

    conns, dists = adata.obsp[connectivity_key], adata.obsp[distances_key]

    # With copy=False the matrices stored in adata.obsp are modified in place.
    if copy:
        conns, dists = conns.copy(), dists.copy()

    # Percentile is taken over the positive distances only (zeros are absent sparse entries).
    threshold = np.percentile(np.array(dists[dists != 0]).squeeze(), distance_percentile)
    conns[dists > threshold] = 0
    dists[dists > threshold] = 0

    conns.eliminate_zeros()
    dists.eliminate_zeros()

    if copy:
        return conns, dists
    else:
        # Record the effective cutoff as the new neighborhood radius.
        adata.uns[neighs_key]["params"]["radius"] = threshold
def _remove_intra_cluster_links(labels, adjacency):
    """Zero out, in place, every CSR entry whose endpoints share a cluster label.

    ``labels`` is a pandas Series positionally aligned with the adjacency rows;
    ``adjacency`` is a CSR matrix, mutated in place and also returned.
    """
    # Row index of every stored entry, recovered from the CSR indptr layout.
    n_rows = adjacency.indptr.shape[0] - 1
    entry_rows = np.repeat(np.arange(n_rows), np.diff(adjacency.indptr))

    source = np.asarray(labels.iloc[entry_rows])
    target = np.asarray(labels.iloc[adjacency.indices])

    # Drop links whose two endpoints carry the same label.
    adjacency.data[source == target] = 0
    adjacency.eliminate_zeros()

    return adjacency
def _connected_components(adj: sps.spmatrix, min_cells: int = 250, count: int = 0) -> tuple[np.ndarray, int]:
    """Label the connected components of ``adj``, discarding components smaller than ``min_cells``.

    Parameters
    ----------
    adj
        Adjacency matrix of the spatial graph.
    min_cells
        Minimum number of cells for a component to be kept; smaller components are labeled -1.
    count
        Offset added to the labels of the kept components, so labels stay unique
        across successive calls (see ``connected_components``).

    Returns
    -------
    Tuple of (per-cell labels, number of components kept). Kept components are
    renumbered consecutively starting at ``count``.
    """
    n_components, labels = sps.csgraph.connected_components(adj, return_labels=True)
    components, counts = np.unique(labels, return_counts=True)

    small_components = components[counts < min_cells]
    # np.isin replaces np.in1d, deprecated since NumPy 1.25.
    small_components_idxs = np.isin(labels, small_components)

    labels[small_components_idxs] = -1
    # Renumber surviving components consecutively and apply the running offset.
    labels[~small_components_idxs] = pd.factorize(labels[~small_components_idxs])[0] + count

    return labels, (n_components - len(small_components))
@d.dedent
def connected_components(
    adata: AnnData,
    cluster_key: str | None = None,
    min_cells: int = 250,
    connectivity_key: str | None = None,
    out_key: str = "component",
    copy: bool = False,
) -> None | np.ndarray:
    """
    Compute the connected components of the spatial graph.

    Parameters
    ----------
    %(adata)s
    cluster_key
        Key in :attr:`anndata.AnnData.obs` where the cluster labels are stored. If :class:`None`, the connected components are computed on the whole dataset.
    min_cells
        Minimum number of cells for a connected component to be considered.
    %(conn_key)s
    out_key
        Key in :attr:`anndata.AnnData.obs` where the output matrix is stored if ``copy = False``.
    %(copy)s

    Returns
    -------
    If ``copy = True``, returns a :class:`numpy.ndarray` with the connected components labels.

    Otherwise, modifies the ``adata`` with the following key:
        - :attr:`anndata.AnnData.obs` ``['{{out_key}}']`` - the above mentioned :class:`numpy.ndarray`.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    output = pd.Series(index=adata.obs.index, dtype="object")

    # Running offset so component labels stay unique across clusters.
    count = 0

    if cluster_key is not None:
        cluster_values = adata.obs[cluster_key].unique()

        # Components are computed per cluster, so cells of different clusters
        # never end up in the same component.
        for cluster in cluster_values:
            adata_cluster = adata[adata.obs[cluster_key] == cluster]

            labels, n_components = _connected_components(
                adj=adata_cluster.obsp[connectivity_key], min_cells=min_cells, count=count
            )
            output[adata.obs[cluster_key] == cluster] = labels
            count += n_components
    else:
        labels, n_components = _connected_components(
            adj=adata.obsp[connectivity_key],
            min_cells=min_cells,
        )
        output.loc[:] = labels

    output = output.astype("category")
    # Cells of components smaller than min_cells were labeled -1: mark them as missing.
    output[output == -1] = np.nan
    output = output.cat.remove_unused_categories()

    if copy:
        return output.values

    adata.obs[out_key] = output
def _proportion(obs, id_key, val_key, normalize=True):
    """Cross-tabulate ``val_key`` within ``id_key``, optionally row-normalized to proportions."""
    pair_counts = obs[[id_key, val_key]].value_counts().reset_index()
    # pivot keeps the counts column as the values; missing combinations become NaN.
    table = pd.pivot(pair_counts, index=id_key, columns=val_key).fillna(0)
    # Drop the artificial top level ("count") introduced by pivoting without `values=`.
    table.columns = table.columns.droplevel(0)
    return table.div(table.sum(axis=1), axis=0) if normalize else table


def _observed_permuted(annotations, group_key, label_key):
    """Shuffle the group labels once and recompute the label-by-group proportion table."""
    # Permutes group assignments on the passed frame (callers pass a copy of obs).
    annotations[group_key] = annotations[group_key].sample(frac=1).reset_index(drop=True).values
    return _proportion(annotations, id_key=label_key, val_key=group_key).reindex().T
log=True): 26 | enrichment = observed.div(expected, axis="index", level=0) 27 | 28 | if log: 29 | with np.errstate(divide="ignore"): 30 | enrichment = np.log2(enrichment) 31 | enrichment = enrichment.fillna(enrichment.min()) 32 | return enrichment 33 | 34 | 35 | def _empirical_pvalues(observed, expected): 36 | pvalues = np.zeros(observed.shape) 37 | pvalues[observed.values > 0] = ( 38 | 1 - np.sum(expected[:, observed.values > 0] < observed.values[observed.values > 0], axis=0) / expected.shape[0] 39 | ) 40 | pvalues[observed.values < 0] = ( 41 | 1 - np.sum(expected[:, observed.values < 0] > observed.values[observed.values < 0], axis=0) / expected.shape[0] 42 | ) 43 | return pd.DataFrame(pvalues, columns=observed.columns, index=observed.index) 44 | 45 | 46 | @d.dedent 47 | def enrichment( 48 | adata: AnnData, 49 | group_key: str, 50 | label_key: str, 51 | pvalues: bool = False, 52 | n_perms: int = 1000, 53 | log: bool = True, 54 | observed_expected: bool = False, 55 | copy: bool = False, 56 | ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame] | None: 57 | """ 58 | Compute the enrichment of `label_key` in `group_key`. 59 | 60 | Parameters 61 | ---------- 62 | %(adata)s 63 | group_key 64 | Key in :attr:`anndata.AnnData.obs` where groups are stored. 65 | label_key 66 | Key in :attr:`anndata.AnnData.obs` where labels are stored. 67 | pvalues 68 | If `True`, compute empirical p-values by permutation. It will result in a slower computation. 69 | n_perms 70 | Number of permutations to compute empirical p-values. 71 | log 72 | If `True` use log2 fold change, otherwise use fold change. 73 | observed_expected 74 | If `True`, return also the observed and expected proportions. 75 | %(copy)s 76 | 77 | Returns 78 | ------- 79 | If ``copy = True``, returns a :class:`dict` with the following keys: 80 | - ``'enrichment'`` - the enrichment values. 81 | - ``'pvalue'`` - the enrichment pvalues (if `pvalues is True`). 
82 | - ``'observed'`` - the observed proportions (if `observed_expected is True`). 83 | - ``'expected'`` - the expected proportions (if `observed_expected is True`). 84 | 85 | Otherwise, modifies the ``adata`` with the following keys: 86 | - :attr:`anndata.AnnData.uns` ``['{group_key}_{label_key}_nhood_enrichment']`` - the above mentioned dict. 87 | - :attr:`anndata.AnnData.uns` ``['{group_key}_{label_key}_nhood_enrichment']['params']`` - the parameters used. 88 | """ 89 | observed = _proportion(adata.obs, id_key=label_key, val_key=group_key).reindex().T 90 | observed[observed.isna()] = 0 91 | if not pvalues: 92 | expected = adata.obs[group_key].value_counts() / adata.shape[0] 93 | # Repeat over the number of labels 94 | expected = pd.concat([expected] * len(observed.columns), axis=1, keys=observed.columns) 95 | else: 96 | annotations = adata.obs.copy() 97 | 98 | expected = [_observed_permuted(annotations, group_key, label_key) for _ in tqdm(range(n_perms))] 99 | expected = np.stack(expected, axis=0) 100 | 101 | empirical_pvalues = _empirical_pvalues(observed, expected) 102 | 103 | expected = np.mean(expected, axis=0) 104 | expected = pd.DataFrame(expected, columns=observed.columns, index=observed.index) 105 | 106 | enrichment = _enrichment(observed, expected, log=log) 107 | 108 | result = {"enrichment": enrichment} 109 | 110 | if observed_expected: 111 | result["observed"] = observed 112 | result["expected"] = expected 113 | 114 | if pvalues: 115 | result["pvalue"] = empirical_pvalues 116 | 117 | if copy: 118 | return result 119 | else: 120 | adata.uns[f"{group_key}_{label_key}_enrichment"] = result 121 | adata.uns[f"{group_key}_{label_key}_enrichment"]["params"] = {"log": log} 122 | -------------------------------------------------------------------------------- /src/cellcharter/gr/_utils.py: -------------------------------------------------------------------------------- 1 | """Graph utilities.""" 2 | 3 | from __future__ import annotations 4 | 5 | from anndata 
import AnnData 6 | 7 | 8 | def _assert_distances_key(adata: AnnData, key: str) -> None: 9 | if key not in adata.obsp: 10 | key_added = key.replace("_distances", "") 11 | raise KeyError( 12 | f"Spatial distances key `{key}` not found in `adata.obsp`. " 13 | f"Please run `squidpy.gr.spatial_neighbors(..., key_added={key_added!r})` first." 14 | ) 15 | -------------------------------------------------------------------------------- /src/cellcharter/pl/__init__.py: -------------------------------------------------------------------------------- 1 | from ._autok import autok_stability 2 | from ._group import enrichment, proportion 3 | from ._nhood import diff_nhood_enrichment, nhood_enrichment 4 | from ._shape import boundaries, shape_metrics 5 | -------------------------------------------------------------------------------- /src/cellcharter/pl/_autok.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | import seaborn as sns 8 | 9 | from cellcharter.tl import ClusterAutoK 10 | 11 | 12 | def autok_stability(autok: ClusterAutoK, save: str | Path | None = None, return_ax: bool = False) -> None: 13 | """ 14 | Plot the clustering stability. 15 | 16 | The clustering stability is computed by :class:`cellcharter.tl.ClusterAutoK`. 17 | 18 | Parameters 19 | ---------- 20 | autok 21 | The fitted :class:`cellcharter.tl.ClusterAutoK` model. 22 | save 23 | Whether to save the plot. 24 | similarity_function 25 | The similarity function to use. Defaults to :func:`sklearn.metrics.fowlkes_mallows_score`. 26 | 27 | Returns 28 | ------- 29 | Nothing, just plots the figure and optionally saves the plot. 30 | """ 31 | robustness_df = pd.melt( 32 | pd.DataFrame.from_dict({k: autok.stability[i] for i, k in enumerate(autok.n_clusters[1:-1])}, orient="columns"), 33 | var_name="N. 
clusters", 34 | value_name="Stability", 35 | ) 36 | ax = sns.lineplot(data=robustness_df, x="N. clusters", y="Stability") 37 | ax.set_xticks(autok.n_clusters[1:-1]) 38 | if save: 39 | plt.savefig(save) 40 | 41 | ax.spines.right.set_visible(False) 42 | ax.spines.top.set_visible(False) 43 | 44 | if return_ax: 45 | return ax 46 | -------------------------------------------------------------------------------- /src/cellcharter/pl/_group.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from pathlib import Path 5 | 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import pandas as pd 10 | import seaborn as sns 11 | from anndata import AnnData 12 | from matplotlib.colors import LogNorm, Normalize 13 | from matplotlib.legend_handler import HandlerTuple 14 | from scipy.cluster import hierarchy 15 | from squidpy._docs import d 16 | from squidpy.gr._utils import _assert_categorical_obs 17 | from squidpy.pl._color_utils import Palette_t, _get_palette, _maybe_set_colors 18 | 19 | try: 20 | from matplotlib.colormaps import get_cmap 21 | except ImportError: 22 | from matplotlib.pyplot import get_cmap 23 | 24 | from cellcharter.gr._group import _proportion 25 | 26 | empty_handle = matplotlib.patches.Rectangle((0, 0), 1, 1, fill=False, edgecolor="none", visible=False) 27 | 28 | 29 | @d.dedent 30 | def proportion( 31 | adata: AnnData, 32 | group_key: str, 33 | label_key: str, 34 | groups: list | None = None, 35 | labels: list | None = None, 36 | rotation_xlabel: int = 45, 37 | ncols: int = 1, 38 | normalize: bool = True, 39 | palette: Palette_t = None, 40 | figsize: tuple[float, float] | None = None, 41 | dpi: int | None = None, 42 | save: str | Path | None = None, 43 | **kwargs, 44 | ) -> None: 45 | """ 46 | Plot the proportion of `y_key` in `x_key`. 
@d.dedent
def proportion(
    adata: AnnData,
    group_key: str,
    label_key: str,
    groups: list | None = None,
    labels: list | None = None,
    rotation_xlabel: int = 45,
    ncols: int = 1,
    normalize: bool = True,
    palette: Palette_t = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    save: str | Path | None = None,
    **kwargs,
) -> None:
    """
    Plot the proportion of `label_key` in `group_key` as a stacked bar plot.

    Parameters
    ----------
    %(adata)s
    group_key
        Key in :attr:`anndata.AnnData.obs` where groups are stored.
    label_key
        Key in :attr:`anndata.AnnData.obs` where labels are stored.
    groups
        List of groups to plot.
    labels
        List of labels to plot.
    rotation_xlabel
        Rotation in degrees of the ticks of the x axis.
    ncols
        Number of columns for the legend.
    normalize
        If `True` use relative frequencies, otherwise use counts.
    palette
        Categorical colormap for the clusters.
        If ``None``, use :attr:`anndata.AnnData.uns` ``['{{cluster_key}}_colors']``, if available.
    %(plotting)s
    kwargs
        Keyword arguments for :func:`pandas.DataFrame.plot.bar`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=group_key)
    _assert_categorical_obs(adata, key=label_key)
    _maybe_set_colors(source=adata, target=adata, key=label_key, palette=palette)

    clusters = adata.obs[label_key].cat.categories
    palette = _get_palette(adata, cluster_key=label_key, categories=clusters)

    df = _proportion(obs=adata.obs, id_key=group_key, val_key=label_key, normalize=normalize)
    # Reverse the column order so the stacking order matches the legend order.
    df = df[df.columns[::-1]]

    if groups is not None:
        df = df.loc[groups, :]

    if labels is not None:
        df = df.loc[:, labels]

    plt.figure(dpi=dpi)
    ax = df.plot.bar(stacked=True, figsize=figsize, color=palette, rot=rotation_xlabel, ax=plt.gca(), **kwargs)
    ax.grid(False)

    handles, labels = ax.get_legend_handles_labels()
    lgd = ax.legend(handles[::-1], labels[::-1], loc="center left", ncol=ncols, bbox_to_anchor=(1.0, 0.5))

    if save:
        # FIX: the legend was passed twice in bbox_extra_artists ((lgd, lgd)); once suffices.
        plt.savefig(save, bbox_extra_artists=(lgd,), bbox_inches="tight")
# Calculate the dendrogram for rows and columns clustering
def _reorder_labels(fold_change, pvalues, group_cluster, label_cluster):
    """Reorder rows and/or columns of ``fold_change`` (and ``pvalues``, if given) by hierarchical clustering."""

    def _leaf_order(matrix):
        # Leaf order of a complete-linkage dendrogram over the rows of `matrix`.
        return hierarchy.leaves_list(hierarchy.linkage(matrix, method="complete"))

    if label_cluster:
        row_order = _leaf_order(fold_change)
        fold_change = fold_change.iloc[row_order]
        if pvalues is not None:
            pvalues = pvalues.iloc[row_order]

    if group_cluster:
        col_order = _leaf_order(fold_change.T)
        fold_change = fold_change.iloc[:, col_order]
        if pvalues is not None:
            pvalues = pvalues.iloc[:, col_order]

    return fold_change, pvalues


def _significance_colors(color, pvalues, significance):
    """Darken significant entries (0.0) and lighten non-significant ones (0.8), in place."""
    color[pvalues <= significance] = 0.0
    color[pvalues > significance] = 0.8
    return color
def _pvalue_colorbar(ax, cmap_enriched, cmap_depleted, norm):
    """Draw one (or two, if a depleted colormap is given) vertical p-value colorbars beside ``ax``."""
    from matplotlib.colorbar import ColorbarBase
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    divider = make_axes_locatable(ax)

    # Append axes to the right of ax, with 5% width of ax
    cax1 = divider.append_axes("right", size="2%", pad=0.05)

    cbar1 = ColorbarBase(cax1, cmap=cmap_enriched, norm=norm, orientation="vertical")

    # Inverted axis: small (more significant) p-values appear at the top.
    cbar1.ax.invert_yaxis()
    cbar1.ax.tick_params(labelsize=10)
    cbar1.set_ticks([], minor=True)
    cbar1.ax.set_title("p-value", fontdict={"fontsize": 10})

    if cmap_depleted is not None:
        cax2 = divider.append_axes("right", size="2%", pad=0.10)

        # Place colorbars next to each other and share ticks
        cbar2 = ColorbarBase(cax2, cmap=cmap_depleted, norm=norm, orientation="vertical")
        cbar2.ax.invert_yaxis()
        cbar2.ax.tick_params(labelsize=10)
        cbar2.set_ticks([], minor=True)

        # Only the outer (depleted) bar keeps its tick labels.
        cbar1.set_ticks([])


def _enrichment_legend(
    scatters, fold_change_melt, dot_scale, size_max, enriched_only, significant_only, significance, size_threshold
):
    """Assemble legend handles/labels for the enrichment dotplot: color meaning, significance, and dot sizes."""
    handles_list = []
    labels_list = []

    # Color legend only needed when both enriched and depleted dots are drawn.
    if enriched_only is False:
        handles_list.extend(
            [scatter.legend_elements(prop="colors", num=None)[0][0] for scatter in scatters] + [empty_handle]
        )
        labels_list.extend(["Enriched", "Depleted", ""])

    if significance is not None:
        handles_list.append(tuple([scatter.legend_elements(prop="colors", num=None)[0][0] for scatter in scatters]))
        labels_list.append(f"p-value < {significance}")

        if significant_only is False:
            handles_list.append(tuple([scatter.legend_elements(prop="colors", num=None)[0][1] for scatter in scatters]))
            labels_list.append(f"p-value >= {significance}")

        handles_list.append(empty_handle)
        labels_list.append("")

    # Map dot sizes back to fold-change units for the size legend entries.
    handles, labels = scatters[0].legend_elements(prop="sizes", num=5, func=lambda x: x / 100 / dot_scale * size_max)

    if size_threshold is not None:
        # Show the threshold as a label only if the threshold is lower than the maximum fold change
        if enriched_only is True and fold_change_melt[fold_change_melt["value"] >= 0]["value"].max() > size_threshold:
            labels[-1] = f">{size_threshold:.1f}"
        elif fold_change_melt["value"].max() > size_threshold:
            labels[-1] = f">{size_threshold:.1f}"

    handles_list.extend([empty_handle] + handles)
    labels_list.extend(["log2 FC"] + labels)

    return handles_list, labels_list
@d.dedent
def enrichment(
    adata: AnnData,
    group_key: str,
    label_key: str,
    dot_scale: float = 3,
    group_cluster: bool = True,
    label_cluster: bool = False,
    groups: list | None = None,
    labels: list | None = None,
    show_pvalues: bool = False,
    significance: float | None = None,
    enriched_only: bool = True,
    significant_only: bool = False,
    size_threshold: float | None = None,
    palette: str | matplotlib.colors.ListedColormap | None = None,
    fontsize: str | int = "small",
    figsize: tuple[float, float] | None = (7, 5),
    save: str | Path | None = None,
    **kwargs,
):
    """
    Plot a dotplot of the enrichment of `label_key` in `group_key`.

    Parameters
    ----------
    %(adata)s
    group_key
        Key in :attr:`anndata.AnnData.obs` where groups are stored.
    label_key
        Key in :attr:`anndata.AnnData.obs` where labels are stored.
    dot_scale
        Scale of the dots.
    group_cluster
        If `True`, display groups ordered according to hierarchical clustering.
    label_cluster
        If `True`, display labels ordered according to hierarchical clustering.
    groups
        The groups for which to show the enrichment.
    labels
        The labels for which to show the enrichment.
    show_pvalues
        If `True`, show p-values as colors.
    significance
        If not `None`, show fold changes with a p-value above this threshold in a lighter color.
    enriched_only
        If `True`, display only enriched values and hide depleted values.
    significant_only
        If `True`, display only significant values and hide non-significant values.
    size_threshold
        Threshold for the size of the dots. Enrichment or depletions with absolute value above this threshold will have all the same size.
    palette
        Colormap for the enrichment values. It must be a diverging colormap.
    %(plotting)s
    kwargs
        Keyword arguments for :func:`matplotlib.pyplot.scatter`.

    Raises
    ------
    ValueError
        If :func:`cellcharter.gr.enrichment` has not been run, if ``size_threshold``
        is not positive, or if p-value-dependent options are requested but the
        enrichment was computed without ``pvalues=True``.
    """
    if f"{group_key}_{label_key}_enrichment" not in adata.uns:
        raise ValueError("Run cellcharter.gr.enrichment first.")

    if size_threshold is not None and size_threshold <= 0:
        raise ValueError("size_threshold must be greater than 0.")

    if palette is None:
        palette = sns.diverging_palette(240, 10, as_cmap=True)
    elif isinstance(palette, str):
        palette = get_cmap(palette)

    pvalues = None
    if "pvalue" not in adata.uns[f"{group_key}_{label_key}_enrichment"]:
        # BUG FIX: these ValueErrors were constructed but never raised, so invalid
        # parameter combinations silently fell through.
        if show_pvalues:
            raise ValueError("show_pvalues requires gr.enrichment to be run with pvalues=True.")

        if significance is not None:
            raise ValueError("significance requires gr.enrichment to be run with pvalues=True.")

        if significant_only:
            raise ValueError("significant_only requires gr.enrichment to be run with pvalues=True.")
    elif show_pvalues:
        pvalues = adata.uns[f"{group_key}_{label_key}_enrichment"]["pvalue"].copy().T
    else:
        if significance is not None:
            warnings.warn(
                "Significance requires show_pvalues=True. Ignoring significance.",
                UserWarning,
                stacklevel=2,
            )
            significance = None

    if significant_only is True and significance is None:
        warnings.warn(
            "Significant_only requires significance to be set. Ignoring significant_only.",
            UserWarning,
            stacklevel=2,
        )
        significant_only = False

    # Set kwargs['alpha'] to 1 if not set
    if "alpha" not in kwargs:
        kwargs["alpha"] = 1

    if "edgecolor" not in kwargs:
        kwargs["edgecolor"] = "none"

    fold_change = adata.uns[f"{group_key}_{label_key}_enrichment"]["enrichment"].copy().T

    fold_change, pvalues = _select_labels(fold_change, pvalues, labels, groups)

    # Set -inf values to minimum and inf values to maximum
    fold_change[:] = np.nan_to_num(
        fold_change,
        neginf=np.min(fold_change[np.isfinite(fold_change)]),
        posinf=np.max(fold_change[np.isfinite(fold_change)]),
    )

    fold_change, pvalues = _reorder_labels(fold_change, pvalues, group_cluster, label_cluster)

    fold_change_melt = pd.melt(fold_change.reset_index(), id_vars=label_key)

    # Normalize the size of dots based on the absolute values in the dataframe, scaled to your preference
    sizes = fold_change_melt.copy()
    sizes["value"] = np.abs(sizes["value"])
    size_max = sizes["value"].max() if size_threshold is None else size_threshold
    if size_threshold is not None:
        sizes["value"] = sizes["value"].clip(upper=size_threshold)

    sizes["value"] = sizes["value"] * 100 / sizes["value"].max() * dot_scale

    norm = Normalize(0, 1)
    # Set colormap to red if below 0, blue if above 0
    if significance is not None:
        color = _significance_colors(fold_change.copy(), pvalues, significance)
    else:
        if pvalues is not None:
            # Small offset avoids log(0) in the LogNorm below.
            pvalues += 0.0001
            norm = LogNorm(vmin=pvalues.min().min(), vmax=pvalues.max().max())
            color = pvalues.copy()
        else:
            color = fold_change.copy()
            color[:] = 0.0

    color = pd.melt(color.reset_index(), id_vars=label_key)

    # Create a figure and axis for plotting
    fig, ax = plt.subplots(figsize=figsize)

    scatters = []
    enriched_mask = fold_change_melt["value"] >= 0

    significant_mask = np.ones_like(fold_change_melt["value"], dtype=bool)

    if significant_only:
        significant_mask = pd.melt(pvalues.reset_index(), id_vars=label_key)["value"] < significance

    cmap_enriched = matplotlib.colors.LinearSegmentedColormap.from_list("", [palette(1.0), palette(0.5)])
    scatter_enriched = ax.scatter(
        pd.factorize(sizes[label_key])[0][enriched_mask & significant_mask],
        pd.factorize(sizes[group_key])[0][enriched_mask & significant_mask],
        s=sizes["value"][enriched_mask & significant_mask],
        c=color["value"][enriched_mask & significant_mask],
        cmap=cmap_enriched,
        norm=norm,
        **kwargs,
    )
    scatters.append(scatter_enriched)

    cmap_depleted = None
    if enriched_only is False:
        cmap_depleted = matplotlib.colors.LinearSegmentedColormap.from_list("", [palette(0.0), palette(0.5)])
        scatter_depleted = ax.scatter(
            pd.factorize(sizes[label_key])[0][~enriched_mask & significant_mask],
            pd.factorize(sizes[group_key])[0][~enriched_mask & significant_mask],
            s=sizes["value"][~enriched_mask & significant_mask],
            c=color["value"][~enriched_mask & significant_mask],
            cmap=cmap_depleted,
            norm=norm,
            **kwargs,
        )
        scatters.append(scatter_depleted)

    if pvalues is not None and significance is None:
        _pvalue_colorbar(ax, cmap_enriched, cmap_depleted, norm)

    handles_list, labels_list = _enrichment_legend(
        scatters, fold_change_melt, dot_scale, size_max, enriched_only, significant_only, significance, size_threshold
    )

    fig.legend(
        handles_list,
        labels_list,
        loc="outside upper left",
        bbox_to_anchor=(0.98, 0.95),
        handler_map={tuple: HandlerTuple(ndivide=None, pad=1)},
        borderpad=1,
        handletextpad=1.0,
        fontsize=fontsize,
    )

    # Adjust the ticks to match the dataframe's indices and columns
    ax.set_xticks(range(len(fold_change.index)))
    ax.set_yticks(range(len(fold_change.columns)))
    ax.set_xticklabels(fold_change.index, rotation=90)
    ax.set_yticklabels(fold_change.columns)
    ax.tick_params(axis="both", which="major", labelsize=fontsize)

    # Remove grid lines
    ax.grid(False)

    plt.tight_layout()

    if save:
        plt.savefig(save, bbox_inches="tight")
None = None, 42 | **kwargs: Any, 43 | ): 44 | enrichment = nhood_enrichment_values["enrichment"] 45 | adata_enrichment = AnnData(X=enrichment.astype(np.float32)) 46 | adata_enrichment.obs[cluster_key] = pd.Categorical(enrichment.index) 47 | 48 | if significance is not None: 49 | if "pvalue" not in nhood_enrichment_values: 50 | warnings.warn( 51 | "Significance requires gr.nhood_enrichment to be run with pvalues=True. Ignoring significance.", 52 | UserWarning, 53 | stacklevel=2, 54 | ) 55 | else: 56 | adata_enrichment.layers["significant"] = np.empty_like(enrichment, dtype=str) 57 | adata_enrichment.layers["significant"][nhood_enrichment_values["pvalue"].values <= significance] = "*" 58 | 59 | _maybe_set_colors(source=adata, target=adata_enrichment, key=cluster_key, palette=palette) 60 | 61 | if figsize is None: 62 | figsize = list(adata_enrichment.shape[::-1]) 63 | 64 | if row_groups is not None: 65 | figsize[1] = len(row_groups) 66 | 67 | if col_groups is not None: 68 | figsize[0] = len(col_groups) 69 | 70 | figsize = tuple(figsize) 71 | 72 | _heatmap( 73 | adata_enrichment, 74 | key=cluster_key, 75 | rows=row_groups, 76 | cols=col_groups, 77 | annotate=annotate, 78 | n_digits=n_digits, 79 | method=method, 80 | title=title, 81 | cont_cmap=cmap, 82 | fontsize=fontsize, 83 | figsize=figsize, 84 | dpi=dpi, 85 | cbar_kwargs=cbar_kwargs, 86 | ax=ax, 87 | **kwargs, 88 | ) 89 | 90 | 91 | @d.dedent 92 | def nhood_enrichment( 93 | adata: AnnData, 94 | cluster_key: str, 95 | row_groups: list[str] | None = None, 96 | col_groups: list[str] | None = None, 97 | min_freq: float | None = None, 98 | annotate: bool = False, 99 | transpose: bool = False, 100 | method: str | None = None, 101 | title: str | None = "Neighborhood enrichment", 102 | cmap: str = "bwr", 103 | palette: Palette_t = None, 104 | cbar_kwargs: Mapping[str, Any] = MappingProxyType({}), 105 | figsize: tuple[float, float] | None = None, 106 | dpi: int | None = None, 107 | ax: Axes | None = None, 108 | n_digits: int 
= 2, 109 | significance: float | None = None, 110 | save: str | Path | None = None, 111 | **kwargs: Any, 112 | ) -> None: 113 | """ 114 | A modified version of squidpy's function for `plotting neighborhood enrichment `_. 115 | 116 | The enrichment is computed by :func:`cellcharter.gr.nhood_enrichment`. 117 | 118 | Parameters 119 | ---------- 120 | %(adata)s 121 | %(cluster_key)s 122 | row_groups 123 | Restrict the rows to these groups. If `None`, all groups are plotted. 124 | col_groups 125 | Restrict the columns to these groups. If `None`, all groups are plotted. 126 | %(heatmap_plotting)s 127 | 128 | n_digits 129 | The number of digits of the number in the annotations. 130 | significance 131 | Mark the values that are below this threshold with a star. If `None`, no significance is computed. It requires ``gr.nhood_enrichment`` to be run with ``pvalues=True``. 132 | kwargs 133 | Keyword arguments for :func:`matplotlib.pyplot.text`. 134 | 135 | Returns 136 | ------- 137 | %(plotting_returns)s 138 | """ 139 | _assert_categorical_obs(adata, key=cluster_key) 140 | nhood_enrichment_values = _get_data(adata, cluster_key=cluster_key, func_name="nhood_enrichment").copy() 141 | nhood_enrichment_values["enrichment"][np.isinf(nhood_enrichment_values["enrichment"])] = np.nan 142 | 143 | if transpose: 144 | nhood_enrichment_values["enrichment"] = nhood_enrichment_values["enrichment"].T 145 | 146 | if min_freq is not None: 147 | frequency = adata.obs[cluster_key].value_counts(normalize=True) 148 | nhood_enrichment_values["enrichment"].loc[frequency[frequency < min_freq].index] = np.nan 149 | nhood_enrichment_values["enrichment"].loc[:, frequency[frequency < min_freq].index] = np.nan 150 | 151 | _plot_nhood_enrichment( 152 | adata, 153 | nhood_enrichment_values, 154 | cluster_key, 155 | row_groups=row_groups, 156 | col_groups=col_groups, 157 | annotate=annotate, 158 | method=method, 159 | title=title, 160 | cmap=cmap, 161 | palette=palette, 162 | cbar_kwargs=cbar_kwargs, 163 | 
@d.dedent
def diff_nhood_enrichment(
    adata: AnnData,
    cluster_key: str,
    condition_key: str,
    condition_groups: list[str] | None = None,
    hspace: float = 0.25,
    wspace: float | None = None,
    ncols: int = 1,
    **nhood_kwargs: Any,
) -> None:
    r"""
    Plot the difference in neighborhood enrichment between conditions.

    The difference is computed by :func:`cellcharter.gr.diff_nhood_enrichment`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    condition_key
        Key in ``adata.obs`` that stores the sample condition (e.g., normal vs disease).
    condition_groups
        Restrict the conditions to these clusters. If `None`, all groups are plotted.
    hspace
        Height space between panels.
    wspace
        Width space between panels.
    ncols
        Number of panels per row.
    nhood_kwargs
        Keyword arguments for :func:`cellcharter.pl.nhood_enrichment`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    _assert_categorical_obs(adata, key=condition_key)

    conditions = adata.obs[condition_key].cat.categories if condition_groups is None else condition_groups

    # `save` and `cmap` are consumed here and must not be forwarded to each panel.
    cmap = nhood_kwargs.pop("cmap", "PRGn_r")
    save = nhood_kwargs.pop("save", None)

    # One panel per unordered condition pair.
    n_combinations = len(conditions) * (len(conditions) - 1) // 2

    figsize = nhood_kwargs.get("figsize", rcParams["figure.figsize"])

    # Plot neighborhood enrichment for each condition pair as a subplot.
    # Fixed: the grid's figsize previously looked up nhood_kwargs["dpi"]
    # instead of the figure size.
    _, grid = _panel_grid(
        num_panels=n_combinations,
        hspace=hspace,
        wspace=0.75 / figsize[0] + 0.02 if wspace is None else wspace,
        ncols=ncols,
        dpi=nhood_kwargs.get("dpi", rcParams["figure.dpi"]),
        figsize=figsize,
    )

    axs = [plt.subplot(grid[c]) for c in range(n_combinations)]

    uns_key = f"{cluster_key}_{condition_key}_diff_nhood_enrichment"
    for i, (condition1, condition2) in enumerate(combinations(conditions, 2)):
        if f"{condition1}_{condition2}" not in adata.uns[uns_key]:
            # Only one orientation of each pair is stored; the reverse comparison
            # is obtained by flipping the sign of the stored enrichment.
            nhood_enrichment_values = dict(adata.uns[uns_key][f"{condition2}_{condition1}"])
            nhood_enrichment_values["enrichment"] = -nhood_enrichment_values["enrichment"]
        else:
            nhood_enrichment_values = adata.uns[uns_key][f"{condition1}_{condition2}"]

        _plot_nhood_enrichment(
            adata,
            nhood_enrichment_values,
            cluster_key,
            cmap=cmap,
            ax=axs[i],
            title=f"{condition1} vs {condition2}",
            show_cols=i >= n_combinations - ncols,  # Show column labels only for the last row of panels
            show_rows=i % ncols == 0,  # Show row labels only for the first panel of each grid row
            **nhood_kwargs,
        )

    if save is not None:
        plt.savefig(save, bbox_inches="tight")
@d.dedent
def boundaries(
    adata: AnnData,
    sample: str,
    library_key: str = "sample",
    component_key: str = "component",
    alpha_boundary: float = 0.5,
    show_cells: bool = True,
    save: str | Path | None = None,
) -> None:
    """
    Plot the boundaries of the clusters.

    Parameters
    ----------
    %(adata)s
    sample
        Sample to plot.
    library_key
        Key in :attr:`anndata.AnnData.obs` where the sample labels are stored.
    component_key
        Key in :attr:`anndata.AnnData.obs` where the component labels are stored.
    alpha_boundary
        Transparency of the boundaries.
    show_cells
        Whether to show the cells or not.
    save
        Path to save the figure. If `None`, the figure is not saved.

    Returns
    -------
    %(plotting_returns)s
    """
    adata = adata[adata.obs[library_key] == sample].copy()
    del adata.raw
    clusters = adata.obs[component_key].unique()

    # Keep only the precomputed boundary polygons of components present in this sample.
    # (Renamed from `boundaries` to avoid shadowing this function's own name.)
    component_boundaries = {
        cluster: boundary
        for cluster, boundary in adata.uns[f"shape_{component_key}"]["boundary"].items()
        if cluster in clusters
    }
    gdf = geopandas.GeoDataFrame(
        geometry=list(component_boundaries.values()), index=np.arange(len(component_boundaries)).astype(str)
    )

    # Cells not assigned to any component are marked NaN so they are not colored.
    adata.obs.loc[adata.obs[component_key] == -1, component_key] = np.nan
    adata.obs.index = "cell_" + adata.obs.index
    adata.obs["instance_id"] = adata.obs.index
    adata.obs["region"] = "cells"

    xy = adata.obsm["spatial"]
    cell_circles = sd.models.ShapesModel.parse(xy, geometry=0, radius=3000, index=adata.obs["instance_id"])

    obs = pd.DataFrame(
        list(component_boundaries.keys()),
        columns=[component_key],
        index=np.arange(len(component_boundaries)).astype(str),
    )
    adata_obs = ad.AnnData(X=pd.DataFrame(index=obs.index, columns=adata.var_names), obs=obs)
    adata_obs.obs["region"] = "clusters"
    # Fixed: this previously assigned to `adata_obs.index`, which AnnData ignores
    # (it just creates an unused attribute). Prefix the obs names instead so they
    # cannot collide with the "cell_"-prefixed observations after concatenation.
    adata_obs.obs.index = "cluster_" + adata_obs.obs.index
    adata_obs.obs["instance_id"] = np.arange(len(component_boundaries)).astype(str)
    adata_obs.obs[component_key] = pd.Categorical(adata_obs.obs[component_key])

    adata = ad.concat((adata, adata_obs), join="outer")

    adata.obs["region"] = adata.obs["region"].astype("category")

    table = sd.models.TableModel.parse(
        adata, region_key="region", region=["clusters", "cells"], instance_key="instance_id"
    )

    shapes = {
        "clusters": sd.models.ShapesModel.parse(gdf),
        "cells": sd.models.ShapesModel.parse(cell_circles),
    }

    sdata = sd.SpatialData(shapes=shapes, tables=table)

    ax = plt.gca()
    if show_cells:
        try:
            sdata.pl.render_shapes(elements="cells", color=component_key).pl.show(ax=ax, legend_loc=None)
        except TypeError:  # TODO: remove after spatialdata-plot issue #256 is fixed
            warnings.warn(
                "Until the next spatialdata_plot release, the cells that do not belong to any component will be displayed with a random color instead of grey.",
                stacklevel=2,
            )
            sdata.tables["table"].obs[component_key] = sdata.tables["table"].obs[component_key].cat.add_categories([-1])
            sdata.tables["table"].obs[component_key] = sdata.tables["table"].obs[component_key].fillna(-1)
            sdata.pl.render_shapes(elements="cells", color=component_key).pl.show(ax=ax, legend_loc=None)

    # NOTE(review): `element=` here vs `elements=` above — spatialdata-plot renamed
    # this parameter between releases; confirm which spelling the pinned version expects.
    sdata.pl.render_shapes(
        element="clusters",
        color=component_key,
        fill_alpha=alpha_boundary,
    ).pl.show(ax=ax)

    if save is not None:
        plt.savefig(save, bbox_inches="tight")
@d.dedent
def shape_metrics(
    adata: AnnData,
    condition_key: str,
    condition_groups: list[str] | None = None,
    cluster_key: str | None = None,
    cluster_id: list[str] | None = None,
    component_key: str = "component",
    metrics: str | tuple[str] | list[str] = ("linearity", "curl"),
    fontsize: str | int = "small",
    figsize: tuple[float, float] = (8, 7),
    title: str | None = None,
) -> None:
    """
    Boxplots of the shape metrics between two conditions.

    Parameters
    ----------
    %(adata)s
    condition_key
        Key in :attr:`anndata.AnnData.obs` where the condition labels are stored.
    condition_groups
        List of two conditions to compare. If None, all pairwise comparisons are made.
    cluster_key
        Key in :attr:`anndata.AnnData.obs` where the cluster labels are stored. This is used to filter the clusters to plot.
    cluster_id
        List of clusters to plot. If None, all clusters are plotted.
    component_key
        Key in :attr:`anndata.AnnData.obs` where the component labels are stored.
    metrics
        List of metrics to plot. Available metrics are ``linearity``, ``curl``, ``elongation``, ``purity``.
    fontsize
        Font size used for labels, ticks, legend and significance annotations.
    figsize
        Figure size.
    title
        Title of the plot.

    Returns
    -------
    %(plotting_returns)s
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    else:
        metrics = list(metrics)

    # One row per component: its metric values plus the condition it belongs to
    # (each component maps to a single condition, hence drop_duplicates).
    metrics_df = {metric: adata.uns[f"shape_{component_key}"][metric] for metric in metrics}
    metrics_df[condition_key] = (
        adata[~adata.obs[condition_key].isna()]
        .obs[[component_key, condition_key]]
        .drop_duplicates()
        .set_index(component_key)
        .to_dict()[condition_key]
    )

    # Fixed: the cluster mapping was previously built unconditionally, which crashed
    # with the default cluster_key=None; it is only needed to filter by cluster_id.
    if cluster_key is not None:
        metrics_df[cluster_key] = (
            adata[~adata.obs[condition_key].isna()]
            .obs[[component_key, cluster_key]]
            .drop_duplicates()
            .set_index(component_key)
            .to_dict()[cluster_key]
        )

    metrics_df = pd.DataFrame(metrics_df)
    if cluster_id is not None:
        if cluster_key is None:
            raise ValueError("cluster_key must be provided when cluster_id is used.")
        metrics_df = metrics_df[metrics_df[cluster_key].isin(cluster_id)]

    metrics_df = pd.melt(
        metrics_df[metrics + [condition_key]],
        id_vars=[condition_key],
        var_name="metric",
    )

    # Fixed: the pairs were previously wrapped in enumerate(), so the loop below
    # unpacked (index, pair) instead of (condition1, condition2).
    conditions = (
        combinations(adata.obs[condition_key].cat.categories, 2)
        if condition_groups is None
        else [condition_groups]
    )

    for condition1, condition2 in conditions:
        fig = plt.figure(figsize=figsize)
        metrics_condition_pair = metrics_df[metrics_df[condition_key].isin([condition1, condition2])]
        ax = sns.boxplot(
            data=metrics_condition_pair,
            x="metric",
            hue=condition_key,
            y="value",
            showfliers=False,
            hue_order=[condition1, condition2],
        )

        ax.tick_params(labelsize=fontsize)
        ax.set_xlabel(ax.get_xlabel(), fontsize=fontsize)
        ax.set_ylabel(ax.get_ylabel(), fontsize=fontsize)

        adjust_box_widths(fig, 0.9)

        ax = sns.stripplot(
            data=metrics_condition_pair,
            x="metric",
            hue=condition_key,
            y="value",
            color="0.08",
            size=4,
            jitter=0.13,
            dodge=True,
            # Fixed: use the same hue order as the boxplot so the points always
            # land on the matching boxes (previously None when condition_groups was None).
            hue_order=[condition1, condition2],
        )

        # The stripplot duplicates the hue legend entries; keep only the first set.
        handles, labels = ax.get_legend_handles_labels()
        n_conditions = len(metrics_condition_pair[condition_key].unique())
        plt.legend(
            handles[:n_conditions],
            labels[:n_conditions],
            bbox_to_anchor=(1.24, 1.02),
            fontsize=fontsize,
        )

        # Significance brackets per metric.
        # Fixed: previously hard-coded to ["linearity", "curl"], which mismatched
        # the plotted columns for any other `metrics` selection.
        for count, metric in enumerate(metrics):
            pvalue = ttest_ind(
                metrics_condition_pair[
                    (metrics_condition_pair[condition_key] == condition1)
                    & (metrics_condition_pair["metric"] == metric)
                ]["value"],
                metrics_condition_pair[
                    (metrics_condition_pair[condition_key] == condition2)
                    & (metrics_condition_pair["metric"] == metric)
                ]["value"],
            )[1]
            x1, x2 = count, count
            # Place the bracket slightly above the tallest point of this metric,
            # staggering successive brackets so they do not overlap.
            y = (
                metrics_condition_pair[metrics_condition_pair["metric"] == metric]["value"].max()
                + 0.02
                + 0.05 * count
            )
            h, col = 0.01, "k"
            plt.plot([x1 - 0.2, x1 - 0.2, x2 + 0.2, x2 + 0.2], [y, y + h, y + h, y], lw=1.5, c=col)
            label = f"p = {pvalue:.2e}" if pvalue < 0.05 else "ns"
            plt.text(
                (x1 + x2) * 0.5,
                y + h * 2,
                label,
                ha="center",
                va="bottom",
                color=col,
                fontdict={"fontsize": fontsize},
            )
        if title is not None:
            plt.title(title, fontdict={"fontsize": fontsize})
        plt.show()
def _get_cmap_norm(
    adata: AnnData,
    key: str,
    order: tuple[list[int], list[int]] | None | None = None,
) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, mcolors.BoundaryNorm, mcolors.BoundaryNorm, int]:
    """Build row/column categorical colormaps and boundary norms for ``key``.

    When ``order`` is ``(row_order, col_order)``, the category colors stored in
    ``adata.uns`` are subset and reordered accordingly; otherwise all categories
    are used on both axes. Returns the row and column colormaps, their norms,
    and the number of rows.
    """
    palette = adata.uns[Key.uns.colors(key)]

    if order is None:
        row_colors = col_colors = palette
        n_rows = n_cols = adata.obs[key].nunique()
    else:
        row_order, col_order = order
        row_colors = [palette[idx] for idx in row_order]
        col_colors = [palette[idx] for idx in col_order]
        n_rows, n_cols = len(row_order), len(col_order)

    row_cmap = mcolors.ListedColormap(row_colors)
    col_cmap = mcolors.ListedColormap(col_colors)

    return (
        row_cmap,
        col_cmap,
        mcolors.BoundaryNorm(np.arange(n_rows + 1), row_cmap.N),
        mcolors.BoundaryNorm(np.arange(n_cols + 1), col_cmap.N),
        n_rows,
    )
Axes | None = None, 67 | n_digits: int = 2, 68 | show_cols: bool = True, 69 | show_rows: bool = True, 70 | **kwargs: Any, 71 | ) -> mpl.figure.Figure: 72 | cbar_kwargs = dict(cbar_kwargs) 73 | 74 | if fontsize is not None: 75 | kwargs["annot_kws"] = {"fontdict": {"fontsize": fontsize}} 76 | 77 | if ax is None: 78 | fig, ax = plt.subplots(constrained_layout=True, dpi=dpi, figsize=figsize) 79 | else: 80 | fig = ax.figure 81 | 82 | if method is not None: 83 | row_order, col_order, row_link, col_link = sq.pl._utils._dendrogram( 84 | adata.X, method, optimal_ordering=adata.n_obs <= 1500 85 | ) 86 | else: 87 | row_order = ( 88 | np.arange(len(adata.obs[key])) 89 | if rows is None 90 | else np.argwhere(adata.obs.index.isin(np.array(rows).astype(str))).flatten() 91 | ) 92 | col_order = ( 93 | np.arange(len(adata.var_names)) 94 | if cols is None 95 | else np.argwhere(adata.var_names.isin(np.array(cols).astype(str))).flatten() 96 | ) 97 | 98 | row_order = row_order[::-1] 99 | row_labels = adata.obs[key].iloc[row_order] 100 | col_labels = adata.var_names[col_order] 101 | 102 | data = adata[row_order, col_order].copy().X 103 | 104 | # row_cmap, col_cmap, row_norm, col_norm, n_cls = sq.pl._utils._get_cmap_norm(adata, key, order=(row_order, len(row_order) + col_order)) 105 | row_cmap, col_cmap, row_norm, col_norm, n_cls = _get_cmap_norm(adata, key, order=(row_order, col_order)) 106 | col_norm = mcolors.BoundaryNorm(np.arange(len(col_order) + 1), col_cmap.N) 107 | 108 | row_sm = mpl.cm.ScalarMappable(cmap=row_cmap, norm=row_norm) 109 | col_sm = mpl.cm.ScalarMappable(cmap=col_cmap, norm=col_norm) 110 | 111 | vmin = kwargs.pop("vmin", np.nanmin(data)) 112 | vmax = kwargs.pop("vmax", np.nanmax(data)) 113 | vcenter = kwargs.pop("vcenter", 0) 114 | norm = mpl.colors.TwoSlopeNorm(vcenter=vcenter, vmin=vmin, vmax=vmax) 115 | cont_cmap = copy(cm.get_cmap(cont_cmap)) 116 | cont_cmap.set_bad(color="grey") 117 | 118 | annot = np.round(data[::-1], n_digits).astype(str) if annotate else None 
119 | if "significant" in adata.layers: 120 | significant = adata.layers["significant"].astype(str) 121 | annot = np.char.add(np.empty_like(data[::-1], dtype=str), significant[row_order[:, None], col_order][::-1]) 122 | 123 | ax = sns.heatmap( 124 | data[::-1], 125 | cmap=cont_cmap, 126 | norm=norm, 127 | ax=ax, 128 | square=True, 129 | annot=annot, 130 | cbar=False, 131 | fmt="", 132 | **kwargs, 133 | ) 134 | 135 | for _, spine in ax.spines.items(): 136 | spine.set_visible(True) 137 | 138 | ax.tick_params(top=False, bottom=False, labeltop=False, labelbottom=False) 139 | ax.set_xticks([]) 140 | ax.set_yticks([]) 141 | 142 | divider = make_axes_locatable(ax) 143 | row_cats = divider.append_axes("left", size=0.1, pad=0.1) 144 | col_cats = divider.append_axes("bottom", size=0.1, pad=0.1) 145 | cax = divider.append_axes("right", size="2%", pad=0.1) 146 | 147 | if method is not None: # cluster rows but don't plot dendrogram 148 | col_ax = divider.append_axes("top", size="5%") 149 | sch.dendrogram(col_link, no_labels=True, ax=col_ax, color_threshold=0, above_threshold_color="black") 150 | col_ax.axis("off") 151 | 152 | c = fig.colorbar( 153 | mpl.cm.ScalarMappable(norm=norm, cmap=cont_cmap), 154 | cax=cax, 155 | ticks=np.linspace(norm.vmin, norm.vmax, 10), 156 | orientation="vertical", 157 | format="%0.2f", 158 | **cbar_kwargs, 159 | ) 160 | c.ax.tick_params(labelsize=fontsize) 161 | 162 | # column labels colorbar 163 | c = fig.colorbar(col_sm, cax=col_cats, orientation="horizontal", ticklocation="bottom") 164 | 165 | if rows == cols or show_cols is False: 166 | c.set_ticks([]) 167 | c.set_ticklabels([]) 168 | else: 169 | c.set_ticks(np.arange(len(col_labels)) + 0.5) 170 | c.set_ticklabels(col_labels, fontdict={"fontsize": fontsize}) 171 | if np.any([len(l) > 3 for l in col_labels]): 172 | c.ax.tick_params(rotation=90) 173 | c.outline.set_visible(False) 174 | 175 | # row labels colorbar 176 | c = fig.colorbar(row_sm, cax=row_cats, orientation="vertical", 
ticklocation="left") 177 | if show_rows is False: 178 | c.set_ticks([]) 179 | c.set_ticklabels([]) 180 | else: 181 | c.set_ticks(np.arange(n_cls) + 0.5) 182 | c.set_ticklabels(row_labels, fontdict={"fontsize": fontsize}) 183 | c.set_label(key, fontsize=fontsize) 184 | c.outline.set_visible(False) 185 | 186 | ax.set_title(title, fontdict={"fontsize": fontsize}) 187 | 188 | return fig, ax 189 | 190 | 191 | def _reorder(values, order, axis=1): 192 | if axis == 0: 193 | values = values.iloc[order, :] 194 | elif axis == 1: 195 | values = values.iloc[:, order] 196 | else: 197 | raise ValueError("The axis parameter accepts only values 0 and 1.") 198 | return values 199 | 200 | 201 | def _clip(values, min_threshold=None, max_threshold=None, new_min=None, new_max=None, new_middle=None): 202 | values_clipped = values.copy() 203 | if new_middle is not None: 204 | values_clipped[:] = new_middle 205 | if min_threshold is not None: 206 | values_clipped[values < min_threshold] = new_min if new_min is not None else min_threshold 207 | if max_threshold is not None: 208 | values_clipped[values > max_threshold] = new_max if new_max is not None else max_threshold 209 | return values_clipped 210 | 211 | 212 | def adjust_box_widths(g, fac): 213 | """Adjust the widths of a seaborn-generated boxplot.""" 214 | # iterating through Axes instances 215 | for ax in g.axes: 216 | # iterating through axes artists: 217 | for c in ax.get_children(): 218 | # searching for PathPatches 219 | if isinstance(c, PathPatch): 220 | # getting current width of box: 221 | p = c.get_path() 222 | verts = p.vertices 223 | verts_sub = verts[:-1] 224 | xmin = np.min(verts_sub[:, 0]) 225 | xmax = np.max(verts_sub[:, 0]) 226 | xmid = 0.5 * (xmin + xmax) 227 | xhalf = 0.5 * (xmax - xmin) 228 | 229 | # setting new width of box 230 | xmin_new = xmid - fac * xhalf 231 | xmax_new = xmid + fac * xhalf 232 | verts_sub[verts_sub[:, 0] == xmin, 0] = xmin_new 233 | verts_sub[verts_sub[:, 0] == xmax, 0] = xmax_new 234 | 235 | # 
class GaussianMixture(TorchGaussianMixture):
    """
    Adapted version of the GaussianMixture clustering model from the ``torchgmm`` library.

    Parameters
    ----------
    n_clusters
        The number of components in the GMM. The dimensionality of each component is automatically inferred from the data.
    covariance_type
        The type of covariance to assume for all Gaussian components.
    init_strategy
        The strategy for initializing component means and covariances.
    init_means
        An optional initial guess for the means of the components. If provided,
        must be a tensor of shape ``[num_components, num_features]``. If this is given,
        the ``init_strategy`` is ignored and the means are handled as if K-means
        initialization has been run.
    convergence_tolerance
        The change in the per-datapoint negative log-likelihood which
        implies that training has converged.
    covariance_regularization
        A small value which is added to the diagonal of the
        covariance matrix to ensure that it is positive semi-definite.
    batch_size
        The batch size to use when fitting the model. If not provided, the full
        data will be used as a single batch. Set this if the full data does not fit into
        memory.
    trainer_params
        Initialization parameters to use when initializing a PyTorch Lightning
        trainer. By default, it disables various stdout logs unless TorchGMM is configured to
        do verbose logging. Checkpointing and logging are disabled regardless of the log
        level. This estimator further sets the following overridable defaults:
        - ``max_epochs=100``.
    random_state
        Initialization seed.
    """

    #: The fitted PyTorch module with all estimated parameters.
    model_: GaussianMixtureModel
    #: A boolean indicating whether the model converged during training.
    converged_: bool
    #: The number of iterations the model was fitted for, excluding initialization.
    num_iter_: int
    #: The average per-datapoint negative log-likelihood at the last training step.
    nll_: float

    def __init__(
        self,
        n_clusters: int = 1,
        *,
        covariance_type: str = "full",
        init_strategy: str = "kmeans",
        init_means: torch.Tensor | None = None,
        convergence_tolerance: float = 0.001,
        covariance_regularization: float = 1e-06,
        batch_size: int | None = None,
        trainer_params: dict | None = None,
        random_state: AnyRandom = 0,
    ):
        super().__init__(
            num_components=n_clusters,
            covariance_type=covariance_type,
            init_strategy=init_strategy,
            init_means=init_means,
            convergence_tolerance=convergence_tolerance,
            covariance_regularization=covariance_regularization,
            batch_size=batch_size,
            trainer_params=trainer_params,
        )
        self.n_clusters = n_clusters
        self.random_state = random_state

    def fit(self, data: TensorLike) -> GaussianMixture:
        """
        Fits the Gaussian mixture on the provided data, estimating component priors, means and covariances.

        Parameters are estimated using the EM algorithm.

        Parameters
        ----------
        data
            The tabular data to fit on. The dimensionality of the Gaussian mixture is automatically inferred from this data.

        Returns
        -------
        The fitted Gaussian mixture.

        Raises
        ------
        ValueError
            If ``data`` is a sparse matrix.
        """
        if sps.issparse(data):
            raise ValueError(
                "Sparse data is not supported. You may have forgotten to reduce the dimensionality of the data. Otherwise, please convert the data to a dense format."
            )
        return self._fit(data)

    def _fit(self, data) -> GaussianMixture:
        # Retry with increasingly strong covariance regularization (x10 each attempt)
        # until the Cholesky decomposition succeeds or regularization reaches 1.
        try:
            return super().fit(data)
        except torch._C._LinAlgError as e:
            if self.covariance_regularization >= 1:
                raise ValueError(
                    "Cholesky decomposition failed even with covariance regularization = 1. The matrix may be singular."
                ) from e
            self.covariance_regularization *= 10
            logger.warning(
                f"Cholesky decomposition failed. Retrying with covariance regularization {self.covariance_regularization}."
            )
            return self._fit(data)

    def predict(self, data: TensorLike) -> np.ndarray:
        """
        Computes the most likely components for each of the provided datapoints.

        Parameters
        ----------
        data
            The datapoints for which to obtain the most likely components.

        Returns
        -------
        An array of shape ``[num_datapoints]`` with the indices of the most likely components.

        Note
        ----
        Use :func:`predict_proba` to obtain probabilities for each component instead of the
        most likely component only.

        Attention
        ---------
        When calling this function in a multi-process environment, each process receives only
        a subset of the predictions. If you want to aggregate predictions, make sure to gather
        the values returned from this method.
        """
        # Unlike torchgmm, return a NumPy array for easier use with AnnData.
        # (The annotation previously claimed torch.Tensor.)
        return super().predict(data).numpy()

    def predict_proba(self, data: TensorLike) -> torch.Tensor:
        """
        Computes a distribution over the components for each of the provided datapoints.

        Parameters
        ----------
        data
            The datapoints for which to compute the component assignment probabilities.

        Returns
        -------
        A tensor of shape ``[num_datapoints, num_components]`` with the assignment
        probabilities for each component and datapoint. Note that each row of the vector sums
        to 1, i.e. the returned tensor provides a proper distribution over the components for
        each datapoint.

        Attention
        ---------
        When calling this function in a multi-process environment, each process receives only
        a subset of the predictions. If you want to aggregate predictions, make sure to gather
        the values returned from this method.
        """
        loader = DataLoader(
            dataset_from_tensors(data),
            batch_size=self.batch_size or len(data),
            collate_fn=collate_tensor,
        )
        trainer_params = self.trainer_params.copy()
        trainer_params["logger"] = False
        result = Trainer(**trainer_params).predict(GaussianMixtureLightningModule(self.model_), loader)
        return torch.cat([x[0] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])

    def score_samples(self, data: TensorLike) -> torch.Tensor:
        """
        Computes the negative log-likelihood (NLL) of each of the provided datapoints.

        Parameters
        ----------
        data
            The datapoints for which to compute the NLL.

        Returns
        -------
        A tensor of shape ``[num_datapoints]`` with the NLL for each datapoint.

        Attention
        ---------
        When calling this function in a multi-process environment, each process receives only
        a subset of the predictions. If you want to aggregate predictions, make sure to gather
        the values returned from this method.
        """
        loader = DataLoader(
            dataset_from_tensors(data),
            batch_size=self.batch_size or len(data),
            collate_fn=collate_tensor,
        )
        trainer_params = self.trainer_params.copy()
        trainer_params["logger"] = False
        result = Trainer(**trainer_params).predict(GaussianMixtureLightningModule(self.model_), loader)
        # Fixed: concatenate the per-batch NLL vectors along the datapoint dimension,
        # matching predict_proba. torch.stack added a leading batch dimension and
        # returned shape [num_batches, batch_size] instead of the documented
        # [num_datapoints].
        return torch.cat([x[1] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
223 | init_strategy 224 | The strategy for initializing component means and covariances. 225 | init_means 226 | An optional initial guess for the means of the components. If provided, 227 | must be a tensor of shape ``[num_components, num_features]``. If this is given, 228 | the ``init_strategy`` is ignored and the means are handled as if K-means 229 | initialization has been run. 230 | convergence_tolerance 231 | The change in the per-datapoint negative log-likelihood which 232 | implies that training has converged. 233 | covariance_regularization 234 | A small value which is added to the diagonal of the 235 | covariance matrix to ensure that it is positive semi-definite. 236 | batch_size: The batch size to use when fitting the model. If not provided, the full 237 | data will be used as a single batch. Set this if the full data does not fit into 238 | memory. 239 | trainer_params 240 | Initialization parameters to use when initializing a PyTorch Lightning 241 | trainer. By default, it disables various stdout logs unless TorchGMM is configured to 242 | do verbose logging. Checkpointing and logging are disabled regardless of the log 243 | level. This estimator further sets the following overridable defaults: 244 | - ``max_epochs=100``. 245 | random_state 246 | Initialization seed. 
247 | Examples 248 | -------- 249 | >>> adata = anndata.read_h5ad(path_to_anndata) 250 | >>> sq.gr.spatial_neighbors(adata, coord_type='generic', delaunay=True) 251 | >>> cc.gr.remove_long_links(adata) 252 | >>> cc.gr.aggregate_neighbors(adata, n_layers=3) 253 | >>> model = cc.tl.Cluster(n_clusters=11) 254 | >>> model.fit(adata, use_rep='X_cellcharter') 255 | """ 256 | 257 | def __init__( 258 | self, 259 | n_clusters: int = 1, 260 | *, 261 | covariance_type: str = "full", 262 | init_strategy: str = "kmeans", 263 | init_means: torch.Tensor = None, 264 | convergence_tolerance: float = 0.001, 265 | covariance_regularization: float = 1e-06, 266 | batch_size: int = None, 267 | trainer_params: dict = None, 268 | random_state: AnyRandom = 0, 269 | ): 270 | super().__init__( 271 | n_clusters=n_clusters, 272 | covariance_type=covariance_type, 273 | init_strategy=init_strategy, 274 | init_means=init_means, 275 | convergence_tolerance=convergence_tolerance, 276 | covariance_regularization=covariance_regularization, 277 | batch_size=batch_size, 278 | trainer_params=trainer_params, 279 | random_state=random_state, 280 | ) 281 | 282 | def fit(self, adata: ad.AnnData, use_rep: str = "X_cellcharter"): 283 | """ 284 | Fit data into ``n_clusters`` clusters. 285 | 286 | Parameters 287 | ---------- 288 | adata 289 | Annotated data object. 290 | use_rep 291 | Key in :attr:`anndata.AnnData.obsm` to use as data to fit the clustering model. 
292 | """ 293 | logging_level = logging.root.level 294 | 295 | X = adata.X if use_rep is None else adata.obsm[use_rep] 296 | 297 | logging_level = logging.getLogger("lightning.pytorch").getEffectiveLevel() 298 | logging.getLogger("lightning.pytorch").setLevel(logging.ERROR) 299 | 300 | super().fit(X) 301 | 302 | logging.getLogger("lightning.pytorch").setLevel(logging_level) 303 | 304 | adata.uns["_cellcharter"] = {k: v for k, v in self.get_params().items() if k != "init_means"} 305 | 306 | def predict(self, adata: ad.AnnData, use_rep: str = "X_cellcharter") -> pd.Categorical: 307 | """ 308 | Predict the labels for the data in ``use_rep`` using the fitted model. 309 | 310 | Parameters 311 | ---------- 312 | adata 313 | Annotated data object. 314 | use_rep 315 | Key in :attr:`anndata.AnnData.obsm` used as data to fit the clustering model. If ``None``, uses :attr:`anndata.AnnData.X`. 316 | k 317 | Number of clusters to predict using the fitted model. If ``None``, the number of clusters with the highest stability will be selected. If ``max_runs > 1``, the model with the largest marginal likelihood will be used among the ones fitted on ``k``. 318 | """ 319 | X = adata.X if use_rep is None else adata.obsm[use_rep] 320 | return pd.Categorical(super().predict(X), categories=np.arange(self.n_clusters)) 321 | -------------------------------------------------------------------------------- /src/cellcharter/tl/_trvae.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from typing import Optional 5 | 6 | from anndata import AnnData, read_h5ad 7 | from torch import nn 8 | 9 | try: 10 | from scarches.models import TRVAE as scaTRVAE 11 | from scarches.models import trVAE 12 | from scarches.models.base._utils import _validate_var_names 13 | except ImportError: 14 | 15 | class TRVAE: 16 | r""" 17 | scArches\'s trVAE model adapted to image-based proteomics data. 
18 | 19 | The last ReLU layer of the neural network is removed to allow for continuous and real output values 20 | 21 | Parameters 22 | ---------- 23 | adata: `~anndata.AnnData` 24 | Annotated data matrix. Has to be count data for 'nb' and 'zinb' loss and normalized log transformed data 25 | for 'mse' loss. 26 | condition_key: String 27 | column name of conditions in `adata.obs` data frame. 28 | conditions: List 29 | List of Condition names that the used data will contain to get the right encoding when used after reloading. 30 | hidden_layer_sizes: List 31 | A list of hidden layer sizes for encoder network. Decoder network will be the reversed order. 32 | latent_dim: Integer 33 | Bottleneck layer (z) size. 34 | dr_rate: Float 35 | Dropout rate applied to all layers, if `dr_rate==0` no dropout will be applied. 36 | use_mmd: Boolean 37 | If 'True' an additional MMD loss will be calculated on the latent dim. 'z' or the first decoder layer 'y'. 38 | mmd_on: String 39 | Choose on which layer MMD loss will be calculated on if 'use_mmd=True': 'z' for latent dim or 'y' for first 40 | decoder layer. 41 | mmd_boundary: Integer or None 42 | Choose on how many conditions the MMD loss should be calculated on. If 'None' MMD will be calculated on all 43 | conditions. 44 | recon_loss: String 45 | Definition of Reconstruction-Loss-Method, 'mse', 'nb' or 'zinb'. 46 | beta: Float 47 | Scaling Factor for MMD loss 48 | use_bn: Boolean 49 | If `True` batch normalization will be applied to layers. 50 | use_ln: Boolean 51 | If `True` layer normalization will be applied to layers. 
52 | """ 53 | 54 | def __init__( 55 | self, 56 | adata: AnnData, 57 | condition_key: str = None, 58 | conditions: Optional[list] = None, 59 | hidden_layer_sizes: list | tuple = (256, 64), 60 | latent_dim: int = 10, 61 | dr_rate: float = 0.05, 62 | use_mmd: bool = True, 63 | mmd_on: str = "z", 64 | mmd_boundary: Optional[int] = None, 65 | recon_loss: Optional[str] = "nb", 66 | beta: float = 1, 67 | use_bn: bool = False, 68 | use_ln: bool = True, 69 | ): 70 | raise ImportError("scarches is not installed. Please install scarches to use this method.") 71 | 72 | @classmethod 73 | def load(cls, dir_path: str, adata: Optional[AnnData] = None, map_location: Optional[str] = None): 74 | """ 75 | Instantiate a model from the saved output. 76 | 77 | Parameters 78 | ---------- 79 | dir_path 80 | Path to saved outputs. 81 | adata 82 | AnnData object. 83 | If None, will check for and load anndata saved with the model. 84 | map_location 85 | Location where all tensors should be loaded (e.g., `torch.device('cpu')`) 86 | Returns 87 | ------- 88 | Model with loaded state dictionaries. 89 | """ 90 | raise ImportError("scarches is not installed. Please install scarches to use this method.") 91 | 92 | else: 93 | 94 | class TRVAE(scaTRVAE): 95 | r""" 96 | scArches\'s trVAE model adapted to image-based proteomics data. 97 | 98 | The last ReLU layer of the neural network is removed to allow for continuous and real output values 99 | 100 | Parameters 101 | ---------- 102 | adata: `~anndata.AnnData` 103 | Annotated data matrix. Has to be count data for 'nb' and 'zinb' loss and normalized log transformed data 104 | for 'mse' loss. 105 | condition_key: String 106 | column name of conditions in `adata.obs` data frame. 107 | conditions: List 108 | List of Condition names that the used data will contain to get the right encoding when used after reloading. 109 | hidden_layer_sizes: List 110 | A list of hidden layer sizes for encoder network. Decoder network will be the reversed order. 
111 | latent_dim: Integer 112 | Bottleneck layer (z) size. 113 | dr_rate: Float 114 | Dropout rate applied to all layers, if `dr_rate==0` no dropout will be applied. 115 | use_mmd: Boolean 116 | If 'True' an additional MMD loss will be calculated on the latent dim. 'z' or the first decoder layer 'y'. 117 | mmd_on: String 118 | Choose on which layer MMD loss will be calculated on if 'use_mmd=True': 'z' for latent dim or 'y' for first 119 | decoder layer. 120 | mmd_boundary: Integer or None 121 | Choose on how many conditions the MMD loss should be calculated on. If 'None' MMD will be calculated on all 122 | conditions. 123 | recon_loss: String 124 | Definition of Reconstruction-Loss-Method, 'mse', 'nb' or 'zinb'. 125 | beta: Float 126 | Scaling Factor for MMD loss 127 | use_bn: Boolean 128 | If `True` batch normalization will be applied to layers. 129 | use_ln: Boolean 130 | If `True` layer normalization will be applied to layers. 131 | """ 132 | 133 | def __init__( 134 | self, 135 | adata: AnnData, 136 | condition_key: str = None, 137 | conditions: Optional[list] = None, 138 | hidden_layer_sizes: list | tuple = (256, 64), 139 | latent_dim: int = 10, 140 | dr_rate: float = 0.05, 141 | use_mmd: bool = True, 142 | mmd_on: str = "z", 143 | mmd_boundary: Optional[int] = None, 144 | recon_loss: Optional[str] = "mse", 145 | beta: float = 1, 146 | use_bn: bool = False, 147 | use_ln: bool = True, 148 | ): 149 | self.adata = adata 150 | 151 | self.condition_key_ = condition_key 152 | 153 | if conditions is None: 154 | if condition_key is not None: 155 | self.conditions_ = adata.obs[condition_key].unique().tolist() 156 | else: 157 | self.conditions_ = [] 158 | else: 159 | self.conditions_ = conditions 160 | 161 | self.hidden_layer_sizes_ = hidden_layer_sizes 162 | self.latent_dim_ = latent_dim 163 | self.dr_rate_ = dr_rate 164 | self.use_mmd_ = use_mmd 165 | self.mmd_on_ = mmd_on 166 | self.mmd_boundary_ = mmd_boundary 167 | self.recon_loss_ = recon_loss 168 | self.beta_ = 
beta 169 | self.use_bn_ = use_bn 170 | self.use_ln_ = use_ln 171 | 172 | self.input_dim_ = adata.n_vars 173 | 174 | self.model = trVAE( 175 | self.input_dim_, 176 | self.conditions_, 177 | list(self.hidden_layer_sizes_), 178 | self.latent_dim_, 179 | self.dr_rate_, 180 | self.use_mmd_, 181 | self.mmd_on_, 182 | self.mmd_boundary_, 183 | self.recon_loss_, 184 | self.beta_, 185 | self.use_bn_, 186 | self.use_ln_, 187 | ) 188 | 189 | decoder_layer_sizes = self.model.hidden_layer_sizes.copy() 190 | decoder_layer_sizes.reverse() 191 | decoder_layer_sizes.append(self.model.input_dim) 192 | 193 | self.model.decoder.recon_decoder = nn.Linear(decoder_layer_sizes[-2], decoder_layer_sizes[-1]) 194 | 195 | self.is_trained_ = False 196 | 197 | self.trainer = None 198 | 199 | @classmethod 200 | def load(cls, dir_path: str, adata: Optional[AnnData] = None, map_location: Optional[str] = None): 201 | """ 202 | Instantiate a model from the saved output. 203 | 204 | Parameters 205 | ---------- 206 | dir_path 207 | Path to saved outputs. 208 | adata 209 | AnnData object. 210 | If None, will check for and load anndata saved with the model. 211 | map_location 212 | Location where all tensors should be loaded (e.g., `torch.device('cpu')`) 213 | Returns 214 | ------- 215 | Model with loaded state dictionaries. 
216 | """ 217 | adata_path = os.path.join(dir_path, "adata.h5ad") 218 | 219 | load_adata = adata is None 220 | 221 | if os.path.exists(adata_path) and load_adata: 222 | adata = read_h5ad(adata_path) 223 | elif not os.path.exists(adata_path) and load_adata: 224 | raise ValueError("Save path contains no saved anndata and no adata was passed.") 225 | 226 | attr_dict, model_state_dict, var_names = cls._load_params(dir_path, map_location=map_location) 227 | 228 | # Overwrite adata with new genes 229 | adata = _validate_var_names(adata, var_names) 230 | 231 | cls._validate_adata(adata, attr_dict) 232 | init_params = cls._get_init_params_from_dict(attr_dict) 233 | 234 | model = cls(adata, **init_params) 235 | model.model.load_state_dict(model_state_dict) 236 | model.model.eval() 237 | 238 | model.is_trained_ = attr_dict["is_trained_"] 239 | 240 | return model 241 | -------------------------------------------------------------------------------- /src/cellcharter/tl/_utils.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | 3 | import numpy as np 4 | from joblib import Parallel, delayed 5 | from sklearn.metrics import adjusted_rand_score 6 | 7 | 8 | def _stability(labels, similarity_function=adjusted_rand_score, n_jobs=-1): 9 | clusters = list(labels.keys()) 10 | max_runs = len(labels[clusters[0]]) 11 | num_combinations = max_runs * (max_runs - 1) // 2 12 | 13 | stabilities = Parallel(n_jobs=n_jobs)( 14 | delayed(similarity_function)(labels[clusters[k]][i], labels[clusters[k] + 1][j]) 15 | for k in range(len(clusters) - 1) 16 | for i, j in combinations(range(max_runs), 2) 17 | ) 18 | 19 | # Transform test into a list of chunks of size num_combinations 20 | stabilities = [stabilities[i : i + num_combinations] for i in range(0, len(stabilities), num_combinations)] 21 | 22 | # Append to every element of test the previous element 23 | stabilities = [stabilities[i] + stabilities[i - 1] for i in range(1, 
len(stabilities))] 24 | 25 | return np.array(stabilities) 26 | -------------------------------------------------------------------------------- /tests/_data/test_data.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_data/test_data.h5ad -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 1, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/model/parameters.pt 
-------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 10, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/params.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 11, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/attributes.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 12, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/attributes.pickle -------------------------------------------------------------------------------- 
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 13, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 14, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- 
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 15, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/model/parameters.pt 
-------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 16, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/params.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 2, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/attributes.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 3, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/attributes.pickle -------------------------------------------------------------------------------- 
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 4, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 5, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- 
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 6, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/model/parameters.pt 
-------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 7, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/params.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 8, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/attributes.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 9, 3 | "num_features": 136, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_imc/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/params.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/attributes.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 1, 3 | "num_features": 144, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_clusters": 1, 3 | "covariance_type": "full", 4 | "init_strategy": "kmeans++", 5 | "init_means": null, 6 | "convergence_tolerance": 0.001, 7 | "covariance_regularization": 1e-6, 8 | "batch_size": null, 9 | "trainer_params": { 10 | "logger": false, 11 | "log_every_n_steps": 1, 
12 | "enable_progress_bar": false, 13 | "enable_checkpointing": false, 14 | "enable_model_summary": false, 15 | "max_epochs": 100, 16 | "accelerator": "cpu" 17 | }, 18 | "random_state": 42 19 | } 20 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 2, 3 | "num_features": 144, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_clusters": 2, 3 | "covariance_type": "full", 4 | "init_strategy": "kmeans++", 5 | "init_means": null, 6 | "convergence_tolerance": 0.001, 7 | "covariance_regularization": 1e-6, 8 | "batch_size": null, 9 | "trainer_params": { 10 | "logger": false, 11 | "log_every_n_steps": 1, 12 | "enable_progress_bar": 
false, 13 | "enable_checkpointing": false, 14 | "enable_model_summary": false, 15 | "max_epochs": 100, 16 | "accelerator": "cpu" 17 | }, 18 | "random_state": 44 19 | } 20 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 3, 3 | "num_features": 144, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_clusters": 3, 3 | "covariance_type": "full", 4 | "init_strategy": "kmeans++", 5 | "init_means": null, 6 | "convergence_tolerance": 0.001, 7 | "covariance_regularization": 1e-6, 8 | "batch_size": null, 9 | "trainer_params": { 10 | "logger": false, 11 | "log_every_n_steps": 1, 12 | "enable_progress_bar": false, 13 | 
"enable_checkpointing": false, 14 | "enable_model_summary": false, 15 | "max_epochs": 100, 16 | "accelerator": "cpu" 17 | }, 18 | "random_state": 44 19 | } 20 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 4, 3 | "num_features": 144, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_clusters": 4, 3 | "covariance_type": "full", 4 | "init_strategy": "kmeans++", 5 | "init_means": null, 6 | "convergence_tolerance": 0.001, 7 | "covariance_regularization": 1e-6, 8 | "batch_size": null, 9 | "trainer_params": { 10 | "logger": false, 11 | "log_every_n_steps": 1, 12 | "enable_progress_bar": false, 13 | "enable_checkpointing": false, 14 
| "enable_model_summary": false, 15 | "max_epochs": 100, 16 | "accelerator": "cpu" 17 | }, 18 | "random_state": 42 19 | } 20 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 5, 3 | "num_features": 144, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_clusters": 5, 3 | "covariance_type": "full", 4 | "init_strategy": "kmeans++", 5 | "init_means": null, 6 | "convergence_tolerance": 0.001, 7 | "covariance_regularization": 1e-6, 8 | "batch_size": null, 9 | "trainer_params": { 10 | "logger": false, 11 | "log_every_n_steps": 1, 12 | "enable_progress_bar": false, 13 | "enable_checkpointing": false, 14 | "enable_model_summary": false, 
15 | "max_epochs": 100, 16 | "accelerator": "cpu" 17 | }, 18 | "random_state": 44 19 | } 20 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/attributes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/attributes.pickle -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_components": 6, 3 | "num_features": 144, 4 | "covariance_type": "full" 5 | } 6 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/model/parameters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/model/parameters.pt -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_clusters": 6, 3 | "covariance_type": "full", 4 | "init_strategy": "kmeans++", 5 | "init_means": null, 6 | "convergence_tolerance": 0.001, 7 | "covariance_regularization": 1e-6, 8 | "batch_size": null, 9 | "trainer_params": { 10 | "logger": false, 11 | "log_every_n_steps": 1, 12 | "enable_progress_bar": false, 13 | "enable_checkpointing": false, 14 | "enable_model_summary": false, 15 | "max_epochs": 100, 16 | 
"accelerator": "cpu" 17 | }, 18 | "random_state": 43 19 | } 20 | -------------------------------------------------------------------------------- /tests/_models/cellcharter_autok_mibitof/params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/params.pickle -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import time 2 | from urllib.error import HTTPError 3 | 4 | import anndata as ad 5 | import numpy as np 6 | import pytest 7 | import scanpy as sc 8 | from squidpy._constants._pkg_constants import Key 9 | 10 | _adata = sc.read("tests/_data/test_data.h5ad") 11 | _adata.raw = _adata.copy() 12 | 13 | 14 | @pytest.fixture() 15 | def non_visium_adata() -> ad.AnnData: 16 | non_visium_coords = np.array([[1, 0], [3, 0], [5, 6], [0, 4]]) 17 | adata = ad.AnnData(X=non_visium_coords, dtype=int) 18 | adata.obsm[Key.obsm.spatial] = non_visium_coords 19 | return adata 20 | 21 | 22 | @pytest.fixture() 23 | def adata() -> ad.AnnData: 24 | return _adata.copy() 25 | 26 | 27 | @pytest.fixture(scope="session") 28 | def codex_adata() -> ad.AnnData: 29 | max_retries = 3 30 | retry_delay = 5 # seconds 31 | 32 | for attempt in range(max_retries): 33 | try: 34 | adata = sc.read( 35 | "tests/_data/codex_adata.h5ad", backup_url="https://figshare.com/ndownloader/files/46832722" 36 | ) 37 | return adata[adata.obs["sample"].isin(["BALBc-1", "MRL-5"])].copy() 38 | except HTTPError as e: 39 | if attempt == max_retries - 1: # Last attempt 40 | pytest.skip(f"Failed to download test data after {max_retries} attempts: {str(e)}") 41 | time.sleep(retry_delay) 42 | -------------------------------------------------------------------------------- /tests/graph/test_aggregate_neighbors.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sps 3 | import squidpy as sq 4 | from anndata import AnnData 5 | 6 | from cellcharter.gr import aggregate_neighbors 7 | 8 | 9 | class TestAggregateNeighbors: 10 | def test_aggregate_neighbors(self): 11 | n_layers = 2 12 | aggregations = ["mean", "var"] 13 | 14 | G = sps.csr_matrix( 15 | np.array( 16 | [ 17 | [0, 1, 1, 1, 1, 0, 0, 0, 0], 18 | [1, 0, 0, 0, 0, 1, 0, 0, 0], 19 | [1, 0, 0, 0, 0, 0, 1, 1, 0], 20 | [1, 0, 0, 0, 1, 0, 0, 0, 1], 21 | [1, 0, 0, 1, 0, 0, 0, 0, 0], 22 | [0, 1, 0, 0, 0, 0, 0, 0, 0], 23 | [0, 0, 1, 0, 0, 0, 0, 1, 0], 24 | [0, 0, 1, 0, 0, 0, 1, 0, 0], 25 | [0, 0, 0, 1, 0, 0, 0, 0, 0], 26 | ] 27 | ) 28 | ) 29 | 30 | X = np.vstack((np.power(2, np.arange(G.shape[0])), np.power(2, np.arange(G.shape[0])[::-1]))).T.astype( 31 | np.float32 32 | ) 33 | 34 | adata = AnnData(X=X, obsp={"spatial_connectivities": G}) 35 | 36 | L1_mean_truth = np.vstack( 37 | ([7.5, 16.5, 64.66, 91, 4.5, 2, 66, 34, 8], [60, 132, 87.33, 91, 144, 128, 33, 34, 32]) 38 | ).T.astype(np.float32) 39 | 40 | L2_mean_truth = np.vstack( 41 | ([120, 9.33, 8.67, 3, 87.33, 1, 1, 1, 8.5], [3.75, 37.33, 58.67, 96, 64.33, 256, 256, 256, 136]) 42 | ).T.astype(np.float32) 43 | 44 | L1_var_truth = np.vstack( 45 | ([5.36, 15.5, 51.85, 116.83, 3.5, 0, 62, 30, 0], [42.90, 124, 119.27, 116.83, 112, 0, 31, 30, 0]) 46 | ).T.astype(np.float32) 47 | 48 | L2_var_truth = np.vstack( 49 | ([85.79, 4.99, 5.73, 1.0, 119.27, 0, 0, 0, 7.5], [2.68, 19.96, 49.46, 32, 51.85, 0, 0, 0, 120]) 50 | ).T.astype(np.float32) 51 | 52 | aggregate_neighbors(adata, n_layers=n_layers, aggregations=["mean", "var"]) 53 | 54 | assert adata.obsm["X_cellcharter"].shape == (adata.shape[0], X.shape[1] * (n_layers * len(aggregations) + 1)) 55 | 56 | np.testing.assert_allclose(adata.obsm["X_cellcharter"][:, [0, 1]], X, rtol=0.01) 57 | np.testing.assert_allclose(adata.obsm["X_cellcharter"][:, [2, 3]], 
L1_mean_truth, rtol=0.01) 58 | np.testing.assert_allclose(adata.obsm["X_cellcharter"][:, [4, 5]], L1_var_truth**2, rtol=0.01) 59 | np.testing.assert_allclose(adata.obsm["X_cellcharter"][:, [6, 7]], L2_mean_truth, rtol=0.01) 60 | np.testing.assert_allclose(adata.obsm["X_cellcharter"][:, [8, 9]], L2_var_truth**2, rtol=0.01) 61 | 62 | def test_aggregations(self, adata: AnnData): 63 | sq.gr.spatial_neighbors(adata) 64 | aggregate_neighbors(adata, n_layers=3, aggregations="mean", out_key="X_str") 65 | aggregate_neighbors(adata, n_layers=3, aggregations=["mean"], out_key="X_list") 66 | 67 | assert (adata.obsm["X_str"] != adata.obsm["X_list"]).nnz == 0 68 | -------------------------------------------------------------------------------- /tests/graph/test_build.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sps 3 | import squidpy as sq 4 | from anndata import AnnData 5 | from squidpy._constants._pkg_constants import Key 6 | 7 | import cellcharter as cc 8 | 9 | 10 | class TestRemoveLongLinks: 11 | def test_remove_long_links(self, non_visium_adata: AnnData): 12 | # ground-truth removing connections longer that 50th percentile 13 | correct_dist_perc = np.array( 14 | [ 15 | [0.0, 2.0, 0.0, 4.12310563], 16 | [2.0, 0.0, 0, 5.0], 17 | [0.0, 0, 0.0, 0.0], 18 | [4.12310563, 5.0, 0.0, 0.0], 19 | ] 20 | ) 21 | correct_graph_perc = np.array( 22 | [[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]] 23 | ) 24 | 25 | sq.gr.spatial_neighbors(non_visium_adata, coord_type="generic", delaunay=True) 26 | cc.gr.remove_long_links(non_visium_adata, distance_percentile=50) 27 | 28 | spatial_graph = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray() 29 | spatial_dist = non_visium_adata.obsp[Key.obsp.spatial_dist()].toarray() 30 | 31 | np.testing.assert_array_equal(spatial_graph, correct_graph_perc) 32 | np.testing.assert_allclose(spatial_dist, correct_dist_perc) 33 | 
34 | 35 | class TestRemoveIntraClusterLinks: 36 | def test_mixed_clusters(self, non_visium_adata: AnnData): 37 | non_visium_adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix( 38 | np.ones((non_visium_adata.shape[0], non_visium_adata.shape[0])) 39 | ) 40 | non_visium_adata.obsp[Key.obsp.spatial_dist()] = sps.csr_matrix( 41 | [[0, 1, 4, 4], [1, 0, 6, 3], [4, 6, 0, 9], [4, 3, 9, 0]] 42 | ) 43 | non_visium_adata.obs["cluster"] = [0, 0, 1, 1] 44 | 45 | correct_conns = np.array([[0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0], [1, 1, 0, 0]]) 46 | correct_dists = np.array([[0, 0, 4, 4], [0, 0, 6, 3], [4, 6, 0, 0], [4, 3, 0, 0]]) 47 | 48 | cc.gr.remove_intra_cluster_links(non_visium_adata, cluster_key="cluster") 49 | 50 | trimmed_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray() 51 | trimmed_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()].toarray() 52 | 53 | np.testing.assert_array_equal(trimmed_conns, correct_conns) 54 | np.testing.assert_allclose(trimmed_dists, correct_dists) 55 | 56 | def test_same_clusters(self, non_visium_adata: AnnData): 57 | non_visium_adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix( 58 | np.ones((non_visium_adata.shape[0], non_visium_adata.shape[0])) 59 | ) 60 | non_visium_adata.obsp[Key.obsp.spatial_dist()] = sps.csr_matrix( 61 | [[0, 1, 4, 4], [1, 0, 6, 3], [4, 6, 0, 9], [4, 3, 9, 0]] 62 | ) 63 | non_visium_adata.obs["cluster"] = [0, 0, 0, 0] 64 | 65 | correct_conns = np.zeros((non_visium_adata.shape[0], non_visium_adata.shape[0])) 66 | correct_dists = np.zeros((non_visium_adata.shape[0], non_visium_adata.shape[0])) 67 | 68 | cc.gr.remove_intra_cluster_links(non_visium_adata, cluster_key="cluster") 69 | 70 | trimmed_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray() 71 | trimmed_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()].toarray() 72 | 73 | np.testing.assert_array_equal(trimmed_conns, correct_conns) 74 | np.testing.assert_allclose(trimmed_dists, correct_dists) 75 | 76 | def 
test_different_clusters(self, non_visium_adata: AnnData): 77 | non_visium_adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix( 78 | np.ones((non_visium_adata.shape[0], non_visium_adata.shape[0])) 79 | ) 80 | non_visium_adata.obsp[Key.obsp.spatial_dist()] = sps.csr_matrix( 81 | [[0, 1, 4, 4], [1, 0, 6, 3], [4, 6, 0, 9], [4, 3, 9, 0]] 82 | ) 83 | non_visium_adata.obs["cluster"] = [0, 1, 2, 3] 84 | 85 | correct_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].copy() 86 | correct_conns.setdiag(0) 87 | correct_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()] 88 | 89 | cc.gr.remove_intra_cluster_links(non_visium_adata, cluster_key="cluster") 90 | 91 | trimmed_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray() 92 | trimmed_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()].toarray() 93 | 94 | np.testing.assert_array_equal(trimmed_conns, correct_conns.toarray()) 95 | np.testing.assert_allclose(trimmed_dists, correct_dists.toarray()) 96 | 97 | def test_copy(self, non_visium_adata: AnnData): 98 | non_visium_adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix( 99 | np.ones((non_visium_adata.shape[0], non_visium_adata.shape[0])) 100 | ) 101 | non_visium_adata.obsp[Key.obsp.spatial_dist()] = sps.csr_matrix( 102 | [[0, 1, 4, 4], [1, 0, 6, 3], [4, 6, 0, 9], [4, 3, 9, 0]] 103 | ) 104 | non_visium_adata.obs["cluster"] = [0, 0, 1, 1] 105 | 106 | correct_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].copy() 107 | correct_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()].copy() 108 | 109 | cc.gr.remove_intra_cluster_links(non_visium_adata, cluster_key="cluster", copy=True) 110 | 111 | trimmed_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray() 112 | trimmed_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()].toarray() 113 | 114 | np.testing.assert_array_equal(trimmed_conns, correct_conns.toarray()) 115 | np.testing.assert_allclose(trimmed_dists, correct_dists.toarray()) 116 | 117 | 118 | class TestConnectedComponents: 119 | 
def test_component_present(self, adata: AnnData): 120 | sq.gr.spatial_neighbors(adata, coord_type="grid", n_neighs=6, delaunay=False) 121 | cc.gr.connected_components(adata, min_cells=10) 122 | 123 | assert "component" in adata.obs 124 | 125 | def test_connected_components_no_cluster(self): 126 | adata = AnnData( 127 | X=np.full((4, 2), 1), 128 | ) 129 | 130 | adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix( 131 | np.array( 132 | [ 133 | [0, 1, 0, 0], 134 | [1, 0, 0, 0], 135 | [0, 0, 0, 1], 136 | [0, 0, 1, 0], 137 | ] 138 | ) 139 | ) 140 | correct_components = np.array([0, 0, 1, 1]) 141 | 142 | components = cc.gr.connected_components(adata, min_cells=0, copy=True) 143 | 144 | np.testing.assert_array_equal(components, correct_components) 145 | 146 | components = cc.gr.connected_components(adata, min_cells=0, copy=False, out_key="comp") 147 | assert "comp" in adata.obs 148 | np.testing.assert_array_equal(adata.obs["comp"].values, correct_components) 149 | 150 | def test_connected_components_cluster(self): 151 | adata = AnnData(X=np.full((4, 2), 1), obs={"cluster": [0, 0, 1, 1]}) 152 | 153 | adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix( 154 | np.array( 155 | [ 156 | [0, 1, 0, 0], 157 | [1, 0, 0, 1], 158 | [0, 0, 0, 1], 159 | [0, 1, 1, 0], 160 | ] 161 | ) 162 | ) 163 | correct_components = np.array([0, 0, 1, 1]) 164 | 165 | components = cc.gr.connected_components(adata, cluster_key="cluster", min_cells=0, copy=True) 166 | 167 | np.testing.assert_array_equal(components, correct_components) 168 | 169 | components = cc.gr.connected_components(adata, cluster_key="cluster", min_cells=0, copy=False, out_key="comp") 170 | assert "comp" in adata.obs 171 | np.testing.assert_array_equal(adata.obs["comp"].values, correct_components) 172 | 173 | def test_connected_components_min_cells(self): 174 | adata = AnnData(X=np.full((5, 2), 1), obs={"cluster": [0, 0, 0, 1, 1]}) 175 | 176 | adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix( 177 | np.array( 178 | [ 179 | [0, 
1, 1, 0, 0], 180 | [1, 0, 0, 0, 0], 181 | [1, 0, 0, 0, 1], 182 | [0, 0, 0, 0, 1], 183 | [0, 0, 1, 1, 0], 184 | ] 185 | ) 186 | ) 187 | correct_components = np.array([0, 0, 0, 1, 1]) 188 | components = cc.gr.connected_components(adata, cluster_key="cluster", min_cells=2, copy=True) 189 | np.testing.assert_array_equal(components, correct_components) 190 | 191 | correct_components = np.array([0, 0, 0, np.nan, np.nan]) 192 | components = cc.gr.connected_components(adata, cluster_key="cluster", min_cells=3, copy=True) 193 | np.testing.assert_array_equal(components, correct_components) 194 | 195 | def test_codex(self, codex_adata: AnnData): 196 | min_cells = 250 197 | correct_number_components = 97 198 | if "component" in codex_adata.obs: 199 | del codex_adata.obs["component"] 200 | cc.gr.connected_components(codex_adata, cluster_key="cluster_cellcharter", min_cells=min_cells) 201 | 202 | assert codex_adata.obs["component"].dtype == "category" 203 | assert len(codex_adata.obs["component"].cat.categories) == correct_number_components 204 | for component in codex_adata.obs["component"].cat.categories: 205 | # Check that all components have at least min_cells cells 206 | assert np.sum(codex_adata.obs["component"] == component) >= min_cells 207 | 208 | # Check that all cells in the component are in the same cluster 209 | assert len(codex_adata.obs["cluster_cellcharter"][codex_adata.obs["component"] == component].unique()) == 1 210 | -------------------------------------------------------------------------------- /tests/graph/test_diff_nhood.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from anndata import AnnData 4 | 5 | import cellcharter as cc 6 | 7 | _CLUSTER_KEY = "cell_type" 8 | _CONDITION_KEY = "sample" 9 | key = f"{_CLUSTER_KEY}_{_CONDITION_KEY}_diff_nhood_enrichment" 10 | 11 | 12 | class TestDiffNhoodEnrichment: 13 | def test_enrichment(self, codex_adata: AnnData): 14 | n_conditions = 
codex_adata.obs[_CONDITION_KEY].cat.categories.shape[0] 15 | cc.gr.diff_nhood_enrichment( 16 | codex_adata, cluster_key=_CLUSTER_KEY, condition_key=_CONDITION_KEY, only_inter=False, log_fold_change=False 17 | ) 18 | 19 | assert len(codex_adata.uns[key]) == n_conditions * (n_conditions - 1) / 2 20 | 21 | for nhood_enrichment in codex_adata.uns[key].values(): 22 | enrichment = nhood_enrichment["enrichment"] 23 | assert np.all((enrichment >= -1) & (enrichment <= 1)) 24 | 25 | del codex_adata.uns[key] 26 | 27 | def test_pvalues(self, codex_adata: AnnData): 28 | n_conditions = codex_adata.obs[_CONDITION_KEY].cat.categories.shape[0] 29 | cc.gr.diff_nhood_enrichment( 30 | codex_adata, 31 | cluster_key=_CLUSTER_KEY, 32 | condition_key=_CONDITION_KEY, 33 | library_key="sample", 34 | only_inter=True, 35 | log_fold_change=True, 36 | pvalues=True, 37 | n_perms=100, 38 | ) 39 | 40 | assert len(codex_adata.uns[key]) == n_conditions * (n_conditions - 1) / 2 41 | 42 | for nhood_enrichment in codex_adata.uns[key].values(): 43 | assert "pvalue" in nhood_enrichment 44 | pvalue = nhood_enrichment["pvalue"] 45 | assert np.all((pvalue >= 0) & (pvalue <= 1)) 46 | 47 | del codex_adata.uns[key] 48 | 49 | def test_symmetric_vs_nonsymmetric(self, codex_adata: AnnData): 50 | # Test symmetric case 51 | cc.gr.diff_nhood_enrichment(codex_adata, cluster_key=_CLUSTER_KEY, condition_key=_CONDITION_KEY, symmetric=True) 52 | symmetric_result = codex_adata.uns[key].copy() 53 | del codex_adata.uns[key] 54 | 55 | # Test non-symmetric case 56 | cc.gr.diff_nhood_enrichment( 57 | codex_adata, cluster_key=_CLUSTER_KEY, condition_key=_CONDITION_KEY, symmetric=False 58 | ) 59 | nonsymmetric_result = codex_adata.uns[key] 60 | 61 | # Results should be different when symmetric=False 62 | for pair_key in symmetric_result: 63 | assert not np.allclose( 64 | symmetric_result[pair_key]["enrichment"], nonsymmetric_result[pair_key]["enrichment"], equal_nan=True 65 | ) 66 | 67 | del codex_adata.uns[key] 68 | 69 | def 
test_condition_groups(self, codex_adata: AnnData): 70 | conditions = codex_adata.obs[_CONDITION_KEY].cat.categories[:2] 71 | cc.gr.diff_nhood_enrichment( 72 | codex_adata, cluster_key=_CLUSTER_KEY, condition_key=_CONDITION_KEY, condition_groups=conditions 73 | ) 74 | 75 | # Should only have one comparison 76 | assert len(codex_adata.uns[key]) == 1 77 | pair_key = f"{conditions[0]}_{conditions[1]}" 78 | assert pair_key in codex_adata.uns[key] 79 | 80 | del codex_adata.uns[key] 81 | 82 | def test_invalid_inputs(self, codex_adata: AnnData): 83 | # Test invalid cluster key 84 | with pytest.raises(KeyError): 85 | cc.gr.diff_nhood_enrichment(codex_adata, cluster_key="invalid_key", condition_key=_CONDITION_KEY) 86 | 87 | # Test invalid condition key 88 | with pytest.raises(KeyError): 89 | cc.gr.diff_nhood_enrichment(codex_adata, cluster_key=_CLUSTER_KEY, condition_key="invalid_key") 90 | 91 | # Test invalid library key when using pvalues 92 | with pytest.raises(KeyError): 93 | cc.gr.diff_nhood_enrichment( 94 | codex_adata, 95 | cluster_key=_CLUSTER_KEY, 96 | condition_key=_CONDITION_KEY, 97 | library_key="invalid_key", 98 | pvalues=True, 99 | ) 100 | 101 | def test_copy_return(self, codex_adata: AnnData): 102 | # Test copy=True returns results without modifying adata 103 | result = cc.gr.diff_nhood_enrichment( 104 | codex_adata, cluster_key=_CLUSTER_KEY, condition_key=_CONDITION_KEY, copy=True 105 | ) 106 | 107 | assert key not in codex_adata.uns 108 | assert isinstance(result, dict) 109 | assert len(result) > 0 110 | -------------------------------------------------------------------------------- /tests/graph/test_group.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import squidpy as sq 3 | 4 | import cellcharter as cc 5 | 6 | GROUP_KEY = "batch" 7 | LABEL_KEY = "Cell_class" 8 | key = f"{GROUP_KEY}_{LABEL_KEY}_enrichment" 9 | 10 | adata = sq.datasets.merfish() 11 | 12 | 13 | class TestEnrichment: 14 | def 
test_enrichment(self): 15 | cc.gr.enrichment(adata, group_key=GROUP_KEY, label_key=LABEL_KEY) 16 | 17 | assert key in adata.uns 18 | assert "enrichment" in adata.uns[key] 19 | assert "params" in adata.uns[key] 20 | 21 | del adata.uns[key] 22 | 23 | def test_copy(self): 24 | enrichment_dict = cc.gr.enrichment(adata, group_key=GROUP_KEY, label_key=LABEL_KEY, copy=True) 25 | 26 | assert "enrichment" in enrichment_dict 27 | 28 | enrichment_dict = cc.gr.enrichment( 29 | adata, group_key=GROUP_KEY, label_key=LABEL_KEY, copy=True, observed_expected=True 30 | ) 31 | 32 | assert "enrichment" in enrichment_dict 33 | assert "observed" in enrichment_dict 34 | assert "expected" in enrichment_dict 35 | 36 | def test_obs_exp(self): 37 | cc.gr.enrichment(adata, group_key=GROUP_KEY, label_key=LABEL_KEY, observed_expected=True) 38 | 39 | assert key in adata.uns 40 | assert "enrichment" in adata.uns[key] 41 | assert "observed" in adata.uns[key] 42 | assert "expected" in adata.uns[key] 43 | 44 | observed = adata.uns[key]["observed"] 45 | expected = adata.uns[key]["expected"] 46 | 47 | assert observed.shape == ( 48 | adata.obs[GROUP_KEY].cat.categories.shape[0], 49 | adata.obs[LABEL_KEY].cat.categories.shape[0], 50 | ) 51 | assert expected.shape[0] == adata.obs[GROUP_KEY].cat.categories.shape[0] 52 | assert np.all((observed >= 0) & (observed <= 1)) 53 | assert np.all((expected >= 0) & (expected <= 1)) 54 | 55 | def test_perm(self): 56 | result_analytical = cc.gr.enrichment( 57 | adata, group_key=GROUP_KEY, label_key=LABEL_KEY, pvalues=False, copy=True, observed_expected=True 58 | ) 59 | result_perm = cc.gr.enrichment( 60 | adata, 61 | group_key=GROUP_KEY, 62 | label_key=LABEL_KEY, 63 | pvalues=True, 64 | n_perms=5000, 65 | observed_expected=True, 66 | copy=True, 67 | ) 68 | np.testing.assert_allclose(result_analytical["enrichment"], result_perm["enrichment"], atol=0.1) 69 | np.testing.assert_allclose(result_analytical["observed"], result_perm["observed"], atol=0.1) 70 | 
np.testing.assert_allclose(result_analytical["expected"], result_perm["expected"], atol=0.1) 71 | -------------------------------------------------------------------------------- /tests/graph/test_nhood.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import squidpy as sq 4 | from squidpy._constants._pkg_constants import Key 5 | 6 | import cellcharter as cc 7 | 8 | _CK = "cell type" 9 | key = Key.uns.nhood_enrichment(_CK) 10 | 11 | adata = sq.datasets.imc() 12 | sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True) 13 | cc.gr.remove_long_links(adata) 14 | 15 | 16 | class TestNhoodEnrichment: 17 | def test_enrichment(self): 18 | cc.gr.nhood_enrichment(adata, cluster_key=_CK, only_inter=False, log_fold_change=False) 19 | enrichment = adata.uns[key]["enrichment"] 20 | assert np.all((enrichment >= -1) & (enrichment <= 1)) 21 | 22 | del adata.uns[key] 23 | 24 | def test_fold_change(self): 25 | cc.gr.nhood_enrichment(adata, cluster_key=_CK, log_fold_change=True) 26 | 27 | del adata.uns[key] 28 | 29 | def test_nhood_obs_exp(self): 30 | cc.gr.nhood_enrichment(adata, cluster_key=_CK, only_inter=False, observed_expected=True) 31 | observed = adata.uns[key]["observed"] 32 | expected = adata.uns[key]["expected"] 33 | 34 | assert observed.shape[0] == adata.obs[_CK].cat.categories.shape[0] 35 | assert observed.shape == expected.shape 36 | assert np.all((observed >= 0) & (observed <= 1)) 37 | assert np.all((expected >= 0) & (expected <= 1)) 38 | 39 | del adata.uns[key] 40 | 41 | def test_symmetric(self): 42 | result = cc.gr.nhood_enrichment( 43 | adata, cluster_key=_CK, symmetric=True, log_fold_change=True, only_inter=False, copy=True 44 | ) 45 | assert scipy.linalg.issymmetric(result["enrichment"].values, atol=1e-02) 46 | 47 | result = cc.gr.nhood_enrichment( 48 | adata, cluster_key=_CK, symmetric=True, log_fold_change=False, only_inter=False, copy=True 49 | ) 50 | assert 
scipy.linalg.issymmetric(result["enrichment"].values, atol=1e-02) 51 | 52 | result = cc.gr.nhood_enrichment( 53 | adata, cluster_key=_CK, symmetric=True, log_fold_change=False, only_inter=True, copy=True 54 | ) 55 | result["enrichment"][result["enrichment"].isna()] = 0 # issymmetric fails with NaNs 56 | assert scipy.linalg.issymmetric(result["enrichment"].values, atol=1e-02) 57 | 58 | result = cc.gr.nhood_enrichment( 59 | adata, cluster_key=_CK, symmetric=True, log_fold_change=True, only_inter=True, copy=True 60 | ) 61 | result["enrichment"][result["enrichment"].isna()] = 0 # issymmetric fails with NaNs 62 | assert scipy.linalg.issymmetric(result["enrichment"].values, atol=1e-02) 63 | 64 | def test_perm(self): 65 | result_analytical = cc.gr.nhood_enrichment( 66 | adata, cluster_key=_CK, only_inter=True, pvalues=False, observed_expected=True, copy=True 67 | ) 68 | result_perm = cc.gr.nhood_enrichment( 69 | adata, 70 | cluster_key=_CK, 71 | only_inter=True, 72 | pvalues=True, 73 | n_perms=5000, 74 | observed_expected=True, 75 | copy=True, 76 | n_jobs=15, 77 | ) 78 | np.testing.assert_allclose(result_analytical["enrichment"], result_perm["enrichment"], atol=0.1) 79 | np.testing.assert_allclose(result_analytical["observed"], result_perm["observed"], atol=0.1) 80 | np.testing.assert_allclose(result_analytical["expected"], result_perm["expected"], atol=0.1) 81 | -------------------------------------------------------------------------------- /tests/plotting/test_group.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pytest 3 | import squidpy as sq 4 | 5 | try: 6 | from matplotlib.colormaps import get_cmap 7 | except ImportError: 8 | from matplotlib.pyplot import get_cmap 9 | 10 | import cellcharter as cc 11 | 12 | GROUP_KEY = "batch" 13 | LABEL_KEY = "Cell_class" 14 | key = f"{GROUP_KEY}_{LABEL_KEY}_enrichment" 15 | 16 | adata = sq.datasets.merfish() 17 | adata_empirical = adata.copy() 18 
cc.gr.enrichment(adata_empirical, group_key=GROUP_KEY, label_key=LABEL_KEY, pvalues=True, n_perms=1000)

adata_analytical = adata.copy()
cc.gr.enrichment(adata_analytical, group_key=GROUP_KEY, label_key=LABEL_KEY)


class TestProportion:
    def test_proportion(self):
        """Smoke test: the proportion plot renders with defaults."""
        cc.pl.proportion(adata, group_key=GROUP_KEY, label_key=LABEL_KEY)

    def test_groups_labels(self):
        """Smoke test: restricting to subsets of groups and labels works."""
        cc.pl.proportion(
            adata,
            group_key=GROUP_KEY,
            label_key=LABEL_KEY,
            groups=adata.obs[GROUP_KEY].cat.categories[:3],
            labels=adata.obs[LABEL_KEY].cat.categories[:4],
        )


class TestPlotEnrichment:
    @pytest.mark.parametrize("adata_enrichment", [adata_analytical, adata_empirical])
    def test_enrichment(self, adata_enrichment):
        """The enrichment plot renders for both analytical and empirical results."""
        cc.pl.enrichment(adata_enrichment, group_key=GROUP_KEY, label_key=LABEL_KEY)

    @pytest.mark.parametrize("adata_enrichment", [adata_analytical, adata_empirical])
    @pytest.mark.parametrize("label_cluster", [False, True])
    @pytest.mark.parametrize("show_pvalues", [False, True])
    @pytest.mark.parametrize("groups", [None, adata.obs[GROUP_KEY].cat.categories[:3]])
    @pytest.mark.parametrize("labels", [None, adata.obs[LABEL_KEY].cat.categories[:4]])
    @pytest.mark.parametrize("size_threshold", [1, 2.5])
    @pytest.mark.parametrize("palette", [None, "coolwarm", get_cmap("coolwarm")])
    @pytest.mark.parametrize("figsize", [None, (10, 8)])
    @pytest.mark.parametrize("alpha,edgecolor", [(1, "red"), (0.5, "blue")])
    def test_params(
        self,
        adata_enrichment,
        label_cluster,
        show_pvalues,
        groups,
        labels,
        size_threshold,
        palette,
        figsize,
        alpha,
        edgecolor,
    ):
        """Exhaustive sweep over the plotting options; each combination must render."""
        cc.pl.enrichment(
            adata_enrichment,
            group_key=GROUP_KEY,
            label_key=LABEL_KEY,
            label_cluster=label_cluster,
            groups=groups,
            labels=labels,
            show_pvalues=show_pvalues,
            size_threshold=size_threshold,
            palette=palette,
            figsize=figsize,
            alpha=alpha,
            edgecolor=edgecolor,
        )
        # Close the figure so the parametrized sweep does not accumulate them.
        plt.close()

    def test_no_pvalues(self):
        """p-value-dependent options must raise when no p-values were computed."""
        with pytest.raises(ValueError):
            cc.pl.enrichment(adata_analytical, "group_key", "label_key", show_pvalues=True)

        with pytest.raises(ValueError):
            cc.pl.enrichment(adata_analytical, "group_key", "label_key", significance=0.01)

        with pytest.raises(ValueError):
            cc.pl.enrichment(adata_analytical, "group_key", "label_key", significant_only=True)

    def test_obs_exp(self):
        """Plotting still works when observed/expected matrices were stored."""
        cc.gr.enrichment(adata, group_key=GROUP_KEY, label_key=LABEL_KEY, observed_expected=True)
        cc.pl.enrichment(adata, group_key=GROUP_KEY, label_key=LABEL_KEY)

    def test_enrichment_no_enrichment_data(self):
        """Plotting without any stored enrichment result must raise."""
        with pytest.raises(ValueError):
            cc.pl.enrichment(adata, "group_key", "label_key")

    def test_size_threshold_zero(self):
        """A zero size threshold is invalid and must raise."""
        with pytest.raises(ValueError):
            cc.pl.enrichment(adata_empirical, group_key=GROUP_KEY, label_key=LABEL_KEY, size_threshold=0)


# ---- tests/plotting/test_plot_nhood.py ----
import squidpy as sq
from squidpy._constants._pkg_constants import Key

import cellcharter as cc

_CK = "cell type"
key = Key.uns.nhood_enrichment(_CK)

adata = sq.datasets.imc()
sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
cc.gr.remove_long_links(adata)


class TestPlotNhoodEnrichment:
    def test_annotate(self):
        """Smoke test: annotated neighborhood-enrichment heatmap renders."""
        cc.gr.nhood_enrichment(adata, cluster_key=_CK)
        cc.pl.nhood_enrichment(adata, cluster_key=_CK, annotate=True)

        del adata.uns[key]

    def test_significance(self):
        """Smoke test: significance masking renders with and without annotation."""
        cc.gr.nhood_enrichment(adata, cluster_key=_CK, pvalues=True, n_perms=100)

        cc.pl.nhood_enrichment(adata, cluster_key=_CK, significance=0.05)
        cc.pl.nhood_enrichment(adata, cluster_key=_CK, annotate=True, significance=0.05)

        del adata.uns[key]


# ---- tests/plotting/test_plot_stability.py ----
import pytest
import scipy.sparse as sps
import squidpy as sq

import cellcharter as cc


class TestPlotStability:
    @pytest.mark.parametrize("dataset_name", ["mibitof"])
    def test_spatial_proteomics(self, dataset_name: str):
        """The stability plot renders for a pre-trained ClusterAutoK model."""
        download_dataset = getattr(sq.datasets, dataset_name)
        adata = download_dataset()
        if sps.issparse(adata.X):
            adata.X = adata.X.todense()
        sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
        cc.gr.remove_long_links(adata)
        cc.gr.aggregate_neighbors(adata, n_layers=3)

        model = cc.tl.ClusterAutoK.load(f"tests/_models/cellcharter_autok_{dataset_name}")

        cc.pl.autok_stability(model)


# ---- tests/plotting/test_shape.py ----
from anndata import AnnData

import cellcharter as cc


class TestPlotBoundaries:
    def test_boundaries(self, codex_adata: AnnData):
        """Smoke test: boundary plot renders after computing components and shapes."""
        cc.gr.connected_components(codex_adata, cluster_key="cluster_cellcharter", min_cells=250)
        cc.tl.boundaries(codex_adata)
        cc.pl.boundaries(codex_adata, sample="BALBc-1", alpha_boundary=0.5, show_cells=False)

    # def test_boundaries_only(self, codex_adata: AnnData):
    #     cc.tl.boundaries(codex_adata)
    #     cc.pl.boundaries(codex_adata, sample="BALBc-1", alpha_boundary=0.5)


# ---- tests/tools/test_autok.py ----
import numpy as np
import pytest
import scipy.sparse as sps
import squidpy as sq

import cellcharter as cc


class TestClusterAutoK:
    @pytest.mark.parametrize("dataset_name", ["mibitof"])
    def test_spatial_proteomics(self, dataset_name: str):
        """End-to-end: fit ClusterAutoK and predict with the best k."""
        download_dataset = getattr(sq.datasets, dataset_name)
        adata = download_dataset()
        if sps.issparse(adata.X):
            adata.X = adata.X.todense()
        sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
        cc.gr.remove_long_links(adata)
        cc.gr.aggregate_neighbors(adata, n_layers=3)

        model_params = {
            "init_strategy": "kmeans",
            "random_state": 42,
            "trainer_params": {"accelerator": "cpu", "enable_progress_bar": False},
        }
        autok = cc.tl.ClusterAutoK(
            n_clusters=(2, 5), model_class=cc.tl.GaussianMixture, model_params=model_params, max_runs=3
        )
        autok.fit(adata, use_rep="X_cellcharter")
        adata.obs[f"cellcharter_{autok.best_k}"] = autok.predict(adata, use_rep="X_cellcharter", k=autok.best_k)

        # The prediction must use exactly best_k distinct labels.
        assert len(np.unique(adata.obs[f"cellcharter_{autok.best_k}"])) == autok.best_k


# ---- tests/tools/test_gmm.py ----
import pytest
import scipy.sparse as sps
import squidpy as sq

import cellcharter as cc


class TestCluster:
    @pytest.mark.parametrize("dataset_name", ["mibitof"])
    def test_sparse(self, dataset_name: str):
        """Fitting on a sparse X without a dense representation must raise."""
        download_dataset = getattr(sq.datasets, dataset_name)
        adata = download_dataset()
        adata.X = sps.csr_matrix(adata.X)

        sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
        cc.gr.remove_long_links(adata)

        gmm = cc.tl.Cluster(n_clusters=(10))

        # Check if fit raises a ValueError
        with pytest.raises(ValueError):
            gmm.fit(adata, use_rep=None)


# ---- tests/tools/test_shape.py ----
from anndata import AnnData
from shapely import Polygon

import cellcharter as cc


# Tests for cc.tl.boundaries, which computes the topological boundaries of
# sets of cells.
class TestBoundaries:
    def test_boundaries(self, codex_adata: AnnData):
        """In-place run stores one boundary polygon per connected component."""
        cc.gr.connected_components(codex_adata, cluster_key="cluster_cellcharter", min_cells=250)
        cc.tl.boundaries(codex_adata)

        boundaries = codex_adata.uns["shape_component"]["boundary"]

        assert isinstance(boundaries, dict)

        # Every component of codex_adata must have a boundary.
        assert set(boundaries.keys()) == set(codex_adata.obs["component"].cat.categories)

    def test_copy(self, codex_adata: AnnData):
        """``copy=True`` returns the boundary dict directly."""
        cc.gr.connected_components(codex_adata, cluster_key="cluster_cellcharter", min_cells=250)
        boundaries = cc.tl.boundaries(codex_adata, copy=True)

        assert isinstance(boundaries, dict)

        # Every component of codex_adata must have a boundary.
        assert set(boundaries.keys()) == set(codex_adata.obs["component"].cat.categories)


class TestLinearity:
    def test_rectangle(self, codex_adata: AnnData):
        """A rectangle is perfectly linear: linearity == 1."""
        codex_adata.obs["rectangle"] = 1

        polygon = Polygon([(0, 0), (0, 10), (2, 10), (2, 0)])

        codex_adata.uns["shape_rectangle"] = {"boundary": {1: polygon}}
        linearities = cc.tl.linearity(codex_adata, "rectangle", copy=True)
        assert linearities[1] == 1.0

    def test_symmetrical_cross(self, codex_adata: AnnData):
        """A symmetrical cross splits evenly between its two axes: linearity ~= 0.5."""
        codex_adata.obs["cross"] = 1

        # Symmetrical cross with arm width of 2 and length of 5.
        polygon = Polygon(
            [(0, 5), (0, 7), (5, 7), (5, 12), (7, 12), (7, 7), (12, 7), (12, 5), (7, 5), (7, 0), (5, 0), (5, 5)]
        )

        codex_adata.uns["shape_cross"] = {"boundary": {1: polygon}}
        linearities = cc.tl.linearity(codex_adata, "cross", copy=True)

        # The cross is symmetrical, so the linearity should be 0.5.
        assert abs(linearities[1] - 0.5) < 0.01

    def test_thickness(self, codex_adata: AnnData):
        """Arm thickness should not influence linearity."""
        codex_adata.obs["cross"] = 1

        # Symmetrical cross with arm width of 1 and length of 5.
        polygon1 = Polygon(
            [(0, 5), (0, 6), (5, 6), (5, 11), (6, 11), (6, 6), (11, 6), (11, 5), (6, 5), (6, 0), (5, 0), (5, 5)]
        )

        # Symmetrical cross with arm width of 2 and length of 5.
        polygon2 = Polygon(
            [(0, 5), (0, 7), (5, 7), (5, 12), (7, 12), (7, 7), (12, 7), (12, 5), (7, 5), (7, 0), (5, 0), (5, 5)]
        )

        codex_adata.uns["shape_cross"] = {"boundary": {1: polygon1}}
        linearities1 = cc.tl.linearity(codex_adata, "cross", copy=True)

        codex_adata.uns["shape_cross"] = {"boundary": {1: polygon2}}
        linearities2 = cc.tl.linearity(codex_adata, "cross", copy=True)

        assert abs(linearities1[1] - linearities2[1]) < 0.01