├── .codecov.yaml
├── .cruft.json
├── .editorconfig
├── .flake8
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   ├── config.yml
│   │   └── feature_request.yml
│   └── workflows
│       ├── build.yaml
│       ├── release.yaml
│       └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CHANGELOG.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── _static
│   │   ├── .gitkeep
│   │   ├── cellcharter.png
│   │   ├── cellcharter_workflow.png
│   │   ├── css
│   │   │   └── custom.css
│   │   └── spatial_clusters.png
│   ├── _templates
│   │   ├── .gitkeep
│   │   └── autosummary
│   │       └── class.rst
│   ├── api.md
│   ├── changelog.md
│   ├── conf.py
│   ├── contributing.md
│   ├── extensions
│   │   └── typed_returns.py
│   ├── graph.md
│   ├── index.md
│   ├── notebooks
│   │   ├── codex_mouse_spleen.ipynb
│   │   ├── cosmx_human_nsclc.ipynb
│   │   ├── data
│   │   │   └── .gitkeep
│   │   ├── example.ipynb
│   │   └── tutorial_models
│   │       └── codex_mouse_spleen_trvae
│   │           ├── attr.pkl
│   │           ├── model_params.pt
│   │           └── var_names.csv
│   ├── plotting.md
│   ├── references.bib
│   ├── references.md
│   └── tools.md
├── pyproject.toml
├── src
│   └── cellcharter
│       ├── __init__.py
│       ├── _constants
│       │   └── _pkg_constants.py
│       ├── _utils.py
│       ├── datasets
│       │   ├── __init__.py
│       │   └── _dataset.py
│       ├── gr
│       │   ├── __init__.py
│       │   ├── _aggr.py
│       │   ├── _build.py
│       │   ├── _group.py
│       │   ├── _nhood.py
│       │   └── _utils.py
│       ├── pl
│       │   ├── __init__.py
│       │   ├── _autok.py
│       │   ├── _group.py
│       │   ├── _nhood.py
│       │   ├── _shape.py
│       │   └── _utils.py
│       └── tl
│           ├── __init__.py
│           ├── _autok.py
│           ├── _gmm.py
│           ├── _shape.py
│           ├── _trvae.py
│           └── _utils.py
└── tests
    ├── _data
    │   └── test_data.h5ad
    ├── _models
    │   ├── cellcharter_autok_imc
    │   │   ├── attributes.pickle
    │   │   ├── best_models
    │   │   │   ├── GaussianMixture_k1
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k10
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k11
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k12
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k13
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k14
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k15
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k16
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k2
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k3
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k4
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k5
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k6
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k7
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   ├── GaussianMixture_k8
    │   │   │   │   ├── attributes.pickle
    │   │   │   │   ├── model
    │   │   │   │   │   ├── config.json
    │   │   │   │   │   └── parameters.pt
    │   │   │   │   └── params.pickle
    │   │   │   └── GaussianMixture_k9
    │   │   │       ├── attributes.pickle
    │   │   │       ├── model
    │   │   │       │   ├── config.json
    │   │   │       │   └── parameters.pt
    │   │   │       └── params.pickle
    │   │   └── params.pickle
    │   └── cellcharter_autok_mibitof
    │       ├── attributes.pickle
    │       ├── best_models
    │       │   ├── GaussianMixture_k1
    │       │   │   ├── attributes.pickle
    │       │   │   ├── model
    │       │   │   │   ├── config.json
    │       │   │   │   └── parameters.pt
    │       │   │   └── params.json
    │       │   ├── GaussianMixture_k2
    │       │   │   ├── attributes.pickle
    │       │   │   ├── model
    │       │   │   │   ├── config.json
    │       │   │   │   └── parameters.pt
    │       │   │   └── params.json
    │       │   ├── GaussianMixture_k3
    │       │   │   ├── attributes.pickle
    │       │   │   ├── model
    │       │   │   │   ├── config.json
    │       │   │   │   └── parameters.pt
    │       │   │   └── params.json
    │       │   ├── GaussianMixture_k4
    │       │   │   ├── attributes.pickle
    │       │   │   ├── model
    │       │   │   │   ├── config.json
    │       │   │   │   └── parameters.pt
    │       │   │   └── params.json
    │       │   ├── GaussianMixture_k5
    │       │   │   ├── attributes.pickle
    │       │   │   ├── model
    │       │   │   │   ├── config.json
    │       │   │   │   └── parameters.pt
    │       │   │   └── params.json
    │       │   └── GaussianMixture_k6
    │       │       ├── attributes.pickle
    │       │       ├── model
    │       │       │   ├── config.json
    │       │       │   └── parameters.pt
    │       │       └── params.json
    │       └── params.pickle
    ├── conftest.py
    ├── graph
    │   ├── test_aggregate_neighbors.py
    │   ├── test_build.py
    │   ├── test_diff_nhood.py
    │   ├── test_group.py
    │   └── test_nhood.py
    ├── plotting
    │   ├── test_group.py
    │   ├── test_plot_nhood.py
    │   ├── test_plot_stability.py
    │   └── test_shape.py
    └── tools
        ├── test_autok.py
        ├── test_gmm.py
        └── test_shape.py
/.codecov.yaml:
--------------------------------------------------------------------------------
1 | # Based on pydata/xarray
2 | codecov:
3 | require_ci_to_pass: no
4 |
5 | coverage:
6 | status:
7 | project:
8 | default:
9 | # Require 1% coverage, i.e., always succeed
10 | target: 1
11 | patch: false
12 | changes: false
13 |
14 | comment:
15 | layout: diff, flags, files
16 | behavior: once
17 | require_base: no
18 |
--------------------------------------------------------------------------------
/.cruft.json:
--------------------------------------------------------------------------------
1 | {
2 | "template": "https://github.com/scverse/cookiecutter-scverse",
3 | "commit": "87a407a65408d75a949c0b54b19fd287475a56f8",
4 | "checkout": "v0.4.0",
5 | "context": {
6 | "cookiecutter": {
7 | "project_name": "cellcharter",
8 | "package_name": "cellcharter",
9 | "project_description": "A Python package for the identification, characterization and comparison of spatial clusters from spatial -omics data.",
10 | "author_full_name": "Marco Varrone",
11 | "author_email": "marco.varrone@unil.ch",
12 | "github_user": "marcovarrone",
13 | "project_repo": "https://github.com/marcovarrone/cellcharter",
14 | "license": "BSD 3-Clause License",
15 | "_copy_without_render": [
16 | ".github/workflows/build.yaml",
17 | ".github/workflows/test.yaml",
18 | "docs/_templates/autosummary/**.rst"
19 | ],
20 | "_render_devdocs": false,
21 | "_jinja2_env_vars": {
22 | "lstrip_blocks": true,
23 | "trim_blocks": true
24 | },
25 | "_template": "https://github.com/scverse/cookiecutter-scverse"
26 | }
27 | },
28 | "directory": null
29 | }
30 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | indent_style = space
5 | indent_size = 4
6 | end_of_line = lf
7 | charset = utf-8
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 |
11 | [*.{yml,yaml}]
12 | indent_size = 2
13 |
14 | [.cruft.json]
15 | indent_size = 2
16 |
17 | [Makefile]
18 | indent_style = tab
19 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | # Can't yet be moved to the pyproject.toml due to https://github.com/PyCQA/flake8/issues/234
2 | [flake8]
3 | max-line-length = 120
4 | ignore =
5 | # line break before a binary operator -> black does not adhere to PEP8
6 | W503
7 | # line break occurred after a binary operator -> black does not adhere to PEP8
8 | W504
9 | # line too long -> we accept long comment lines; black gets rid of long code lines
10 | E501
11 | # whitespace before : -> black does not adhere to PEP8
12 | E203
13 | # line break before binary operator -> black does not adhere to PEP8
14 | W503
15 | # missing whitespace after ',', ';', or ':' -> black does not adhere to PEP8
16 | E231
17 | # continuation line over-indented for hanging indent -> black does not adhere to PEP8
18 | E126
19 | # too many leading '#' for block comment -> this is fine for indicating sections
20 | E262
21 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
22 | E731
23 | # allow I, O, l as variable names -> I is the identity matrix
24 | E741
25 | # Missing docstring in public package
26 | D104
27 | # Missing docstring in public module
28 | D100
29 | # Missing docstring in __init__
30 | D107
31 | # Errors from function calls in argument defaults. These are fine when the result is immutable.
32 | B008
33 | # Missing docstring in magic method
34 | D105
35 | # format string does contain unindexed parameters
36 | P101
37 | # first line should end with a period [Bug: doesn't work with single-line docstrings]
38 | D400
39 | # First line should be in imperative mood; try rephrasing
40 | D401
41 | # Block quote ends without a blank line; unexpected unindent
42 | RST201
43 | # Definition list ends without a blank line; unexpected unindent
44 | RST203
45 | # Block quote ends without a blank line; unexpected unindent
46 | RST301
47 | exclude = .git,__pycache__,build,docs/_build,dist
48 | per-file-ignores =
49 | tests/*: D
50 | */__init__.py: F401
51 | src/cellcharter/_constants/_pkg_constants.py: D101,D102,D106
52 | rst-roles =
53 | attr,
54 | class,
55 | func,
56 | ref,
57 | cite:p,
58 | cite:t,
59 | mod,
60 | rst-directives =
61 | envvar,
62 | exception,
63 | rst-substitutions =
64 | version,
65 | extend-ignore =
66 | RST307,
67 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug report
2 | description: Report something that is broken or incorrect
3 | labels: bug
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports)
9 | detailing how to provide the necessary information for us to reproduce your bug. In brief:
10 | * Please provide the exact steps to reproduce the bug in a clean Python environment.
11 | * In case it's not clear what's causing this bug, please provide the data or the data generation procedure.
12 | * Sometimes it is not possible to share the data, but usually it is possible to replicate problems on publicly
13 | available datasets or to share a subset of your data.
14 |
15 | - type: textarea
16 | id: report
17 | attributes:
18 | label: Report
19 | description: A clear and concise description of what the bug is.
20 | validations:
21 | required: true
22 |
23 | - type: textarea
24 | id: versions
25 | attributes:
26 | label: Version information
27 | description: |
28 | Please paste below the output of
29 |
30 | ```python
31 | import session_info
32 | session_info.show(html=False, dependencies=True)
33 | ```
34 | placeholder: |
35 | -----
36 | anndata 0.8.0rc2.dev27+ge524389
37 | session_info 1.0.0
38 | -----
39 | asttokens NA
40 | awkward 1.8.0
41 | backcall 0.2.0
42 | cython_runtime NA
43 | dateutil 2.8.2
44 | debugpy 1.6.0
45 | decorator 5.1.1
46 | entrypoints 0.4
47 | executing 0.8.3
48 | h5py 3.7.0
49 | ipykernel 6.15.0
50 | jedi 0.18.1
51 | mpl_toolkits NA
52 | natsort 8.1.0
53 | numpy 1.22.4
54 | packaging 21.3
55 | pandas 1.4.2
56 | parso 0.8.3
57 | pexpect 4.8.0
58 | pickleshare 0.7.5
59 | pkg_resources NA
60 | prompt_toolkit 3.0.29
61 | psutil 5.9.1
62 | ptyprocess 0.7.0
63 | pure_eval 0.2.2
64 | pydev_ipython NA
65 | pydevconsole NA
66 | pydevd 2.8.0
67 | pydevd_file_utils NA
68 | pydevd_plugins NA
69 | pydevd_tracing NA
70 | pygments 2.12.0
71 | pytz 2022.1
72 | scipy 1.8.1
73 | setuptools 62.5.0
74 | setuptools_scm NA
75 | six 1.16.0
76 | stack_data 0.3.0
77 | tornado 6.1
78 | traitlets 5.3.0
79 | wcwidth 0.2.5
80 | zmq 23.1.0
81 | -----
82 | IPython 8.4.0
83 | jupyter_client 7.3.4
84 | jupyter_core 4.10.0
85 | -----
86 | Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50) [GCC 10.3.0]
87 | Linux-5.18.6-arch1-1-x86_64-with-glibc2.35
88 | -----
89 | Session information updated at 2022-07-07 17:55
90 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Scverse Community Forum
4 | url: https://discourse.scverse.org/
5 | about: If you have questions about “How to do X”, please ask them here.
6 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature request
2 | description: Propose a new feature for cellcharter
3 | labels: enhancement
4 | body:
5 | - type: textarea
6 | id: description
7 | attributes:
8 | label: Description of feature
9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered.
10 | validations:
11 | required: true
12 |
--------------------------------------------------------------------------------
/.github/workflows/build.yaml:
--------------------------------------------------------------------------------
1 | name: Check Build
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | package:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v3
18 | - name: Set up Python 3.10
19 | uses: actions/setup-python@v4
20 | with:
21 | python-version: "3.10"
22 | cache: "pip"
23 | cache-dependency-path: "**/pyproject.toml"
24 | - name: Install build dependencies
25 | run: python -m pip install --upgrade pip wheel twine build
26 | - name: Build package
27 | run: python -m build
28 | - name: Check package
29 | run: twine check --strict dist/*.whl
30 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/
8 | jobs:
9 | release:
10 | name: Upload release to PyPI
11 | runs-on: ubuntu-latest
12 | environment:
13 | name: pypi
14 | url: https://pypi.org/p/cellcharter
15 | permissions:
16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
17 | steps:
18 | - uses: actions/checkout@v4
19 | with:
20 | filter: blob:none
21 | fetch-depth: 0
22 | - uses: actions/setup-python@v4
23 | with:
24 | python-version: "3.x"
25 | cache: "pip"
26 | - run: pip install build
27 | - run: python -m build
28 | - name: Publish package distributions to PyPI
29 | uses: pypa/gh-action-pypi-publish@release/v1
30 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 | schedule:
9 | - cron: "0 5 1,15 * *"
10 |
11 | concurrency:
12 | group: ${{ github.workflow }}-${{ github.ref }}
13 | cancel-in-progress: true
14 |
15 | jobs:
16 | test:
17 | runs-on: ${{ matrix.os }}
18 | defaults:
19 | run:
20 | shell: bash -e {0} # -e to fail on error
21 |
22 | strategy:
23 | fail-fast: false
24 | matrix:
25 | include:
26 | - os: ubuntu-latest
27 | python: "3.10"
28 | - os: ubuntu-latest
29 | python: "3.12"
30 | - os: ubuntu-latest
31 | python: "3.12"
32 | pip-flags: "--pre"
33 | name: PRE-RELEASE DEPENDENCIES
34 |
35 | name: ${{ matrix.name }} Python ${{ matrix.python }}
36 |
37 | env:
38 | OS: ${{ matrix.os }}
39 | PYTHON: ${{ matrix.python }}
40 |
41 | steps:
42 | - uses: actions/checkout@v3
43 | with:
44 | lfs: true # Ensure LFS files are fetched
45 |
46 | - name: Set up Python ${{ matrix.python }}
47 | uses: actions/setup-python@v4
48 | with:
49 | python-version: ${{ matrix.python }}
50 | cache: "pip"
51 | cache-dependency-path: "**/pyproject.toml"
52 |
53 | - name: Install test dependencies
54 | run: |
55 | python -m pip install --upgrade pip wheel
56 | - name: Install dependencies
57 | run: |
58 | pip install ${{ matrix.pip-flags }} ".[dev,test]"
59 | - name: Test
60 | env:
61 | MPLBACKEND: agg
62 | PLATFORM: ${{ matrix.os }}
63 | DISPLAY: :42
64 | run: |
65 | coverage run -m pytest -v --color=yes
66 | - name: Report coverage
67 | run: |
68 | coverage report
69 | - name: Upload coverage
70 | uses: codecov/codecov-action@v3
71 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Temp files
2 | .DS_Store
3 | *~
4 | buck-out/
5 |
6 | # Compiled files
7 | .venv/
8 | __pycache__/
9 | .mypy_cache/
10 | .ruff_cache/
11 |
12 | # Distribution / packaging
13 | /build/
14 | /dist/
15 | /*.egg-info/
16 |
17 | # Tests and coverage
18 | /.pytest_cache/
19 | /.cache/
20 | /data/
21 | /node_modules/
22 |
23 | # docs
24 | /docs/generated/
25 | /docs/_build/
26 |
27 | # IDEs
28 | /.idea/
29 | /.vscode/
30 |
31 | docs/notebooks/tutorial_data
32 | docs/jupyter_execute
33 |
34 | .ipynb_checkpoints
35 | *.ipynb
36 |
37 | trash
38 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | fail_fast: false
2 | default_language_version:
3 | python: python3
4 | node: 16.14.2
5 | default_stages:
6 | - pre-commit
7 | - pre-push
8 | minimum_pre_commit_version: 2.16.0
9 | repos:
10 | - repo: https://github.com/psf/black
11 | rev: 25.1.0
12 | hooks:
13 | - id: black
14 | - repo: https://github.com/pre-commit/mirrors-prettier
15 | rev: v4.0.0-alpha.8
16 | hooks:
17 | - id: prettier
18 | exclude: '.*\.py$ |.*\.toml$'
19 | - repo: https://github.com/asottile/blacken-docs
20 | rev: 1.19.1
21 | hooks:
22 | - id: blacken-docs
23 | - repo: https://github.com/PyCQA/isort
24 | rev: 6.0.1
25 | hooks:
26 | - id: isort
27 | - repo: https://github.com/asottile/yesqa
28 | rev: v1.5.0
29 | hooks:
30 | - id: yesqa
31 | additional_dependencies:
32 | - flake8-tidy-imports
33 | - flake8-docstrings
34 | - flake8-rst-docstrings
35 | - flake8-comprehensions
36 | - flake8-bugbear
37 | - flake8-blind-except
38 | - repo: https://github.com/pre-commit/pre-commit-hooks
39 | rev: v5.0.0
40 | hooks:
41 | - id: detect-private-key
42 | - id: check-ast
43 | - id: end-of-file-fixer
44 | - id: mixed-line-ending
45 | args: [--fix=lf]
46 | - id: trailing-whitespace
47 | - id: check-case-conflict
48 | - repo: https://github.com/PyCQA/autoflake
49 | rev: v2.3.1
50 | hooks:
51 | - id: autoflake
52 | args:
53 | - --in-place
54 | - --remove-all-unused-imports
55 | - --remove-unused-variable
56 | - --ignore-init-module-imports
57 | - repo: https://github.com/PyCQA/flake8
58 | rev: 7.2.0
59 | hooks:
60 | - id: flake8
61 | additional_dependencies:
62 | - flake8-tidy-imports
63 | - flake8-docstrings
64 | - flake8-rst-docstrings
65 | - flake8-comprehensions
66 | - flake8-bugbear
67 | - flake8-blind-except
68 | - repo: https://github.com/asottile/pyupgrade
69 | rev: v3.19.1
70 | hooks:
71 | - id: pyupgrade
72 | args: [--py3-plus, --py38-plus, --keep-runtime-typing]
73 | - repo: local
74 | hooks:
75 | - id: forbid-to-commit
76 | name: Don't commit rej files
77 | entry: |
78 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates.
79 | Fix the merge conflicts manually and remove the .rej files.
80 | language: fail
81 | files: '.*\.rej$'
82 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html
2 | version: 2
3 | build:
4 | os: ubuntu-20.04
5 | tools:
6 | python: "3.10"
7 | sphinx:
8 | configuration: docs/conf.py
9 | # disable this for more lenient docs builds
10 | fail_on_warning: false
11 | python:
12 | install:
13 | - method: pip
14 | path: .
15 | extra_requirements:
16 | - doc
17 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog][],
6 | and this project adheres to [Semantic Versioning][].
7 |
8 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/
9 | [semantic versioning]: https://semver.org/spec/v2.0.0.html
10 |
11 | ## [Unreleased]
12 |
13 | ### Added
14 |
15 | - Basic tool, preprocessing and plotting functions
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, Marco Varrone
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 | **A Python package for the identification, characterization and comparison of spatial clusters from spatial omics data.**
5 |
6 | ---
7 |
8 |
9 | Documentation •
10 | Examples •
11 | Paper •
12 | Preprint
13 |
14 |
15 | [![Tests][badge-tests]][link-tests]
16 | [![Documentation][badge-docs]][link-docs]
17 | [](https://pepy.tech/projects/cellcharter)
18 |
19 | [badge-tests]: https://img.shields.io/github/actions/workflow/status/CSOgroup/cellcharter/test.yaml?branch=main
20 | [link-tests]: https://github.com/CSOgroup/cellcharter/actions/workflows/test.yaml
21 | [badge-docs]: https://img.shields.io/readthedocs/cellcharter
22 |
23 |
24 |
25 | ## Background
26 |
27 |
28 | Spatial clustering (or spatial domain identification) determines cellular niches characterized by a specific admixing of cell populations. It assigns cells to clusters based both on their intrinsic features (e.g., protein or mRNA expression) and on the features of neighboring cells in the tissue.
29 |
30 |
31 |
32 |
33 |
34 |
35 | CellCharter is able to automatically identify spatial domains and offers a suite of approaches for cluster characterization and comparison.
36 |
37 |
38 |
39 |
40 |
41 | ## Features
42 |
43 | - **Identify niches for multiple samples**: By combining the power of scVI and scArches, CellCharter can identify domains across multiple samples simultaneously, even in the presence of batch effects.
44 | - **Scalability**: CellCharter can handle large datasets with millions of cells and thousands of features, and it can run on GPUs for an additional speedup.
45 | - **Flexibility**: CellCharter can be used with different types of spatial omics data, such as spatial transcriptomics, proteomics, epigenomics and multiomics data. The only difference is the method used for dimensionality reduction and batch effect removal.
46 |   - Spatial transcriptomics: CellCharter has been tested on [scVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI) with a zero-inflated negative binomial distribution.
47 |   - Spatial proteomics: CellCharter has been tested on a version of [scArches](https://docs.scarches.org/en/latest/api/models.html#scarches.models.TRVAE), modified to use a mean squared error loss instead of the default negative binomial loss.
48 |   - Spatial epigenomics: CellCharter has been tested on [scVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI) with a Poisson distribution.
49 |   - Spatial multiomics: it is possible to use multi-omics models such as [MultiVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.MULTIVI.html#scvi.model.MULTIVI), or to concatenate the outputs of the individual models.
50 | - **Best candidates for the number of domains**: CellCharter offers a [method to find multiple best candidates](https://cellcharter.readthedocs.io/en/latest/generated/cellcharter.tl.ClusterAutoK.html) for the number of domains, based on the stability of a certain number of domains across multiple runs.
51 | - **Domain characterization**: CellCharter provides a set of tools to characterize and compare the spatial domains, such as domain proportion, cell type enrichment, (differential) neighborhood enrichment, and domain shape characterization.
52 |
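The typical workflow chains these pieces together: build a spatial graph, aggregate each cell's features with those of its neighbors, and cluster the aggregated representation. A minimal sketch on toy data (the `cc.gr`, `cc.tl` and `cc.pl` function names match the API documented below; parameter names such as `n_layers`, `use_rep`, `out_key`, `n_clusters` and `max_runs` follow the tutorials and are assumptions that may differ between versions):

```python
import numpy as np
import squidpy as sq
import cellcharter as cc
from anndata import AnnData

# Toy stand-in for a dimensionality-reduced dataset (e.g., an scVI latent space).
adata = AnnData(np.random.normal(size=(200, 10)).astype(np.float32))
adata.obsm["spatial"] = np.random.uniform(size=(200, 2))

# 1. Build the spatial neighborhood graph and prune spurious long links.
sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
cc.gr.remove_long_links(adata)

# 2. Aggregate each cell's features with those of its neighbors up to 3 hops away.
cc.gr.aggregate_neighbors(adata, n_layers=3, use_rep=None, out_key="X_cellcharter")

# 3. Fit GMMs for a range of candidate numbers of domains and inspect their stability.
autok = cc.tl.ClusterAutoK(n_clusters=(2, 10), max_runs=5)
autok.fit(adata, use_rep="X_cellcharter")
cc.pl.autok_stability(autok)

# 4. Assign each cell to a spatial domain using the most stable candidate.
adata.obs["spatial_domain"] = autok.predict(adata, use_rep="X_cellcharter")
```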
53 | Since CellCharter 0.3.0, we moved the implementation of the Gaussian Mixture Model (GMM) from [PyCave](https://github.com/borchero/pycave), which is no longer maintained, to [TorchGMM](https://github.com/CSOgroup/torchgmm), a fork of PyCave maintained by the CSOgroup. This gives us a stable, maintained GMM implementation that is compatible with the most recent versions of PyTorch.
54 |
55 | ## Getting started
56 |
57 | Please refer to the [documentation][link-docs]. In particular, the
58 |
59 | - [API documentation][link-api].
60 | - [Tutorials][link-tutorial]
61 |
62 | ## Installation
63 |
64 | 1. Create a conda or pyenv environment
65 | 2. Install Python >= 3.10,<3.14 and [PyTorch](https://pytorch.org) >= 1.12.0. If you are planning to use a GPU, make sure to download and install the correct version of PyTorch first from [here](https://pytorch.org/get-started/locally/).
66 | 3. Install the library used for dimensionality reduction and batch effect removal according to the data type you are planning to analyze:
67 |    - [scVI](https://github.com/scverse/scvi-tools) for spatial transcriptomics and/or epigenomics data such as 10x Visium and Xenium, Nanostring CosMx, Vizgen MERSCOPE, Stereo-seq, DBiT-seq, MERFISH and seqFISH data.
68 |    - A modified version of [scArches](https://github.com/theislab/scarches)'s TRVAE model for spatial proteomics data such as Akoya CODEX, Lunaphore COMET, CyCIF, IMC and MIBI-TOF data.
69 | 4. Install CellCharter using pip:
70 |
71 | ```bash
72 | pip install cellcharter
73 | ```
74 |
75 | We report here an example installation aimed at analyzing spatial transcriptomics data (and thus installing `scvi-tools`).
76 | This example is based on a Linux CentOS 7 system with an NVIDIA A100 GPU.
77 | By default, it installs the GPU build of PyTorch.
78 |
79 | ```bash
80 | conda create -n cellcharter-env -c conda-forge python=3.12
81 | conda activate cellcharter-env
82 | pip install scvi-tools
83 | pip install cellcharter
84 | ```
85 |
86 | Note: a different system may require different commands to install PyTorch and JAX. Refer to their respective documentation for more details.
87 |
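To verify that the GPU build of PyTorch was picked up, a quick check using the standard PyTorch API (independent of CellCharter):

```python
import torch

# True on a machine with a working CUDA build of PyTorch; False means CPU-only.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```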
88 | ## Contribution
89 |
90 | If you found a bug or you want to propose a new feature, please use the [issue tracker][issue-tracker].
91 |
92 | [issue-tracker]: https://github.com/CSOgroup/cellcharter/issues
93 | [link-docs]: https://cellcharter.readthedocs.io
94 | [link-api]: https://cellcharter.readthedocs.io/en/latest/api.html
95 | [link-tutorial]: https://cellcharter.readthedocs.io/en/latest/notebooks/codex_mouse_spleen.html
96 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/_static/.gitkeep
--------------------------------------------------------------------------------
/docs/_static/cellcharter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/_static/cellcharter.png
--------------------------------------------------------------------------------
/docs/_static/cellcharter_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/_static/cellcharter_workflow.png
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */
2 | div.cell_output table.dataframe {
3 | font-size: 0.8em;
4 | }
5 |
--------------------------------------------------------------------------------
/docs/_static/spatial_clusters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/_static/spatial_clusters.png
--------------------------------------------------------------------------------
/docs/_templates/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/_templates/.gitkeep
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | Attributes table
12 | ~~~~~~~~~~~~~~~~~~
13 |
14 | .. autosummary::
15 | {% for item in attributes %}
16 | ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | Methods table
24 | ~~~~~~~~~~~~~
25 |
26 | .. autosummary::
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 | ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
35 | {% block attributes_documentation %}
36 | {% if attributes %}
37 | Attributes
38 | ~~~~~~~~~~~
39 |
40 | {% for item in attributes %}
41 |
42 | .. autoattribute:: {{ [objname, item] | join(".") }}
43 | {%- endfor %}
44 |
45 | {% endif %}
46 | {% endblock %}
47 |
48 | {% block methods_documentation %}
49 | {% if methods %}
50 | Methods
51 | ~~~~~~~
52 |
53 | {% for item in methods %}
54 | {%- if item != '__init__' %}
55 |
56 | .. automethod:: {{ [objname, item] | join(".") }}
57 | {%- endif -%}
58 | {%- endfor %}
59 |
60 | {% endif %}
61 | {% endblock %}
62 |
--------------------------------------------------------------------------------
/docs/api.md:
--------------------------------------------------------------------------------
1 | # API
2 |
3 | ## Graph
4 |
5 | ```{eval-rst}
6 | .. module:: cellcharter.gr
7 | .. currentmodule:: cellcharter
8 |
9 | .. autosummary::
10 | :toctree: generated
11 |
12 | gr.aggregate_neighbors
13 | gr.nhood_enrichment
14 | gr.remove_long_links
15 | gr.remove_intra_cluster_links
16 |
17 | ```
18 |
19 | ## Tools
20 |
21 | ```{eval-rst}
22 | .. module:: cellcharter.tl
23 | .. currentmodule:: cellcharter
24 |
25 | .. autosummary::
26 | :toctree: generated
27 |
28 | tl.Cluster
29 | tl.ClusterAutoK
30 | tl.TRVAE
31 | ```
32 |
33 | ## Plotting
34 |
35 | ```{eval-rst}
36 | .. module:: cellcharter.pl
37 | .. currentmodule:: cellcharter
38 |
39 | .. autosummary::
40 | :toctree: generated
41 |
42 | pl.autok_stability
43 | pl.nhood_enrichment
44 | pl.proportion
45 | ```
46 |
--------------------------------------------------------------------------------
/docs/changelog.md:
--------------------------------------------------------------------------------
1 | ```{include} ../CHANGELOG.md
2 |
3 | ```
4 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 |
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 | import sys
9 | from datetime import datetime
10 | from importlib.metadata import metadata
11 | from pathlib import Path
12 |
13 | HERE = Path(__file__).parent
14 | sys.path.insert(0, str(HERE / "extensions"))
15 |
16 |
17 | # -- Project information -----------------------------------------------------
18 |
19 | info = metadata("cellcharter")
20 | project_name = info["Name"]
21 | author = info["Author"]
22 | copyright = f"{datetime.now():%Y}, {author}."
23 | version = info["Version"]
24 | repository_url = f"https://github.com/CSOgroup/{project_name}"
25 |
26 | # The full version, including alpha/beta/rc tags
27 | release = info["Version"]
28 |
29 | bibtex_bibfiles = ["references.bib"]
30 | templates_path = ["_templates"]
31 | nitpicky = True # Warn about broken links
32 | needs_sphinx = "4.0"
33 |
34 | html_context = {
35 | "display_github": True, # Integrate GitHub
36 | "github_user": "CSOgroup", # Username
37 | "github_repo": project_name, # Repo name
38 | "github_version": "main", # Version
39 | "conf_py_path": "/docs/", # Path in the checkout to the docs root
40 | }
41 |
42 | # -- General configuration ---------------------------------------------------
43 |
44 | # Add any Sphinx extension module names here, as strings.
45 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
46 | extensions = [
47 | "myst_nb",
48 | "sphinx_copybutton",
49 | "sphinx.ext.autodoc",
50 | "sphinx.ext.intersphinx",
51 | "sphinx.ext.autosummary",
52 | "sphinx.ext.napoleon",
53 | "sphinxcontrib.bibtex",
54 | "sphinx_autodoc_typehints",
55 | "sphinx.ext.mathjax",
56 | "IPython.sphinxext.ipython_console_highlighting",
57 | "sphinxext.opengraph",
58 | *[p.stem for p in (HERE / "extensions").glob("*.py")],
59 | ]
60 |
61 | autosummary_generate = True
62 | autodoc_member_order = "groupwise"
63 | default_role = "literal"
64 | napoleon_google_docstring = False
65 | napoleon_numpy_docstring = True
66 | napoleon_include_init_with_doc = False
67 | napoleon_use_rtype = True # having a separate entry generally helps readability
68 | napoleon_use_param = True
69 | myst_heading_anchors = 6 # create anchors for h1-h6
70 | myst_enable_extensions = [
71 | "amsmath",
72 | "colon_fence",
73 | "deflist",
74 | "dollarmath",
75 | "html_image",
76 | "html_admonition",
77 | ]
78 | myst_url_schemes = ("http", "https", "mailto")
79 | nb_output_stderr = "remove"
80 | nb_execution_mode = "off"
81 | nb_merge_streams = True
82 | typehints_defaults = "braces"
83 |
84 | source_suffix = {
85 | ".rst": "restructuredtext",
86 | ".ipynb": "myst-nb",
87 | ".myst": "myst-nb",
88 | }
89 |
90 | intersphinx_mapping = {
91 | "python": ("https://docs.python.org/3", None),
92 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
93 | "numpy": ("https://numpy.org/doc/stable/", None),
94 | }
95 |
96 | # List of patterns, relative to source directory, that match files and
97 | # directories to ignore when looking for source files.
98 | # This pattern also affects html_static_path and html_extra_path.
99 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]
100 |
101 |
102 | # -- Options for HTML output -------------------------------------------------
103 |
104 | # The theme to use for HTML and HTML Help pages. See the documentation for
105 | # a list of builtin themes.
106 | #
107 | html_theme = "sphinx_book_theme"
108 | html_static_path = ["_static"]
109 | html_css_files = ["css/custom.css"]
110 |
111 | html_title = project_name
112 |
113 | html_theme_options = {
114 | "repository_url": repository_url,
115 | "use_repository_button": True,
116 | "path_to_docs": "docs/",
117 | "navigation_with_keys": False,
118 | }
119 |
120 | pygments_style = "default"
121 |
122 | nitpick_ignore = [
123 | # If building the documentation fails because of a missing link that is outside your control,
124 | # you can add an exception to this list.
125 | # ("py:class", "igraph.Graph"),
126 | ]
127 |
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | # Contributing guide
2 |
3 | Scanpy provides extensive [developer documentation][scanpy developer guide], most of which applies to this project, too.
4 | This document does not reproduce the entire content from there. Instead, it aims to summarize the most important
5 | information to get you started on contributing.
6 |
7 | We assume that you are already familiar with git and with making pull requests on GitHub. If not, please refer
8 | to the [scanpy developer guide][].
9 |
10 | ## Installing dev dependencies
11 |
12 | In addition to the packages needed to _use_ this package, you need additional Python packages to _run tests_ and _build
13 | the documentation_. It's easy to install them using `pip`:
14 |
15 | ```bash
16 | cd cellcharter
17 | pip install -e ".[dev,test,doc]"
18 | ```
19 |
20 | ## Code-style
21 |
22 | This package uses [pre-commit][] to enforce consistent code-styles.
23 | On every commit, pre-commit checks will either automatically fix issues with the code, or raise an error message.
24 |
25 | To enable pre-commit locally, simply run
26 |
27 | ```bash
28 | pre-commit install
29 | ```
30 |
31 | in the root of the repository. Pre-commit will automatically download all dependencies when it is run for the first time.
32 |
33 | Alternatively, you can rely on the [pre-commit.ci][] service enabled on GitHub. If you didn't run `pre-commit` before
34 | pushing changes to GitHub, it will automatically commit fixes to your pull request or show an error message.
35 |
36 | If pre-commit.ci added a commit on a branch you still have been working on locally, simply use
37 |
38 | ```bash
39 | git pull --rebase
40 | ```
41 |
42 | to integrate the changes into yours.
43 | While the [pre-commit.ci][] is useful, we strongly encourage installing and running pre-commit locally first to understand its usage.
44 |
45 | Finally, most editors have an _autoformat on save_ feature. Consider enabling this option for [ruff][ruff-editors]
46 | and [prettier][prettier-editors].
47 |
48 | [ruff-editors]: https://docs.astral.sh/ruff/integrations/
49 | [prettier-editors]: https://prettier.io/docs/en/editors.html
50 |
51 | ## Writing tests
52 |
53 | ```{note}
54 | Remember to first install the package with `pip install -e '.[dev,test]'`
55 | ```
56 |
57 | This package uses the [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added
58 | to the package.
59 |
60 | Most IDEs integrate with pytest and provide a GUI to run tests. Alternatively, you can run all tests from the
61 | command line by executing
62 |
63 | ```bash
64 | pytest
65 | ```
66 |
67 | in the root of the repository.
68 |
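A minimal test module could look like the following sketch (the file path is hypothetical; `str2list` is the helper defined in `src/cellcharter/_utils.py`):

```python
# tests/test_utils.py (hypothetical location)
from cellcharter._utils import str2list


def test_str2list_wraps_strings():
    # Strings are wrapped into a single-element list; lists pass through unchanged.
    assert str2list("sample") == ["sample"]
    assert str2list(["a", "b"]) == ["a", "b"]
```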
69 | ### Continuous integration
70 |
71 | Continuous integration will automatically run the tests on all pull requests and test
72 | against the minimum and maximum supported Python version.
73 |
74 | Additionally, there's a CI job that tests against pre-releases of all dependencies
75 | (if there are any). The purpose of this check is to detect incompatibilities
76 | of new package versions early on, giving you time to fix the issue or reach
77 | out to the developers of the dependency before the package is released to a wider audience.
78 |
79 | [scanpy-test-docs]: https://scanpy.readthedocs.io/en/latest/dev/testing.html#writing-tests
80 |
81 | ## Publishing a release
82 |
83 | ### Updating the version number
84 |
85 | Before making a release, you need to update the version number in the `pyproject.toml` file. Please adhere to [Semantic Versioning][semver], in brief
86 |
87 | > Given a version number MAJOR.MINOR.PATCH, increment the:
88 | >
89 | > 1. MAJOR version when you make incompatible API changes,
90 | > 2. MINOR version when you add functionality in a backwards compatible manner, and
91 | > 3. PATCH version when you make backwards compatible bug fixes.
92 | >
93 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format.
94 |
95 | Once you are done, commit and push your changes and navigate to the "Releases" page of this project on GitHub.
96 | Specify `vX.X.X` as a tag name and create a release. For more information, see [managing GitHub releases][]. This will automatically create a git tag and trigger a GitHub workflow that creates a release on PyPI.
97 |
98 | ## Writing documentation
99 |
100 | Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features:
101 |
102 | - the [myst][] extension allows writing documentation in markdown/Markedly Structured Text
103 | - [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension).
104 | - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
105 | - [Sphinx autodoc typehints][], to automatically reference annotated input and output types
106 | - Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
107 |
108 | See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
109 | on how to write documentation.
110 |
111 | ### Tutorials with myst-nb and jupyter notebooks
112 |
113 | The documentation is set-up to render jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][].
114 | Currently, only notebooks in `.ipynb` format are supported; they will be included with both their input and output cells.
115 | It is your responsibility to update and re-run the notebook whenever necessary.
116 |
117 | If you are interested in automatically running notebooks as part of the continuous integration, please check
118 | out [this feature request](https://github.com/scverse/cookiecutter-scverse/issues/40) in the `cookiecutter-scverse`
119 | repository.
120 |
121 | #### Hints
122 |
123 | - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
124 | if you do so can sphinx automatically create a link to the external documentation (see the example after this list).
125 | - If building the documentation fails because of a missing link that is outside your control, you can add an entry to
126 | the `nitpick_ignore` list in `docs/conf.py`
127 |
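For example, making squidpy references resolvable would mean extending the existing mapping in `docs/conf.py` like this (the python, anndata and numpy entries are the ones already present in this repo; the squidpy entry is a hypothetical addition):

```python
intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    # Hypothetical addition: lets sphinx resolve references such as squidpy.gr.spatial_neighbors.
    "squidpy": ("https://squidpy.readthedocs.io/en/stable/", None),
}
```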
128 | #### Building the docs locally
129 |
130 | ```bash
131 | cd docs
132 | make html
133 | open _build/html/index.html
134 | ```
135 |
136 |
137 |
138 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html
139 | [cookiecutter-scverse-instance]: https://cookiecutter-scverse-instance.readthedocs.io/en/latest/template_usage.html
140 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui
141 | [codecov]: https://about.codecov.io/sign-up/
142 | [codecov docs]: https://docs.codecov.com/docs
143 | [codecov bot]: https://docs.codecov.com/docs/team-bot
144 | [codecov app]: https://github.com/apps/codecov
145 | [pre-commit.ci]: https://pre-commit.ci/
146 | [readthedocs.org]: https://readthedocs.org/
147 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/
148 | [jupytext]: https://jupytext.readthedocs.io/en/latest/
149 | [pre-commit]: https://pre-commit.com/
150 | [anndata]: https://github.com/scverse/anndata
151 | [mudata]: https://github.com/scverse/mudata
152 | [pytest]: https://docs.pytest.org/
153 | [semver]: https://semver.org/
154 | [sphinx]: https://www.sphinx-doc.org/en/master/
155 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html
156 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
157 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html
158 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints
159 | [pypi]: https://pypi.org/
160 | [managing GitHub releases]: https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository
161 |
--------------------------------------------------------------------------------
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | from __future__ import annotations
4 |
5 | import re
6 | from collections.abc import Generator, Iterable
7 |
8 | from sphinx.application import Sphinx
9 | from sphinx.ext.napoleon import NumpyDocstring
10 |
11 |
12 | def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
13 |     for line in lines:
14 |         if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line):
15 |             yield f'-{m["param"]} (:class:`~{m["type"]}`)'
16 |         else:
17 |             yield line
18 |
19 |
20 | def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]:
21 |     lines_raw = self._dedent(self._consume_to_next_section())
22 |     if lines_raw[0] == ":":
23 |         del lines_raw[0]
24 |     lines = self._format_block(":returns: ", list(_process_return(lines_raw)))
25 |     if lines and lines[-1]:
26 |         lines.append("")
27 |     return lines
28 |
29 |
30 | def setup(app: Sphinx):
31 |     """Set app."""
32 |     NumpyDocstring._parse_returns_section = _parse_returns_section
33 |
--------------------------------------------------------------------------------
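For reference, `_process_return` above rewrites a NumPy-style `name : dotted.Type` return line into a Sphinx `:class:` reference. A standalone sketch of the same transformation:

```python
import re


def process_return(lines):
    # Same pattern as _process_return: "<name> : <dotted.type>".
    for line in lines:
        if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line):
            yield f'-{m["param"]} (:class:`~{m["type"]}`)'
        else:
            yield line


print(list(process_return(["adata : anndata.AnnData", "Description of the result"])))
# ['-adata (:class:`~anndata.AnnData`)', 'Description of the result']
```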
/docs/graph.md:
--------------------------------------------------------------------------------
1 | # Graph
2 |
3 | ```{eval-rst}
4 | .. module:: cellcharter.gr
5 | .. currentmodule:: cellcharter
6 |
7 | .. autosummary::
8 | :toctree: generated
9 |
10 | gr.aggregate_neighbors
11 | gr.connected_components
12 | gr.diff_nhood_enrichment
13 | gr.enrichment
14 | gr.nhood_enrichment
15 | gr.remove_long_links
16 | gr.remove_intra_cluster_links
17 | ```
18 |
--------------------------------------------------------------------------------
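A hedged sketch of how these graph functions are commonly chained, assuming `adata` is an `AnnData` with spatial coordinates and a `spatial_cluster` annotation; the function names come from the listing above, while the `cluster_key` argument name is an assumption based on the tutorials:

```python
import squidpy as sq
import cellcharter as cc

# `adata`: AnnData with obsm["spatial"] and obs["spatial_cluster"] (see the tutorials).
# Build the spatial graph with squidpy, then drop implausibly long links.
sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
cc.gr.remove_long_links(adata)

# Test which pairs of spatial clusters neighbor each other more often than expected.
cc.gr.nhood_enrichment(adata, cluster_key="spatial_cluster")
cc.pl.nhood_enrichment(adata, cluster_key="spatial_cluster")
```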
/docs/index.md:
--------------------------------------------------------------------------------
1 | ```{include} ../README.md
2 |
3 | ```
4 |
5 | ```{toctree}
6 | :hidden: true
7 | :maxdepth: 2
8 | :caption: API
9 |
10 | graph.md
11 | tools.md
12 | plotting.md
13 | ```
14 |
15 | ```{toctree}
16 | :hidden: true
17 | :maxdepth: 2
18 | :caption: Tutorials
19 |
20 | notebooks/codex_mouse_spleen
21 | notebooks/cosmx_human_nsclc
22 |
23 |
24 | ```
25 |
--------------------------------------------------------------------------------
/docs/notebooks/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/notebooks/data/.gitkeep
--------------------------------------------------------------------------------
/docs/notebooks/example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Example notebook"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "from anndata import AnnData\n",
18 | "import cellcharter"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 2,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "adata = AnnData(np.random.normal(size=(20, 10)))"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 3,
33 | "metadata": {},
34 | "outputs": [
35 | {
36 | "name": "stdout",
37 | "output_type": "stream",
38 | "text": [
39 | "Implement a preprocessing function here."
40 | ]
41 | },
42 | {
43 | "data": {
44 | "text/plain": [
45 | "0"
46 | ]
47 | },
48 | "execution_count": 3,
49 | "metadata": {},
50 | "output_type": "execute_result"
51 | }
52 | ],
53 | "source": [
54 | "cellcharter.pp.basic_preproc(adata)"
55 | ]
56 | }
57 | ],
58 | "metadata": {
59 | "kernelspec": {
60 | "display_name": "Python 3",
61 | "language": "python",
62 | "name": "python3"
63 | },
64 | "language_info": {
65 | "codemirror_mode": {
66 | "name": "ipython",
67 | "version": 3
68 | },
69 | "file_extension": ".py",
70 | "mimetype": "text/x-python",
71 | "name": "python",
72 | "nbconvert_exporter": "python",
73 | "pygments_lexer": "ipython3",
74 | "version": "3.8.13"
75 | }
76 | },
77 | "nbformat": 4,
78 | "nbformat_minor": 4
79 | }
80 |
--------------------------------------------------------------------------------
/docs/notebooks/tutorial_models/codex_mouse_spleen_trvae/attr.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/notebooks/tutorial_models/codex_mouse_spleen_trvae/attr.pkl
--------------------------------------------------------------------------------
/docs/notebooks/tutorial_models/codex_mouse_spleen_trvae/model_params.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/docs/notebooks/tutorial_models/codex_mouse_spleen_trvae/model_params.pt
--------------------------------------------------------------------------------
/docs/notebooks/tutorial_models/codex_mouse_spleen_trvae/var_names.csv:
--------------------------------------------------------------------------------
1 | CD45
2 | Ly6C
3 | TCR
4 | Ly6G
5 | CD19
6 | CD169
7 | CD106
8 | CD3
9 | CD1632
10 | CD8a
11 | CD90
12 | F480
13 | CD11c
14 | Ter119
15 | CD11b
16 | IgD
17 | CD27
18 | CD5
19 | B220
20 | CD71
21 | CD31
22 | CD4
23 | IgM
24 | CD79b
25 | ERTR7
26 | CD35
27 | CD2135
28 | CD44
29 | NKp46
30 |
--------------------------------------------------------------------------------
/docs/plotting.md:
--------------------------------------------------------------------------------
1 | # Plotting
2 |
3 | ```{eval-rst}
4 | .. module:: cellcharter.pl
5 | .. currentmodule:: cellcharter
6 |
7 | .. autosummary::
8 | :toctree: generated
9 |
10 | pl.autok_stability
11 | pl.diff_nhood_enrichment
12 | pl.enrichment
13 | pl.nhood_enrichment
14 | pl.proportion
15 | pl.boundaries
16 | pl.shape_metrics
17 | ```
18 |
--------------------------------------------------------------------------------
/docs/references.bib:
--------------------------------------------------------------------------------
1 | @article{Wolf2018,
2 | author = {Wolf, F. Alexander
3 | and Angerer, Philipp
4 | and Theis, Fabian J.},
5 | title = {SCANPY: large-scale single-cell gene expression data analysis},
6 | journal = {Genome Biology},
7 | year = {2018},
8 | month = {Feb},
9 | day = {06},
10 | volume = {19},
11 | number = {1},
12 | pages = {15},
13 | abstract = {Scanpy is a scalable toolkit for analyzing single-cell gene expression data. It includes methods for preprocessing, visualization, clustering, pseudotime and trajectory inference, differential expression testing, and simulation of gene regulatory networks. Its Python-based implementation efficiently deals with data sets of more than one million cells (https://github.com/theislab/Scanpy). Along with Scanpy, we present AnnData, a generic class for handling annotated data matrices (https://github.com/theislab/anndata).},
14 | issn = {1474-760X},
15 | doi = {10.1186/s13059-017-1382-0},
16 | url = {https://doi.org/10.1186/s13059-017-1382-0}
17 | }
18 |
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
1 | # References
2 |
3 | ```{bibliography}
4 | :cited:
5 | ```
6 |
--------------------------------------------------------------------------------
/docs/tools.md:
--------------------------------------------------------------------------------
1 | # Tools
2 |
3 | ```{eval-rst}
4 | .. module:: cellcharter.tl
5 | .. currentmodule:: cellcharter
6 |
7 | .. autosummary::
8 | :toctree: generated
9 |
10 | tl.boundaries
11 | tl.Cluster
12 | tl.ClusterAutoK
13 | tl.curl
14 | tl.elongation
15 | tl.linearity
16 | tl.purity
17 | tl.TRVAE
18 | ```
19 |
--------------------------------------------------------------------------------
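`tl.Cluster` is the fixed-K counterpart of `tl.ClusterAutoK`. A hedged sketch (the class name comes from the listing above; the `n_clusters` and `use_rep` arguments are assumptions based on the tutorials):

```python
import cellcharter as cc

# `adata` is assumed to already carry the aggregated neighborhood features
# in adata.obsm["X_cellcharter"] (see cellcharter.gr.aggregate_neighbors).
gmm = cc.tl.Cluster(n_clusters=8)
gmm.fit(adata, use_rep="X_cellcharter")
adata.obs["spatial_domain"] = gmm.predict(adata, use_rep="X_cellcharter")
```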
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "hatchling.build"
3 | requires = ["hatchling"]
4 |
5 | [tool.hatch.build]
6 | exclude = [
7 | "docs*",
8 | "tests*",
9 | ]
10 |
11 | [project]
12 | name = "cellcharter"
13 | version = "0.3.4"
14 | description = "A Python package for the identification, characterization and comparison of spatial clusters from spatial -omics data."
15 | readme = "README.md"
16 | requires-python = ">=3.10,<3.14"
17 | license = {file = "LICENSE"}
18 | authors = [
19 | {name = "CSO group"},
20 | ]
21 | maintainers = [
22 | {name = "Marco Varrone", email = "marco.varrone@unil.ch"},
23 | ]
24 | urls.Documentation = "https://cellcharter.readthedocs.io/"
25 | urls.Source = "https://github.com/CSOgroup/cellcharter"
26 | urls.Home-page = "https://github.com/CSOgroup/cellcharter"
27 | dependencies = [
28 | "anndata",
29 | "scikit-learn",
30 | "squidpy >= 1.6.3",
31 | "torchgmm >= 0.1.2",
32 | # for debug logging (referenced from the issue template)
33 | "session-info",
34 | "spatialdata",
35 | "spatialdata-plot",
36 | "rasterio",
37 | "sknw",
38 | ]
39 |
40 | [project.optional-dependencies]
41 | dev = [
42 | "pre-commit",
43 | "twine>=4.0.2"
44 | ]
45 | doc = [
46 | "docutils>=0.8,!=0.18.*,!=0.19.*",
47 | "sphinx>=4",
48 | "sphinx-book-theme>=1.0.0",
49 | "myst-nb>=1.1.0",
50 | "sphinxcontrib-bibtex>=1.0.0",
51 | "sphinx-autodoc-typehints",
52 | "sphinxext-opengraph",
53 | # For notebooks
54 | "ipykernel",
55 | "ipython",
56 | "sphinx-copybutton",
57 | "sphinx-design"
58 | ]
59 | test = [
60 | "pytest",
61 | "pytest-cov",
62 | ]
63 | transcriptomics = [
64 | "scvi-tools",
65 | ]
66 | proteomics = [
67 | "scarches",
68 | ]
69 |
70 | [tool.coverage.run]
71 | source = ["cellcharter"]
72 | omit = [
73 | "**/test_*.py",
74 | ]
75 |
76 | [tool.pytest.ini_options]
77 | testpaths = ["tests"]
78 | xfail_strict = true
79 | addopts = [
80 | "--import-mode=importlib", # allow using test files with same name
81 | ]
82 | filterwarnings = [
83 | "ignore::anndata.OldFormatWarning",
84 | "ignore:.*this fit will run with no optimizer.*",
85 | "ignore:.*Consider increasing the value of the `num_workers` argument.*",
86 | ]
87 |
88 | [tool.isort]
89 | include_trailing_comma = true
90 | multi_line_output = 3
91 | profile = "black"
92 | skip_glob = ["docs/*"]
93 |
94 | [tool.black]
95 | line-length = 120
96 | target-version = ['py310']
97 | include = '\.pyi?$'
98 | exclude = '''
99 | (
100 | /(
101 | \.eggs
102 | | \.git
103 | | \.hg
104 | | \.mypy_cache
105 | | \.tox
106 | | \.venv
107 | | _build
108 | | buck-out
109 | | build
110 | | dist
111 | )/
112 | )
113 | '''
114 |
115 | [tool.jupytext]
116 | formats = "ipynb,md"
117 |
118 | [tool.cruft]
119 | skip = [
120 | "tests",
121 | "src/**/__init__.py",
122 | "src/**/basic.py",
123 | "docs/api.md",
124 | "docs/changelog.md",
125 | "docs/references.bib",
126 | "docs/references.md",
127 | "docs/notebooks"
128 | ]
129 |
--------------------------------------------------------------------------------
/src/cellcharter/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version
2 |
3 | from . import datasets, gr, pl, tl
4 |
5 | __all__ = ["gr", "pl", "tl", "datasets"]
6 |
7 | __version__ = version("cellcharter")
8 |
--------------------------------------------------------------------------------
/src/cellcharter/_constants/_pkg_constants.py:
--------------------------------------------------------------------------------
1 | """Internal constants not exposed to the user."""
2 |
3 |
4 | class Key:
5 |     class obs:
6 |         # Plain class attribute: chaining ``@classmethod`` and ``@property``
7 |         # no longer works on Python 3.13, which this package supports.
8 |         sample: str = "sample"
9 |
--------------------------------------------------------------------------------
/src/cellcharter/_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Union
4 |
5 | import numpy as np
6 |
7 | AnyRandom = Union[int, np.random.RandomState, None]
8 |
9 |
10 | def str2list(value: Union[str, list]) -> list:
11 |     """Check whether value is a string and, if so, convert it into a list containing value."""
12 | return [value] if isinstance(value, str) else value
13 |
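A quick sketch of the helper's behaviour, as used to normalise the ``aggregations`` argument of :func:`cellcharter.gr.aggregate_neighbors`:

```python
from cellcharter._utils import str2list

str2list("mean")           # -> ["mean"]
str2list(["mean", "var"])  # -> ["mean", "var"] (lists are returned unchanged)
```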
--------------------------------------------------------------------------------
/src/cellcharter/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from ._dataset import * # noqa: F403
2 |
--------------------------------------------------------------------------------
/src/cellcharter/datasets/_dataset.py:
--------------------------------------------------------------------------------
1 | from copy import copy
2 |
3 | from squidpy.datasets._utils import AMetadata
4 |
5 | _codex_mouse_spleen = AMetadata(
6 | name="codex_mouse_spleen",
7 | doc_header="Pre-processed CODEX dataset of mouse spleen from `Goltsev et al "
8 | "`__.",
9 | shape=(707474, 29),
10 | url="https://figshare.com/ndownloader/files/38538101",
11 | )
12 |
13 | for name, var in copy(locals()).items():
14 | if isinstance(var, AMetadata):
15 | var._create_function(name, globals())
16 |
17 |
18 | __all__ = ["codex_mouse_spleen"] # noqa: F822
19 |
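A minimal sketch of the generated loader: calling it downloads the file from the figshare URL registered above (caching is handled by squidpy's ``AMetadata`` machinery) and returns the data as an :class:`anndata.AnnData`.

```python
import cellcharter as cc

# Pre-processed CODEX mouse spleen dataset, 707474 cells x 29 markers.
adata = cc.datasets.codex_mouse_spleen()
print(adata.shape)  # (707474, 29)
```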
--------------------------------------------------------------------------------
/src/cellcharter/gr/__init__.py:
--------------------------------------------------------------------------------
1 | from ._aggr import aggregate_neighbors
2 | from ._build import connected_components, remove_intra_cluster_links, remove_long_links
3 | from ._group import enrichment
4 | from ._nhood import diff_nhood_enrichment, nhood_enrichment
5 |
--------------------------------------------------------------------------------
/src/cellcharter/gr/_aggr.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import warnings
4 | from typing import Optional, Union
5 |
6 | import numpy as np
7 | import scipy.sparse as sps
8 | from anndata import AnnData
9 | from scipy.sparse import spdiags
10 | from squidpy._constants._pkg_constants import Key as sqKey
11 | from squidpy._docs import d
12 | from tqdm.auto import tqdm
13 |
14 | from cellcharter._constants._pkg_constants import Key
15 | from cellcharter._utils import str2list
16 |
17 |
18 | def _aggregate_mean(adj, x):
19 | return adj @ x
20 |
21 |
22 | def _aggregate_var(adj, x):
23 | mean = adj @ x
24 | mean_squared = adj @ (x * x)
25 | return mean_squared - mean * mean
26 |
27 |
28 | def _aggregate(adj, x, method):
29 | if method == "mean":
30 | return _aggregate_mean(adj, x)
31 | elif method == "var":
32 | return _aggregate_var(adj, x)
33 | else:
34 | raise NotImplementedError
35 |
36 |
37 | def _mul_broadcast(mat1, mat2):
38 | return spdiags(mat2, 0, len(mat2), len(mat2)) * mat1
39 |
40 |
41 | def _hop(adj_hop, adj, adj_visited=None):
42 | adj_hop = adj_hop @ adj
43 |
44 | if adj_visited is not None:
45 |         adj_hop = adj_hop > adj_visited  # adj_hop AND NOT adj_visited (sparse matrices have no logical-not operator)
46 | adj_visited = adj_visited + adj_hop
47 |
48 | return adj_hop, adj_visited
49 |
50 |
51 | def _normalize(adj):
52 | deg = np.array(np.sum(adj, axis=1)).squeeze()
53 |
54 | with warnings.catch_warnings():
55 | # If a cell doesn't have neighbors deg = 0 -> divide by zero
56 | warnings.filterwarnings(action="ignore", category=RuntimeWarning)
57 | deg_inv = 1 / deg
58 | deg_inv[deg_inv == float("inf")] = 0
59 |
60 | return _mul_broadcast(adj, deg_inv)
61 |
62 |
63 | def _setdiag(array, value):
64 | if isinstance(array, sps.csr_matrix):
65 | array = array.tolil()
66 | array.setdiag(value)
67 | array = array.tocsr()
68 | if value == 0:
69 | array.eliminate_zeros()
70 | return array
71 |
72 |
73 | def _aggregate_neighbors(
74 | adj: sps.spmatrix,
75 | X: np.ndarray,
76 | nhood_layers: list,
77 | aggregations: Optional[Union[str, list]] = "mean",
78 | disable_tqdm: bool = True,
79 | ) -> np.ndarray:
80 | adj = adj.astype(bool)
81 | adj = _setdiag(adj, 0)
82 | adj_hop = adj.copy()
83 | adj_visited = _setdiag(adj.copy(), 1)
84 |
85 | Xs = []
86 | for i in tqdm(range(0, max(nhood_layers) + 1), disable=disable_tqdm):
87 | if i in nhood_layers:
88 | if i == 0:
89 | Xs.append(X)
90 | else:
91 | if i > 1:
92 | adj_hop, adj_visited = _hop(adj_hop, adj, adj_visited)
93 | adj_hop_norm = _normalize(adj_hop)
94 |
95 | for agg in aggregations:
96 | x = _aggregate(adj_hop_norm, X, agg)
97 | Xs.append(x)
98 | if sps.issparse(X):
99 | return sps.hstack(Xs)
100 | else:
101 | return np.hstack(Xs)
102 |
103 |
104 | @d.dedent
105 | def aggregate_neighbors(
106 | adata: AnnData,
107 | n_layers: Union[int, list],
108 | aggregations: Optional[Union[str, list]] = "mean",
109 | connectivity_key: Optional[str] = None,
110 | use_rep: Optional[str] = None,
111 | sample_key: Optional[str] = None,
112 | out_key: Optional[str] = "X_cellcharter",
113 | copy: bool = False,
114 | ) -> np.ndarray | None:
115 | """
116 |     Aggregate the features of each neighborhood layer and concatenate them, optionally together with the cells' own features, into a single vector.
117 |
118 | Parameters
119 | ----------
120 | %(adata)s
121 | n_layers
122 | Which neighborhood layers to aggregate from.
123 |         If :class:`int`, the output vector includes the cells' own features and the aggregated features of the neighbors up to distance ``n_layers``, i.e. cells | 1-hop neighbors | ... | ``n_layers``-hop neighbors.
124 |         If :class:`list`, each element is a distance at which the neighbors' features are aggregated and concatenated. For example, ``[0, 1, 3]`` corresponds to cells | 1-hop neighbors | 3-hop neighbors.
125 |     aggregations
126 |         Which functions to use to aggregate the neighbors' features. Default: ``'mean'``.
127 | connectivity_key
128 | Key in :attr:`anndata.AnnData.obsp` where spatial connectivities are stored.
129 | use_rep
130 |         Key in :attr:`anndata.AnnData.obsm` of the features to aggregate. If :class:`None`, ``adata.X`` is used.
131 | sample_key
132 | Key in :attr:`anndata.AnnData.obs` where the sample labels are stored. Must be different from :class:`None` if adata contains multiple samples.
133 | out_key
134 | Key in :attr:`anndata.AnnData.obsm` where the output matrix is stored if ``copy = False``.
135 | %(copy)s
136 |
137 | Returns
138 | -------
139 | If ``copy = True``, returns a :class:`numpy.ndarray` of the features aggregated and concatenated.
140 |
141 | Otherwise, modifies the ``adata`` with the following key:
142 | - :attr:`anndata.AnnData.obsm` ``['{{out_key}}']`` - the above mentioned :class:`numpy.ndarray`.
143 | """
144 | connectivity_key = sqKey.obsp.spatial_conn(connectivity_key)
145 | sample_key = Key.obs.sample if sample_key is None else sample_key
146 | aggregations = str2list(aggregations)
147 |
148 | X = adata.X if use_rep is None else adata.obsm[use_rep]
149 |
150 | if isinstance(n_layers, int):
151 | n_layers = list(range(n_layers + 1))
152 |
153 | if sps.issparse(X):
154 | X_aggregated = sps.dok_matrix(
155 | (X.shape[0], X.shape[1] * ((len(n_layers) - 1) * len(aggregations) + 1)), dtype=np.float32
156 | )
157 | else:
158 | X_aggregated = np.empty(
159 | (X.shape[0], X.shape[1] * ((len(n_layers) - 1) * len(aggregations) + 1)), dtype=np.float32
160 | )
161 |
162 | if sample_key in adata.obs:
163 | samples = adata.obs[sample_key].unique()
164 | sample_idxs = [adata.obs[sample_key] == sample for sample in samples]
165 | else:
166 | sample_idxs = [np.arange(adata.shape[0])]
167 |
168 | for idxs in tqdm(sample_idxs, disable=(len(sample_idxs) == 1)):
169 | X_sample_aggregated = _aggregate_neighbors(
170 | adj=adata[idxs].obsp[connectivity_key],
171 | X=X[idxs],
172 | nhood_layers=n_layers,
173 | aggregations=aggregations,
174 | disable_tqdm=(len(sample_idxs) != 1),
175 | )
176 | X_aggregated[idxs] = X_sample_aggregated
177 |
178 | if isinstance(X_aggregated, sps.dok_matrix):
179 | X_aggregated = X_aggregated.tocsr()
180 |
181 | if copy:
182 | return X_aggregated
183 |
184 | adata.obsm[out_key] = X_aggregated
185 |
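A minimal usage sketch, assuming an :class:`anndata.AnnData` ``adata`` with spatial coordinates and a latent representation in ``adata.obsm['X_latent']`` (an illustrative key, e.g. produced by scVI or TRVAE):

```python
import squidpy as sq
import cellcharter as cc

# Build the spatial graph (Delaunay triangulation on generic coordinates).
sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)

# Concatenate each cell's own features with the mean of its 1-, 2- and 3-hop
# neighborhoods; the result is written to adata.obsm["X_cellcharter"].
cc.gr.aggregate_neighbors(adata, n_layers=3, use_rep="X_latent", sample_key="sample")
```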
--------------------------------------------------------------------------------
/src/cellcharter/gr/_build.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import scipy.sparse as sps
6 | from anndata import AnnData
7 | from scipy.sparse import csr_matrix
8 | from squidpy._constants._pkg_constants import Key
9 | from squidpy._docs import d
10 | from squidpy.gr._utils import _assert_connectivity_key
11 |
12 |
13 | @d.dedent
14 | def remove_long_links(
15 | adata: AnnData,
16 | distance_percentile: float = 99.0,
17 | connectivity_key: str | None = None,
18 | distances_key: str | None = None,
19 | neighs_key: str | None = None,
20 | copy: bool = False,
21 | ) -> tuple[csr_matrix, csr_matrix] | None:
22 | """
23 | Remove links between cells at a distance bigger than a certain percentile of all positive distances.
24 |
25 | It is designed for data with generic coordinates.
26 |
27 | Parameters
28 | ----------
29 | %(adata)s
30 |
31 | distance_percentile
32 | Percentile of the distances between cells over which links are trimmed after the network is built.
33 | %(conn_key)s
34 |
35 | distances_key
36 | Key in :attr:`anndata.AnnData.obsp` where spatial distances are stored.
37 | Default is: :attr:`anndata.AnnData.obsp` ``['{{Key.obsp.spatial_dist()}}']``.
38 | neighs_key
39 | Key in :attr:`anndata.AnnData.uns` where the parameters from gr.spatial_neighbors are stored.
40 | Default is: :attr:`anndata.AnnData.uns` ``['{{Key.uns.spatial_neighs()}}']``.
41 |
42 | %(copy)s
43 |
44 | Returns
45 | -------
46 | If ``copy = True``, returns a :class:`tuple` with the new spatial connectivities and distances matrices.
47 |
48 | Otherwise, modifies the ``adata`` with the following keys:
49 | - :attr:`anndata.AnnData.obsp` ``['{{connectivity_key}}']`` - the new spatial connectivities.
50 | - :attr:`anndata.AnnData.obsp` ``['{{distances_key}}']`` - the new spatial distances.
51 | - :attr:`anndata.AnnData.uns` ``['{{neighs_key}}']`` - :class:`dict` containing parameters.
52 | """
53 | connectivity_key = Key.obsp.spatial_conn(connectivity_key)
54 | distances_key = Key.obsp.spatial_dist(distances_key)
55 | neighs_key = Key.uns.spatial_neighs(neighs_key)
56 | _assert_connectivity_key(adata, connectivity_key)
57 | _assert_connectivity_key(adata, distances_key)
58 |
59 | conns, dists = adata.obsp[connectivity_key], adata.obsp[distances_key]
60 |
61 | if copy:
62 | conns, dists = conns.copy(), dists.copy()
63 |
64 | threshold = np.percentile(np.array(dists[dists != 0]).squeeze(), distance_percentile)
65 | conns[dists > threshold] = 0
66 | dists[dists > threshold] = 0
67 |
68 | conns.eliminate_zeros()
69 | dists.eliminate_zeros()
70 |
71 | if copy:
72 | return conns, dists
73 | else:
74 | adata.uns[neighs_key]["params"]["radius"] = threshold
75 |
76 |
77 | def _remove_intra_cluster_links(labels, adjacency):
78 | target_labels = np.array(labels.iloc[adjacency.indices])
79 | source_labels = np.array(
80 | labels.iloc[np.repeat(np.arange(adjacency.indptr.shape[0] - 1), np.diff(adjacency.indptr))]
81 | )
82 |
83 | inter_cluster_mask = (source_labels != target_labels).astype(int)
84 |
85 | adjacency.data *= inter_cluster_mask
86 | adjacency.eliminate_zeros()
87 |
88 | return adjacency
89 |
90 |
91 | @d.dedent
92 | def remove_intra_cluster_links(
93 | adata: AnnData,
94 | cluster_key: str,
95 | connectivity_key: str | None = None,
96 | distances_key: str | None = None,
97 | copy: bool = False,
98 | ) -> tuple[csr_matrix, csr_matrix] | None:
99 | """
100 | Remove links between cells that belong to the same cluster.
101 |
102 | Used in :func:`cellcharter.gr.nhood_enrichment` to consider only interactions between cells of different clusters.
103 |
104 | Parameters
105 | ----------
106 | %(adata)s
107 |
108 | cluster_key
109 | Key in :attr:`anndata.AnnData.obs` of the cluster labeling to consider.
110 |
111 | %(conn_key)s
112 |
113 | distances_key
114 | Key in :attr:`anndata.AnnData.obsp` where spatial distances are stored.
115 | Default is: :attr:`anndata.AnnData.obsp` ``['{{Key.obsp.spatial_dist()}}']``.
116 |
117 | %(copy)s
118 |
119 | Returns
120 | -------
121 | If ``copy = True``, returns a :class:`tuple` with the new spatial connectivities and distances matrices.
122 |
123 | Otherwise, modifies the ``adata`` with the following keys:
124 | - :attr:`anndata.AnnData.obsp` ``['{{connectivity_key}}']`` - the new spatial connectivities.
125 | - :attr:`anndata.AnnData.obsp` ``['{{distances_key}}']`` - the new spatial distances.
126 | """
127 | connectivity_key = Key.obsp.spatial_conn(connectivity_key)
128 | distances_key = Key.obsp.spatial_dist(distances_key)
129 | _assert_connectivity_key(adata, connectivity_key)
130 | _assert_connectivity_key(adata, distances_key)
131 |
132 | conns = adata.obsp[connectivity_key].copy() if copy else adata.obsp[connectivity_key]
133 | dists = adata.obsp[distances_key].copy() if copy else adata.obsp[distances_key]
134 |
135 | conns, dists = (_remove_intra_cluster_links(adata.obs[cluster_key], adjacency) for adjacency in [conns, dists])
136 |
137 | if copy:
138 | return conns, dists
139 |
140 |
141 | def _connected_components(adj: sps.spmatrix, min_cells: int = 250, count: int = 0) -> tuple[np.ndarray, int]:
142 | n_components, labels = sps.csgraph.connected_components(adj, return_labels=True)
143 | components, counts = np.unique(labels, return_counts=True)
144 |
145 | small_components = components[counts < min_cells]
146 | small_components_idxs = np.in1d(labels, small_components)
147 |
148 | labels[small_components_idxs] = -1
149 | labels[~small_components_idxs] = pd.factorize(labels[~small_components_idxs])[0] + count
150 |
151 | return labels, (n_components - len(small_components))
152 |
153 |
154 | @d.dedent
155 | def connected_components(
156 | adata: AnnData,
157 |     cluster_key: str | None = None,
158 |     min_cells: int = 250,
159 |     connectivity_key: str | None = None,
160 | out_key: str = "component",
161 | copy: bool = False,
162 | ) -> None | np.ndarray:
163 | """
164 | Compute the connected components of the spatial graph.
165 |
166 | Parameters
167 | ----------
168 | %(adata)s
169 | cluster_key
170 | Key in :attr:`anndata.AnnData.obs` where the cluster labels are stored. If :class:`None`, the connected components are computed on the whole dataset.
171 | min_cells
172 | Minimum number of cells for a connected component to be considered.
173 | %(conn_key)s
174 | out_key
175 | Key in :attr:`anndata.AnnData.obs` where the output matrix is stored if ``copy = False``.
176 | %(copy)s
177 |
178 | Returns
179 | -------
180 | If ``copy = True``, returns a :class:`numpy.ndarray` with the connected components labels.
181 |
182 | Otherwise, modifies the ``adata`` with the following key:
183 |         - :attr:`anndata.AnnData.obs` ``['{{out_key}}']`` - the above mentioned :class:`numpy.ndarray`.
184 | """
185 | connectivity_key = Key.obsp.spatial_conn(connectivity_key)
186 | output = pd.Series(index=adata.obs.index, dtype="object")
187 |
188 | count = 0
189 |
190 | if cluster_key is not None:
191 | cluster_values = adata.obs[cluster_key].unique()
192 |
193 | for cluster in cluster_values:
194 | adata_cluster = adata[adata.obs[cluster_key] == cluster]
195 |
196 | labels, n_components = _connected_components(
197 | adj=adata_cluster.obsp[connectivity_key], min_cells=min_cells, count=count
198 | )
199 | output[adata.obs[cluster_key] == cluster] = labels
200 | count += n_components
201 | else:
202 | labels, n_components = _connected_components(
203 | adj=adata.obsp[connectivity_key],
204 | min_cells=min_cells,
205 | )
206 | output.loc[:] = labels
207 |
208 | output = output.astype("category")
209 | output[output == -1] = np.nan
210 | output = output.cat.remove_unused_categories()
211 |
212 | if copy:
213 | return output.values
214 |
215 | adata.obs[out_key] = output
216 |
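A minimal sketch of the three helpers above, assuming a spatial graph built with :func:`squidpy.gr.spatial_neighbors` and cluster labels stored in ``adata.obs['spatial_cluster']`` (an illustrative key):

```python
import cellcharter as cc

# Trim links longer than the 99th percentile of all positive distances.
cc.gr.remove_long_links(adata, distance_percentile=99.0)

# Label connected components with at least 250 cells, separately per cluster;
# the component labels are written to adata.obs["component"].
cc.gr.connected_components(adata, cluster_key="spatial_cluster", min_cells=250)

# Copy of the graph keeping only inter-cluster links (as used by gr.nhood_enrichment).
conns, dists = cc.gr.remove_intra_cluster_links(adata, cluster_key="spatial_cluster", copy=True)
```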
--------------------------------------------------------------------------------
/src/cellcharter/gr/_group.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from anndata import AnnData
6 | from squidpy._docs import d
7 | from tqdm import tqdm
8 |
9 |
10 | def _proportion(obs, id_key, val_key, normalize=True):
11 | df = pd.pivot(obs[[id_key, val_key]].value_counts().reset_index(), index=id_key, columns=val_key)
12 | df[df.isna()] = 0
13 | df.columns = df.columns.droplevel(0)
14 | if normalize:
15 | return df.div(df.sum(axis=1), axis=0)
16 | else:
17 | return df
18 |
19 |
20 | def _observed_permuted(annotations, group_key, label_key):
21 | annotations[group_key] = annotations[group_key].sample(frac=1).reset_index(drop=True).values
22 | return _proportion(annotations, id_key=label_key, val_key=group_key).reindex().T
23 |
24 |
25 | def _enrichment(observed, expected, log=True):
26 | enrichment = observed.div(expected, axis="index", level=0)
27 |
28 | if log:
29 | with np.errstate(divide="ignore"):
30 | enrichment = np.log2(enrichment)
31 | enrichment = enrichment.fillna(enrichment.min())
32 | return enrichment
33 |
34 |
35 | def _empirical_pvalues(observed, expected):
36 | pvalues = np.zeros(observed.shape)
37 | pvalues[observed.values > 0] = (
38 | 1 - np.sum(expected[:, observed.values > 0] < observed.values[observed.values > 0], axis=0) / expected.shape[0]
39 | )
40 | pvalues[observed.values < 0] = (
41 | 1 - np.sum(expected[:, observed.values < 0] > observed.values[observed.values < 0], axis=0) / expected.shape[0]
42 | )
43 | return pd.DataFrame(pvalues, columns=observed.columns, index=observed.index)
44 |
45 |
46 | @d.dedent
47 | def enrichment(
48 | adata: AnnData,
49 | group_key: str,
50 | label_key: str,
51 | pvalues: bool = False,
52 | n_perms: int = 1000,
53 | log: bool = True,
54 | observed_expected: bool = False,
55 | copy: bool = False,
56 | ) -> dict[str, pd.DataFrame] | None:
57 | """
58 | Compute the enrichment of `label_key` in `group_key`.
59 |
60 | Parameters
61 | ----------
62 | %(adata)s
63 | group_key
64 | Key in :attr:`anndata.AnnData.obs` where groups are stored.
65 | label_key
66 | Key in :attr:`anndata.AnnData.obs` where labels are stored.
67 | pvalues
68 | If `True`, compute empirical p-values by permutation. It will result in a slower computation.
69 | n_perms
70 | Number of permutations to compute empirical p-values.
71 | log
72 | If `True` use log2 fold change, otherwise use fold change.
73 | observed_expected
74 | If `True`, return also the observed and expected proportions.
75 | %(copy)s
76 |
77 | Returns
78 | -------
79 | If ``copy = True``, returns a :class:`dict` with the following keys:
80 | - ``'enrichment'`` - the enrichment values.
81 | - ``'pvalue'`` - the enrichment pvalues (if `pvalues is True`).
82 | - ``'observed'`` - the observed proportions (if `observed_expected is True`).
83 | - ``'expected'`` - the expected proportions (if `observed_expected is True`).
84 |
85 | Otherwise, modifies the ``adata`` with the following keys:
86 |         - :attr:`anndata.AnnData.uns` ``['{group_key}_{label_key}_enrichment']`` - the above mentioned dict.
87 |         - :attr:`anndata.AnnData.uns` ``['{group_key}_{label_key}_enrichment']['params']`` - the parameters used.
88 | """
89 | observed = _proportion(adata.obs, id_key=label_key, val_key=group_key).reindex().T
90 | observed[observed.isna()] = 0
91 | if not pvalues:
92 | expected = adata.obs[group_key].value_counts() / adata.shape[0]
93 | # Repeat over the number of labels
94 | expected = pd.concat([expected] * len(observed.columns), axis=1, keys=observed.columns)
95 | else:
96 | annotations = adata.obs.copy()
97 |
98 | expected = [_observed_permuted(annotations, group_key, label_key) for _ in tqdm(range(n_perms))]
99 | expected = np.stack(expected, axis=0)
100 |
101 | empirical_pvalues = _empirical_pvalues(observed, expected)
102 |
103 | expected = np.mean(expected, axis=0)
104 | expected = pd.DataFrame(expected, columns=observed.columns, index=observed.index)
105 |
106 | enrichment = _enrichment(observed, expected, log=log)
107 |
108 | result = {"enrichment": enrichment}
109 |
110 | if observed_expected:
111 | result["observed"] = observed
112 | result["expected"] = expected
113 |
114 | if pvalues:
115 | result["pvalue"] = empirical_pvalues
116 |
117 | if copy:
118 | return result
119 | else:
120 | adata.uns[f"{group_key}_{label_key}_enrichment"] = result
121 | adata.uns[f"{group_key}_{label_key}_enrichment"]["params"] = {"log": log}
122 |
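A minimal sketch, with illustrative ``obs`` keys, of computing the enrichment with empirical p-values and reading the result back from ``adata.uns``:

```python
import cellcharter as cc

cc.gr.enrichment(adata, group_key="condition", label_key="spatial_cluster", pvalues=True, n_perms=1000)

result = adata.uns["condition_spatial_cluster_enrichment"]
log2_fc = result["enrichment"]  # log2 fold changes, groups as rows and labels as columns
pvals = result["pvalue"]        # empirical p-values from the permutations
```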
--------------------------------------------------------------------------------
/src/cellcharter/gr/_utils.py:
--------------------------------------------------------------------------------
1 | """Graph utilities."""
2 |
3 | from __future__ import annotations
4 |
5 | from anndata import AnnData
6 |
7 |
8 | def _assert_distances_key(adata: AnnData, key: str) -> None:
9 | if key not in adata.obsp:
10 | key_added = key.replace("_distances", "")
11 | raise KeyError(
12 | f"Spatial distances key `{key}` not found in `adata.obsp`. "
13 | f"Please run `squidpy.gr.spatial_neighbors(..., key_added={key_added!r})` first."
14 | )
15 |
--------------------------------------------------------------------------------
/src/cellcharter/pl/__init__.py:
--------------------------------------------------------------------------------
1 | from ._autok import autok_stability
2 | from ._group import enrichment, proportion
3 | from ._nhood import diff_nhood_enrichment, nhood_enrichment
4 | from ._shape import boundaries, shape_metrics
5 |
--------------------------------------------------------------------------------
/src/cellcharter/pl/_autok.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 |
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 | import seaborn as sns
8 |
9 | from cellcharter.tl import ClusterAutoK
10 |
11 |
12 | def autok_stability(autok: ClusterAutoK, save: str | Path | None = None, return_ax: bool = False) -> plt.Axes | None:
13 | """
14 | Plot the clustering stability.
15 |
16 | The clustering stability is computed by :class:`cellcharter.tl.ClusterAutoK`.
17 |
18 | Parameters
19 | ----------
20 | autok
21 | The fitted :class:`cellcharter.tl.ClusterAutoK` model.
22 |     save
23 |         Path where to save the plot. If `None`, the plot is not saved.
24 |     return_ax
25 |         If `True`, return the :class:`matplotlib.axes.Axes` of the plot.
26 |
27 | Returns
28 | -------
29 |     Nothing, just plots the figure and optionally saves it. If ``return_ax = True``, returns the :class:`matplotlib.axes.Axes`.
30 | """
31 | robustness_df = pd.melt(
32 | pd.DataFrame.from_dict({k: autok.stability[i] for i, k in enumerate(autok.n_clusters[1:-1])}, orient="columns"),
33 | var_name="N. clusters",
34 | value_name="Stability",
35 | )
36 | ax = sns.lineplot(data=robustness_df, x="N. clusters", y="Stability")
37 | ax.set_xticks(autok.n_clusters[1:-1])
38 |     ax.spines.right.set_visible(False)
39 |     ax.spines.top.set_visible(False)
40 |
41 |     if save:
42 |         plt.savefig(save)
43 |
44 | if return_ax:
45 | return ax
46 |
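A minimal sketch; the :class:`cellcharter.tl.ClusterAutoK` constructor and ``fit`` call shown here are assumptions about its interface (the class is defined in ``tl/_autok.py``, not in this file):

```python
import cellcharter as cc

# Assumed interface: search the number of clusters over a range of values,
# fitting on the neighborhood-aggregated features.
autok = cc.tl.ClusterAutoK(n_clusters=(2, 10))
autok.fit(adata, use_rep="X_cellcharter")

# Plot the stability curve and keep the axis for further styling.
ax = cc.pl.autok_stability(autok, return_ax=True)
```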
--------------------------------------------------------------------------------
/src/cellcharter/pl/_group.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import warnings
4 | from pathlib import Path
5 |
6 | import matplotlib
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import pandas as pd
10 | import seaborn as sns
11 | from anndata import AnnData
12 | from matplotlib.colors import LogNorm, Normalize
13 | from matplotlib.legend_handler import HandlerTuple
14 | from scipy.cluster import hierarchy
15 | from squidpy._docs import d
16 | from squidpy.gr._utils import _assert_categorical_obs
17 | from squidpy.pl._color_utils import Palette_t, _get_palette, _maybe_set_colors
18 |
19 | try:
20 | from matplotlib.colormaps import get_cmap
21 | except ImportError:
22 | from matplotlib.pyplot import get_cmap
23 |
24 | from cellcharter.gr._group import _proportion
25 |
26 | empty_handle = matplotlib.patches.Rectangle((0, 0), 1, 1, fill=False, edgecolor="none", visible=False)
27 |
28 |
29 | @d.dedent
30 | def proportion(
31 | adata: AnnData,
32 | group_key: str,
33 | label_key: str,
34 | groups: list | None = None,
35 | labels: list | None = None,
36 | rotation_xlabel: int = 45,
37 | ncols: int = 1,
38 | normalize: bool = True,
39 | palette: Palette_t = None,
40 | figsize: tuple[float, float] | None = None,
41 | dpi: int | None = None,
42 | save: str | Path | None = None,
43 | **kwargs,
44 | ) -> None:
45 | """
46 |     Plot the proportion of `label_key` in `group_key`.
47 |
48 | Parameters
49 | ----------
50 | %(adata)s
51 | group_key
52 | Key in :attr:`anndata.AnnData.obs` where groups are stored.
53 | label_key
54 | Key in :attr:`anndata.AnnData.obs` where labels are stored.
55 | groups
56 | List of groups to plot.
57 | labels
58 | List of labels to plot.
59 | rotation_xlabel
60 | Rotation in degrees of the ticks of the x axis.
61 | ncols
62 | Number of columns for the legend.
63 | normalize
64 |         If `True` use relative frequencies, otherwise use counts.
65 | palette
66 | Categorical colormap for the clusters.
67 |         If ``None``, use :attr:`anndata.AnnData.uns` ``['{{label_key}}_colors']``, if available.
68 | %(plotting)s
69 | kwargs
70 | Keyword arguments for :func:`pandas.DataFrame.plot.bar`.
71 | Returns
72 | -------
73 | %(plotting_returns)s
74 | """
75 | _assert_categorical_obs(adata, key=group_key)
76 | _assert_categorical_obs(adata, key=label_key)
77 | _maybe_set_colors(source=adata, target=adata, key=label_key, palette=palette)
78 |
79 | clusters = adata.obs[label_key].cat.categories
80 | palette = _get_palette(adata, cluster_key=label_key, categories=clusters)
81 |
82 | df = _proportion(obs=adata.obs, id_key=group_key, val_key=label_key, normalize=normalize)
83 | df = df[df.columns[::-1]]
84 |
85 | if groups is not None:
86 | df = df.loc[groups, :]
87 |
88 | if labels is not None:
89 | df = df.loc[:, labels]
90 |
91 | plt.figure(dpi=dpi)
92 | ax = df.plot.bar(stacked=True, figsize=figsize, color=palette, rot=rotation_xlabel, ax=plt.gca(), **kwargs)
93 | ax.grid(False)
94 |
95 | handles, labels = ax.get_legend_handles_labels()
96 | lgd = ax.legend(handles[::-1], labels[::-1], loc="center left", ncol=ncols, bbox_to_anchor=(1.0, 0.5))
97 |
98 | if save:
99 |         plt.savefig(save, bbox_extra_artists=(lgd,), bbox_inches="tight")
100 |
101 |
102 | def _select_labels(fold_change, pvalues, labels, groups):
103 | col_name = fold_change.columns.name
104 | idx_name = fold_change.index.name
105 |
106 | if labels is not None:
107 | fold_change = fold_change.loc[labels]
108 |
109 | # The indexing removes the name of the index, so we need to set it back
110 | fold_change.index.name = idx_name
111 |
112 | if pvalues is not None:
113 | pvalues = pvalues.loc[labels]
114 |
115 | # The indexing removes the name of the index, so we need to set it back
116 | pvalues.index.name = idx_name
117 |
118 | if groups is not None:
119 | fold_change = fold_change.loc[:, groups]
120 |
121 | # The indexing removes the name of the columns, so we need to set it back
122 | fold_change.columns.name = col_name
123 |
124 | if pvalues is not None:
125 | pvalues = pvalues.loc[:, groups]
126 |
127 | # The indexing removes the name of the columns, so we need to set it back
128 | pvalues.columns.name = col_name
129 |
130 | return fold_change, pvalues
131 |
132 |
133 | # Calculate the dendrogram for rows and columns clustering
134 | def _reorder_labels(fold_change, pvalues, group_cluster, label_cluster):
135 | if label_cluster:
136 | order_rows = hierarchy.leaves_list(hierarchy.linkage(fold_change, method="complete"))
137 | fold_change = fold_change.iloc[order_rows]
138 |
139 | if pvalues is not None:
140 | pvalues = pvalues.iloc[order_rows]
141 |
142 | if group_cluster:
143 | order_cols = hierarchy.leaves_list(hierarchy.linkage(fold_change.T, method="complete"))
144 | fold_change = fold_change.iloc[:, order_cols]
145 |
146 | if pvalues is not None:
147 | pvalues = pvalues.iloc[:, order_cols]
148 | return fold_change, pvalues
149 |
150 |
151 | def _significance_colors(color, pvalues, significance):
152 | color[pvalues <= significance] = 0.0
153 | color[pvalues > significance] = 0.8
154 | return color
155 |
156 |
157 | def _pvalue_colorbar(ax, cmap_enriched, cmap_depleted, norm):
158 | from matplotlib.colorbar import ColorbarBase
159 | from mpl_toolkits.axes_grid1 import make_axes_locatable
160 |
161 | divider = make_axes_locatable(ax)
162 |
163 | # Append axes to the right of ax, with 5% width of ax
164 | cax1 = divider.append_axes("right", size="2%", pad=0.05)
165 |
166 | cbar1 = ColorbarBase(cax1, cmap=cmap_enriched, norm=norm, orientation="vertical")
167 |
168 | cbar1.ax.invert_yaxis()
169 | cbar1.ax.tick_params(labelsize=10)
170 | cbar1.set_ticks([], minor=True)
171 | cbar1.ax.set_title("p-value", fontdict={"fontsize": 10})
172 |
173 | if cmap_depleted is not None:
174 | cax2 = divider.append_axes("right", size="2%", pad=0.10)
175 |
176 | # Place colorbars next to each other and share ticks
177 | cbar2 = ColorbarBase(cax2, cmap=cmap_depleted, norm=norm, orientation="vertical")
178 | cbar2.ax.invert_yaxis()
179 | cbar2.ax.tick_params(labelsize=10)
180 | cbar2.set_ticks([], minor=True)
181 |
182 | cbar1.set_ticks([])
183 |
184 |
185 | def _enrichment_legend(
186 | scatters, fold_change_melt, dot_scale, size_max, enriched_only, significant_only, significance, size_threshold
187 | ):
188 | handles_list = []
189 | labels_list = []
190 |
191 | if enriched_only is False:
192 | handles_list.extend(
193 | [scatter.legend_elements(prop="colors", num=None)[0][0] for scatter in scatters] + [empty_handle]
194 | )
195 | labels_list.extend(["Enriched", "Depleted", ""])
196 |
197 | if significance is not None:
198 | handles_list.append(tuple([scatter.legend_elements(prop="colors", num=None)[0][0] for scatter in scatters]))
199 | labels_list.append(f"p-value < {significance}")
200 |
201 | if significant_only is False:
202 | handles_list.append(tuple([scatter.legend_elements(prop="colors", num=None)[0][1] for scatter in scatters]))
203 | labels_list.append(f"p-value >= {significance}")
204 |
205 | handles_list.append(empty_handle)
206 | labels_list.append("")
207 |
208 | handles, labels = scatters[0].legend_elements(prop="sizes", num=5, func=lambda x: x / 100 / dot_scale * size_max)
209 |
210 | if size_threshold is not None:
211 | # Show the threshold as a label only if the threshold is lower than the maximum fold change
212 | if enriched_only is True and fold_change_melt[fold_change_melt["value"] >= 0]["value"].max() > size_threshold:
213 | labels[-1] = f">{size_threshold:.1f}"
214 | elif fold_change_melt["value"].max() > size_threshold:
215 | labels[-1] = f">{size_threshold:.1f}"
216 |
217 | handles_list.extend([empty_handle] + handles)
218 | labels_list.extend(["log2 FC"] + labels)
219 |
220 | return handles_list, labels_list
221 |
222 |
223 | @d.dedent
224 | def enrichment(
225 | adata: AnnData,
226 | group_key: str,
227 | label_key: str,
228 | dot_scale: float = 3,
229 | group_cluster: bool = True,
230 | label_cluster: bool = False,
231 | groups: list | None = None,
232 | labels: list | None = None,
233 | show_pvalues: bool = False,
234 | significance: float | None = None,
235 | enriched_only: bool = True,
236 | significant_only: bool = False,
237 | size_threshold: float | None = None,
238 | palette: str | matplotlib.colors.ListedColormap | None = None,
239 | fontsize: str | int = "small",
240 | figsize: tuple[float, float] | None = (7, 5),
241 | save: str | Path | None = None,
242 | **kwargs,
243 | ):
244 | """
245 | Plot a dotplot of the enrichment of `label_key` in `group_key`.
246 |
247 | Parameters
248 | ----------
249 | %(adata)s
250 | group_key
251 | Key in :attr:`anndata.AnnData.obs` where groups are stored.
252 | label_key
253 | Key in :attr:`anndata.AnnData.obs` where labels are stored.
254 | dot_scale
255 | Scale of the dots.
256 | group_cluster
257 | If `True`, display groups ordered according to hierarchical clustering.
258 | label_cluster
259 | If `True`, display labels ordered according to hierarchical clustering.
260 | groups
261 | The groups for which to show the enrichment.
262 | labels
263 | The labels for which to show the enrichment.
264 | show_pvalues
265 | If `True`, show p-values as colors.
266 | significance
267 | If not `None`, show fold changes with a p-value above this threshold in a lighter color.
268 | enriched_only
269 | If `True`, display only enriched values and hide depleted values.
270 | significant_only
271 | If `True`, display only significant values and hide non-significant values.
272 | size_threshold
273 | Threshold for the size of the dots. Enrichment or depletions with absolute value above this threshold will have all the same size.
274 | palette
275 | Colormap for the enrichment values. It must be a diverging colormap.
276 | %(plotting)s
277 | kwargs
278 | Keyword arguments for :func:`matplotlib.pyplot.scatter`.
279 | """
280 | if f"{group_key}_{label_key}_enrichment" not in adata.uns:
281 | raise ValueError("Run cellcharter.gr.enrichment first.")
282 |
283 | if size_threshold is not None and size_threshold <= 0:
284 | raise ValueError("size_threshold must be greater than 0.")
285 |
286 | if palette is None:
287 | palette = sns.diverging_palette(240, 10, as_cmap=True)
288 | elif isinstance(palette, str):
289 | palette = get_cmap(palette)
290 |
291 | pvalues = None
292 | if "pvalue" not in adata.uns[f"{group_key}_{label_key}_enrichment"]:
293 | if show_pvalues:
294 |             raise ValueError("show_pvalues requires gr.enrichment to be run with pvalues=True.")
295 |
296 | if significance is not None:
297 |             raise ValueError("significance requires gr.enrichment to be run with pvalues=True.")
298 |
299 | if significant_only:
300 |             raise ValueError("significant_only requires gr.enrichment to be run with pvalues=True.")
301 | elif show_pvalues:
302 | pvalues = adata.uns[f"{group_key}_{label_key}_enrichment"]["pvalue"].copy().T
303 | else:
304 | if significance is not None:
305 | warnings.warn(
306 | "Significance requires show_pvalues=True. Ignoring significance.",
307 | UserWarning,
308 | stacklevel=2,
309 | )
310 | significance = None
311 |
312 | if significant_only is True and significance is None:
313 | warnings.warn(
314 | "Significant_only requires significance to be set. Ignoring significant_only.",
315 | UserWarning,
316 | stacklevel=2,
317 | )
318 | significant_only = False
319 |
320 | # Set kwargs['alpha'] to 1 if not set
321 | if "alpha" not in kwargs:
322 | kwargs["alpha"] = 1
323 |
324 | if "edgecolor" not in kwargs:
325 | kwargs["edgecolor"] = "none"
326 |
327 | fold_change = adata.uns[f"{group_key}_{label_key}_enrichment"]["enrichment"].copy().T
328 |
329 | fold_change, pvalues = _select_labels(fold_change, pvalues, labels, groups)
330 |
331 | # Set -inf values to minimum and inf values to maximum
332 | fold_change[:] = np.nan_to_num(
333 | fold_change,
334 | neginf=np.min(fold_change[np.isfinite(fold_change)]),
335 | posinf=np.max(fold_change[np.isfinite(fold_change)]),
336 | )
337 |
338 | fold_change, pvalues = _reorder_labels(fold_change, pvalues, group_cluster, label_cluster)
339 |
340 | fold_change_melt = pd.melt(fold_change.reset_index(), id_vars=label_key)
341 |
342 | # Normalize the size of dots based on the absolute values in the dataframe, scaled to your preference
343 | sizes = fold_change_melt.copy()
344 | sizes["value"] = np.abs(sizes["value"])
345 | size_max = sizes["value"].max() if size_threshold is None else size_threshold
346 | if size_threshold is not None:
347 | sizes["value"] = sizes["value"].clip(upper=size_threshold)
348 |
349 | sizes["value"] = sizes["value"] * 100 / sizes["value"].max() * dot_scale
350 |
351 | norm = Normalize(0, 1)
352 | # Set colormap to red if below 0, blue if above 0
353 | if significance is not None:
354 | color = _significance_colors(fold_change.copy(), pvalues, significance)
355 | else:
356 | if pvalues is not None:
357 | pvalues += 0.0001
358 | norm = LogNorm(vmin=pvalues.min().min(), vmax=pvalues.max().max())
359 | color = pvalues.copy()
360 | else:
361 | color = fold_change.copy()
362 | color[:] = 0.0
363 |
364 | color = pd.melt(color.reset_index(), id_vars=label_key)
365 |
366 | # Create a figure and axis for plotting
367 | fig, ax = plt.subplots(figsize=figsize)
368 |
369 | scatters = []
370 | enriched_mask = fold_change_melt["value"] >= 0
371 |
372 | significant_mask = np.ones_like(fold_change_melt["value"], dtype=bool)
373 |
374 | if significant_only:
375 | significant_mask = pd.melt(pvalues.reset_index(), id_vars=label_key)["value"] < significance
376 |
377 | cmap_enriched = matplotlib.colors.LinearSegmentedColormap.from_list("", [palette(1.0), palette(0.5)])
378 | scatter_enriched = ax.scatter(
379 | pd.factorize(sizes[label_key])[0][enriched_mask & significant_mask],
380 | pd.factorize(sizes[group_key])[0][enriched_mask & significant_mask],
381 | s=sizes["value"][enriched_mask & significant_mask],
382 | c=color["value"][enriched_mask & significant_mask],
383 | cmap=cmap_enriched,
384 | norm=norm,
385 | **kwargs,
386 | )
387 | scatters.append(scatter_enriched)
388 |
389 | cmap_depleted = None
390 | if enriched_only is False:
391 | cmap_depleted = matplotlib.colors.LinearSegmentedColormap.from_list("", [palette(0.0), palette(0.5)])
392 | scatter_depleted = ax.scatter(
393 | pd.factorize(sizes[label_key])[0][~enriched_mask & significant_mask],
394 | pd.factorize(sizes[group_key])[0][~enriched_mask & significant_mask],
395 | s=sizes["value"][~enriched_mask & significant_mask],
396 | c=color["value"][~enriched_mask & significant_mask],
397 | cmap=cmap_depleted,
398 | norm=norm,
399 | **kwargs,
400 | )
401 | scatters.append(scatter_depleted)
402 |
403 | if pvalues is not None and significance is None:
404 | _pvalue_colorbar(ax, cmap_enriched, cmap_depleted, norm)
405 |
406 | handles_list, labels_list = _enrichment_legend(
407 | scatters, fold_change_melt, dot_scale, size_max, enriched_only, significant_only, significance, size_threshold
408 | )
409 |
410 | fig.legend(
411 | handles_list,
412 | labels_list,
413 | loc="outside upper left",
414 | bbox_to_anchor=(0.98, 0.95),
415 | handler_map={tuple: HandlerTuple(ndivide=None, pad=1)},
416 | borderpad=1,
417 | handletextpad=1.0,
418 | fontsize=fontsize,
419 | )
420 |
421 | # Adjust the ticks to match the dataframe's indices and columns
422 | ax.set_xticks(range(len(fold_change.index)))
423 | ax.set_yticks(range(len(fold_change.columns)))
424 | ax.set_xticklabels(fold_change.index, rotation=90)
425 | ax.set_yticklabels(fold_change.columns)
426 | ax.tick_params(axis="both", which="major", labelsize=fontsize)
427 |
428 | # Remove grid lines
429 | ax.grid(False)
430 |
431 | plt.tight_layout()
432 |
433 | if save:
434 | plt.savefig(save, bbox_inches="tight")
435 |
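A minimal sketch of the two plots, assuming :func:`cellcharter.gr.enrichment` was already run with ``pvalues=True`` and the same illustrative keys:

```python
import cellcharter as cc

# Stacked barplot of the cluster composition of each condition.
cc.pl.proportion(adata, group_key="condition", label_key="spatial_cluster")

# Dotplot of the log2 fold changes; values with p-value above 0.05 are drawn
# in a lighter color.
cc.pl.enrichment(adata, group_key="condition", label_key="spatial_cluster",
                 show_pvalues=True, significance=0.05)
```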
--------------------------------------------------------------------------------
/src/cellcharter/pl/_nhood.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import warnings
4 | from itertools import combinations
5 | from pathlib import Path
6 | from types import MappingProxyType
7 | from typing import Any, Mapping
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import pandas as pd
12 | from anndata import AnnData
13 | from matplotlib import rcParams
14 | from matplotlib.axes import Axes
15 | from squidpy._docs import d
16 | from squidpy.gr._utils import _assert_categorical_obs
17 | from squidpy.pl._color_utils import Palette_t, _maybe_set_colors
18 | from squidpy.pl._graph import _get_data
19 | from squidpy.pl._spatial_utils import _panel_grid
20 |
21 | from cellcharter.pl._utils import _heatmap
22 |
23 |
24 | def _plot_nhood_enrichment(
25 | adata: AnnData,
26 | nhood_enrichment_values: dict,
27 | cluster_key: str,
28 | row_groups: str | None = None,
29 | col_groups: str | None = None,
30 | annotate: bool = False,
31 | n_digits: int = 2,
32 | significance: float | None = None,
33 | method: str | None = None,
34 | title: str | None = "Neighborhood enrichment",
35 | cmap: str = "bwr",
36 | palette: Palette_t = None,
37 | cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
38 | fontsize=None,
39 | figsize: tuple[float, float] | None = None,
40 | dpi: int | None = None,
41 | ax: Axes | None = None,
42 | **kwargs: Any,
43 | ):
44 | enrichment = nhood_enrichment_values["enrichment"]
45 | adata_enrichment = AnnData(X=enrichment.astype(np.float32))
46 | adata_enrichment.obs[cluster_key] = pd.Categorical(enrichment.index)
47 |
48 | if significance is not None:
49 | if "pvalue" not in nhood_enrichment_values:
50 | warnings.warn(
51 | "Significance requires gr.nhood_enrichment to be run with pvalues=True. Ignoring significance.",
52 | UserWarning,
53 | stacklevel=2,
54 | )
55 | else:
56 | adata_enrichment.layers["significant"] = np.empty_like(enrichment, dtype=str)
57 | adata_enrichment.layers["significant"][nhood_enrichment_values["pvalue"].values <= significance] = "*"
58 |
59 | _maybe_set_colors(source=adata, target=adata_enrichment, key=cluster_key, palette=palette)
60 |
61 | if figsize is None:
62 | figsize = list(adata_enrichment.shape[::-1])
63 |
64 | if row_groups is not None:
65 | figsize[1] = len(row_groups)
66 |
67 | if col_groups is not None:
68 | figsize[0] = len(col_groups)
69 |
70 | figsize = tuple(figsize)
71 |
72 | _heatmap(
73 | adata_enrichment,
74 | key=cluster_key,
75 | rows=row_groups,
76 | cols=col_groups,
77 | annotate=annotate,
78 | n_digits=n_digits,
79 | method=method,
80 | title=title,
81 | cont_cmap=cmap,
82 | fontsize=fontsize,
83 | figsize=figsize,
84 | dpi=dpi,
85 | cbar_kwargs=cbar_kwargs,
86 | ax=ax,
87 | **kwargs,
88 | )
89 |
90 |
91 | @d.dedent
92 | def nhood_enrichment(
93 | adata: AnnData,
94 | cluster_key: str,
95 | row_groups: list[str] | None = None,
96 | col_groups: list[str] | None = None,
97 | min_freq: float | None = None,
98 | annotate: bool = False,
99 | transpose: bool = False,
100 | method: str | None = None,
101 | title: str | None = "Neighborhood enrichment",
102 | cmap: str = "bwr",
103 | palette: Palette_t = None,
104 | cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
105 | figsize: tuple[float, float] | None = None,
106 | dpi: int | None = None,
107 | ax: Axes | None = None,
108 | n_digits: int = 2,
109 | significance: float | None = None,
110 | save: str | Path | None = None,
111 | **kwargs: Any,
112 | ) -> None:
113 | """
114 | A modified version of squidpy's function for `plotting neighborhood enrichment `_.
115 |
116 | The enrichment is computed by :func:`cellcharter.gr.nhood_enrichment`.
117 |
118 | Parameters
119 | ----------
120 | %(adata)s
121 | %(cluster_key)s
122 | row_groups
123 | Restrict the rows to these groups. If `None`, all groups are plotted.
124 | col_groups
125 | Restrict the columns to these groups. If `None`, all groups are plotted.
126 | %(heatmap_plotting)s
127 |
128 | n_digits
129 | The number of digits of the number in the annotations.
130 | significance
131 | Mark the values that are below this threshold with a star. If `None`, no significance is computed. It requires ``gr.nhood_enrichment`` to be run with ``pvalues=True``.
132 | kwargs
133 | Keyword arguments for :func:`matplotlib.pyplot.text`.
134 |
135 | Returns
136 | -------
137 | %(plotting_returns)s
138 | """
139 | _assert_categorical_obs(adata, key=cluster_key)
140 | nhood_enrichment_values = _get_data(adata, cluster_key=cluster_key, func_name="nhood_enrichment").copy()
141 | nhood_enrichment_values["enrichment"][np.isinf(nhood_enrichment_values["enrichment"])] = np.nan
142 |
143 | if transpose:
144 | nhood_enrichment_values["enrichment"] = nhood_enrichment_values["enrichment"].T
145 |
146 | if min_freq is not None:
147 | frequency = adata.obs[cluster_key].value_counts(normalize=True)
148 | nhood_enrichment_values["enrichment"].loc[frequency[frequency < min_freq].index] = np.nan
149 | nhood_enrichment_values["enrichment"].loc[:, frequency[frequency < min_freq].index] = np.nan
150 |
151 | _plot_nhood_enrichment(
152 | adata,
153 | nhood_enrichment_values,
154 | cluster_key,
155 | row_groups=row_groups,
156 | col_groups=col_groups,
157 | annotate=annotate,
158 | method=method,
159 | title=title,
160 | cmap=cmap,
161 | palette=palette,
162 | cbar_kwargs=cbar_kwargs,
163 | figsize=figsize,
164 | dpi=dpi,
165 | ax=ax,
166 | n_digits=n_digits,
167 | significance=significance,
168 | **kwargs,
169 | )
170 |
171 | if save is not None:
172 | plt.savefig(save, bbox_inches="tight")
173 |
174 |
175 | @d.dedent
176 | def diff_nhood_enrichment(
177 | adata: AnnData,
178 | cluster_key: str,
179 | condition_key: str,
180 | condition_groups: list[str] | None = None,
181 | hspace: float = 0.25,
182 | wspace: float | None = None,
183 | ncols: int = 1,
184 | **nhood_kwargs: Any,
185 | ) -> None:
186 | r"""
187 | Plot the difference in neighborhood enrichment between conditions.
188 |
189 | The difference is computed by :func:`cellcharter.gr.diff_nhood_enrichment`.
190 |
191 | Parameters
192 | ----------
193 | %(adata)s
194 | %(cluster_key)s
195 | condition_key
196 | Key in ``adata.obs`` that stores the sample condition (e.g., normal vs disease).
197 | condition_groups
198 | Restrict the conditions to these clusters. If `None`, all groups are plotted.
199 | hspace
200 | Height space between panels.
201 | wspace
202 | Width space between panels.
203 | ncols
204 | Number of panels per row.
205 | nhood_kwargs
206 | Keyword arguments for :func:`cellcharter.pl.nhood_enrichment`.
207 |
208 | Returns
209 | -------
210 | %(plotting_returns)s
211 | """
212 | _assert_categorical_obs(adata, key=cluster_key)
213 | _assert_categorical_obs(adata, key=condition_key)
214 |
215 | conditions = adata.obs[condition_key].cat.categories if condition_groups is None else condition_groups
216 |
217 | if nhood_kwargs is None:
218 | nhood_kwargs = {}
219 |
220 | cmap = nhood_kwargs.pop("cmap", "PRGn_r")
221 | save = nhood_kwargs.pop("save", None)
222 |
223 | n_combinations = len(conditions) * (len(conditions) - 1) // 2
224 |
225 | figsize = nhood_kwargs.get("figsize", rcParams["figure.figsize"])
226 |
227 | # Plot neighborhood enrichment for each condition pair as a subplot
228 | _, grid = _panel_grid(
229 | num_panels=n_combinations,
230 | hspace=hspace,
231 | wspace=0.75 / figsize[0] + 0.02 if wspace is None else wspace,
232 | ncols=ncols,
233 | dpi=nhood_kwargs.get("dpi", rcParams["figure.dpi"]),
234 |         figsize=figsize,
235 | )
236 |
237 | axs = [plt.subplot(grid[c]) for c in range(n_combinations)]
238 |
239 | for i, (condition1, condition2) in enumerate(combinations(conditions, 2)):
240 | if f"{condition1}_{condition2}" not in adata.uns[f"{cluster_key}_{condition_key}_diff_nhood_enrichment"]:
241 | nhood_enrichment_values = dict(
242 | adata.uns[f"{cluster_key}_{condition_key}_diff_nhood_enrichment"][f"{condition2}_{condition1}"]
243 | )
244 | nhood_enrichment_values["enrichment"] = -nhood_enrichment_values["enrichment"]
245 | else:
246 | nhood_enrichment_values = adata.uns[f"{cluster_key}_{condition_key}_diff_nhood_enrichment"][
247 | f"{condition1}_{condition2}"
248 | ]
249 |
250 | _plot_nhood_enrichment(
251 | adata,
252 | nhood_enrichment_values,
253 | cluster_key,
254 | cmap=cmap,
255 | ax=axs[i],
256 | title=f"{condition1} vs {condition2}",
257 |             show_cols=i >= n_combinations - ncols,  # Show column labels only for subplots in the last row
258 |             show_rows=i % ncols == 0,  # Show row labels only for subplots in the first column
259 | **nhood_kwargs,
260 | )
261 |
262 | if save is not None:
263 | plt.savefig(save, bbox_inches="tight")
264 |
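A minimal sketch with illustrative keys; the ``gr.nhood_enrichment`` and ``gr.diff_nhood_enrichment`` calls are assumptions about their signatures, since those functions are defined in ``gr/_nhood.py`` rather than in this file:

```python
import cellcharter as cc

# Neighborhood enrichment between spatial clusters, with empirical p-values.
cc.gr.nhood_enrichment(adata, cluster_key="spatial_cluster", pvalues=True)
cc.pl.nhood_enrichment(adata, cluster_key="spatial_cluster", annotate=True, significance=0.05)

# Difference in neighborhood enrichment between conditions, one panel per pair.
cc.gr.diff_nhood_enrichment(adata, cluster_key="spatial_cluster", condition_key="condition")
cc.pl.diff_nhood_enrichment(adata, cluster_key="spatial_cluster", condition_key="condition")
```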
--------------------------------------------------------------------------------
/src/cellcharter/pl/_shape.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import warnings
4 | from itertools import combinations
5 | from pathlib import Path
6 |
7 | import anndata as ad
8 | import geopandas
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import pandas as pd
12 | import seaborn as sns
13 | import spatialdata as sd
14 | import spatialdata_plot # noqa: F401
15 | from anndata import AnnData
16 | from scipy.stats import ttest_ind
17 | from squidpy._docs import d
18 |
19 | from ._utils import adjust_box_widths
20 |
21 |
22 | def plot_boundaries(
23 | adata: AnnData,
24 | sample: str,
25 | library_key: str = "sample",
26 | component_key: str = "component",
27 | alpha_boundary: float = 0.5,
28 | show_cells: bool = True,
29 | save: str | Path | None = None,
30 | ) -> None:
31 | """
32 | Plot the boundaries of the clusters.
33 |
34 | Parameters
35 | ----------
36 | %(adata)s
37 | sample
38 | Sample to plot.
39 | library_key
40 | Key in :attr:`anndata.AnnData.obs` where the sample labels are stored.
41 | component_key
42 | Key in :attr:`anndata.AnnData.obs` where the component labels are stored.
43 | alpha_boundary
44 | Transparency of the boundaries.
45 | show_cells
46 | Whether to show the cells or not.
47 |
48 | Returns
49 | -------
50 | %(plotting_returns)s
51 | """
52 | # Print warning and call boundaries
53 | warnings.warn(
54 | "plot_boundaries is deprecated and will be removed in the next release. " "Please use `boundaries` instead.",
55 | FutureWarning,
56 | stacklevel=2,
57 | )
58 | boundaries(
59 | adata=adata,
60 | sample=sample,
61 | library_key=library_key,
62 | component_key=component_key,
63 | alpha_boundary=alpha_boundary,
64 | show_cells=show_cells,
65 | save=save,
66 | )
67 |
68 |
69 | @d.dedent
70 | def boundaries(
71 | adata: AnnData,
72 | sample: str,
73 | library_key: str = "sample",
74 | component_key: str = "component",
75 | alpha_boundary: float = 0.5,
76 | show_cells: bool = True,
77 | save: str | Path | None = None,
78 | ) -> None:
79 | """
80 | Plot the boundaries of the clusters.
81 |
82 | Parameters
83 | ----------
84 | %(adata)s
85 | sample
86 | Sample to plot.
87 | library_key
88 | Key in :attr:`anndata.AnnData.obs` where the sample labels are stored.
89 | component_key
90 | Key in :attr:`anndata.AnnData.obs` where the component labels are stored.
91 | alpha_boundary
92 | Transparency of the boundaries.
93 | show_cells
94 | Whether to show the cells or not.
95 |
96 | Returns
97 | -------
98 | %(plotting_returns)s
99 | """
100 | adata = adata[adata.obs[library_key] == sample].copy()
101 | del adata.raw
102 | clusters = adata.obs[component_key].unique()
103 |
104 | boundaries = {
105 | cluster: boundary
106 | for cluster, boundary in adata.uns[f"shape_{component_key}"]["boundary"].items()
107 | if cluster in clusters
108 | }
109 | gdf = geopandas.GeoDataFrame(geometry=list(boundaries.values()), index=np.arange(len(boundaries)).astype(str))
110 | adata.obs.loc[adata.obs[component_key] == -1, component_key] = np.nan
111 | adata.obs.index = "cell_" + adata.obs.index
112 | adata.obs["instance_id"] = adata.obs.index
113 | adata.obs["region"] = "cells"
114 |
115 | xy = adata.obsm["spatial"]
116 | cell_circles = sd.models.ShapesModel.parse(xy, geometry=0, radius=3000, index=adata.obs["instance_id"])
117 |
118 | obs = pd.DataFrame(list(boundaries.keys()), columns=[component_key], index=np.arange(len(boundaries)).astype(str))
119 | adata_obs = ad.AnnData(X=pd.DataFrame(index=obs.index, columns=adata.var_names), obs=obs)
120 | adata_obs.obs["region"] = "clusters"
121 |     adata_obs.obs.index = "cluster_" + adata_obs.obs.index
122 | adata_obs.obs["instance_id"] = np.arange(len(boundaries)).astype(str)
123 | adata_obs.obs[component_key] = pd.Categorical(adata_obs.obs[component_key])
124 |
125 | adata = ad.concat((adata, adata_obs), join="outer")
126 |
127 | adata.obs["region"] = adata.obs["region"].astype("category")
128 |
129 | table = sd.models.TableModel.parse(
130 | adata, region_key="region", region=["clusters", "cells"], instance_key="instance_id"
131 | )
132 |
133 | shapes = {
134 | "clusters": sd.models.ShapesModel.parse(gdf),
135 | "cells": sd.models.ShapesModel.parse(cell_circles),
136 | }
137 |
138 | sdata = sd.SpatialData(shapes=shapes, tables=table)
139 |
140 | ax = plt.gca()
141 | if show_cells:
142 | try:
143 | sdata.pl.render_shapes(elements="cells", color=component_key).pl.show(ax=ax, legend_loc=None)
144 | except TypeError: # TODO: remove after spatialdata-plot issue #256 is fixed
145 | warnings.warn(
146 | "Until the next spatialdata_plot release, the cells that do not belong to any component will be displayed with a random color instead of grey.",
147 | stacklevel=2,
148 | )
149 | sdata.tables["table"].obs[component_key] = sdata.tables["table"].obs[component_key].cat.add_categories([-1])
150 | sdata.tables["table"].obs[component_key] = sdata.tables["table"].obs[component_key].fillna(-1)
151 | sdata.pl.render_shapes(elements="cells", color=component_key).pl.show(ax=ax, legend_loc=None)
152 |
153 | sdata.pl.render_shapes(
154 | element="clusters",
155 | color=component_key,
156 | fill_alpha=alpha_boundary,
157 | ).pl.show(ax=ax)
158 |
159 | if save is not None:
160 | plt.savefig(save, bbox_inches="tight")
161 |
162 |
163 | def plot_shape_metrics(
164 | adata: AnnData,
165 | condition_key: str,
166 | condition_groups: list[str] | None = None,
167 | cluster_key: str | None = None,
168 | cluster_id: list[str] | None = None,
169 | component_key: str = "component",
170 | metrics: str | tuple[str] | list[str] = ("linearity", "curl"),
171 | figsize: tuple[float, float] = (8, 7),
172 | title: str | None = None,
173 | ) -> None:
174 | """
175 | Boxplots of the shape metrics between two conditions.
176 |
177 | Parameters
178 | ----------
179 | %(adata)s
180 | condition_key
181 | Key in :attr:`anndata.AnnData.obs` where the condition labels are stored.
182 | condition_groups
183 | List of two conditions to compare. If None, all pairwise comparisons are made.
184 | cluster_key
185 | Key in :attr:`anndata.AnnData.obs` where the cluster labels are stored. This is used to filter the clusters to plot.
186 | cluster_id
187 | List of clusters to plot. If None, all clusters are plotted.
188 | component_key
189 | Key in :attr:`anndata.AnnData.obs` where the component labels are stored.
190 | metrics
191 | List of metrics to plot. Available metrics are ``linearity``, ``curl``, ``elongation``, ``purity``.
192 | figsize
193 | Figure size.
194 | title
195 | Title of the plot.
196 |
197 | Returns
198 | -------
199 | %(plotting_returns)s
200 | """
201 | # Print warning and call shape_metrics
202 | warnings.warn(
203 | "plot_shape_metrics is deprecated and will be removed in the next release. "
204 | "Please use `shape_metrics` instead.",
205 | FutureWarning,
206 | stacklevel=2,
207 | )
208 | shape_metrics(
209 | adata=adata,
210 | condition_key=condition_key,
211 | condition_groups=condition_groups,
212 | cluster_key=cluster_key,
213 | cluster_id=cluster_id,
214 | component_key=component_key,
215 | metrics=metrics,
216 | figsize=figsize,
217 | title=title,
218 | )
219 |
220 |
221 | @d.dedent
222 | def shape_metrics(
223 | adata: AnnData,
224 | condition_key: str,
225 | condition_groups: list[str] | None = None,
226 | cluster_key: str | None = None,
227 | cluster_id: list[str] | None = None,
228 | component_key: str = "component",
229 | metrics: str | tuple[str] | list[str] = ("linearity", "curl"),
230 | fontsize: str | int = "small",
231 | figsize: tuple[float, float] = (8, 7),
232 | title: str | None = None,
233 | ) -> None:
234 | """
235 | Boxplots of the shape metrics between two conditions.
236 |
237 | Parameters
238 | ----------
239 | %(adata)s
240 | condition_key
241 | Key in :attr:`anndata.AnnData.obs` where the condition labels are stored.
242 | condition_groups
243 | List of two conditions to compare. If None, all pairwise comparisons are made.
244 | cluster_key
245 | Key in :attr:`anndata.AnnData.obs` where the cluster labels are stored. This is used to filter the clusters to plot.
246 | cluster_id
247 | List of clusters to plot. If None, all clusters are plotted.
248 | component_key
249 | Key in :attr:`anndata.AnnData.obs` where the component labels are stored.
250 | metrics
251 | List of metrics to plot. Available metrics are ``linearity``, ``curl``, ``elongation``, ``purity``.
252 | figsize
253 | Figure size.
254 | title
255 | Title of the plot.
256 |
257 | Returns
258 | -------
259 | %(plotting_returns)s
260 | """
261 | if isinstance(metrics, str):
262 | metrics = [metrics]
263 | elif isinstance(metrics, tuple):
264 | metrics = list(metrics)
265 |
266 | metrics_df = {metric: adata.uns[f"shape_{component_key}"][metric] for metric in metrics}
267 | metrics_df[condition_key] = (
268 | adata[~adata.obs[condition_key].isna()]
269 | .obs[[component_key, condition_key]]
270 | .drop_duplicates()
271 | .set_index(component_key)
272 | .to_dict()[condition_key]
273 | )
274 |
275 |     if cluster_key is not None:  # the cluster mapping is only needed to filter by cluster_id
276 |         metrics_df[cluster_key] = (
277 |             adata[~adata.obs[condition_key].isna()]
278 |             .obs[[component_key, cluster_key]]
279 |             .drop_duplicates()
280 |             .set_index(component_key).to_dict()[cluster_key]
281 |         )
282 |
283 | metrics_df = pd.DataFrame(metrics_df)
284 | if cluster_id is not None:
285 | metrics_df = metrics_df[metrics_df[cluster_key].isin(cluster_id)]
286 |
287 | metrics_df = pd.melt(
288 | metrics_df[metrics + [condition_key]],
289 | id_vars=[condition_key],
290 | var_name="metric",
291 | )
292 |
293 | conditions = (
294 |         combinations(adata.obs[condition_key].cat.categories, 2)
295 | if condition_groups is None
296 | else [condition_groups]
297 | )
298 |
299 | for condition1, condition2 in conditions:
300 | fig = plt.figure(figsize=figsize)
301 | metrics_condition_pair = metrics_df[metrics_df[condition_key].isin([condition1, condition2])]
302 | ax = sns.boxplot(
303 | data=metrics_condition_pair,
304 | x="metric",
305 | hue=condition_key,
306 | y="value",
307 | showfliers=False,
308 | hue_order=[condition1, condition2],
309 | )
310 |
311 | ax.tick_params(labelsize=fontsize)
312 | ax.set_xlabel(ax.get_xlabel(), fontsize=fontsize)
313 | ax.tick_params(labelsize=fontsize)
314 | ax.set_ylabel(ax.get_ylabel(), fontsize=fontsize)
315 |
316 | adjust_box_widths(fig, 0.9)
317 |
318 | ax = sns.stripplot(
319 | data=metrics_condition_pair,
320 | x="metric",
321 | hue=condition_key,
322 | y="value",
323 | color="0.08",
324 | size=4,
325 | jitter=0.13,
326 | dodge=True,
327 |             hue_order=[condition1, condition2],  # match the boxplot's hue order so the dodged points align
328 | )
329 | handles, labels = ax.get_legend_handles_labels()
330 | plt.legend(
331 | handles[0 : len(metrics_condition_pair[condition_key].unique())],
332 | labels[0 : len(metrics_condition_pair[condition_key].unique())],
333 | bbox_to_anchor=(1.24, 1.02),
334 | fontsize=fontsize,
335 | )
336 |
337 |         for count, metric in enumerate(metrics):  # annotate every requested metric, not a hardcoded pair
338 | pvalue = ttest_ind(
339 | metrics_condition_pair[
340 | (metrics_condition_pair[condition_key] == condition1) & (metrics_condition_pair["metric"] == metric)
341 | ]["value"],
342 | metrics_condition_pair[
343 | (metrics_condition_pair[condition_key] == condition2) & (metrics_condition_pair["metric"] == metric)
344 | ]["value"],
345 | )[1]
346 | x1, x2 = count, count
347 | y, h, col = (
348 | metrics_condition_pair[(metrics_condition_pair["metric"] == metric)]["value"].max()
349 | + 0.02
350 | + 0.05 * count,
351 | 0.01,
352 | "k",
353 | )
354 | plt.plot([x1 - 0.2, x1 - 0.2, x2 + 0.2, x2 + 0.2], [y, y + h, y + h, y], lw=1.5, c=col)
355 | if pvalue < 0.05:
356 | plt.text(
357 | (x1 + x2) * 0.5,
358 | y + h * 2,
359 | f"p = {pvalue:.2e}",
360 | ha="center",
361 | va="bottom",
362 | color=col,
363 | fontdict={"fontsize": fontsize},
364 | )
365 | else:
366 | plt.text(
367 | (x1 + x2) * 0.5,
368 | y + h * 2,
369 | "ns",
370 | ha="center",
371 | va="bottom",
372 | color=col,
373 | fontdict={"fontsize": fontsize},
374 | )
375 | if title is not None:
376 | plt.title(title, fontdict={"fontsize": fontsize})
377 | plt.show()
378 |
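A minimal usage sketch for shape_metrics: the observation keys "condition" and "spatial_cluster", the group names and the cluster id below are placeholders, and the call assumes the spatial components and their shape metrics were computed beforehand, so that adata.obs["component"] and adata.uns["shape_component"] are already populated.

import cellcharter as cc

cc.pl.shape_metrics(
    adata,
    condition_key="condition",                 # hypothetical obs column holding the condition labels
    condition_groups=["control", "treated"],   # compare this pair only, instead of all pairwise combinations
    cluster_key="spatial_cluster",             # hypothetical obs column used to restrict the plotted clusters
    cluster_id=["3"],                          # keep only components belonging to this cluster
    metrics=("linearity", "curl"),
    fontsize="medium",
)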
--------------------------------------------------------------------------------
/src/cellcharter/pl/_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from copy import copy
4 | from types import MappingProxyType
5 | from typing import Any, Mapping
6 |
7 | import matplotlib as mpl
8 | import numpy as np
9 | import seaborn as sns
10 | import squidpy as sq
11 | from anndata import AnnData
12 | from matplotlib import colors as mcolors
13 | from matplotlib import pyplot as plt
14 | from matplotlib.axes import Axes
15 | from matplotlib.patches import PathPatch
16 | from mpl_toolkits.axes_grid1 import make_axes_locatable
17 | from scipy.cluster import hierarchy as sch
18 | from squidpy._constants._pkg_constants import Key
19 |
20 | try:
21 | from matplotlib import colormaps as cm
22 | except ImportError:
23 | from matplotlib import cm
24 |
25 |
26 | def _get_cmap_norm(
27 | adata: AnnData,
28 | key: str,
29 |     order: tuple[list[int], list[int]] | None = None,
30 | ) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, mcolors.BoundaryNorm, mcolors.BoundaryNorm, int]:
31 | n_rows = n_cols = adata.obs[key].nunique()
32 |
33 | colors = adata.uns[Key.uns.colors(key)]
34 |
35 | if order is not None:
36 | row_order, col_order = order
37 | row_colors = [colors[i] for i in row_order]
38 | col_colors = [colors[i] for i in col_order]
39 |
40 | n_rows = len(row_order)
41 | n_cols = len(col_order)
42 | else:
43 | row_colors = col_colors = colors
44 |
45 | row_cmap = mcolors.ListedColormap(row_colors)
46 | col_cmap = mcolors.ListedColormap(col_colors)
47 | row_norm = mcolors.BoundaryNorm(np.arange(n_rows + 1), row_cmap.N)
48 | col_norm = mcolors.BoundaryNorm(np.arange(n_cols + 1), col_cmap.N)
49 |
50 | return row_cmap, col_cmap, row_norm, col_norm, n_rows
51 |
52 |
53 | def _heatmap(
54 | adata: AnnData,
55 | key: str,
56 | rows: list[str] | None = None,
57 | cols: list[str] | None = None,
58 | title: str = "",
59 | method: str | None = None,
60 | cont_cmap: str | mcolors.Colormap = "bwr",
61 | annotate: bool = True,
62 | fontsize: int | None = None,
63 | figsize: tuple[float, float] | None = None,
64 | dpi: int | None = None,
65 | cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
66 | ax: Axes | None = None,
67 | n_digits: int = 2,
68 | show_cols: bool = True,
69 | show_rows: bool = True,
70 | **kwargs: Any,
71 | ) -> tuple[mpl.figure.Figure, Axes]:
72 | cbar_kwargs = dict(cbar_kwargs)
73 |
74 | if fontsize is not None:
75 | kwargs["annot_kws"] = {"fontdict": {"fontsize": fontsize}}
76 |
77 | if ax is None:
78 | fig, ax = plt.subplots(constrained_layout=True, dpi=dpi, figsize=figsize)
79 | else:
80 | fig = ax.figure
81 |
82 | if method is not None:
83 | row_order, col_order, row_link, col_link = sq.pl._utils._dendrogram(
84 | adata.X, method, optimal_ordering=adata.n_obs <= 1500
85 | )
86 | else:
87 | row_order = (
88 | np.arange(len(adata.obs[key]))
89 | if rows is None
90 | else np.argwhere(adata.obs.index.isin(np.array(rows).astype(str))).flatten()
91 | )
92 | col_order = (
93 | np.arange(len(adata.var_names))
94 | if cols is None
95 | else np.argwhere(adata.var_names.isin(np.array(cols).astype(str))).flatten()
96 | )
97 |
98 | row_order = row_order[::-1]
99 | row_labels = adata.obs[key].iloc[row_order]
100 | col_labels = adata.var_names[col_order]
101 |
102 | data = adata[row_order, col_order].copy().X
103 |
104 | # row_cmap, col_cmap, row_norm, col_norm, n_cls = sq.pl._utils._get_cmap_norm(adata, key, order=(row_order, len(row_order) + col_order))
105 | row_cmap, col_cmap, row_norm, col_norm, n_cls = _get_cmap_norm(adata, key, order=(row_order, col_order))
106 | col_norm = mcolors.BoundaryNorm(np.arange(len(col_order) + 1), col_cmap.N)
107 |
108 | row_sm = mpl.cm.ScalarMappable(cmap=row_cmap, norm=row_norm)
109 | col_sm = mpl.cm.ScalarMappable(cmap=col_cmap, norm=col_norm)
110 |
111 | vmin = kwargs.pop("vmin", np.nanmin(data))
112 | vmax = kwargs.pop("vmax", np.nanmax(data))
113 | vcenter = kwargs.pop("vcenter", 0)
114 | norm = mpl.colors.TwoSlopeNorm(vcenter=vcenter, vmin=vmin, vmax=vmax)
115 | cont_cmap = copy(cm.get_cmap(cont_cmap))
116 | cont_cmap.set_bad(color="grey")
117 |
118 | annot = np.round(data[::-1], n_digits).astype(str) if annotate else None
119 | if "significant" in adata.layers:
120 | significant = adata.layers["significant"].astype(str)
121 | annot = np.char.add(np.empty_like(data[::-1], dtype=str), significant[row_order[:, None], col_order][::-1])
122 |
123 | ax = sns.heatmap(
124 | data[::-1],
125 | cmap=cont_cmap,
126 | norm=norm,
127 | ax=ax,
128 | square=True,
129 | annot=annot,
130 | cbar=False,
131 | fmt="",
132 | **kwargs,
133 | )
134 |
135 | for _, spine in ax.spines.items():
136 | spine.set_visible(True)
137 |
138 | ax.tick_params(top=False, bottom=False, labeltop=False, labelbottom=False)
139 | ax.set_xticks([])
140 | ax.set_yticks([])
141 |
142 | divider = make_axes_locatable(ax)
143 | row_cats = divider.append_axes("left", size=0.1, pad=0.1)
144 | col_cats = divider.append_axes("bottom", size=0.1, pad=0.1)
145 | cax = divider.append_axes("right", size="2%", pad=0.1)
146 |
147 | if method is not None: # cluster rows but don't plot dendrogram
148 | col_ax = divider.append_axes("top", size="5%")
149 | sch.dendrogram(col_link, no_labels=True, ax=col_ax, color_threshold=0, above_threshold_color="black")
150 | col_ax.axis("off")
151 |
152 | c = fig.colorbar(
153 | mpl.cm.ScalarMappable(norm=norm, cmap=cont_cmap),
154 | cax=cax,
155 | ticks=np.linspace(norm.vmin, norm.vmax, 10),
156 | orientation="vertical",
157 | format="%0.2f",
158 | **cbar_kwargs,
159 | )
160 | c.ax.tick_params(labelsize=fontsize)
161 |
162 | # column labels colorbar
163 | c = fig.colorbar(col_sm, cax=col_cats, orientation="horizontal", ticklocation="bottom")
164 |
165 | if rows == cols or show_cols is False:
166 | c.set_ticks([])
167 | c.set_ticklabels([])
168 | else:
169 | c.set_ticks(np.arange(len(col_labels)) + 0.5)
170 | c.set_ticklabels(col_labels, fontdict={"fontsize": fontsize})
171 | if np.any([len(l) > 3 for l in col_labels]):
172 | c.ax.tick_params(rotation=90)
173 | c.outline.set_visible(False)
174 |
175 | # row labels colorbar
176 | c = fig.colorbar(row_sm, cax=row_cats, orientation="vertical", ticklocation="left")
177 | if show_rows is False:
178 | c.set_ticks([])
179 | c.set_ticklabels([])
180 | else:
181 | c.set_ticks(np.arange(n_cls) + 0.5)
182 | c.set_ticklabels(row_labels, fontdict={"fontsize": fontsize})
183 | c.set_label(key, fontsize=fontsize)
184 | c.outline.set_visible(False)
185 |
186 | ax.set_title(title, fontdict={"fontsize": fontsize})
187 |
188 | return fig, ax
189 |
190 |
191 | def _reorder(values, order, axis=1):
192 | if axis == 0:
193 | values = values.iloc[order, :]
194 | elif axis == 1:
195 | values = values.iloc[:, order]
196 | else:
197 | raise ValueError("The axis parameter accepts only values 0 and 1.")
198 | return values
199 |
200 |
201 | def _clip(values, min_threshold=None, max_threshold=None, new_min=None, new_max=None, new_middle=None):
202 | values_clipped = values.copy()
203 | if new_middle is not None:
204 | values_clipped[:] = new_middle
205 | if min_threshold is not None:
206 | values_clipped[values < min_threshold] = new_min if new_min is not None else min_threshold
207 | if max_threshold is not None:
208 | values_clipped[values > max_threshold] = new_max if new_max is not None else max_threshold
209 | return values_clipped
210 |
211 |
212 | def adjust_box_widths(g, fac):
213 | """Adjust the widths of a seaborn-generated boxplot."""
214 | # iterating through Axes instances
215 | for ax in g.axes:
216 | # iterating through axes artists:
217 | for c in ax.get_children():
218 | # searching for PathPatches
219 | if isinstance(c, PathPatch):
220 | # getting current width of box:
221 | p = c.get_path()
222 | verts = p.vertices
223 | verts_sub = verts[:-1]
224 | xmin = np.min(verts_sub[:, 0])
225 | xmax = np.max(verts_sub[:, 0])
226 | xmid = 0.5 * (xmin + xmax)
227 | xhalf = 0.5 * (xmax - xmin)
228 |
229 | # setting new width of box
230 | xmin_new = xmid - fac * xhalf
231 | xmax_new = xmid + fac * xhalf
232 | verts_sub[verts_sub[:, 0] == xmin, 0] = xmin_new
233 | verts_sub[verts_sub[:, 0] == xmax, 0] = xmax_new
234 |
235 | # setting new width of median line
236 | for l in ax.lines:
237 | if np.all(l.get_xdata() == [xmin, xmax]):
238 | l.set_xdata([xmin_new, xmax_new])
239 |
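adjust_box_widths is a generic helper for matplotlib/seaborn figures. Below is a small self-contained sketch (toy, made-up values) showing how it narrows the boxes of a seaborn boxplot; it relies on the boxes being drawn as PathPatch artists, which is what this module assumes for the seaborn versions it targets.

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from cellcharter.pl._utils import adjust_box_widths

# Toy data: two groups with a handful of values each.
df = pd.DataFrame(
    {
        "group": ["A"] * 5 + ["B"] * 5,
        "value": [1.0, 1.2, 0.9, 1.1, 1.3, 2.0, 2.2, 1.9, 2.1, 2.3],
    }
)

fig = plt.figure(figsize=(4, 3))
sns.boxplot(data=df, x="group", y="value")

# Shrink every box to 60% of its original width, keeping it centred and
# moving the median lines accordingly.
adjust_box_widths(fig, 0.6)
plt.show()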
--------------------------------------------------------------------------------
/src/cellcharter/tl/__init__.py:
--------------------------------------------------------------------------------
1 | from ._autok import ClusterAutoK
2 | from ._gmm import Cluster, GaussianMixture
3 | from ._shape import boundaries, curl, elongation, linearity, purity
4 | from ._trvae import TRVAE
5 |
--------------------------------------------------------------------------------
/src/cellcharter/tl/_gmm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | from typing import List, Tuple, cast
5 |
6 | import anndata as ad
7 | import numpy as np
8 | import pandas as pd
9 | import scipy.sparse as sps
10 | import torch
11 | from pytorch_lightning import Trainer
12 | from torchgmm.base.data import (
13 | DataLoader,
14 | TensorLike,
15 | collate_tensor,
16 | dataset_from_tensors,
17 | )
18 | from torchgmm.bayes import GaussianMixture as TorchGaussianMixture
19 | from torchgmm.bayes.gmm.lightning_module import GaussianMixtureLightningModule
20 | from torchgmm.bayes.gmm.model import GaussianMixtureModel
21 |
22 | from .._utils import AnyRandom
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | class GaussianMixture(TorchGaussianMixture):
28 | """
29 |     Adapted version of the GaussianMixture clustering model from the ``torchgmm`` library.
30 |
31 | Parameters
32 | ----------
33 | n_clusters
34 | The number of components in the GMM. The dimensionality of each component is automatically inferred from the data.
35 | covariance_type
36 | The type of covariance to assume for all Gaussian components.
37 | init_strategy
38 | The strategy for initializing component means and covariances.
39 | init_means
40 | An optional initial guess for the means of the components. If provided,
41 | must be a tensor of shape ``[num_components, num_features]``. If this is given,
42 | the ``init_strategy`` is ignored and the means are handled as if K-means
43 | initialization has been run.
44 | convergence_tolerance
45 | The change in the per-datapoint negative log-likelihood which
46 | implies that training has converged.
47 | covariance_regularization
48 | A small value which is added to the diagonal of the
49 | covariance matrix to ensure that it is positive semi-definite.
50 |     batch_size
51 |         The batch size to use when fitting the model. If not provided, the full data will be used as a
52 |         single batch. Set this if the full data does not fit into memory.
53 | trainer_params
54 | Initialization parameters to use when initializing a PyTorch Lightning
55 | trainer. By default, it disables various stdout logs unless TorchGMM is configured to
56 | do verbose logging. Checkpointing and logging are disabled regardless of the log
57 | level. This estimator further sets the following overridable defaults:
58 | - ``max_epochs=100``.
59 | random_state
60 | Initialization seed.
61 |
62 | """
63 |
64 | #: The fitted PyTorch module with all estimated parameters.
65 | model_: GaussianMixtureModel
66 | #: A boolean indicating whether the model converged during training.
67 | converged_: bool
68 | #: The number of iterations the model was fitted for, excluding initialization.
69 | num_iter_: int
70 | #: The average per-datapoint negative log-likelihood at the last training step.
71 | nll_: float
72 |
73 | def __init__(
74 | self,
75 | n_clusters: int = 1,
76 | *,
77 | covariance_type: str = "full",
78 | init_strategy: str = "kmeans",
79 | init_means: torch.Tensor = None,
80 | convergence_tolerance: float = 0.001,
81 | covariance_regularization: float = 1e-06,
82 | batch_size: int = None,
83 | trainer_params: dict = None,
84 | random_state: AnyRandom = 0,
85 | ):
86 | super().__init__(
87 | num_components=n_clusters,
88 | covariance_type=covariance_type,
89 | init_strategy=init_strategy,
90 | init_means=init_means,
91 | convergence_tolerance=convergence_tolerance,
92 | covariance_regularization=covariance_regularization,
93 | batch_size=batch_size,
94 | trainer_params=trainer_params,
95 | )
96 | self.n_clusters = n_clusters
97 | self.random_state = random_state
98 |
99 | def fit(self, data: TensorLike) -> GaussianMixture:
100 | """
101 | Fits the Gaussian mixture on the provided data, estimating component priors, means and covariances. Parameters are estimated using the EM algorithm.
102 |
103 | Parameters
104 | ----------
105 | data
106 | The tabular data to fit on. The dimensionality of the Gaussian mixture is automatically inferred from this data.
107 | Returns
108 | ----------
109 | The fitted Gaussian mixture.
110 | """
111 | if sps.issparse(data):
112 | raise ValueError(
113 | "Sparse data is not supported. You may have forgotten to reduce the dimensionality of the data. Otherwise, please convert the data to a dense format."
114 | )
115 | return self._fit(data)
116 |
117 | def _fit(self, data) -> GaussianMixture:
118 | try:
119 | return super().fit(data)
120 | except torch._C._LinAlgError as e:
121 | if self.covariance_regularization >= 1:
122 | raise ValueError(
123 | "Cholesky decomposition failed even with covariance regularization = 1. The matrix may be singular."
124 | ) from e
125 | else:
126 | self.covariance_regularization *= 10
127 | logger.warning(
128 | f"Cholesky decomposition failed. Retrying with covariance regularization {self.covariance_regularization}."
129 | )
130 | return self._fit(data)
131 |
132 |     def predict(self, data: TensorLike) -> np.ndarray:
133 | """
134 | Computes the most likely components for each of the provided datapoints.
135 |
136 | Parameters
137 | ----------
138 | data
139 | The datapoints for which to obtain the most likely components.
140 | Returns
141 | ----------
142 | A tensor of shape ``[num_datapoints]`` with the indices of the most likely components.
143 | Note
144 | ----------
145 | Use :func:`predict_proba` to obtain probabilities for each component instead of the
146 | most likely component only.
147 | Attention
148 | ----------
149 | When calling this function in a multi-process environment, each process receives only
150 | a subset of the predictions. If you want to aggregate predictions, make sure to gather
151 | the values returned from this method.
152 | """
153 | return super().predict(data).numpy()
154 |
155 | def predict_proba(self, data: TensorLike) -> torch.Tensor:
156 | """
157 | Computes a distribution over the components for each of the provided datapoints.
158 |
159 | Parameters
160 | ----------
161 | data
162 | The datapoints for which to compute the component assignment probabilities.
163 | Returns
164 | ----------
165 | A tensor of shape ``[num_datapoints, num_components]`` with the assignment
166 | probabilities for each component and datapoint. Note that each row of the vector sums
167 | to 1, i.e. the returned tensor provides a proper distribution over the components for
168 | each datapoint.
169 | Attention
170 | ----------
171 | When calling this function in a multi-process environment, each process receives only
172 | a subset of the predictions. If you want to aggregate predictions, make sure to gather
173 | the values returned from this method.
174 | """
175 | loader = DataLoader(
176 | dataset_from_tensors(data),
177 | batch_size=self.batch_size or len(data),
178 | collate_fn=collate_tensor,
179 | )
180 | trainer_params = self.trainer_params.copy()
181 | trainer_params["logger"] = False
182 | result = Trainer(**trainer_params).predict(GaussianMixtureLightningModule(self.model_), loader)
183 | return torch.cat([x[0] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
184 |
185 | def score_samples(self, data: TensorLike) -> torch.Tensor:
186 | """
187 | Computes the negative log-likelihood (NLL) of each of the provided datapoints.
188 |
189 | Parameters
190 | ----------
191 | data
192 | The datapoints for which to compute the NLL.
193 | Returns
194 | ----------
195 | A tensor of shape ``[num_datapoints]`` with the NLL for each datapoint.
196 | Attention
197 | ----------
198 | When calling this function in a multi-process environment, each process receives only
199 | a subset of the predictions. If you want to aggregate predictions, make sure to gather
200 | the values returned from this method.
201 | """
202 | loader = DataLoader(
203 | dataset_from_tensors(data),
204 | batch_size=self.batch_size or len(data),
205 | collate_fn=collate_tensor,
206 | )
207 | trainer_params = self.trainer_params.copy()
208 | trainer_params["logger"] = False
209 | result = Trainer(**trainer_params).predict(GaussianMixtureLightningModule(self.model_), loader)
210 | return torch.stack([x[1] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
211 |
212 |
213 | class Cluster(GaussianMixture):
214 | """
215 | Cluster cells or spots based on the neighborhood aggregated features from CellCharter.
216 |
217 | Parameters
218 | ----------
219 | n_clusters
220 | The number of components in the GMM. The dimensionality of each component is automatically inferred from the data.
221 | covariance_type
222 | The type of covariance to assume for all Gaussian components.
223 | init_strategy
224 | The strategy for initializing component means and covariances.
225 | init_means
226 | An optional initial guess for the means of the components. If provided,
227 | must be a tensor of shape ``[num_components, num_features]``. If this is given,
228 | the ``init_strategy`` is ignored and the means are handled as if K-means
229 | initialization has been run.
230 | convergence_tolerance
231 | The change in the per-datapoint negative log-likelihood which
232 | implies that training has converged.
233 | covariance_regularization
234 | A small value which is added to the diagonal of the
235 | covariance matrix to ensure that it is positive semi-definite.
236 |     batch_size
237 |         The batch size to use when fitting the model. If not provided, the full data will be used as a
238 |         single batch. Set this if the full data does not fit into memory.
239 | trainer_params
240 | Initialization parameters to use when initializing a PyTorch Lightning
241 | trainer. By default, it disables various stdout logs unless TorchGMM is configured to
242 | do verbose logging. Checkpointing and logging are disabled regardless of the log
243 | level. This estimator further sets the following overridable defaults:
244 | - ``max_epochs=100``.
245 | random_state
246 | Initialization seed.
247 | Examples
248 | --------
249 | >>> adata = anndata.read_h5ad(path_to_anndata)
250 | >>> sq.gr.spatial_neighbors(adata, coord_type='generic', delaunay=True)
251 | >>> cc.gr.remove_long_links(adata)
252 | >>> cc.gr.aggregate_neighbors(adata, n_layers=3)
253 | >>> model = cc.tl.Cluster(n_clusters=11)
254 | >>> model.fit(adata, use_rep='X_cellcharter')
255 | """
256 |
257 | def __init__(
258 | self,
259 | n_clusters: int = 1,
260 | *,
261 | covariance_type: str = "full",
262 | init_strategy: str = "kmeans",
263 | init_means: torch.Tensor = None,
264 | convergence_tolerance: float = 0.001,
265 | covariance_regularization: float = 1e-06,
266 | batch_size: int = None,
267 | trainer_params: dict = None,
268 | random_state: AnyRandom = 0,
269 | ):
270 | super().__init__(
271 | n_clusters=n_clusters,
272 | covariance_type=covariance_type,
273 | init_strategy=init_strategy,
274 | init_means=init_means,
275 | convergence_tolerance=convergence_tolerance,
276 | covariance_regularization=covariance_regularization,
277 | batch_size=batch_size,
278 | trainer_params=trainer_params,
279 | random_state=random_state,
280 | )
281 |
282 | def fit(self, adata: ad.AnnData, use_rep: str = "X_cellcharter"):
283 | """
284 | Fit data into ``n_clusters`` clusters.
285 |
286 | Parameters
287 | ----------
288 | adata
289 | Annotated data object.
290 | use_rep
291 | Key in :attr:`anndata.AnnData.obsm` to use as data to fit the clustering model.
292 | """
293 | logging_level = logging.root.level
294 |
295 | X = adata.X if use_rep is None else adata.obsm[use_rep]
296 |
297 | logging_level = logging.getLogger("lightning.pytorch").getEffectiveLevel()
298 | logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
299 |
300 | super().fit(X)
301 |
302 | logging.getLogger("lightning.pytorch").setLevel(logging_level)
303 |
304 | adata.uns["_cellcharter"] = {k: v for k, v in self.get_params().items() if k != "init_means"}
305 |
306 | def predict(self, adata: ad.AnnData, use_rep: str = "X_cellcharter") -> pd.Categorical:
307 | """
308 | Predict the labels for the data in ``use_rep`` using the fitted model.
309 |
310 | Parameters
311 | ----------
312 | adata
313 | Annotated data object.
314 | use_rep
315 | Key in :attr:`anndata.AnnData.obsm` used as data to fit the clustering model. If ``None``, uses :attr:`anndata.AnnData.X`.
318 | """
319 | X = adata.X if use_rep is None else adata.obsm[use_rep]
320 | return pd.Categorical(super().predict(X), categories=np.arange(self.n_clusters))
321 |
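A short end-to-end sketch of the Cluster API defined above, extending the Examples section of its docstring: it assumes adata.obsm["X_cellcharter"] was produced by cc.gr.aggregate_neighbors, and the obs column name used to store the labels is an arbitrary choice.

import cellcharter as cc

model = cc.tl.Cluster(n_clusters=11, random_state=12345)
model.fit(adata, use_rep="X_cellcharter")

# predict() returns a pandas Categorical with one label per observation.
adata.obs["cluster_cellcharter"] = model.predict(adata, use_rep="X_cellcharter")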
--------------------------------------------------------------------------------
/src/cellcharter/tl/_trvae.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from typing import Optional
5 |
6 | from anndata import AnnData, read_h5ad
7 | from torch import nn
8 |
9 | try:
10 | from scarches.models import TRVAE as scaTRVAE
11 | from scarches.models import trVAE
12 | from scarches.models.base._utils import _validate_var_names
13 | except ImportError:
14 |
15 | class TRVAE:
16 | r"""
17 |         scArches's trVAE model adapted to image-based proteomics data.
18 |
19 |         The last ReLU layer of the neural network is removed to allow for continuous, real-valued outputs.
20 |
21 | Parameters
22 | ----------
23 | adata: `~anndata.AnnData`
24 | Annotated data matrix. Has to be count data for 'nb' and 'zinb' loss and normalized log transformed data
25 | for 'mse' loss.
26 | condition_key: String
27 | column name of conditions in `adata.obs` data frame.
28 | conditions: List
29 | List of Condition names that the used data will contain to get the right encoding when used after reloading.
30 | hidden_layer_sizes: List
31 | A list of hidden layer sizes for encoder network. Decoder network will be the reversed order.
32 | latent_dim: Integer
33 | Bottleneck layer (z) size.
34 | dr_rate: Float
35 | Dropout rate applied to all layers, if `dr_rate==0` no dropout will be applied.
36 | use_mmd: Boolean
37 | If 'True' an additional MMD loss will be calculated on the latent dim. 'z' or the first decoder layer 'y'.
38 | mmd_on: String
39 | Choose on which layer MMD loss will be calculated on if 'use_mmd=True': 'z' for latent dim or 'y' for first
40 | decoder layer.
41 | mmd_boundary: Integer or None
42 | Choose on how many conditions the MMD loss should be calculated on. If 'None' MMD will be calculated on all
43 | conditions.
44 | recon_loss: String
45 | Definition of Reconstruction-Loss-Method, 'mse', 'nb' or 'zinb'.
46 | beta: Float
47 | Scaling Factor for MMD loss
48 | use_bn: Boolean
49 | If `True` batch normalization will be applied to layers.
50 | use_ln: Boolean
51 | If `True` layer normalization will be applied to layers.
52 | """
53 |
54 | def __init__(
55 | self,
56 | adata: AnnData,
57 | condition_key: str = None,
58 | conditions: Optional[list] = None,
59 | hidden_layer_sizes: list | tuple = (256, 64),
60 | latent_dim: int = 10,
61 | dr_rate: float = 0.05,
62 | use_mmd: bool = True,
63 | mmd_on: str = "z",
64 | mmd_boundary: Optional[int] = None,
65 |             recon_loss: Optional[str] = "mse",
66 | beta: float = 1,
67 | use_bn: bool = False,
68 | use_ln: bool = True,
69 | ):
70 | raise ImportError("scarches is not installed. Please install scarches to use this method.")
71 |
72 | @classmethod
73 | def load(cls, dir_path: str, adata: Optional[AnnData] = None, map_location: Optional[str] = None):
74 | """
75 | Instantiate a model from the saved output.
76 |
77 | Parameters
78 | ----------
79 | dir_path
80 | Path to saved outputs.
81 | adata
82 | AnnData object.
83 | If None, will check for and load anndata saved with the model.
84 | map_location
85 | Location where all tensors should be loaded (e.g., `torch.device('cpu')`)
86 | Returns
87 | -------
88 | Model with loaded state dictionaries.
89 | """
90 | raise ImportError("scarches is not installed. Please install scarches to use this method.")
91 |
92 | else:
93 |
94 | class TRVAE(scaTRVAE):
95 | r"""
96 |         scArches's trVAE model adapted to image-based proteomics data.
97 |
98 |         The last ReLU layer of the neural network is removed to allow for continuous, real-valued outputs.
99 |
100 | Parameters
101 | ----------
102 | adata: `~anndata.AnnData`
103 | Annotated data matrix. Has to be count data for 'nb' and 'zinb' loss and normalized log transformed data
104 | for 'mse' loss.
105 | condition_key: String
106 | column name of conditions in `adata.obs` data frame.
107 | conditions: List
108 | List of Condition names that the used data will contain to get the right encoding when used after reloading.
109 | hidden_layer_sizes: List
110 | A list of hidden layer sizes for encoder network. Decoder network will be the reversed order.
111 | latent_dim: Integer
112 | Bottleneck layer (z) size.
113 | dr_rate: Float
114 | Dropout rate applied to all layers, if `dr_rate==0` no dropout will be applied.
115 | use_mmd: Boolean
116 | If 'True' an additional MMD loss will be calculated on the latent dim. 'z' or the first decoder layer 'y'.
117 | mmd_on: String
118 | Choose on which layer MMD loss will be calculated on if 'use_mmd=True': 'z' for latent dim or 'y' for first
119 | decoder layer.
120 | mmd_boundary: Integer or None
121 | Choose on how many conditions the MMD loss should be calculated on. If 'None' MMD will be calculated on all
122 | conditions.
123 | recon_loss: String
124 | Definition of Reconstruction-Loss-Method, 'mse', 'nb' or 'zinb'.
125 | beta: Float
126 | Scaling Factor for MMD loss
127 | use_bn: Boolean
128 | If `True` batch normalization will be applied to layers.
129 | use_ln: Boolean
130 | If `True` layer normalization will be applied to layers.
131 | """
132 |
133 | def __init__(
134 | self,
135 | adata: AnnData,
136 | condition_key: str = None,
137 | conditions: Optional[list] = None,
138 | hidden_layer_sizes: list | tuple = (256, 64),
139 | latent_dim: int = 10,
140 | dr_rate: float = 0.05,
141 | use_mmd: bool = True,
142 | mmd_on: str = "z",
143 | mmd_boundary: Optional[int] = None,
144 | recon_loss: Optional[str] = "mse",
145 | beta: float = 1,
146 | use_bn: bool = False,
147 | use_ln: bool = True,
148 | ):
149 | self.adata = adata
150 |
151 | self.condition_key_ = condition_key
152 |
153 | if conditions is None:
154 | if condition_key is not None:
155 | self.conditions_ = adata.obs[condition_key].unique().tolist()
156 | else:
157 | self.conditions_ = []
158 | else:
159 | self.conditions_ = conditions
160 |
161 | self.hidden_layer_sizes_ = hidden_layer_sizes
162 | self.latent_dim_ = latent_dim
163 | self.dr_rate_ = dr_rate
164 | self.use_mmd_ = use_mmd
165 | self.mmd_on_ = mmd_on
166 | self.mmd_boundary_ = mmd_boundary
167 | self.recon_loss_ = recon_loss
168 | self.beta_ = beta
169 | self.use_bn_ = use_bn
170 | self.use_ln_ = use_ln
171 |
172 | self.input_dim_ = adata.n_vars
173 |
174 | self.model = trVAE(
175 | self.input_dim_,
176 | self.conditions_,
177 | list(self.hidden_layer_sizes_),
178 | self.latent_dim_,
179 | self.dr_rate_,
180 | self.use_mmd_,
181 | self.mmd_on_,
182 | self.mmd_boundary_,
183 | self.recon_loss_,
184 | self.beta_,
185 | self.use_bn_,
186 | self.use_ln_,
187 | )
188 |
189 | decoder_layer_sizes = self.model.hidden_layer_sizes.copy()
190 | decoder_layer_sizes.reverse()
191 | decoder_layer_sizes.append(self.model.input_dim)
192 |
193 | self.model.decoder.recon_decoder = nn.Linear(decoder_layer_sizes[-2], decoder_layer_sizes[-1])
194 |
195 | self.is_trained_ = False
196 |
197 | self.trainer = None
198 |
199 | @classmethod
200 | def load(cls, dir_path: str, adata: Optional[AnnData] = None, map_location: Optional[str] = None):
201 | """
202 | Instantiate a model from the saved output.
203 |
204 | Parameters
205 | ----------
206 | dir_path
207 | Path to saved outputs.
208 | adata
209 | AnnData object.
210 | If None, will check for and load anndata saved with the model.
211 | map_location
212 | Location where all tensors should be loaded (e.g., `torch.device('cpu')`)
213 | Returns
214 | -------
215 | Model with loaded state dictionaries.
216 | """
217 | adata_path = os.path.join(dir_path, "adata.h5ad")
218 |
219 | load_adata = adata is None
220 |
221 | if os.path.exists(adata_path) and load_adata:
222 | adata = read_h5ad(adata_path)
223 | elif not os.path.exists(adata_path) and load_adata:
224 | raise ValueError("Save path contains no saved anndata and no adata was passed.")
225 |
226 | attr_dict, model_state_dict, var_names = cls._load_params(dir_path, map_location=map_location)
227 |
228 | # Overwrite adata with new genes
229 | adata = _validate_var_names(adata, var_names)
230 |
231 | cls._validate_adata(adata, attr_dict)
232 | init_params = cls._get_init_params_from_dict(attr_dict)
233 |
234 | model = cls(adata, **init_params)
235 | model.model.load_state_dict(model_state_dict)
236 | model.model.eval()
237 |
238 | model.is_trained_ = attr_dict["is_trained_"]
239 |
240 | return model
241 |
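A sketch of typical TRVAE usage (requires scarches to be installed); the condition key, layer sizes and the save directory below are placeholders, and training itself follows scArches' trVAE API rather than anything defined in this module.

from cellcharter.tl import TRVAE

model = TRVAE(
    adata,
    condition_key="sample",            # hypothetical obs column with batch/condition labels
    hidden_layer_sizes=(128, 64),
    latent_dim=10,
    recon_loss="mse",                  # the adapted decoder outputs real values, so MSE is the natural choice
)
# Train with the scArches trVAE API (e.g. model.train(n_epochs=...)); see the scArches documentation.

# A previously saved model can be restored together with its AnnData:
model = TRVAE.load("path/to/saved_model", adata=adata)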
--------------------------------------------------------------------------------
/src/cellcharter/tl/_utils.py:
--------------------------------------------------------------------------------
1 | from itertools import combinations
2 |
3 | import numpy as np
4 | from joblib import Parallel, delayed
5 | from sklearn.metrics import adjusted_rand_score
6 |
7 |
8 | def _stability(labels, similarity_function=adjusted_rand_score, n_jobs=-1):
9 | clusters = list(labels.keys())
10 | max_runs = len(labels[clusters[0]])
11 | num_combinations = max_runs * (max_runs - 1) // 2
12 |
13 | stabilities = Parallel(n_jobs=n_jobs)(
14 | delayed(similarity_function)(labels[clusters[k]][i], labels[clusters[k] + 1][j])
15 | for k in range(len(clusters) - 1)
16 | for i, j in combinations(range(max_runs), 2)
17 | )
18 |
19 |     # Split the flat list of similarities into chunks of size num_combinations (one chunk per consecutive pair of cluster numbers)
20 | stabilities = [stabilities[i : i + num_combinations] for i in range(0, len(stabilities), num_combinations)]
21 |
22 |     # Concatenate each chunk with the previous one, so the row for k combines its similarities to k - 1 and k + 1 clusters
23 | stabilities = [stabilities[i] + stabilities[i - 1] for i in range(1, len(stabilities))]
24 |
25 | return np.array(stabilities)
26 |
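A toy sketch of _stability (the label arrays below are made up): labels maps each number of clusters k to one labeling per run, and the function scores, with the adjusted Rand index by default, labelings obtained for k clusters against labelings obtained for k + 1 clusters, so the keys must be consecutive integers.

import numpy as np

from cellcharter.tl._utils import _stability

labels = {
    2: [np.array([0, 0, 1, 1]), np.array([0, 1, 1, 0])],
    3: [np.array([0, 1, 2, 2]), np.array([0, 1, 1, 2])],
    4: [np.array([0, 1, 2, 3]), np.array([3, 2, 1, 0])],
}

stab = _stability(labels, n_jobs=1)
# One row per intermediate k (here only k = 3), concatenating its similarities
# towards k + 1 and k - 1 clusters.
print(stab.shape)  # (1, 2)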
--------------------------------------------------------------------------------
/tests/_data/test_data.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_data/test_data.h5ad
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 1,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k1/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 10,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k10/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 11,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k11/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 12,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k12/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 13,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k13/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 14,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k14/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 15,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k15/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 16,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k16/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 2,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k2/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 3,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k3/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 4,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k4/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 5,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k5/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 6,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k6/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 7,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k7/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 8,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k8/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 9,
3 | "num_features": 136,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/best_models/GaussianMixture_k9/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_imc/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_imc/params.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 1,
3 | "num_features": 144,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "n_clusters": 1,
3 | "covariance_type": "full",
4 | "init_strategy": "kmeans++",
5 | "init_means": null,
6 | "convergence_tolerance": 0.001,
7 | "covariance_regularization": 1e-6,
8 | "batch_size": null,
9 | "trainer_params": {
10 | "logger": false,
11 | "log_every_n_steps": 1,
12 | "enable_progress_bar": false,
13 | "enable_checkpointing": false,
14 | "enable_model_summary": false,
15 | "max_epochs": 100,
16 | "accelerator": "cpu"
17 | },
18 | "random_state": 42
19 | }
20 |
--------------------------------------------------------------------------------
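The params.json files above store each candidate GaussianMixture's hyperparameters as plain JSON. A minimal stdlib sketch for inspecting one of them (path taken from the listing above; purely illustrative, not part of the repository):

import json

path = "tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k1/params.json"
with open(path) as fh:
    params = json.load(fh)

# Keys mirror the listing above: n_clusters, covariance_type, trainer_params, random_state, ...
print(params["n_clusters"], params["covariance_type"], params["trainer_params"]["max_epochs"])
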
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 2,
3 | "num_features": 144,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k2/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "n_clusters": 2,
3 | "covariance_type": "full",
4 | "init_strategy": "kmeans++",
5 | "init_means": null,
6 | "convergence_tolerance": 0.001,
7 | "covariance_regularization": 1e-6,
8 | "batch_size": null,
9 | "trainer_params": {
10 | "logger": false,
11 | "log_every_n_steps": 1,
12 | "enable_progress_bar": false,
13 | "enable_checkpointing": false,
14 | "enable_model_summary": false,
15 | "max_epochs": 100,
16 | "accelerator": "cpu"
17 | },
18 | "random_state": 44
19 | }
20 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 3,
3 | "num_features": 144,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k3/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "n_clusters": 3,
3 | "covariance_type": "full",
4 | "init_strategy": "kmeans++",
5 | "init_means": null,
6 | "convergence_tolerance": 0.001,
7 | "covariance_regularization": 1e-6,
8 | "batch_size": null,
9 | "trainer_params": {
10 | "logger": false,
11 | "log_every_n_steps": 1,
12 | "enable_progress_bar": false,
13 | "enable_checkpointing": false,
14 | "enable_model_summary": false,
15 | "max_epochs": 100,
16 | "accelerator": "cpu"
17 | },
18 | "random_state": 44
19 | }
20 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 4,
3 | "num_features": 144,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k4/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "n_clusters": 4,
3 | "covariance_type": "full",
4 | "init_strategy": "kmeans++",
5 | "init_means": null,
6 | "convergence_tolerance": 0.001,
7 | "covariance_regularization": 1e-6,
8 | "batch_size": null,
9 | "trainer_params": {
10 | "logger": false,
11 | "log_every_n_steps": 1,
12 | "enable_progress_bar": false,
13 | "enable_checkpointing": false,
14 | "enable_model_summary": false,
15 | "max_epochs": 100,
16 | "accelerator": "cpu"
17 | },
18 | "random_state": 42
19 | }
20 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 5,
3 | "num_features": 144,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k5/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "n_clusters": 5,
3 | "covariance_type": "full",
4 | "init_strategy": "kmeans++",
5 | "init_means": null,
6 | "convergence_tolerance": 0.001,
7 | "covariance_regularization": 1e-6,
8 | "batch_size": null,
9 | "trainer_params": {
10 | "logger": false,
11 | "log_every_n_steps": 1,
12 | "enable_progress_bar": false,
13 | "enable_checkpointing": false,
14 | "enable_model_summary": false,
15 | "max_epochs": 100,
16 | "accelerator": "cpu"
17 | },
18 | "random_state": 44
19 | }
20 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/attributes.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/attributes.pickle
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_components": 6,
3 | "num_features": 144,
4 | "covariance_type": "full"
5 | }
6 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/model/parameters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/model/parameters.pt
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/best_models/GaussianMixture_k6/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "n_clusters": 6,
3 | "covariance_type": "full",
4 | "init_strategy": "kmeans++",
5 | "init_means": null,
6 | "convergence_tolerance": 0.001,
7 | "covariance_regularization": 1e-6,
8 | "batch_size": null,
9 | "trainer_params": {
10 | "logger": false,
11 | "log_every_n_steps": 1,
12 | "enable_progress_bar": false,
13 | "enable_checkpointing": false,
14 | "enable_model_summary": false,
15 | "max_epochs": 100,
16 | "accelerator": "cpu"
17 | },
18 | "random_state": 43
19 | }
20 |
--------------------------------------------------------------------------------
/tests/_models/cellcharter_autok_mibitof/params.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSOgroup/cellcharter/ff55ad4da7b9e6896abf8775c94ce01cf13a6c19/tests/_models/cellcharter_autok_mibitof/params.pickle
--------------------------------------------------------------------------------
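The pickle and parameters.pt artifacts listed above are serialized ClusterAutoK fixtures that the test suite reloads instead of refitting. A short sketch of how they are consumed, mirroring the call in tests/plotting/test_plot_stability.py below (illustrative only; assumes cellcharter is installed and the fixture files are present):

import cellcharter as cc

# Reload a pre-fitted ClusterAutoK model from one of the fixture directories listed above.
model = cc.tl.ClusterAutoK.load("tests/_models/cellcharter_autok_mibitof")
print(model.best_k)  # best_k is accessed the same way in tests/tools/test_autok.py
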
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import time
2 | from urllib.error import HTTPError
3 |
4 | import anndata as ad
5 | import numpy as np
6 | import pytest
7 | import scanpy as sc
8 | from squidpy._constants._pkg_constants import Key
9 |
10 | _adata = sc.read("tests/_data/test_data.h5ad")
11 | _adata.raw = _adata.copy()
12 |
13 |
14 | @pytest.fixture()
15 | def non_visium_adata() -> ad.AnnData:
16 | non_visium_coords = np.array([[1, 0], [3, 0], [5, 6], [0, 4]])
17 | adata = ad.AnnData(X=non_visium_coords, dtype=int)
18 | adata.obsm[Key.obsm.spatial] = non_visium_coords
19 | return adata
20 |
21 |
22 | @pytest.fixture()
23 | def adata() -> ad.AnnData:
24 | return _adata.copy()
25 |
26 |
27 | @pytest.fixture(scope="session")
28 | def codex_adata() -> ad.AnnData:
29 | max_retries = 3
30 | retry_delay = 5 # seconds
31 |
32 | for attempt in range(max_retries):
33 | try:
34 | adata = sc.read(
35 | "tests/_data/codex_adata.h5ad", backup_url="https://figshare.com/ndownloader/files/46832722"
36 | )
37 | return adata[adata.obs["sample"].isin(["BALBc-1", "MRL-5"])].copy()
38 | except HTTPError as e:
39 | if attempt == max_retries - 1: # Last attempt
40 | pytest.skip(f"Failed to download test data after {max_retries} attempts: {str(e)}")
41 | time.sleep(retry_delay)
42 |
--------------------------------------------------------------------------------
/tests/graph/test_aggregate_neighbors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sps
3 | import squidpy as sq
4 | from anndata import AnnData
5 |
6 | from cellcharter.gr import aggregate_neighbors
7 |
8 |
9 | class TestAggregateNeighbors:
10 | def test_aggregate_neighbors(self):
11 | n_layers = 2
12 | aggregations = ["mean", "var"]
13 |
14 | G = sps.csr_matrix(
15 | np.array(
16 | [
17 | [0, 1, 1, 1, 1, 0, 0, 0, 0],
18 | [1, 0, 0, 0, 0, 1, 0, 0, 0],
19 | [1, 0, 0, 0, 0, 0, 1, 1, 0],
20 | [1, 0, 0, 0, 1, 0, 0, 0, 1],
21 | [1, 0, 0, 1, 0, 0, 0, 0, 0],
22 | [0, 1, 0, 0, 0, 0, 0, 0, 0],
23 | [0, 0, 1, 0, 0, 0, 0, 1, 0],
24 | [0, 0, 1, 0, 0, 0, 1, 0, 0],
25 | [0, 0, 0, 1, 0, 0, 0, 0, 0],
26 | ]
27 | )
28 | )
29 |
30 | X = np.vstack((np.power(2, np.arange(G.shape[0])), np.power(2, np.arange(G.shape[0])[::-1]))).T.astype(
31 | np.float32
32 | )
33 |
34 | adata = AnnData(X=X, obsp={"spatial_connectivities": G})
35 |
36 | L1_mean_truth = np.vstack(
37 | ([7.5, 16.5, 64.66, 91, 4.5, 2, 66, 34, 8], [60, 132, 87.33, 91, 144, 128, 33, 34, 32])
38 | ).T.astype(np.float32)
39 |
40 | L2_mean_truth = np.vstack(
41 | ([120, 9.33, 8.67, 3, 87.33, 1, 1, 1, 8.5], [3.75, 37.33, 58.67, 96, 64.33, 256, 256, 256, 136])
42 | ).T.astype(np.float32)
43 |
44 | L1_var_truth = np.vstack(
45 | ([5.36, 15.5, 51.85, 116.83, 3.5, 0, 62, 30, 0], [42.90, 124, 119.27, 116.83, 112, 0, 31, 30, 0])
46 | ).T.astype(np.float32)
47 |
48 | L2_var_truth = np.vstack(
49 | ([85.79, 4.99, 5.73, 1.0, 119.27, 0, 0, 0, 7.5], [2.68, 19.96, 49.46, 32, 51.85, 0, 0, 0, 120])
50 | ).T.astype(np.float32)
51 |
52 | aggregate_neighbors(adata, n_layers=n_layers, aggregations=["mean", "var"])
53 |
54 | assert adata.obsm["X_cellcharter"].shape == (adata.shape[0], X.shape[1] * (n_layers * len(aggregations) + 1))
55 |
56 | np.testing.assert_allclose(adata.obsm["X_cellcharter"][:, [0, 1]], X, rtol=0.01)
57 | np.testing.assert_allclose(adata.obsm["X_cellcharter"][:, [2, 3]], L1_mean_truth, rtol=0.01)
58 | np.testing.assert_allclose(adata.obsm["X_cellcharter"][:, [4, 5]], L1_var_truth**2, rtol=0.01)
59 | np.testing.assert_allclose(adata.obsm["X_cellcharter"][:, [6, 7]], L2_mean_truth, rtol=0.01)
60 | np.testing.assert_allclose(adata.obsm["X_cellcharter"][:, [8, 9]], L2_var_truth**2, rtol=0.01)
61 |
62 | def test_aggregations(self, adata: AnnData):
63 | sq.gr.spatial_neighbors(adata)
64 | aggregate_neighbors(adata, n_layers=3, aggregations="mean", out_key="X_str")
65 | aggregate_neighbors(adata, n_layers=3, aggregations=["mean"], out_key="X_list")
66 |
67 | assert (adata.obsm["X_str"] != adata.obsm["X_list"]).nnz == 0
68 |
--------------------------------------------------------------------------------
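As a sanity check on the hard-coded truth arrays above: cell 0's layer-1 neighbourhood in G is {1, 2, 3, 4}, and feature 0 of X is the powers of two 2^i, so its layer-1 mean is (2 + 4 + 8 + 16) / 4 = 7.5, the first entry of L1_mean_truth. A minimal sketch of that check (illustrative only):

import numpy as np

x0 = 2.0 ** np.arange(9)            # feature 0 of X: 1, 2, 4, ..., 256
neighbors_of_cell0 = [1, 2, 3, 4]   # non-zero entries in row 0 of G
assert x0[neighbors_of_cell0].mean() == 7.5   # matches the first entry of L1_mean_truth
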
/tests/graph/test_build.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sps
3 | import squidpy as sq
4 | from anndata import AnnData
5 | from squidpy._constants._pkg_constants import Key
6 |
7 | import cellcharter as cc
8 |
9 |
10 | class TestRemoveLongLinks:
11 | def test_remove_long_links(self, non_visium_adata: AnnData):
12 |         # ground-truth after removing connections longer than the 50th percentile
13 | correct_dist_perc = np.array(
14 | [
15 | [0.0, 2.0, 0.0, 4.12310563],
16 | [2.0, 0.0, 0, 5.0],
17 | [0.0, 0, 0.0, 0.0],
18 | [4.12310563, 5.0, 0.0, 0.0],
19 | ]
20 | )
21 | correct_graph_perc = np.array(
22 | [[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
23 | )
24 |
25 | sq.gr.spatial_neighbors(non_visium_adata, coord_type="generic", delaunay=True)
26 | cc.gr.remove_long_links(non_visium_adata, distance_percentile=50)
27 |
28 | spatial_graph = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()
29 | spatial_dist = non_visium_adata.obsp[Key.obsp.spatial_dist()].toarray()
30 |
31 | np.testing.assert_array_equal(spatial_graph, correct_graph_perc)
32 | np.testing.assert_allclose(spatial_dist, correct_dist_perc)
33 |
34 |
35 | class TestRemoveIntraClusterLinks:
36 | def test_mixed_clusters(self, non_visium_adata: AnnData):
37 | non_visium_adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix(
38 | np.ones((non_visium_adata.shape[0], non_visium_adata.shape[0]))
39 | )
40 | non_visium_adata.obsp[Key.obsp.spatial_dist()] = sps.csr_matrix(
41 | [[0, 1, 4, 4], [1, 0, 6, 3], [4, 6, 0, 9], [4, 3, 9, 0]]
42 | )
43 | non_visium_adata.obs["cluster"] = [0, 0, 1, 1]
44 |
45 | correct_conns = np.array([[0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0], [1, 1, 0, 0]])
46 | correct_dists = np.array([[0, 0, 4, 4], [0, 0, 6, 3], [4, 6, 0, 0], [4, 3, 0, 0]])
47 |
48 | cc.gr.remove_intra_cluster_links(non_visium_adata, cluster_key="cluster")
49 |
50 | trimmed_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()
51 | trimmed_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()].toarray()
52 |
53 | np.testing.assert_array_equal(trimmed_conns, correct_conns)
54 | np.testing.assert_allclose(trimmed_dists, correct_dists)
55 |
56 | def test_same_clusters(self, non_visium_adata: AnnData):
57 | non_visium_adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix(
58 | np.ones((non_visium_adata.shape[0], non_visium_adata.shape[0]))
59 | )
60 | non_visium_adata.obsp[Key.obsp.spatial_dist()] = sps.csr_matrix(
61 | [[0, 1, 4, 4], [1, 0, 6, 3], [4, 6, 0, 9], [4, 3, 9, 0]]
62 | )
63 | non_visium_adata.obs["cluster"] = [0, 0, 0, 0]
64 |
65 | correct_conns = np.zeros((non_visium_adata.shape[0], non_visium_adata.shape[0]))
66 | correct_dists = np.zeros((non_visium_adata.shape[0], non_visium_adata.shape[0]))
67 |
68 | cc.gr.remove_intra_cluster_links(non_visium_adata, cluster_key="cluster")
69 |
70 | trimmed_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()
71 | trimmed_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()].toarray()
72 |
73 | np.testing.assert_array_equal(trimmed_conns, correct_conns)
74 | np.testing.assert_allclose(trimmed_dists, correct_dists)
75 |
76 | def test_different_clusters(self, non_visium_adata: AnnData):
77 | non_visium_adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix(
78 | np.ones((non_visium_adata.shape[0], non_visium_adata.shape[0]))
79 | )
80 | non_visium_adata.obsp[Key.obsp.spatial_dist()] = sps.csr_matrix(
81 | [[0, 1, 4, 4], [1, 0, 6, 3], [4, 6, 0, 9], [4, 3, 9, 0]]
82 | )
83 | non_visium_adata.obs["cluster"] = [0, 1, 2, 3]
84 |
85 | correct_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].copy()
86 | correct_conns.setdiag(0)
87 | correct_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()]
88 |
89 | cc.gr.remove_intra_cluster_links(non_visium_adata, cluster_key="cluster")
90 |
91 | trimmed_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()
92 | trimmed_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()].toarray()
93 |
94 | np.testing.assert_array_equal(trimmed_conns, correct_conns.toarray())
95 | np.testing.assert_allclose(trimmed_dists, correct_dists.toarray())
96 |
97 | def test_copy(self, non_visium_adata: AnnData):
98 | non_visium_adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix(
99 | np.ones((non_visium_adata.shape[0], non_visium_adata.shape[0]))
100 | )
101 | non_visium_adata.obsp[Key.obsp.spatial_dist()] = sps.csr_matrix(
102 | [[0, 1, 4, 4], [1, 0, 6, 3], [4, 6, 0, 9], [4, 3, 9, 0]]
103 | )
104 | non_visium_adata.obs["cluster"] = [0, 0, 1, 1]
105 |
106 | correct_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].copy()
107 | correct_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()].copy()
108 |
109 | cc.gr.remove_intra_cluster_links(non_visium_adata, cluster_key="cluster", copy=True)
110 |
111 | trimmed_conns = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()
112 | trimmed_dists = non_visium_adata.obsp[Key.obsp.spatial_dist()].toarray()
113 |
114 | np.testing.assert_array_equal(trimmed_conns, correct_conns.toarray())
115 | np.testing.assert_allclose(trimmed_dists, correct_dists.toarray())
116 |
117 |
118 | class TestConnectedComponents:
119 | def test_component_present(self, adata: AnnData):
120 | sq.gr.spatial_neighbors(adata, coord_type="grid", n_neighs=6, delaunay=False)
121 | cc.gr.connected_components(adata, min_cells=10)
122 |
123 | assert "component" in adata.obs
124 |
125 | def test_connected_components_no_cluster(self):
126 | adata = AnnData(
127 | X=np.full((4, 2), 1),
128 | )
129 |
130 | adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix(
131 | np.array(
132 | [
133 | [0, 1, 0, 0],
134 | [1, 0, 0, 0],
135 | [0, 0, 0, 1],
136 | [0, 0, 1, 0],
137 | ]
138 | )
139 | )
140 | correct_components = np.array([0, 0, 1, 1])
141 |
142 | components = cc.gr.connected_components(adata, min_cells=0, copy=True)
143 |
144 | np.testing.assert_array_equal(components, correct_components)
145 |
146 | components = cc.gr.connected_components(adata, min_cells=0, copy=False, out_key="comp")
147 | assert "comp" in adata.obs
148 | np.testing.assert_array_equal(adata.obs["comp"].values, correct_components)
149 |
150 | def test_connected_components_cluster(self):
151 | adata = AnnData(X=np.full((4, 2), 1), obs={"cluster": [0, 0, 1, 1]})
152 |
153 | adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix(
154 | np.array(
155 | [
156 | [0, 1, 0, 0],
157 | [1, 0, 0, 1],
158 | [0, 0, 0, 1],
159 | [0, 1, 1, 0],
160 | ]
161 | )
162 | )
163 | correct_components = np.array([0, 0, 1, 1])
164 |
165 | components = cc.gr.connected_components(adata, cluster_key="cluster", min_cells=0, copy=True)
166 |
167 | np.testing.assert_array_equal(components, correct_components)
168 |
169 | components = cc.gr.connected_components(adata, cluster_key="cluster", min_cells=0, copy=False, out_key="comp")
170 | assert "comp" in adata.obs
171 | np.testing.assert_array_equal(adata.obs["comp"].values, correct_components)
172 |
173 | def test_connected_components_min_cells(self):
174 | adata = AnnData(X=np.full((5, 2), 1), obs={"cluster": [0, 0, 0, 1, 1]})
175 |
176 | adata.obsp[Key.obsp.spatial_conn()] = sps.csr_matrix(
177 | np.array(
178 | [
179 | [0, 1, 1, 0, 0],
180 | [1, 0, 0, 0, 0],
181 | [1, 0, 0, 0, 1],
182 | [0, 0, 0, 0, 1],
183 | [0, 0, 1, 1, 0],
184 | ]
185 | )
186 | )
187 | correct_components = np.array([0, 0, 0, 1, 1])
188 | components = cc.gr.connected_components(adata, cluster_key="cluster", min_cells=2, copy=True)
189 | np.testing.assert_array_equal(components, correct_components)
190 |
191 | correct_components = np.array([0, 0, 0, np.nan, np.nan])
192 | components = cc.gr.connected_components(adata, cluster_key="cluster", min_cells=3, copy=True)
193 | np.testing.assert_array_equal(components, correct_components)
194 |
195 | def test_codex(self, codex_adata: AnnData):
196 | min_cells = 250
197 | correct_number_components = 97
198 | if "component" in codex_adata.obs:
199 | del codex_adata.obs["component"]
200 | cc.gr.connected_components(codex_adata, cluster_key="cluster_cellcharter", min_cells=min_cells)
201 |
202 | assert codex_adata.obs["component"].dtype == "category"
203 | assert len(codex_adata.obs["component"].cat.categories) == correct_number_components
204 | for component in codex_adata.obs["component"].cat.categories:
205 | # Check that all components have at least min_cells cells
206 | assert np.sum(codex_adata.obs["component"] == component) >= min_cells
207 |
208 | # Check that all cells in the component are in the same cluster
209 | assert len(codex_adata.obs["cluster_cellcharter"][codex_adata.obs["component"] == component].unique()) == 1
210 |
--------------------------------------------------------------------------------
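For reference, the correct_conns and correct_dists matrices in test_mixed_clusters follow from zeroing every link between cells that share a cluster label. A small numpy sketch of that masking (illustrative; not the library's implementation):

import numpy as np

conns = np.ones((4, 4))
clusters = np.array([0, 0, 1, 1])
same_cluster = clusters[:, None] == clusters[None, :]  # True for intra-cluster pairs (incl. diagonal)
conns[same_cluster] = 0
# conns now equals correct_conns from test_mixed_clusters.
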
/tests/graph/test_diff_nhood.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from anndata import AnnData
4 |
5 | import cellcharter as cc
6 |
7 | _CLUSTER_KEY = "cell_type"
8 | _CONDITION_KEY = "sample"
9 | key = f"{_CLUSTER_KEY}_{_CONDITION_KEY}_diff_nhood_enrichment"
10 |
11 |
12 | class TestDiffNhoodEnrichment:
13 | def test_enrichment(self, codex_adata: AnnData):
14 | n_conditions = codex_adata.obs[_CONDITION_KEY].cat.categories.shape[0]
15 | cc.gr.diff_nhood_enrichment(
16 | codex_adata, cluster_key=_CLUSTER_KEY, condition_key=_CONDITION_KEY, only_inter=False, log_fold_change=False
17 | )
18 |
19 | assert len(codex_adata.uns[key]) == n_conditions * (n_conditions - 1) / 2
20 |
21 | for nhood_enrichment in codex_adata.uns[key].values():
22 | enrichment = nhood_enrichment["enrichment"]
23 | assert np.all((enrichment >= -1) & (enrichment <= 1))
24 |
25 | del codex_adata.uns[key]
26 |
27 | def test_pvalues(self, codex_adata: AnnData):
28 | n_conditions = codex_adata.obs[_CONDITION_KEY].cat.categories.shape[0]
29 | cc.gr.diff_nhood_enrichment(
30 | codex_adata,
31 | cluster_key=_CLUSTER_KEY,
32 | condition_key=_CONDITION_KEY,
33 | library_key="sample",
34 | only_inter=True,
35 | log_fold_change=True,
36 | pvalues=True,
37 | n_perms=100,
38 | )
39 |
40 | assert len(codex_adata.uns[key]) == n_conditions * (n_conditions - 1) / 2
41 |
42 | for nhood_enrichment in codex_adata.uns[key].values():
43 | assert "pvalue" in nhood_enrichment
44 | pvalue = nhood_enrichment["pvalue"]
45 | assert np.all((pvalue >= 0) & (pvalue <= 1))
46 |
47 | del codex_adata.uns[key]
48 |
49 | def test_symmetric_vs_nonsymmetric(self, codex_adata: AnnData):
50 | # Test symmetric case
51 | cc.gr.diff_nhood_enrichment(codex_adata, cluster_key=_CLUSTER_KEY, condition_key=_CONDITION_KEY, symmetric=True)
52 | symmetric_result = codex_adata.uns[key].copy()
53 | del codex_adata.uns[key]
54 |
55 | # Test non-symmetric case
56 | cc.gr.diff_nhood_enrichment(
57 | codex_adata, cluster_key=_CLUSTER_KEY, condition_key=_CONDITION_KEY, symmetric=False
58 | )
59 | nonsymmetric_result = codex_adata.uns[key]
60 |
61 | # Results should be different when symmetric=False
62 | for pair_key in symmetric_result:
63 | assert not np.allclose(
64 | symmetric_result[pair_key]["enrichment"], nonsymmetric_result[pair_key]["enrichment"], equal_nan=True
65 | )
66 |
67 | del codex_adata.uns[key]
68 |
69 | def test_condition_groups(self, codex_adata: AnnData):
70 | conditions = codex_adata.obs[_CONDITION_KEY].cat.categories[:2]
71 | cc.gr.diff_nhood_enrichment(
72 | codex_adata, cluster_key=_CLUSTER_KEY, condition_key=_CONDITION_KEY, condition_groups=conditions
73 | )
74 |
75 | # Should only have one comparison
76 | assert len(codex_adata.uns[key]) == 1
77 | pair_key = f"{conditions[0]}_{conditions[1]}"
78 | assert pair_key in codex_adata.uns[key]
79 |
80 | del codex_adata.uns[key]
81 |
82 | def test_invalid_inputs(self, codex_adata: AnnData):
83 | # Test invalid cluster key
84 | with pytest.raises(KeyError):
85 | cc.gr.diff_nhood_enrichment(codex_adata, cluster_key="invalid_key", condition_key=_CONDITION_KEY)
86 |
87 | # Test invalid condition key
88 | with pytest.raises(KeyError):
89 | cc.gr.diff_nhood_enrichment(codex_adata, cluster_key=_CLUSTER_KEY, condition_key="invalid_key")
90 |
91 | # Test invalid library key when using pvalues
92 | with pytest.raises(KeyError):
93 | cc.gr.diff_nhood_enrichment(
94 | codex_adata,
95 | cluster_key=_CLUSTER_KEY,
96 | condition_key=_CONDITION_KEY,
97 | library_key="invalid_key",
98 | pvalues=True,
99 | )
100 |
101 | def test_copy_return(self, codex_adata: AnnData):
102 | # Test copy=True returns results without modifying adata
103 | result = cc.gr.diff_nhood_enrichment(
104 | codex_adata, cluster_key=_CLUSTER_KEY, condition_key=_CONDITION_KEY, copy=True
105 | )
106 |
107 | assert key not in codex_adata.uns
108 | assert isinstance(result, dict)
109 | assert len(result) > 0
110 |
--------------------------------------------------------------------------------
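All of the tests above read the results through the same adata.uns entry. A sketch of that access pattern, with the key format taken from the module-level constants (illustrative; assumes cc.gr.diff_nhood_enrichment(adata, cluster_key="cell_type", condition_key="sample") has already been run):

key = "cell_type_sample_diff_nhood_enrichment"   # f"{cluster_key}_{condition_key}_diff_nhood_enrichment"
for pair, result in adata.uns[key].items():      # one entry per pair of conditions
    enrichment = result["enrichment"]            # bounded in [-1, 1] when log_fold_change=False
    pvalue = result.get("pvalue")                # present only when pvalues=True
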
/tests/graph/test_group.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import squidpy as sq
3 |
4 | import cellcharter as cc
5 |
6 | GROUP_KEY = "batch"
7 | LABEL_KEY = "Cell_class"
8 | key = f"{GROUP_KEY}_{LABEL_KEY}_enrichment"
9 |
10 | adata = sq.datasets.merfish()
11 |
12 |
13 | class TestEnrichment:
14 | def test_enrichment(self):
15 | cc.gr.enrichment(adata, group_key=GROUP_KEY, label_key=LABEL_KEY)
16 |
17 | assert key in adata.uns
18 | assert "enrichment" in adata.uns[key]
19 | assert "params" in adata.uns[key]
20 |
21 | del adata.uns[key]
22 |
23 | def test_copy(self):
24 | enrichment_dict = cc.gr.enrichment(adata, group_key=GROUP_KEY, label_key=LABEL_KEY, copy=True)
25 |
26 | assert "enrichment" in enrichment_dict
27 |
28 | enrichment_dict = cc.gr.enrichment(
29 | adata, group_key=GROUP_KEY, label_key=LABEL_KEY, copy=True, observed_expected=True
30 | )
31 |
32 | assert "enrichment" in enrichment_dict
33 | assert "observed" in enrichment_dict
34 | assert "expected" in enrichment_dict
35 |
36 | def test_obs_exp(self):
37 | cc.gr.enrichment(adata, group_key=GROUP_KEY, label_key=LABEL_KEY, observed_expected=True)
38 |
39 | assert key in adata.uns
40 | assert "enrichment" in adata.uns[key]
41 | assert "observed" in adata.uns[key]
42 | assert "expected" in adata.uns[key]
43 |
44 | observed = adata.uns[key]["observed"]
45 | expected = adata.uns[key]["expected"]
46 |
47 | assert observed.shape == (
48 | adata.obs[GROUP_KEY].cat.categories.shape[0],
49 | adata.obs[LABEL_KEY].cat.categories.shape[0],
50 | )
51 | assert expected.shape[0] == adata.obs[GROUP_KEY].cat.categories.shape[0]
52 | assert np.all((observed >= 0) & (observed <= 1))
53 | assert np.all((expected >= 0) & (expected <= 1))
54 |
55 | def test_perm(self):
56 | result_analytical = cc.gr.enrichment(
57 | adata, group_key=GROUP_KEY, label_key=LABEL_KEY, pvalues=False, copy=True, observed_expected=True
58 | )
59 | result_perm = cc.gr.enrichment(
60 | adata,
61 | group_key=GROUP_KEY,
62 | label_key=LABEL_KEY,
63 | pvalues=True,
64 | n_perms=5000,
65 | observed_expected=True,
66 | copy=True,
67 | )
68 | np.testing.assert_allclose(result_analytical["enrichment"], result_perm["enrichment"], atol=0.1)
69 | np.testing.assert_allclose(result_analytical["observed"], result_perm["observed"], atol=0.1)
70 | np.testing.assert_allclose(result_analytical["expected"], result_perm["expected"], atol=0.1)
71 |
--------------------------------------------------------------------------------
/tests/graph/test_nhood.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import squidpy as sq
4 | from squidpy._constants._pkg_constants import Key
5 |
6 | import cellcharter as cc
7 |
8 | _CK = "cell type"
9 | key = Key.uns.nhood_enrichment(_CK)
10 |
11 | adata = sq.datasets.imc()
12 | sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
13 | cc.gr.remove_long_links(adata)
14 |
15 |
16 | class TestNhoodEnrichment:
17 | def test_enrichment(self):
18 | cc.gr.nhood_enrichment(adata, cluster_key=_CK, only_inter=False, log_fold_change=False)
19 | enrichment = adata.uns[key]["enrichment"]
20 | assert np.all((enrichment >= -1) & (enrichment <= 1))
21 |
22 | del adata.uns[key]
23 |
24 | def test_fold_change(self):
25 | cc.gr.nhood_enrichment(adata, cluster_key=_CK, log_fold_change=True)
26 |
27 | del adata.uns[key]
28 |
29 | def test_nhood_obs_exp(self):
30 | cc.gr.nhood_enrichment(adata, cluster_key=_CK, only_inter=False, observed_expected=True)
31 | observed = adata.uns[key]["observed"]
32 | expected = adata.uns[key]["expected"]
33 |
34 | assert observed.shape[0] == adata.obs[_CK].cat.categories.shape[0]
35 | assert observed.shape == expected.shape
36 | assert np.all((observed >= 0) & (observed <= 1))
37 | assert np.all((expected >= 0) & (expected <= 1))
38 |
39 | del adata.uns[key]
40 |
41 | def test_symmetric(self):
42 | result = cc.gr.nhood_enrichment(
43 | adata, cluster_key=_CK, symmetric=True, log_fold_change=True, only_inter=False, copy=True
44 | )
45 | assert scipy.linalg.issymmetric(result["enrichment"].values, atol=1e-02)
46 |
47 | result = cc.gr.nhood_enrichment(
48 | adata, cluster_key=_CK, symmetric=True, log_fold_change=False, only_inter=False, copy=True
49 | )
50 | assert scipy.linalg.issymmetric(result["enrichment"].values, atol=1e-02)
51 |
52 | result = cc.gr.nhood_enrichment(
53 | adata, cluster_key=_CK, symmetric=True, log_fold_change=False, only_inter=True, copy=True
54 | )
55 | result["enrichment"][result["enrichment"].isna()] = 0 # issymmetric fails with NaNs
56 | assert scipy.linalg.issymmetric(result["enrichment"].values, atol=1e-02)
57 |
58 | result = cc.gr.nhood_enrichment(
59 | adata, cluster_key=_CK, symmetric=True, log_fold_change=True, only_inter=True, copy=True
60 | )
61 | result["enrichment"][result["enrichment"].isna()] = 0 # issymmetric fails with NaNs
62 | assert scipy.linalg.issymmetric(result["enrichment"].values, atol=1e-02)
63 |
64 | def test_perm(self):
65 | result_analytical = cc.gr.nhood_enrichment(
66 | adata, cluster_key=_CK, only_inter=True, pvalues=False, observed_expected=True, copy=True
67 | )
68 | result_perm = cc.gr.nhood_enrichment(
69 | adata,
70 | cluster_key=_CK,
71 | only_inter=True,
72 | pvalues=True,
73 | n_perms=5000,
74 | observed_expected=True,
75 | copy=True,
76 | n_jobs=15,
77 | )
78 | np.testing.assert_allclose(result_analytical["enrichment"], result_perm["enrichment"], atol=0.1)
79 | np.testing.assert_allclose(result_analytical["observed"], result_perm["observed"], atol=0.1)
80 | np.testing.assert_allclose(result_analytical["expected"], result_perm["expected"], atol=0.1)
81 |
--------------------------------------------------------------------------------
/tests/plotting/test_group.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import pytest
3 | import squidpy as sq
4 |
5 | try:
6 | from matplotlib.colormaps import get_cmap
7 | except ImportError:
8 | from matplotlib.pyplot import get_cmap
9 |
10 | import cellcharter as cc
11 |
12 | GROUP_KEY = "batch"
13 | LABEL_KEY = "Cell_class"
14 | key = f"{GROUP_KEY}_{LABEL_KEY}_enrichment"
15 |
16 | adata = sq.datasets.merfish()
17 | adata_empirical = adata.copy()
18 | cc.gr.enrichment(adata_empirical, group_key=GROUP_KEY, label_key=LABEL_KEY, pvalues=True, n_perms=1000)
19 |
20 | adata_analytical = adata.copy()
21 | cc.gr.enrichment(adata_analytical, group_key=GROUP_KEY, label_key=LABEL_KEY)
22 |
23 |
24 | class TestProportion:
25 | def test_proportion(self):
26 | cc.pl.proportion(adata, group_key=GROUP_KEY, label_key=LABEL_KEY)
27 |
28 | def test_groups_labels(self):
29 | cc.pl.proportion(
30 | adata,
31 | group_key=GROUP_KEY,
32 | label_key=LABEL_KEY,
33 | groups=adata.obs[GROUP_KEY].cat.categories[:3],
34 | labels=adata.obs[LABEL_KEY].cat.categories[:4],
35 | )
36 |
37 |
38 | class TestPlotEnrichment:
39 | @pytest.mark.parametrize("adata_enrichment", [adata_analytical, adata_empirical])
40 | def test_enrichment(self, adata_enrichment):
41 | cc.pl.enrichment(adata_enrichment, group_key=GROUP_KEY, label_key=LABEL_KEY)
42 |
43 | @pytest.mark.parametrize("adata_enrichment", [adata_analytical, adata_empirical])
44 | @pytest.mark.parametrize("label_cluster", [False, True])
45 | @pytest.mark.parametrize("show_pvalues", [False, True])
46 | @pytest.mark.parametrize("groups", [None, adata.obs[GROUP_KEY].cat.categories[:3]])
47 | @pytest.mark.parametrize("labels", [None, adata.obs[LABEL_KEY].cat.categories[:4]])
48 | @pytest.mark.parametrize("size_threshold", [1, 2.5])
49 | @pytest.mark.parametrize("palette", [None, "coolwarm", get_cmap("coolwarm")])
50 | @pytest.mark.parametrize("figsize", [None, (10, 8)])
51 | @pytest.mark.parametrize("alpha,edgecolor", [(1, "red"), (0.5, "blue")])
52 | def test_params(
53 | self,
54 | adata_enrichment,
55 | label_cluster,
56 | show_pvalues,
57 | groups,
58 | labels,
59 | size_threshold,
60 | palette,
61 | figsize,
62 | alpha,
63 | edgecolor,
64 | ):
65 | cc.pl.enrichment(
66 | adata_enrichment,
67 | group_key=GROUP_KEY,
68 | label_key=LABEL_KEY,
69 | label_cluster=label_cluster,
70 | groups=groups,
71 | labels=labels,
72 | show_pvalues=show_pvalues,
73 | size_threshold=size_threshold,
74 | palette=palette,
75 | figsize=figsize,
76 | alpha=alpha,
77 | edgecolor=edgecolor,
78 | )
79 | plt.close()
80 |
81 | def test_no_pvalues(self):
82 |         # Without precomputed enrichment results for these keys, plotting should raise an error
83 | with pytest.raises(ValueError):
84 | cc.pl.enrichment(adata_analytical, "group_key", "label_key", show_pvalues=True)
85 |
86 | with pytest.raises(ValueError):
87 | cc.pl.enrichment(adata_analytical, "group_key", "label_key", significance=0.01)
88 |
89 | with pytest.raises(ValueError):
90 | cc.pl.enrichment(adata_analytical, "group_key", "label_key", significant_only=True)
91 |
92 | def test_obs_exp(self):
93 | cc.gr.enrichment(adata, group_key=GROUP_KEY, label_key=LABEL_KEY, observed_expected=True)
94 | cc.pl.enrichment(adata, group_key=GROUP_KEY, label_key=LABEL_KEY)
95 |
96 | def test_enrichment_no_enrichment_data(self):
97 | with pytest.raises(ValueError):
98 | cc.pl.enrichment(adata, "group_key", "label_key")
99 |
100 | def test_size_threshold_zero(self):
101 | with pytest.raises(ValueError):
102 | cc.pl.enrichment(adata_empirical, group_key=GROUP_KEY, label_key=LABEL_KEY, size_threshold=0)
103 |
--------------------------------------------------------------------------------
/tests/plotting/test_plot_nhood.py:
--------------------------------------------------------------------------------
1 | import squidpy as sq
2 | from squidpy._constants._pkg_constants import Key
3 |
4 | import cellcharter as cc
5 |
6 | _CK = "cell type"
7 | key = Key.uns.nhood_enrichment(_CK)
8 |
9 | adata = sq.datasets.imc()
10 | sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
11 | cc.gr.remove_long_links(adata)
12 |
13 |
14 | class TestPlotNhoodEnrichment:
15 | def test_annotate(self):
16 | cc.gr.nhood_enrichment(adata, cluster_key=_CK)
17 | cc.pl.nhood_enrichment(adata, cluster_key=_CK, annotate=True)
18 |
19 | del adata.uns[key]
20 |
21 | def test_significance(self):
22 | cc.gr.nhood_enrichment(adata, cluster_key=_CK, pvalues=True, n_perms=100)
23 |
24 | cc.pl.nhood_enrichment(adata, cluster_key=_CK, significance=0.05)
25 | cc.pl.nhood_enrichment(adata, cluster_key=_CK, annotate=True, significance=0.05)
26 |
27 | del adata.uns[key]
28 |
--------------------------------------------------------------------------------
/tests/plotting/test_plot_stability.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import scipy.sparse as sps
3 | import squidpy as sq
4 |
5 | import cellcharter as cc
6 |
7 |
8 | class TestPlotStability:
9 | @pytest.mark.parametrize("dataset_name", ["mibitof"])
10 | def test_spatial_proteomics(self, dataset_name: str):
11 | download_dataset = getattr(sq.datasets, dataset_name)
12 | adata = download_dataset()
13 | if sps.issparse(adata.X):
14 | adata.X = adata.X.todense()
15 | sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
16 | cc.gr.remove_long_links(adata)
17 | cc.gr.aggregate_neighbors(adata, n_layers=3)
18 |
19 | model = cc.tl.ClusterAutoK.load(f"tests/_models/cellcharter_autok_{dataset_name}")
20 |
21 | cc.pl.autok_stability(model)
22 |
--------------------------------------------------------------------------------
/tests/plotting/test_shape.py:
--------------------------------------------------------------------------------
1 | from anndata import AnnData
2 |
3 | import cellcharter as cc
4 |
5 |
6 | class TestPlotBoundaries:
7 | def test_boundaries(self, codex_adata: AnnData):
8 | cc.gr.connected_components(codex_adata, cluster_key="cluster_cellcharter", min_cells=250)
9 | cc.tl.boundaries(codex_adata)
10 | cc.pl.boundaries(codex_adata, sample="BALBc-1", alpha_boundary=0.5, show_cells=False)
11 |
12 | # def test_boundaries_only(self, codex_adata: AnnData):
13 | # cc.tl.boundaries(codex_adata)
14 | # cc.pl.boundaries(codex_adata, sample="BALBc-1", alpha_boundary=0.5)
15 |
--------------------------------------------------------------------------------
/tests/tools/test_autok.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import scipy.sparse as sps
4 | import squidpy as sq
5 |
6 | import cellcharter as cc
7 |
8 |
9 | class TestClusterAutoK:
10 | @pytest.mark.parametrize("dataset_name", ["mibitof"])
11 | def test_spatial_proteomics(self, dataset_name: str):
12 | download_dataset = getattr(sq.datasets, dataset_name)
13 | adata = download_dataset()
14 | if sps.issparse(adata.X):
15 | adata.X = adata.X.todense()
16 | sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
17 | cc.gr.remove_long_links(adata)
18 | cc.gr.aggregate_neighbors(adata, n_layers=3)
19 |
20 | model_params = {
21 | "init_strategy": "kmeans",
22 | "random_state": 42,
23 | "trainer_params": {"accelerator": "cpu", "enable_progress_bar": False},
24 | }
25 | autok = cc.tl.ClusterAutoK(
26 | n_clusters=(2, 5), model_class=cc.tl.GaussianMixture, model_params=model_params, max_runs=3
27 | )
28 | autok.fit(adata, use_rep="X_cellcharter")
29 | adata.obs[f"cellcharter_{autok.best_k}"] = autok.predict(adata, use_rep="X_cellcharter", k=autok.best_k)
30 |
31 | assert len(np.unique(adata.obs[f"cellcharter_{autok.best_k}"])) == autok.best_k
32 |
--------------------------------------------------------------------------------
/tests/tools/test_gmm.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import scipy.sparse as sps
3 | import squidpy as sq
4 |
5 | import cellcharter as cc
6 |
7 |
8 | class TestCluster:
9 | @pytest.mark.parametrize("dataset_name", ["mibitof"])
10 | def test_sparse(self, dataset_name: str):
11 | download_dataset = getattr(sq.datasets, dataset_name)
12 | adata = download_dataset()
13 | adata.X = sps.csr_matrix(adata.X)
14 |
15 | sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
16 | cc.gr.remove_long_links(adata)
17 |
18 | gmm = cc.tl.Cluster(n_clusters=(10))
19 |
20 |         # Fitting on a sparse X without a dense representation should raise a ValueError
21 | with pytest.raises(ValueError):
22 | gmm.fit(adata, use_rep=None)
23 |
--------------------------------------------------------------------------------
/tests/tools/test_shape.py:
--------------------------------------------------------------------------------
1 | from anndata import AnnData
2 | from shapely import Polygon
3 |
4 | import cellcharter as cc
5 |
6 |
7 | # Tests for cc.tl.boundaries, which computes the topological boundaries of sets of cells.
8 | class TestBoundaries:
9 | def test_boundaries(self, codex_adata: AnnData):
10 | cc.gr.connected_components(codex_adata, cluster_key="cluster_cellcharter", min_cells=250)
11 | cc.tl.boundaries(codex_adata)
12 |
13 | boundaries = codex_adata.uns["shape_component"]["boundary"]
14 |
15 | assert isinstance(boundaries, dict)
16 |
17 | # Check if boundaries contains all components of codex_adata
18 | assert set(boundaries.keys()) == set(codex_adata.obs["component"].cat.categories)
19 |
20 | def test_copy(self, codex_adata: AnnData):
21 | cc.gr.connected_components(codex_adata, cluster_key="cluster_cellcharter", min_cells=250)
22 | boundaries = cc.tl.boundaries(codex_adata, copy=True)
23 |
24 | assert isinstance(boundaries, dict)
25 |
26 | # Check if boundaries contains all components of codex_adata
27 | assert set(boundaries.keys()) == set(codex_adata.obs["component"].cat.categories)
28 |
29 |
30 | class TestLinearity:
31 | def test_rectangle(self, codex_adata: AnnData):
32 | codex_adata.obs["rectangle"] = 1
33 |
34 | polygon = Polygon([(0, 0), (0, 10), (2, 10), (2, 0)])
35 |
36 | codex_adata.uns["shape_rectangle"] = {"boundary": {1: polygon}}
37 | linearities = cc.tl.linearity(codex_adata, "rectangle", copy=True)
38 | assert linearities[1] == 1.0
39 |
40 | def test_symmetrical_cross(self, codex_adata: AnnData):
41 | codex_adata.obs["cross"] = 1
42 |
43 | # Symmetrical cross with arm width of 2 and length of 5
44 | polygon = Polygon(
45 | [(0, 5), (0, 7), (5, 7), (5, 12), (7, 12), (7, 7), (12, 7), (12, 5), (7, 5), (7, 0), (5, 0), (5, 5)]
46 | )
47 |
48 | codex_adata.uns["shape_cross"] = {"boundary": {1: polygon}}
49 | linearities = cc.tl.linearity(codex_adata, "cross", copy=True)
50 |
51 | # The cross is symmetrical, so the linearity should be 0.5
52 | assert abs(linearities[1] - 0.5) < 0.01
53 |
54 | def test_thickness(self, codex_adata: AnnData):
55 | # The thickness of the cross should not influence the linearity
56 | codex_adata.obs["cross"] = 1
57 |
58 |         # Symmetrical cross with arm width of 1 and length of 5
59 | polygon1 = Polygon(
60 | [(0, 5), (0, 6), (5, 6), (5, 11), (6, 11), (6, 6), (11, 6), (11, 5), (6, 5), (6, 0), (5, 0), (5, 5)]
61 | )
62 |
63 | # Symmetrical cross with arm width of 2 and length of 5
64 | polygon2 = Polygon(
65 | [(0, 5), (0, 7), (5, 7), (5, 12), (7, 12), (7, 7), (12, 7), (12, 5), (7, 5), (7, 0), (5, 0), (5, 5)]
66 | )
67 |
68 | codex_adata.uns["shape_cross"] = {"boundary": {1: polygon1}}
69 | linearities1 = cc.tl.linearity(codex_adata, "cross", copy=True)
70 |
71 | codex_adata.uns["shape_cross"] = {"boundary": {1: polygon2}}
72 | linearities2 = cc.tl.linearity(codex_adata, "cross", copy=True)
73 |
74 | assert abs(linearities1[1] - linearities2[1]) < 0.01
75 |
--------------------------------------------------------------------------------
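The boundary results used throughout these shape tests are stored under adata.uns[f"shape_{key}"]["boundary"] as a dict of shapely geometries keyed by component, as both TestBoundaries above and the plotting test rely on. A minimal sketch of reading them back (illustrative; assumes cc.gr.connected_components and cc.tl.boundaries have been run):

boundaries = adata.uns["shape_component"]["boundary"]   # dict: component id -> shapely geometry
for component, polygon in boundaries.items():
    print(component, polygon.area)                      # shapely geometries expose .area, .length, ...
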