├── .codecov.yaml
├── .cruft.json
├── .editorconfig
├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.yml
│ ├── config.yml
│ └── feature_request.yml
└── workflows
│ ├── build.yaml
│ ├── release.yaml
│ ├── test_linux.yaml
│ ├── test_linux_cuda.yaml
│ ├── test_linux_pre.yaml
│ ├── test_macos.yaml
│ ├── test_macos_m1.yaml
│ └── test_windows.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CHANGELOG.md
├── LICENSE
├── README.md
├── docs
├── Makefile
├── _static
│ ├── .gitkeep
│ └── css
│ │ └── custom.css
├── _templates
│ ├── .gitkeep
│ ├── autosummary
│ │ └── class.rst
│ └── class_no_inherited.rst
├── api.md
├── changelog.md
├── conf.py
├── contributing.md
├── extensions
│ ├── .gitkeep
│ └── typed_returns.py
├── index.md
├── notebooks
│ ├── large_scale.ipynb
│ └── lung_example.ipynb
├── references.bib
├── references.md
├── template_usage.md
└── tutorials.md
├── pyproject.toml
├── setup.py
├── src
└── scib_metrics
│ ├── __init__.py
│ ├── _settings.py
│ ├── _types.py
│ ├── benchmark
│ ├── __init__.py
│ └── _core.py
│ ├── metrics
│ ├── __init__.py
│ ├── _graph_connectivity.py
│ ├── _isolated_labels.py
│ ├── _kbet.py
│ ├── _lisi.py
│ ├── _nmi_ari.py
│ ├── _pcr_comparison.py
│ └── _silhouette.py
│ ├── nearest_neighbors
│ ├── __init__.py
│ ├── _dataclass.py
│ ├── _jax.py
│ └── _pynndescent.py
│ └── utils
│ ├── __init__.py
│ ├── _diffusion_nn.py
│ ├── _dist.py
│ ├── _kmeans.py
│ ├── _lisi.py
│ ├── _pca.py
│ ├── _pcr.py
│ ├── _silhouette.py
│ └── _utils.py
└── tests
├── __init__.py
├── test_benchmarker.py
├── test_metrics.py
├── test_neighbors.py
├── test_pcr_comparison.py
└── utils
├── __init__.py
├── data.py
├── sampling.py
├── test_pca.py
└── test_pcr.py
/.codecov.yaml:
--------------------------------------------------------------------------------
1 | # Based on pydata/xarray
2 | codecov:
3 | require_ci_to_pass: no
4 |
5 | coverage:
6 | status:
7 | project:
8 | default:
9 | # Require 1% coverage, i.e., always succeed
10 | target: 1
11 | patch: false
12 | changes: false
13 |
14 | comment:
15 | layout: diff, flags, files
16 | behavior: once
17 | require_base: no
18 |
--------------------------------------------------------------------------------
/.cruft.json:
--------------------------------------------------------------------------------
1 | {
2 | "template": "https://github.com/scverse/cookiecutter-scverse",
3 | "commit": "87a407a65408d75a949c0b54b19fd287475a56f8",
4 | "checkout": "v0.4.0",
5 | "context": {
6 | "cookiecutter": {
7 | "project_name": "scib-metrics",
8 | "package_name": "scib_metrics",
9 | "project_description": "Accelerated and Python-only scIB metrics",
10 | "author_full_name": "Adam Gayoso",
11 | "author_email": "adamgayoso@berkeley.edu",
12 | "github_user": "adamgayoso",
13 | "project_repo": "https://github.com/yoseflab/scib-metrics",
14 | "license": "BSD 3-Clause License",
15 | "_copy_without_render": [
16 | ".github/workflows/build.yaml",
17 | ".github/workflows/test.yaml",
18 | "docs/_templates/autosummary/**.rst"
19 | ],
20 | "_render_devdocs": false,
21 | "_jinja2_env_vars": {
22 | "lstrip_blocks": true,
23 | "trim_blocks": true
24 | },
25 | "_template": "https://github.com/scverse/cookiecutter-scverse"
26 | }
27 | },
28 | "directory": null,
29 | "skip": [
30 | ".github/workflows/**.yaml",
31 | ".pre-commit-config.yaml",
32 | "pyproject.toml",
33 | "tests",
34 | "src/**/__init__.py",
35 | "src/**/basic.py",
36 | "docs/api.md",
37 | "docs/changelog.md",
38 | "docs/references.bib",
39 | "docs/references.md",
40 |         "docs/notebooks/example.ipynb"
41 |     ]
50 | }
51 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | indent_style = space
5 | indent_size = 4
6 | end_of_line = lf
7 | charset = utf-8
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 |
11 | [*.{yml,yaml}]
12 | indent_size = 2
13 |
14 | [.cruft.json]
15 | indent_size = 2
16 |
17 | [Makefile]
18 | indent_style = tab
19 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug report
2 | description: Report something that is broken or incorrect
3 | labels: bug
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports)
9 | detailing how to provide the necessary information for us to reproduce your bug. In brief:
10 | * Please provide exact steps how to reproduce the bug in a clean Python environment.
11 | * In case it's not clear what's causing this bug, please provide the data or the data generation procedure.
12 | * Sometimes it is not possible to share the data, but usually it is possible to replicate problems on publicly
13 | available datasets or to share a subset of your data.
14 |
15 | - type: textarea
16 | id: report
17 | attributes:
18 | label: Report
19 | description: A clear and concise description of what the bug is.
20 | validations:
21 | required: true
22 |
23 | - type: textarea
24 | id: versions
25 | attributes:
26 | label: Version information
27 | description: |
28 | Please paste below the output of
29 |
30 | ```python
31 | import session_info
32 | session_info.show(html=False, dependencies=True)
33 | ```
34 | placeholder: |
35 | -----
36 | anndata 0.8.0rc2.dev27+ge524389
37 | session_info 1.0.0
38 | -----
39 | asttokens NA
40 | awkward 1.8.0
41 | backcall 0.2.0
42 | cython_runtime NA
43 | dateutil 2.8.2
44 | debugpy 1.6.0
45 | decorator 5.1.1
46 | entrypoints 0.4
47 | executing 0.8.3
48 | h5py 3.7.0
49 | ipykernel 6.15.0
50 | jedi 0.18.1
51 | mpl_toolkits NA
52 | natsort 8.1.0
53 | numpy 1.22.4
54 | packaging 21.3
55 | pandas 1.4.2
56 | parso 0.8.3
57 | pexpect 4.8.0
58 | pickleshare 0.7.5
59 | pkg_resources NA
60 | prompt_toolkit 3.0.29
61 | psutil 5.9.1
62 | ptyprocess 0.7.0
63 | pure_eval 0.2.2
64 | pydev_ipython NA
65 | pydevconsole NA
66 | pydevd 2.8.0
67 | pydevd_file_utils NA
68 | pydevd_plugins NA
69 | pydevd_tracing NA
70 | pygments 2.12.0
71 | pytz 2022.1
72 | scipy 1.8.1
73 | setuptools 62.5.0
74 | setuptools_scm NA
75 | six 1.16.0
76 | stack_data 0.3.0
77 | tornado 6.1
78 | traitlets 5.3.0
79 | wcwidth 0.2.5
80 | zmq 23.1.0
81 | -----
82 | IPython 8.4.0
83 | jupyter_client 7.3.4
84 | jupyter_core 4.10.0
85 | -----
86 | Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50) [GCC 10.3.0]
87 | Linux-5.18.6-arch1-1-x86_64-with-glibc2.35
88 | -----
89 | Session information updated at 2022-07-07 17:55
90 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Scverse Community Forum
4 | url: https://discourse.scverse.org/
5 | about: If you have questions about “How to do X”, please ask them here.
6 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature request
2 | description: Propose a new feature for scib-metrics
3 | labels: enhancement
4 | body:
5 | - type: textarea
6 | id: description
7 | attributes:
8 | label: Description of feature
9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered.
10 | validations:
11 | required: true
12 |
--------------------------------------------------------------------------------
/.github/workflows/build.yaml:
--------------------------------------------------------------------------------
1 | name: Build
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | package:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Set up Python 3.11
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: "3.11"
22 | cache: "pip"
23 | cache-dependency-path: "**/pyproject.toml"
24 | - name: Install build dependencies
25 | run: python -m pip install --upgrade pip wheel twine build
26 | - name: Build package
27 | run: python -m build
28 | - name: Check package
29 | run: twine check --strict dist/*.whl
30 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/
8 | jobs:
9 | release:
10 | name: Upload release to PyPI
11 | runs-on: ubuntu-latest
12 | environment:
13 | name: pypi
14 | url: https://pypi.org/p/scib-metrics
15 | permissions:
16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
17 | steps:
18 | - uses: actions/checkout@v4
19 | with:
20 | filter: blob:none
21 | fetch-depth: 0
22 | - uses: actions/setup-python@v4
23 | with:
24 | python-version: "3.x"
25 | cache: "pip"
26 | - run: pip install build
27 | - run: python -m build
28 | - name: Publish package distributions to PyPI
29 | uses: pypa/gh-action-pypi-publish@release/v1
30 |
--------------------------------------------------------------------------------
/.github/workflows/test_linux.yaml:
--------------------------------------------------------------------------------
1 | name: Test (Linux)
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | test:
15 | runs-on: ${{ matrix.os }}
16 | defaults:
17 | run:
18 | shell: bash -e {0} # -e to fail on error
19 |
20 | strategy:
21 | fail-fast: false
22 | matrix:
23 | os: [ubuntu-latest]
24 | python: ["3.10", "3.11", "3.12"]
25 |
26 | name: Integration
27 |
28 | env:
29 | OS: ${{ matrix.os }}
30 | PYTHON: ${{ matrix.python }}
31 |
32 | steps:
33 | - uses: actions/checkout@v4
34 | - name: Set up Python ${{ matrix.python }}
35 | uses: actions/setup-python@v5
36 | with:
37 | python-version: ${{ matrix.python }}
38 | cache: "pip"
39 | cache-dependency-path: "**/pyproject.toml"
40 |
41 | - name: Install test dependencies
42 | run: |
43 | python -m pip install --upgrade pip wheel
44 |
45 | - name: Install dependencies
46 | run: |
47 | pip install ".[dev,test]"
48 |
49 | - name: Test
50 | env:
51 | MPLBACKEND: agg
52 | PLATFORM: ${{ matrix.os }}
53 | DISPLAY: :42
54 | run: |
55 | coverage run -m pytest -v --color=yes
56 | - name: Report coverage
57 | run: |
58 | coverage report
59 | - name: Upload coverage
60 | uses: codecov/codecov-action@v4
61 | with:
62 | token: ${{ secrets.CODECOV_TOKEN }}
63 |
--------------------------------------------------------------------------------
/.github/workflows/test_linux_cuda.yaml:
--------------------------------------------------------------------------------
1 | name: Test (Linux, CUDA)
2 |
3 | on:
4 | pull_request:
5 | branches: [main]
6 | types: [labeled, synchronize, opened]
7 | schedule:
8 | - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day
9 | workflow_dispatch:
10 |
11 | concurrency:
12 | group: ${{ github.workflow }}-${{ github.ref }}
13 | cancel-in-progress: true
14 |
15 | jobs:
16 | test:
17 | # if PR has label "cuda tests" or "all tests" or if scheduled or manually triggered
18 | if: >-
19 | (
20 | contains(github.event.pull_request.labels.*.name, 'cuda tests') ||
21 | contains(github.event.pull_request.labels.*.name, 'all tests') ||
22 | contains(github.event_name, 'schedule') ||
23 | contains(github.event_name, 'workflow_dispatch')
24 | )
25 | runs-on: [self-hosted, Linux, X64, CUDA]
26 | defaults:
27 | run:
28 | shell: bash -e {0} # -e to fail on error
29 |
30 | strategy:
31 | fail-fast: false
32 | matrix:
33 | python: ["3.11"]
34 | cuda: ["11"]
35 |
36 | container:
37 | image: scverse/scvi-tools:py${{ matrix.python }}-cu${{ matrix.cuda }}-base
38 | options: --user root --gpus all
39 |
40 | name: Integration (CUDA)
41 |
42 | env:
43 | OS: ${{ matrix.os }}
44 | PYTHON: ${{ matrix.python }}
45 |
46 | steps:
47 | - uses: actions/checkout@v4
48 |
49 | - name: Install dependencies
50 | run: |
51 | pip install ".[dev,test]"
52 |
53 | - name: Test
54 | env:
55 | MPLBACKEND: agg
56 | PLATFORM: ${{ matrix.os }}
57 | DISPLAY: :42
58 | run: |
59 | coverage run -m pytest -v --color=yes
60 | - name: Report coverage
61 | run: |
62 | coverage report
63 | - name: Upload coverage
64 | uses: codecov/codecov-action@v4
65 | with:
66 | token: ${{ secrets.CODECOV_TOKEN }}
67 |
--------------------------------------------------------------------------------
/.github/workflows/test_linux_pre.yaml:
--------------------------------------------------------------------------------
1 | name: Test (Linux, prereleases)
2 |
3 | on:
4 | pull_request:
5 | branches: [main]
6 | types: [labeled, synchronize, opened]
7 | schedule:
8 | - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day
9 | workflow_dispatch:
10 |
11 | concurrency:
12 | group: ${{ github.workflow }}-${{ github.ref }}
13 | cancel-in-progress: true
14 |
15 | jobs:
16 | test:
17 | # if PR has label "cuda tests" or "all tests" or if scheduled or manually triggered
18 | if: >-
19 | (
20 | contains(github.event.pull_request.labels.*.name, 'prerelease tests') ||
21 | contains(github.event.pull_request.labels.*.name, 'all tests') ||
22 | contains(github.event_name, 'schedule') ||
23 | contains(github.event_name, 'workflow_dispatch')
24 | )
25 | runs-on: ${{ matrix.os }}
26 | defaults:
27 | run:
28 | shell: bash -e {0} # -e to fail on error
29 |
30 | strategy:
31 | fail-fast: false
32 | matrix:
33 | os: [ubuntu-latest]
34 | python: ["3.10", "3.11", "3.12"]
35 |
36 | name: Integration (Prereleases)
37 |
38 | env:
39 | OS: ${{ matrix.os }}
40 | PYTHON: ${{ matrix.python }}
41 |
42 | steps:
43 | - uses: actions/checkout@v4
44 | - name: Set up Python ${{ matrix.python }}
45 | uses: actions/setup-python@v5
46 | with:
47 | python-version: ${{ matrix.python }}
48 | cache: "pip"
49 | cache-dependency-path: "**/pyproject.toml"
50 |
51 | - name: Install test dependencies
52 | run: |
53 | python -m pip install --upgrade pip wheel
54 |
55 | - name: Install dependencies
56 | run: |
57 | pip install --pre ".[dev,test]"
58 |
59 | - name: Test
60 | env:
61 | MPLBACKEND: agg
62 | PLATFORM: ${{ matrix.os }}
63 | DISPLAY: :42
64 | run: |
65 | coverage run -m pytest -v --color=yes
66 | - name: Report coverage
67 | run: |
68 | coverage report
69 | - name: Upload coverage
70 | uses: codecov/codecov-action@v4
71 | with:
72 | token: ${{ secrets.CODECOV_TOKEN }}
73 |
--------------------------------------------------------------------------------
/.github/workflows/test_macos.yaml:
--------------------------------------------------------------------------------
1 | name: Test (MacOS)
2 |
3 | on:
4 | schedule:
5 | - cron: "0 10 * * *" # runs at 10:00 UTC -> 03:00 PST every day
6 | workflow_dispatch:
7 |
8 | concurrency:
9 | group: ${{ github.workflow }}-${{ github.ref }}
10 | cancel-in-progress: true
11 |
12 | jobs:
13 | test:
14 | runs-on: ${{ matrix.os }}
15 | defaults:
16 | run:
17 | shell: bash -e {0} # -e to fail on error
18 |
19 | strategy:
20 | fail-fast: false
21 | matrix:
22 | os: [macos-latest]
23 | python: ["3.10", "3.11", "3.12"]
24 |
25 | name: Integration
26 |
27 | env:
28 | OS: ${{ matrix.os }}
29 | PYTHON: ${{ matrix.python }}
30 |
31 | steps:
32 | - uses: actions/checkout@v4
33 | - name: Set up Python ${{ matrix.python }}
34 | uses: actions/setup-python@v5
35 | with:
36 | python-version: ${{ matrix.python }}
37 | cache: "pip"
38 | cache-dependency-path: "**/pyproject.toml"
39 |
40 | - name: Install test dependencies
41 | run: |
42 | python -m pip install --upgrade pip wheel
43 |
44 | - name: Install dependencies
45 | run: |
46 | pip install ".[dev,test]"
47 |
48 | - name: Test
49 | env:
50 | MPLBACKEND: agg
51 | PLATFORM: ${{ matrix.os }}
52 | DISPLAY: :42
53 | run: |
54 | coverage run -m pytest -v --color=yes
55 | - name: Report coverage
56 | run: |
57 | coverage report
58 | - name: Upload coverage
59 | uses: codecov/codecov-action@v4
60 | with:
61 | token: ${{ secrets.CODECOV_TOKEN }}
62 |
--------------------------------------------------------------------------------
/.github/workflows/test_macos_m1.yaml:
--------------------------------------------------------------------------------
1 | name: Test (MacOS M1)
2 |
3 | on:
4 | schedule:
5 | - cron: "0 10 * * *" # runs at 10:00 UTC -> 03:00 PST every day
6 | workflow_dispatch:
7 |
8 | concurrency:
9 | group: ${{ github.workflow }}-${{ github.ref }}
10 | cancel-in-progress: true
11 |
12 | jobs:
13 | test:
14 | runs-on: ${{ matrix.os }}
15 | defaults:
16 | run:
17 | shell: bash -e {0} # -e to fail on error
18 |
19 | strategy:
20 | fail-fast: false
21 | matrix:
22 | os: [macos-14]
23 | python: ["3.10", "3.11", "3.12"]
24 |
25 | name: Integration
26 |
27 | env:
28 | OS: ${{ matrix.os }}
29 | PYTHON: ${{ matrix.python }}
30 |
31 | steps:
32 | - uses: actions/checkout@v4
33 | - name: Set up Python ${{ matrix.python }}
34 | uses: actions/setup-python@v5
35 | with:
36 | python-version: ${{ matrix.python }}
37 | cache: "pip"
38 | cache-dependency-path: "**/pyproject.toml"
39 |
40 | - name: Install test dependencies
41 | run: |
42 | python -m pip install --upgrade pip wheel
43 |
44 | - name: Install dependencies
45 | run: |
46 | pip install ".[dev,test]"
47 |
48 | - name: Test
49 | env:
50 | MPLBACKEND: agg
51 | PLATFORM: ${{ matrix.os }}
52 | DISPLAY: :42
53 | run: |
54 | coverage run -m pytest -v --color=yes
55 | - name: Report coverage
56 | run: |
57 | coverage report
58 | - name: Upload coverage
59 | uses: codecov/codecov-action@v4
60 | with:
61 | token: ${{ secrets.CODECOV_TOKEN }}
62 |
--------------------------------------------------------------------------------
/.github/workflows/test_windows.yaml:
--------------------------------------------------------------------------------
1 | name: Test (Windows)
2 |
3 | on:
4 | schedule:
5 | - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day
6 | workflow_dispatch:
7 |
8 | concurrency:
9 | group: ${{ github.workflow }}-${{ github.ref }}
10 | cancel-in-progress: true
11 |
12 | jobs:
13 | test:
14 | runs-on: ${{ matrix.os }}
15 | defaults:
16 | run:
17 | shell: bash -e {0} # -e to fail on error
18 |
19 | strategy:
20 | fail-fast: false
21 | matrix:
22 | os: [windows-latest]
23 | python: ["3.10", "3.11", "3.12"]
24 |
25 | name: Integration
26 |
27 | env:
28 | OS: ${{ matrix.os }}
29 | PYTHON: ${{ matrix.python }}
30 |
31 | steps:
32 | - uses: actions/checkout@v4
33 | - name: Set up Python ${{ matrix.python }}
34 | uses: actions/setup-python@v5
35 | with:
36 | python-version: ${{ matrix.python }}
37 | cache: "pip"
38 | cache-dependency-path: "**/pyproject.toml"
39 |
40 | - name: Install test dependencies
41 | run: |
42 | python -m pip install --upgrade pip wheel
43 |
44 | - name: Install dependencies
45 | run: |
46 | pip install ".[dev,test]"
47 |
48 | - name: Test
49 | env:
50 | MPLBACKEND: agg
51 | PLATFORM: ${{ matrix.os }}
52 | DISPLAY: :42
53 | run: |
54 | coverage run -m pytest -v --color=yes
55 | - name: Report coverage
56 | run: |
57 | coverage report
58 | - name: Upload coverage
59 | uses: codecov/codecov-action@v4
60 | with:
61 | token: ${{ secrets.CODECOV_TOKEN }}
62 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Temp files
2 | .DS_Store
3 | *~
4 | buck-out/
5 |
6 | # Compiled files
7 | .venv/
8 | __pycache__/
9 | .mypy_cache/
10 | .ruff_cache/
11 |
12 | # Distribution / packaging
13 | /build/
14 | /dist/
15 | /*.egg-info/
16 |
17 | # Tests and coverage
18 | /.pytest_cache/
19 | /.cache/
20 | /data/
21 | /node_modules/
22 |
23 | # docs
24 | /docs/generated/
25 | /docs/_build/
26 |
27 | # IDEs
28 | /.idea/
29 | /.vscode/
30 |
31 | # node
32 | /node_modules/
33 |
34 | docs/notebooks/data/
35 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | fail_fast: false
2 | default_language_version:
3 | python: python3
4 | default_stages:
5 | - pre-commit
6 | - pre-push
7 | minimum_pre_commit_version: 2.16.0
8 | repos:
9 | - repo: https://github.com/pre-commit/mirrors-prettier
10 | rev: v4.0.0-alpha.8
11 | hooks:
12 | - id: prettier
13 | - repo: https://github.com/astral-sh/ruff-pre-commit
14 | rev: v0.11.11
15 | hooks:
16 | - id: ruff
17 | types_or: [python, pyi, jupyter]
18 | args: [--fix, --exit-non-zero-on-fix]
19 | - id: ruff-format
20 | types_or: [python, pyi, jupyter]
21 | - repo: https://github.com/pre-commit/pre-commit-hooks
22 | rev: v5.0.0
23 | hooks:
24 | - id: detect-private-key
25 | - id: check-ast
26 | - id: end-of-file-fixer
27 | - id: mixed-line-ending
28 | args: [--fix=lf]
29 | - id: trailing-whitespace
30 | - id: check-case-conflict
31 | # Check that there are no merge conflicts (could be generated by template sync)
32 | - id: check-merge-conflict
33 | args: [--assume-in-merge]
34 | - repo: local
35 | hooks:
36 | - id: forbid-to-commit
37 | name: Don't commit rej files
38 | entry: |
39 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates.
40 | Fix the merge conflicts manually and remove the .rej files.
41 | language: fail
42 | files: '.*\.rej$'
43 | - repo: https://github.com/kynan/nbstripout
44 | rev: 0.8.1
45 | hooks:
46 | - id: nbstripout
47 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html
2 | version: 2
3 | build:
4 | os: ubuntu-20.04
5 | tools:
6 | python: "3.10"
7 | sphinx:
8 | configuration: docs/conf.py
9 | # disable this for more lenient docs builds
10 | fail_on_warning: false
11 | python:
12 | install:
13 | - method: pip
14 | path: .
15 | extra_requirements:
16 | - doc
17 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog][],
6 | and this project adheres to [Semantic Versioning][].
7 |
8 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/
9 | [semantic versioning]: https://semver.org/spec/v2.0.0.html
10 |
11 | ## 0.6.0 (unreleased)
12 |
13 | ## 0.5.4 (2025-04-23)
14 |
15 | ### Fixed
16 |
17 | - Apply default values for benchmarker metrics {pr}`203`.
18 |
19 | ## 0.5.3 (2025-02-17)
20 |
21 | ### Removed
22 |
23 | - Reverted a change that was needed for scib-autotune in scvi-tools {pr}`189`.
24 |
25 | ## 0.5.2 (2025-02-13)
26 |
27 | ### Added
28 |
29 | - Add `progress_bar` argument to {class}`scib_metrics.benchmark.Benchmarker` {pr}`152`.
30 | - Add ability of {class}`scib_metrics.benchmark.Benchmarker` plotting code to handle missing sets of metrics {pr}`181`.
31 | - Add random score in case of aggregate metrics not selected to be used in scib autotune in scvi-tools, {pr}`188`.
32 |
33 | ### Changed
34 |
35 | - Changed Leiden clustering now has a seed argument for reproducibility {pr}`173`.
36 | - Changed passing `None` to `bio_conservation_metrics` or `batch_correction_metrics` in {class}`scib_metrics.benchmark.Benchmarker` now implies to skip this set of metrics {pr}`181`.
37 |
38 | ### Fixed
39 |
40 | - Fix neighbors connectivities in test to use new scanpy fn {pr}`170`.
41 | - Fix Kmeans test {pr}`172`.
42 | - Fix deprecation and future warnings {pr}`171`.
43 | - Fix lisi return type and docstring {pr}`182`.
44 |
45 | ## 0.5.1 (2024-02-23)
46 |
47 | ### Changed
48 |
49 | - Replace removed {class}`jax.random.KeyArray` with {class}`jax.Array` {pr}`135`.
50 |
51 | ## 0.5.0 (2024-01-04)
52 |
53 | ### Changed
54 |
55 | - Refactor all relevant metrics to use `NeighborsResults` as input instead of sparse
56 | distance/connectivity matrices {pr}`129`.
57 |
58 | ## 0.4.1 (2023-10-08)
59 |
60 | ### Fixed
61 |
62 | - Fix KMeans. All previous versions had a bug with KMeans and ARI/NMI metrics are not reliable
63 | with this clustering {pr}`115`.
64 |
65 | ## 0.4.0 (2023-09-19)
66 |
67 | ### Added
68 |
69 | - Update isolated labels to use newest scib methodology {pr}`108`.
70 |
71 | ### Fixed
72 |
73 | - Fix jax one-hot error {pr}`107`.
74 |
75 | ### Removed
76 |
77 | - Drop Python 3.8 {pr}`107`.
78 |
79 | ## 0.3.3 (2023-03-29)
80 |
81 | ### Fixed
82 |
83 | - Large scale tutorial now properly uses gpu index {pr}`92`
84 |
85 | ## 0.3.2 (2023-03-13)
86 |
87 | ### Changed
88 |
89 | - Switch to Ruff for linting/formatting {pr}`87`
90 | - Update cookiecutter template {pr}`88`
91 |
92 | ## 0.3.1 (2023-02-16)
93 |
94 | ### Changed
95 |
96 | - Expose chunk size for silhouette {pr}`82`
97 |
98 | ## 0.3.0 (2023-02-16)
99 |
100 | ### Changed
101 |
102 | - Rename `KmeansJax` to `Kmeans` and fix ++ initialization, use Kmeans as default in benchmarker instead of Leiden {pr}`81`.
103 | - Warn about joblib, add progress bar postfix str {pr}`80`
104 |
105 | ## 0.2.0 (2023-02-02)
106 |
107 | ### Added
108 |
109 | - Allow custom nearest neighbors methods in Benchmarker {pr}`78`.
110 |
111 | ## 0.1.1 (2023-01-04)
112 |
113 | ### Added
114 |
115 | - Add new tutorial and fix scalability of lisi {pr}`71`.
116 |
117 | ## 0.1.0 (2023-01-03)
118 |
119 | ### Added
120 |
121 | - Add benchmarking pipeline with plotting {pr}`52` {pr}`69`.
122 |
123 | ### Fixed
124 |
125 | - Fix diffusion distance computation, affecting kbet {pr}`70`.
126 |
127 | ## 0.0.9 (2022-12-16)
128 |
129 | ### Added
130 |
131 | - Add kbet {pr}`60`.
132 | - Add graph connectivity metric {pr}`61`.
133 |
134 | ## 0.0.8 (2022-11-18)
135 |
136 | ### Changed
137 |
138 | - Switch to random kmeans initialization due to kmeans++ complexity issues {pr}`54`.
139 |
140 | ### Fixed
141 |
142 | - Begin fixes to make kmeans++ initialization faster {pr}`49`.
143 |
144 | ## 0.0.7 (2022-10-31)
145 |
146 | ### Changed
147 |
148 | - Move PCR to utils module in favor of PCR comparison {pr}`46`.
149 |
150 | ### Fixed
151 |
152 | - Fix memory issue in `KMeansJax` by using `_kmeans_full_run` with `map` instead of `vmap` {pr}`45`.
153 |
154 | ## 0.0.6 (2022-10-25)
155 |
156 | ### Changed
157 |
158 | - Reimplement silhouette in a memory constant way, pdist using lax scan {pr}`42`.
159 |
160 | ## 0.0.5 (2022-10-24)
161 |
162 | ### Added
163 |
164 | - Standardize language of docstring {pr}`30`.
165 | - Use K-means++ initialization {pr}`23`.
166 | - Add pc regression and pc comparison {pr}`16` {pr}`38`.
167 | - Lax'd silhouette {pr}`33`.
168 | - Cookiecutter template sync {pr}`35`.
169 |
170 | ## 0.0.4 (2022-10-10)
171 |
172 | ### Added
173 |
174 | - NMI/ARI metric with Leiden clustering resolution optimization {pr}`24`.
175 | - iLISI/cLISI metrics {pr}`20`.
176 |
177 | ## 0.0.1 - 0.0.3
178 |
179 | See the [GitHub releases](https://github.com/yoseflab/scib-metrics/releases) for details.
180 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, Adam Gayoso
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # scib-metrics
2 |
3 | [![Stars][badge-stars]][link-stars]
4 | [![PyPI][badge-pypi]][link-pypi]
5 | [![PyPIDownloads][badge-downloads]][link-downloads]
6 | [![Docs][badge-docs]][link-docs]
7 | [![Build][badge-build]][link-build]
8 | [![Coverage][badge-cov]][link-cov]
9 | [![Discourse][badge-discourse]][link-discourse]
10 | [![Chat][badge-zulip]][link-zulip]
11 |
12 | [badge-stars]: https://img.shields.io/github/stars/YosefLab/scib-metrics?logo=GitHub&color=yellow
13 | [link-stars]: https://github.com/YosefLab/scib-metrics/stargazers
14 | [badge-pypi]: https://img.shields.io/pypi/v/scib-metrics.svg
15 | [link-pypi]: https://pypi.org/project/scib-metrics
16 | [badge-downloads]: https://static.pepy.tech/badge/scib-metrics
17 | [link-downloads]: https://pepy.tech/project/scib-metrics
18 | [badge-docs]: https://readthedocs.org/projects/scib-metrics/badge/?version=latest
19 | [link-docs]: https://scib-metrics.readthedocs.io/en/latest/?badge=latest
20 | [badge-build]: https://github.com/YosefLab/scib-metrics/actions/workflows/build.yaml/badge.svg
21 | [link-build]: https://github.com/YosefLab/scib-metrics/actions/workflows/build.yaml/
22 | [badge-cov]: https://codecov.io/gh/YosefLab/scib-metrics/branch/main/graph/badge.svg
23 | [link-cov]: https://codecov.io/gh/YosefLab/scib-metrics
24 | [badge-discourse]: https://img.shields.io/discourse/posts?color=yellow&logo=discourse&server=https%3A%2F%2Fdiscourse.scverse.org
25 | [link-discourse]: https://discourse.scverse.org/
26 | [badge-zulip]: https://img.shields.io/badge/zulip-join_chat-brightgreen.svg
27 | [link-zulip]: https://scverse.zulipchat.com/
28 |
29 | Accelerated and Python-only metrics for benchmarking single-cell integration outputs.
30 |
31 | This package contains implementations of metrics for evaluating the performance of single-cell omics data integration methods. The implementations of these metrics use [JAX](https://jax.readthedocs.io/en/latest/) when possible for jit-compilation and hardware acceleration. All implementations are in Python.
32 |
33 | Currently we are porting metrics used in the scIB [manuscript](https://www.nature.com/articles/s41592-021-01336-8) (and [code](https://github.com/theislab/scib)). Deviations from the original implementations are documented. However, metric values from this repository should not be compared to the scIB repository.
34 |
35 | ## Getting started
36 |
37 | Please refer to the [documentation][link-docs].
38 |
39 | ## Installation
40 |
41 | You need to have Python 3.10 or newer installed on your system. If you don't have
42 | Python installed, we recommend installing [Miniconda](https://docs.conda.io/en/latest/miniconda.html).
43 |
44 | There are several options to install scib-metrics:
45 |
46 | 1. Install the latest release on PyPI:
47 |
48 | ```bash
49 | pip install scib-metrics
50 | ```
51 |
52 | 2. Install the latest development version:
53 |
54 | ```bash
55 | pip install git+https://github.com/yoseflab/scib-metrics.git@main
56 | ```
57 |
58 | To leverage hardware acceleration (e.g., GPU) please install the appropriate version of [JAX](https://github.com/google/jax#installation) separately. This is often easier with conda-distributed versions of JAX.
59 |
60 | ## Release notes
61 |
62 | See the [changelog][changelog].
63 |
64 | ## Contact
65 |
66 | For questions and help requests, you can reach out in the [scverse Discourse][link-discourse].
67 | If you found a bug, please use the [issue tracker][issue-tracker].
68 |
69 | ## Citation
70 |
71 | References for individual metrics can be found in the corresponding documentation. This package is heavily inspired by the single-cell integration benchmarking work:
72 |
73 | ```
74 | @article{luecken2022benchmarking,
75 | title={Benchmarking atlas-level data integration in single-cell genomics},
76 | author={Luecken, Malte D and B{\"u}ttner, Maren and Chaichoompu, Kridsadakorn and Danese, Anna and Interlandi, Marta and M{\"u}ller, Michaela F and Strobl, Daniel C and Zappia, Luke and Dugas, Martin and Colom{\'e}-Tatch{\'e}, Maria and others},
77 | journal={Nature methods},
78 | volume={19},
79 | number={1},
80 | pages={41--50},
81 | year={2022},
82 | publisher={Nature Publishing Group}
83 | }
84 | ```
85 |
86 | [scverse-discourse]: https://discourse.scverse.org/
87 | [issue-tracker]: https://github.com/YosefLab/scib-metrics/issues
88 | [changelog]: https://scib-metrics.readthedocs.io/en/latest/changelog.html
89 | [link-api]: https://scib-metrics.readthedocs.io/en/latest/api.html
90 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YosefLab/scib-metrics/5d01bb46f4d2317b4b3379851c7202df5be36c68/docs/_static/.gitkeep
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */
2 | div.cell_output table.dataframe {
3 | font-size: 0.8em;
4 | }
5 |
--------------------------------------------------------------------------------
/docs/_templates/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YosefLab/scib-metrics/5d01bb46f4d2317b4b3379851c7202df5be36c68/docs/_templates/.gitkeep
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | Attributes table
12 | ~~~~~~~~~~~~~~~~~~
13 |
14 | .. autosummary::
15 | {% for item in attributes %}
16 | ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | Methods table
24 | ~~~~~~~~~~~~~
25 |
26 | .. autosummary::
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 | ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
35 | {% block attributes_documentation %}
36 | {% if attributes %}
37 | Attributes
38 | ~~~~~~~~~~~
39 |
40 | {% for item in attributes %}
41 |
42 | .. autoattribute:: {{ [objname, item] | join(".") }}
43 | {%- endfor %}
44 |
45 | {% endif %}
46 | {% endblock %}
47 |
48 | {% block methods_documentation %}
49 | {% if methods %}
50 | Methods
51 | ~~~~~~~
52 |
53 | {% for item in methods %}
54 | {%- if item != '__init__' %}
55 |
56 | .. automethod:: {{ [objname, item] | join(".") }}
57 | {%- endif -%}
58 | {%- endfor %}
59 |
60 | {% endif %}
61 | {% endblock %}
62 |
--------------------------------------------------------------------------------
/docs/_templates/class_no_inherited.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 | :show-inheritance:
9 |
10 | {% block attributes %}
11 | {% if attributes %}
12 | Attributes table
13 | ~~~~~~~~~~~~~~~~
14 |
15 | .. autosummary::
16 | {% for item in attributes %}
17 | {%- if item not in inherited_members%}
18 | ~{{ fullname }}.{{ item }}
19 | {%- endif -%}
20 | {%- endfor %}
21 | {% endif %}
22 | {% endblock %}
23 |
24 |
25 | {% block methods %}
26 | {% if methods %}
27 | Methods table
28 | ~~~~~~~~~~~~~~
29 |
30 | .. autosummary::
31 | {% for item in methods %}
32 | {%- if item != '__init__' and item not in inherited_members%}
33 | ~{{ fullname }}.{{ item }}
34 | {%- endif -%}
35 |
36 | {%- endfor %}
37 | {% endif %}
38 | {% endblock %}
39 |
40 | {% block attributes_documentation %}
41 | {% if attributes %}
42 | Attributes
43 | ~~~~~~~~~~
44 |
45 | {% for item in attributes %}
46 | {{ item }}
47 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
48 |
49 | .. autoattribute:: {{ [objname, item] | join(".") }}
50 | {%- endfor %}
51 |
52 | {% endif %}
53 | {% endblock %}
54 |
55 | {% block methods_documentation %}
56 | {% if methods %}
57 | Methods
58 | ~~~~~~~
59 |
60 | {% for item in methods %}
61 | {%- if item != '__init__' and item not in inherited_members%}
62 | {{ item }}
63 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
64 | .. automethod:: {{ [objname, item] | join(".") }}
65 | {%- endif -%}
66 | {%- endfor %}
67 |
68 | {% endif %}
69 | {% endblock %}
70 |
--------------------------------------------------------------------------------
/docs/api.md:
--------------------------------------------------------------------------------
1 | # API
2 |
3 | ## Benchmarking pipeline
4 |
5 | Import as:
6 |
7 | ```
8 | from scib_metrics.benchmark import Benchmarker
9 | ```
10 |
11 | ```{eval-rst}
12 | .. module:: scib_metrics.benchmark
13 | .. currentmodule:: scib_metrics
14 |
15 | .. autosummary::
16 | :toctree: generated
17 |
18 | benchmark.Benchmarker
19 | benchmark.BioConservation
20 | benchmark.BatchCorrection
21 | ```
22 |
23 | ## Metrics
24 |
25 | Import as:
26 |
27 | ```
28 | import scib_metrics
29 | scib_metrics.ilisi_knn(...)
30 | ```
31 |
32 | ```{eval-rst}
33 | .. module:: scib_metrics
34 | .. currentmodule:: scib_metrics
35 |
36 | .. autosummary::
37 | :toctree: generated
38 |
39 | isolated_labels
40 | nmi_ari_cluster_labels_kmeans
41 | nmi_ari_cluster_labels_leiden
42 | pcr_comparison
43 | silhouette_label
44 | silhouette_batch
45 | ilisi_knn
46 | clisi_knn
47 | kbet
48 | kbet_per_label
49 | graph_connectivity
50 | ```
51 |
52 | ## Utils
53 |
54 | ```{eval-rst}
55 | .. module:: scib_metrics.utils
56 | .. currentmodule:: scib_metrics
57 |
58 | .. autosummary::
59 | :toctree: generated
60 |
61 | utils.cdist
62 | utils.pdist_squareform
63 | utils.silhouette_samples
64 | utils.KMeans
65 | utils.pca
66 | utils.principal_component_regression
67 | utils.one_hot
68 | utils.compute_simpson_index
69 | utils.convert_knn_graph_to_idx
70 | utils.check_square
71 | utils.diffusion_nn
72 | ```
73 |
74 | ### Nearest neighbors
75 |
76 | ```{eval-rst}
77 | .. module:: scib_metrics.nearest_neighbors
78 | .. currentmodule:: scib_metrics
79 |
80 | .. autosummary::
81 | :toctree: generated
82 |
83 | nearest_neighbors.pynndescent
84 | nearest_neighbors.jax_approx_min_k
85 | nearest_neighbors.NeighborsResults
86 | ```
87 |
88 | ## Settings
89 |
90 | An instance of the {class}`~scib_metrics._settings.ScibConfig` is available as `scib_metrics.settings` and allows configuring scib_metrics.
91 |
92 | ```{eval-rst}
93 | .. autosummary::
94 | :toctree: reference/
95 | :nosignatures:
96 |
97 | _settings.ScibConfig
98 | ```
99 |
--------------------------------------------------------------------------------
/docs/changelog.md:
--------------------------------------------------------------------------------
1 | ```{include} ../CHANGELOG.md
2 |
3 | ```
4 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 |
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 | from typing import Any
9 | import subprocess
10 | import os
11 | import importlib
12 | import inspect
13 | import re
14 | import sys
15 | from datetime import datetime
16 | from importlib.metadata import metadata
17 | from pathlib import Path
18 |
19 | HERE = Path(__file__).parent
20 | sys.path.insert(0, str(HERE / "extensions"))
21 |
22 |
# -- Project information -----------------------------------------------------

project_name = "scib-metrics"
info = metadata(project_name)  # metadata of the *installed* distribution, so docs stay in sync with the package
package_name = "scib_metrics"
author = info["Author"]
copyright = f"{datetime.now():%Y}, {author}."
version = info["Version"]
repository_url = f"https://github.com/YosefLab/{project_name}"

# The full version, including alpha/beta/rc tags
release = info["Version"]

bibtex_bibfiles = ["references.bib"]
templates_path = ["_templates"]
nitpicky = True  # Warn about broken links
needs_sphinx = "4.0"

# Passed through to the HTML templates; drives the "edit on GitHub" links.
html_context = {
    "display_github": True,  # Integrate GitHub
    "github_user": "yoseflab",  # Username
    "github_repo": project_name,  # Repo name
    "github_version": "main",  # Version
    "conf_py_path": "/docs/",  # Path in the checkout to the docs root
}
48 |
# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings.
# They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [
    "myst_nb",
    "sphinx.ext.autodoc",
    "sphinx_copybutton",
    "sphinx.ext.linkcode",
    "sphinx.ext.intersphinx",
    "sphinx.ext.autosummary",
    "sphinx.ext.napoleon",
    "sphinx.ext.extlinks",
    "sphinxcontrib.bibtex",
    "sphinx_autodoc_typehints",
    "sphinx.ext.mathjax",
    "IPython.sphinxext.ipython_console_highlighting",
    "sphinxext.opengraph",
    # Auto-register every local extension module in docs/extensions
    # (HERE/"extensions" was added to sys.path above), e.g. typed_returns.
    *[p.stem for p in (HERE / "extensions").glob("*.py")],
]

autosummary_generate = True
autodoc_member_order = "groupwise"
default_role = "literal"
bibtex_reference_style = "author_year"
# Docstrings follow the NumPy convention, not the Google one.
napoleon_google_docstring = False
napoleon_numpy_docstring = True
napoleon_include_init_with_doc = False
napoleon_use_rtype = True  # having a separate entry generally helps readability
napoleon_use_param = True
myst_heading_anchors = 6  # create anchors for h1-h6
myst_enable_extensions = [
    "amsmath",
    "colon_fence",
    "deflist",
    "dollarmath",
    "html_image",
    "html_admonition",
]
myst_url_schemes = ("http", "https", "mailto")
nb_output_stderr = "remove"
nb_execution_mode = "off"  # notebooks are committed pre-executed; the docs build does not re-run them
nb_merge_streams = True
typehints_defaults = "braces"

source_suffix = {
    ".rst": "restructuredtext",
    ".ipynb": "myst-nb",
    ".myst": "myst-nb",
}

# Cross-project reference targets for sphinx.ext.intersphinx.
intersphinx_mapping = {
    "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
    "ipython": ("https://ipython.readthedocs.io/en/stable/", None),
    "matplotlib": ("https://matplotlib.org/", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "pandas": ("https://pandas.pydata.org/docs/", None),
    "python": ("https://docs.python.org/3", None),
    "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
    "sklearn": ("https://scikit-learn.org/stable/", None),
    "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None),
    "jax": ("https://jax.readthedocs.io/en/latest/", None),
    "plottable": ("https://plottable.readthedocs.io/en/latest/", None),
}

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]

# extlinks config: enables :issue:`123`, :pr:`456`, :ghuser:`name` roles.
extlinks = {
    "issue": (f"{repository_url}/issues/%s", "#%s"),
    "pr": (f"{repository_url}/pull/%s", "#%s"),
    "ghuser": ("https://github.com/%s", "@%s"),
}
125 |
126 | # -- Linkcode settings -------------------------------------------------
127 |
128 |
def git(*args):
    """Run a git command and return its captured stdout, stripped and decoded."""
    raw = subprocess.check_output(["git", *args])
    return raw.strip().decode()
132 |
133 |
# https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192
# Current git reference. Uses branch/tag name if found, otherwise uses commit hash
git_ref = None
try:
    # Resolve HEAD to a symbolic name; strip the "remotes/<name>/" or "tags/" prefix.
    git_ref = git("name-rev", "--name-only", "--no-undefined", "HEAD")
    git_ref = re.sub(r"^(remotes/[^/]+|tags)/", "", git_ref)
except Exception:
    # Best-effort: a failed lookup falls through to the commit-hash branch below.
    pass

# (if no name found or relative ref, use commit hash instead)
if not git_ref or re.search(r"[\^~]", git_ref):
    try:
        git_ref = git("rev-parse", "HEAD")
    except Exception:
        # Not a git checkout at all (e.g. a source tarball): link against "main".
        git_ref = "main"

# https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192
github_repo = "https://github.com/" + html_context["github_user"] + "/" + project_name
# Filesystem root of the installed package; linkcode paths are made relative to it.
# NOTE(review): relies on ``importlib.util`` being reachable via the bare
# ``import importlib`` above — consider an explicit ``import importlib.util``.
_project_module_path = os.path.dirname(importlib.util.find_spec(package_name).origin)  # type: ignore
153 |
154 |
def linkcode_resolve(domain, info):
    """Resolve links for the linkcode extension."""
    # Only Python objects can be mapped to files in this repository.
    if domain != "py":
        return None

    try:
        target: Any = sys.modules[info["module"]]
        # Walk the dotted attribute path down from the module object.
        for attr in info["fullname"].split("."):
            target = getattr(target, attr)
        target = inspect.unwrap(target)

        # Properties hide the real function behind .fget.
        if isinstance(target, property):
            target = inspect.unwrap(target.fget)  # type: ignore

        rel_path = os.path.relpath(inspect.getsourcefile(target), start=_project_module_path)  # type: ignore
        source, start_line = inspect.getsourcelines(target)
    except Exception:
        # Anything unresolvable (builtins, C extensions, dynamic attrs) gets no link.
        return None

    end_line = start_line + len(source) - 1
    anchored = f"{rel_path}#L{start_line}-L{end_line}"
    return f"{github_repo}/blob/{git_ref}/src/{package_name}/{anchored}"
176 |
177 |
# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_book_theme"
html_static_path = ["_static"]
html_css_files = ["css/custom.css"]
html_title = "scib-metrics"

# Theme-specific options; github_repo is computed above from html_context.
html_theme_options = {
    "repository_url": github_repo,
    "use_repository_button": True,
}

pygments_style = "default"

nitpick_ignore = [
    # If building the documentation fails because of a missing link that is outside your control,
    # you can add an exception to this list.
]
199 |
200 |
201 | def setup(app):
202 | """App setup hook."""
203 | app.add_config_value(
204 | "recommonmark_config",
205 | {
206 | "auto_toc_tree_section": "Contents",
207 | "enable_auto_toc_tree": True,
208 | "enable_math": True,
209 | "enable_inline_math": False,
210 | "enable_eval_rst": True,
211 | },
212 | True,
213 | )
214 |
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | # Contributing guide
2 |
3 | Scanpy provides extensive [developer documentation][scanpy developer guide], most of which applies to this project, too.
4 | This document will not reproduce the entire content from there. Instead, it aims at summarizing the most important
5 | information to get you started on contributing.
6 |
7 | We assume that you are already familiar with git and with making pull requests on GitHub. If not, please refer
8 | to the [scanpy developer guide][].
9 |
10 | ## Installing dev dependencies
11 |
12 | In addition to the packages needed to _use_ this package, you need additional python packages to _run tests_ and _build
13 | the documentation_. It's easy to install them using `pip`:
14 |
15 | ```bash
16 | cd scib-metrics
17 | pip install -e ".[dev,test,doc]"
18 | ```
19 |
20 | ## Code-style
21 |
22 | This package uses [pre-commit][] to enforce consistent code-styles.
23 | On every commit, pre-commit checks will either automatically fix issues with the code, or raise an error message.
24 |
25 | To enable pre-commit locally, simply run
26 |
27 | ```bash
28 | pre-commit install
29 | ```
30 |
31 | in the root of the repository. Pre-commit will automatically download all dependencies when it is run for the first time.
32 |
33 | Alternatively, you can rely on the [pre-commit.ci][] service enabled on GitHub. If you didn't run `pre-commit` before
34 | pushing changes to GitHub it will automatically commit fixes to your pull request, or show an error message.
35 |
36 | If pre-commit.ci added a commit on a branch you still have been working on locally, simply use
37 |
38 | ```bash
39 | git pull --rebase
40 | ```
41 |
42 | to integrate the changes into yours.
43 | While the [pre-commit.ci][] is useful, we strongly encourage installing and running pre-commit locally first to understand its usage.
44 |
45 | Finally, most editors have an _autoformat on save_ feature. Consider enabling this option for [ruff][ruff-editors]
46 | and [prettier][prettier-editors].
47 |
48 | [ruff-editors]: https://docs.astral.sh/ruff/integrations/
49 | [prettier-editors]: https://prettier.io/docs/en/editors.html
50 |
51 | ## Writing tests
52 |
53 | ```{note}
54 | Remember to first install the package with `pip install -e '.[dev,test]'`
55 | ```
56 |
57 | This package uses the [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added
58 | to the package.
59 |
60 | Most IDEs integrate with pytest and provide a GUI to run tests. Alternatively, you can run all tests from the
61 | command line by executing
62 |
63 | ```bash
64 | pytest
65 | ```
66 |
67 | in the root of the repository.
68 |
69 | ### Continuous integration
70 |
71 | Continuous integration will automatically run the tests on all pull requests and test
72 | against the minimum and maximum supported Python version.
73 |
74 | Additionally, there's a CI job that tests against pre-releases of all dependencies
75 | (if there are any). The purpose of this check is to detect incompatibilities
76 | of new package versions early on and gives you time to fix the issue or reach
77 | out to the developers of the dependency before the package is released to a wider audience.
78 |
79 | [scanpy-test-docs]: https://scanpy.readthedocs.io/en/latest/dev/testing.html#writing-tests
80 |
81 | ## Publishing a release
82 |
83 | ### Updating the version number
84 |
85 | Before making a release, you need to update the version number in the `pyproject.toml` file. Please adhere to [Semantic Versioning][semver], in brief
86 |
87 | > Given a version number MAJOR.MINOR.PATCH, increment the:
88 | >
89 | > 1. MAJOR version when you make incompatible API changes,
90 | > 2. MINOR version when you add functionality in a backwards compatible manner, and
91 | > 3. PATCH version when you make backwards compatible bug fixes.
92 | >
93 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format.
94 |
95 | Once you are done, commit and push your changes and navigate to the "Releases" page of this project on GitHub.
96 | Specify `vX.X.X` as a tag name and create a release. For more information, see [managing GitHub releases][]. This will automatically create a git tag and trigger a Github workflow that creates a release on PyPI.
97 |
98 | ## Writing documentation
99 |
100 | Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features:
101 |
102 | - the [myst][] extension allows to write documentation in markdown/Markedly Structured Text
103 | - [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension).
104 | - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
105 | - [Sphinx autodoc typehints][], to automatically reference annotated input and output types
106 | - Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
107 |
108 | See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
109 | on how to write documentation.
110 |
111 | ### Tutorials with myst-nb and jupyter notebooks
112 |
113 | The documentation is set-up to render jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][].
114 | Currently, only notebooks in `.ipynb` format are supported that will be included with both their input and output cells.
115 | It is your responsibility to update and re-run the notebook whenever necessary.
116 |
117 | If you are interested in automatically running notebooks as part of the continuous integration, please check
118 | out [this feature request](https://github.com/scverse/cookiecutter-scverse/issues/40) in the `cookiecutter-scverse`
119 | repository.
120 |
121 | #### Hints
122 |
123 | - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
124 | if you do so can sphinx automatically create a link to the external documentation.
125 | - If building the documentation fails because of a missing link that is outside your control, you can add an entry to
126 | the `nitpick_ignore` list in `docs/conf.py`
127 |
128 | #### Building the docs locally
129 |
130 | ```bash
131 | cd docs
132 | make html
133 | open _build/html/index.html
134 | ```
135 |
136 |
137 |
138 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html
139 | [cookiecutter-scverse-instance]: https://cookiecutter-scverse-instance.readthedocs.io/en/latest/template_usage.html
140 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui
141 | [codecov]: https://about.codecov.io/sign-up/
142 | [codecov docs]: https://docs.codecov.com/docs
143 | [codecov bot]: https://docs.codecov.com/docs/team-bot
144 | [codecov app]: https://github.com/apps/codecov
145 | [pre-commit.ci]: https://pre-commit.ci/
146 | [readthedocs.org]: https://readthedocs.org/
147 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/
148 | [jupytext]: https://jupytext.readthedocs.io/en/latest/
149 | [pre-commit]: https://pre-commit.com/
150 | [anndata]: https://github.com/scverse/anndata
151 | [mudata]: https://github.com/scverse/mudata
152 | [pytest]: https://docs.pytest.org/
153 | [semver]: https://semver.org/
154 | [sphinx]: https://www.sphinx-doc.org/en/master/
155 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html
156 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
157 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html
158 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints
159 | [pypi]: https://pypi.org/
160 | [managing GitHub releases]: https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository
161 |
--------------------------------------------------------------------------------
/docs/extensions/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YosefLab/scib-metrics/5d01bb46f4d2317b4b3379851c7202df5be36c68/docs/extensions/.gitkeep
--------------------------------------------------------------------------------
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | from __future__ import annotations
4 | from sphinx.ext.napoleon import NumpyDocstring
5 | from typing import TYPE_CHECKING
6 | import re
7 |
8 | if TYPE_CHECKING:
9 | from collections.abc import Generator, Iterable
10 |
11 | from sphinx.application import Sphinx
12 |
13 |
14 | def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
15 | for line in lines:
16 | if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line):
17 | yield f"-{m['param']} (:class:`~{m['type']}`)"
18 | else:
19 | yield line
20 |
21 |
def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]:
    """Parse a numpydoc ``Returns`` section into a single ``:returns:`` field.

    Installed over :meth:`NumpyDocstring._parse_returns_section` so that
    ``name : type`` entries are rendered as linked list items (see
    ``_process_return``).
    """
    lines_raw = self._dedent(self._consume_to_next_section())
    # Guard against an empty section before inspecting the first line
    # (the original indexed lines_raw[0] unconditionally and could IndexError).
    if lines_raw and lines_raw[0] == ":":
        del lines_raw[0]
    lines = self._format_block(":returns: ", list(_process_return(lines_raw)))
    if lines and lines[-1]:
        # napoleon expects a trailing blank line after a non-empty block.
        lines.append("")
    return lines
30 |
31 |
def setup(app: Sphinx):
    """Sphinx extension entry point.

    Monkey-patches napoleon's ``NumpyDocstring`` class-wide, so every
    NumPy-style docstring processed during the build gets the typed
    Returns-section formatting from ``_parse_returns_section``.
    """
    NumpyDocstring._parse_returns_section = _parse_returns_section
35 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ```{include} ../README.md
2 |
3 | ```
4 |
5 | ```{toctree}
6 | :hidden: true
7 | :maxdepth: 1
8 |
9 | api.md
10 | tutorials.md
11 | changelog.md
12 | template_usage.md
13 | contributing.md
14 | references.md
15 | ```
16 |
--------------------------------------------------------------------------------
/docs/references.bib:
--------------------------------------------------------------------------------
1 | @article{luecken2022benchmarking,
2 | title = {Benchmarking atlas-level data integration in single-cell genomics},
3 | author = {Luecken, Malte D and B{\"u}ttner, Maren and Chaichoompu, Kridsadakorn and Danese, Anna and Interlandi, Marta and M{\"u}ller, Michaela F and Strobl, Daniel C and Zappia, Luke and Dugas, Martin and Colom{\'e}-Tatch{\'e}, Maria and others},
4 | journal = {Nature methods},
5 | volume = {19},
6 | number = {1},
7 | pages = {41--50},
8 | year = {2022},
9 | publisher = {Nature Publishing Group}
10 | }
11 |
12 |
13 | @article{korsunsky2019harmony,
14 | title = {Fast, sensitive and accurate integration of single-cell data with
15 | Harmony},
16 | author = {Korsunsky, Ilya and Millard, Nghia and Fan, Jean and Slowikowski,
17 | Kamil and Zhang, Fan and Wei, Kevin and Baglaenko, Yuriy and
18 | Brenner, Michael and Loh, Po-Ru and Raychaudhuri, Soumya},
19 | journal = {Nat. Methods},
20 | volume = {16},
21 | number = {12},
22 | pages = {1289--1296},
23 | month = {dec},
24 | year = {2019},
25 | }
26 |
27 | @article{buttner2018,
28 | title = {A test metric for assessing single-cell {RNA}-seq batch correction},
29 | author = {Maren B\"{u}ttner and Zhichao Miao and F. Alexander Wolf and Sarah A. Teichmann and Fabian J. Theis},
30 | doi = {10.1038/s41592-018-0254-1},
31 | year = {2018},
32 | month = dec,
33 | journal = {Nature Methods},
34 | volume = {16},
35 | number = {1},
36 | pages = {43--49},
37 | publisher = {Springer Science and Business Media {LLC}}
38 | }
39 |
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
1 | # References
2 |
3 | ```{bibliography}
4 | :cited:
5 | ```
6 |
--------------------------------------------------------------------------------
/docs/template_usage.md:
--------------------------------------------------------------------------------
1 | # Using this template
2 |
3 | Welcome to the developer guidelines! This document is split into two parts:
4 |
5 | 1. The [repository setup](#setting-up-the-repository). This section is relevant primarily for the repository maintainer and shows how to connect
6 | continuous integration services and documents initial set-up of the repository.
7 | 2. The [contributor guide](contributing.md#contributing-guide). It contains information relevant to all developers who want to make a contribution.
8 |
9 | ## Setting up the repository
10 |
11 | ### First commit
12 |
13 | If you are reading this, you should have just completed the repository creation with:
14 |
15 | ```bash
16 | cruft create https://github.com/scverse/cookiecutter-scverse
17 | ```
18 |
19 | and you should have
20 |
21 | ```
22 | cd scib-metrics
23 | ```
24 |
25 | into the new project directory. Now that you have created a new repository locally, the first step is to push it to github. To do this, you'd have to create a **new repository** on github.
26 | You can follow the instructions directly on [github quickstart guide][].
27 | Since `cruft` already populated the local repository of your project with all the necessary files, we suggest to _NOT_ initialize the repository with a `README.md` file or `.gitignore`, because you might encounter git conflicts on your first push.
28 | If you are familiar with git and know how to handle git conflicts, you can go ahead with your preferred choice.
29 |
30 | :::{note}
31 | If you are looking at this document in the [cookiecutter-scverse-instance][] repository documentation, throughout this document the name of the project is `cookiecutter-scverse-instance`. Otherwise it should be replaced by your new project name: `scib-metrics`.
32 | :::
33 |
34 | Now that your new project repository has been created on github at `https://github.com/adamgayoso/scib-metrics` you can push your first commit to github.
35 | To do this, simply follow the instructions on your github repository page or a more verbose walkthrough here:
36 |
37 | Assuming you are in `/your/path/to/scib-metrics`. Add all files and commit.
38 |
39 | ```bash
40 | # stage all files of your new repo
41 | git add --all
42 | # commit
43 | git commit -m "first commit"
44 | ```
45 |
46 | You'll notice that the command `git commit` installed a bunch of packages and triggered their execution: those are pre-commit! To read more about what they are and what they do, you can go to the related section [Pre-commit checks](#pre-commit-checks) in this document.
47 |
48 | :::{note}
49 | There is a chance that `git commit -m "first commit"` fails due to the `prettier` pre-commit formatting the file `.cruft.json`. No problem, you have just experienced what pre-commit checks do in action. Just go ahead and re-add the modified file and try to commit again:
50 |
51 | ```bash
52 | git add -u # update all tracked file
53 | git commit -m "first commit"
54 | ```
55 |
56 | :::
57 |
58 | Now that all the files of the newly created project have been committed, go ahead with the remaining steps:
59 |
60 | ```bash
61 | # update the `origin` of your local repo with the remote github link
62 | git remote add origin https://github.com/adamgayoso/scib-metrics.git
63 | # rename the default branch to main
64 | git branch -M main
65 | # push all your files to remote
66 | git push -u origin main
67 | ```
68 |
Your project should be now available at `https://github.com/adamgayoso/scib-metrics`. While the repository at this point can be directly used, there are a few remaining steps that need to be done in order to achieve full functionality.
70 |
71 | ### Coverage tests with _Codecov_
72 |
73 | Coverage tells what fraction of the code is "covered" by unit tests, thereby encouraging contributors to
74 | [write tests](contributing.md#writing-tests).
75 | To enable coverage checks, head over to [codecov][] and sign in with your GitHub account.
76 | You'll find more information in "getting started" section of the [codecov docs][].
77 |
78 | In the `Actions` tab of your projects' github repository, you can see that the workflows are failing due to the **Upload coverage** step. The error message in the workflow should display something like:
79 |
80 | ```
81 | ...
82 | Retrying 5/5 in 2s..
83 | {'detail': ErrorDetail(string='Could not find a repository, try using repo upload token', code='not_found')}
84 | Error: 404 Client Error: Not Found for url:
85 | ...
86 | ```
87 |
88 | While [codecov docs][] has a very extensive documentation on how to get started, _if_ you are using the default settings of this template we can assume that you are using [codecov][] in a github action workflow and hence you can make use of the [codecov bot][].
89 |
90 | To set it up, simply go to the [codecov app][] page and follow the instructions to activate it for your repository.
91 | Once the activation is completed, go back to the `Actions` tab and re-run the failing workflows.
92 |
93 | The workflows should now succeed and you will be able to find the code coverage at this link: `https://app.codecov.io/gh/adamgayoso/scib-metrics`. You might have to wait couple of minutes and the coverage of this repository should be ~60%.
94 |
95 | If your repository is private, you will have to specify an additional token in the repository secrets. In brief, you need to:
96 |
97 | 1. Generate a Codecov Token by clicking _setup repo_ in the codecov dashboard.
98 | - If you have already set up codecov in the repository by following the previous steps, you can directly go to the codecov repo webpage.
99 | 2. Go to _Settings_ and copy **only** the token `_______-____-...`.
100 | 3. Go to _Settings_ of your newly created repository on GitHub.
101 | 4. Go to _Security > Secrets > Actions_.
102 | 5. Create new repository secret with name `CODECOV_TOKEN` and paste the token generated by codecov.
6. Paste these additional lines in `/.github/workflows/test.yaml` under the **Upload coverage** step:
104 | ```bash
105 | - name: Upload coverage
106 | uses: codecov/codecov-action@v3
107 | with:
108 | token: ${{ secrets.CODECOV_TOKEN }}
109 | ```
7. Go back to github `Actions` page and re-run previously failed jobs.
111 |
112 | ### Documentation on _readthedocs_
113 |
114 | We recommend using [readthedocs.org][] (RTD) to build and host the documentation for your project.
115 | To enable readthedocs, head over to [their website][readthedocs.org] and sign in with your GitHub account.
116 | On the RTD dashboard choose "Import a Project" and follow the instructions to add your repository.
117 |
118 | - Make sure to choose the correct name of the default branch. On GitHub, the name of the default branch should be `main` (it has
119 | recently changed from `master` to `main`).
120 | - We recommend to enable documentation builds for pull requests (PRs). This ensures that a PR doesn't introduce changes
that break the documentation. To do so, go to `Admin -> Advanced Settings`, check the
122 | `Build pull requests for this projects` option, and click `Save`. For more information, please refer to
123 | the [official RTD documentation](https://docs.readthedocs.io/en/stable/pull-requests.html).
124 | - If you find the RTD builds are failing, you can disable the `fail_on_warning` option in `.readthedocs.yaml`.
125 |
126 | If your project is private, there are ways to enable docs rendering on [readthedocs.org][] but it is more cumbersome and requires a different subscription for read the docs. See a guide [here](https://docs.readthedocs.io/en/stable/guides/importing-private-repositories.html).
127 |
128 | ### Pre-commit checks
129 |
130 | [Pre-commit][] checks are fast programs that
131 | check code for errors, inconsistencies and code styles, before the code
132 | is committed.
133 |
134 | This template uses a number of pre-commit checks. In this section we'll detail what is used, where they're defined, and how to modify these checks.
135 |
136 | #### Pre-commit CI
137 |
138 | We recommend setting up [pre-commit.ci][] to enforce consistency checks on every commit
139 | and pull-request.
140 |
141 | To do so, head over to [pre-commit.ci][] and click "Sign In With GitHub". Follow
142 | the instructions to enable pre-commit.ci for your account or your organization. You
143 | may choose to enable the service for an entire organization or on a per-repository basis.
144 |
145 | Once authorized, pre-commit.ci should automatically be activated.
146 |
147 | #### Overview of pre-commit hooks used by the template
148 |
149 | The following pre-commit hooks are for code style and format:
150 |
151 | - [black](https://black.readthedocs.io/en/stable/):
152 | standard code formatter in Python.
153 | - [blacken-docs](https://github.com/asottile/blacken-docs):
154 | black on Python code in docs.
155 | - [prettier](https://prettier.io/docs/en/index.html):
156 | standard code formatter for non-Python files (e.g. YAML).
157 | - [ruff][] based checks:
158 | - [isort](https://beta.ruff.rs/docs/rules/#isort-i) (rule category: `I`):
159 | sort module imports into sections and types.
160 | - [pydocstyle](https://beta.ruff.rs/docs/rules/#pydocstyle-d) (rule category: `D`):
161 | pydocstyle extension of flake8.
162 | - [flake8-tidy-imports](https://beta.ruff.rs/docs/rules/#flake8-tidy-imports-tid) (rule category: `TID`):
163 | tidy module imports.
164 | - [flake8-comprehensions](https://beta.ruff.rs/docs/rules/#flake8-comprehensions-c4) (rule category: `C4`):
165 | write better list/set/dict comprehensions.
166 | - [pyupgrade](https://beta.ruff.rs/docs/rules/#pyupgrade-up) (rule category: `UP`):
167 | upgrade syntax for newer versions of the language.
168 |
169 | The following pre-commit hooks are for errors and inconsistencies:
170 |
171 | - [pre-commit-hooks](https://github.com/pre-commit/pre-commit-hooks): generic pre-commit hooks for text files.
172 | - **detect-private-key**: checks for the existence of private keys.
173 | - **check-ast**: check whether files parse as valid python.
174 | - **end-of-file-fixer**: check files end in a newline and only a newline.
175 | - **mixed-line-ending**: checks mixed line ending.
176 | - **trailing-whitespace**: trims trailing whitespace.
177 | - **check-case-conflict**: check files that would conflict with case-insensitive file systems.
- **forbid-to-commit**: Make sure that `*.rej` files cannot be committed.
179 | These files are created by the [automated template sync](#automated-template-sync)
180 | if there's a merge conflict and need to be addressed manually.
181 | - [ruff][] based checks:
182 | - [pyflakes](https://beta.ruff.rs/docs/rules/#pyflakes-f) (rule category: `F`):
183 | various checks for errors.
184 | - [pycodestyle](https://beta.ruff.rs/docs/rules/#pycodestyle-e-w) (rule category: `E`, `W`):
185 | various checks for errors.
186 | - [flake8-bugbear](https://beta.ruff.rs/docs/rules/#flake8-bugbear-b) (rule category: `B`):
187 | find possible bugs and design issues in program.
188 | - [flake8-blind-except](https://beta.ruff.rs/docs/rules/#flake8-blind-except-ble) (rule category: `BLE`):
189 | checks for blind, catch-all `except` statements.
190 | - [Ruff-specific rules](https://beta.ruff.rs/docs/rules/#ruff-specific-rules-ruf) (rule category: `RUF`):
- `RUF100`: remove unnecessary `# noqa` comments
192 |
193 | #### How to add or remove pre-commit checks
194 |
195 | The [pre-commit checks](#pre-commit-checks) check for both correctness and stylistic errors.
196 | In some cases it might overshoot and you may have good reasons to ignore certain warnings.
197 | This section shows you where these checks are defined, and how to enable/ disable them.
198 |
199 | ##### pre-commit
200 |
201 | You can add or remove pre-commit checks by simply deleting relevant lines in the `.pre-commit-config.yaml` file under the repository root.
202 | Some pre-commit checks have additional options that can be specified either in the `pyproject.toml` (for `ruff` and `black`) or tool-specific
203 | config files, such as `.prettierrc.yml` for **prettier**.
204 |
205 | ##### Ruff
206 |
207 | This template configures `ruff` through the `[tool.ruff]` entry in the `pyproject.toml`.
208 | For further information `ruff` configuration, see [the docs](https://beta.ruff.rs/docs/configuration/).
209 |
210 | Ruff assigns code to the rules it checks (e.g. `E401`) and groups them under a rule category (e.g. `E`).
211 | Rule categories are selectively enabled by including them under the `select` key:
212 |
213 | ```toml
214 | [tool.ruff]
215 | ...
216 |
217 | select = [
218 | "F", # Errors detected by Pyflakes
219 | "E", # Error detected by Pycodestyle
220 | "W", # Warning detected by Pycodestyle
221 | "I", # isort
222 | ...
223 | ]
224 | ```
225 |
226 | The `ignore` entry is used to disable specific rules for the entire project.
227 | Add the rule code(s) you want to ignore and don't forget to add a comment explaining why.
228 | You can find a long list of checks that this template disables by default sitting there already.
229 |
230 | ```toml
231 | ignore = [
232 | ...
# __magic__ methods are often self-explanatory, allow missing docstrings
234 | "D105",
235 | ...
236 | ]
237 | ```
238 |
239 | Checks can be ignored per-file (or glob pattern) with `[tool.ruff.per-file-ignores]`.
240 |
241 | ```toml
242 | [tool.ruff.per-file-ignores]
243 | "docs/*" = ["I"]
244 | "tests/*" = ["D"]
245 | "*/__init__.py" = ["F401"]
246 | ```
247 |
To ignore a specific rule on a per-case basis, you can add a `# noqa: <code>[, <code>, …]` comment to the offending line.
249 | Specify the rule code(s) to ignore, with e.g. `# noqa: E731`. Check the [Ruff guide][] for reference.
250 |
251 | ```{note}
252 | The `RUF100` check will remove rule codes that are no longer necessary from `noqa` comments.
253 | If you want to add a code that comes from a tool other than Ruff,
254 | add it to Ruff’s [`external = [...]`](https://beta.ruff.rs/docs/settings/#external) setting to prevent `RUF100` from removing it.
255 | ```
256 |
257 | [ruff]: https://beta.ruff.rs/docs/
258 | [ruff guide]: https://beta.ruff.rs/docs/configuration/#suppressing-errors
259 |
260 | ### API design
261 |
262 | Scverse ecosystem packages should operate on [AnnData][] and/or [MuData][] data structures and typically use an API
263 | as originally [introduced by scanpy][scanpy-api] with the following submodules:
264 |
265 | - `pp` for preprocessing
266 | - `tl` for tools (that, compared to `pp` generate interpretable output, often associated with a corresponding plotting
267 | function)
268 | - `pl` for plotting functions
269 |
270 | You may add additional submodules as appropriate. While we encourage to follow a scanpy-like API for ecosystem packages,
271 | there may also be good reasons to choose a different approach, e.g. using an object-oriented API.
272 |
273 | [scanpy-api]: https://scanpy.readthedocs.io/en/stable/usage-principles.html
274 |
275 | ### Using VCS-based versioning
276 |
277 | By default, the template uses hard-coded version numbers that are set in `pyproject.toml` and [managed with
278 | bump2version](contributing.md#publishing-a-release). If you prefer to have your project automatically infer version numbers from git
279 | tags, it is straightforward to switch to vcs-based versioning using [hatch-vcs][].
280 |
281 | In `pyproject.toml` add the following changes, and you are good to go!
282 |
283 | ```diff
284 | --- a/pyproject.toml
285 | +++ b/pyproject.toml
286 | @@ -1,11 +1,11 @@
287 | [build-system]
288 | build-backend = "hatchling.build"
289 | -requires = ["hatchling"]
290 | +requires = ["hatchling", "hatch-vcs"]
291 | [project]
292 | name = "scib-metrics"
293 | -version = "0.3.1dev"
294 | +dynamic = ["version"]
295 | @@ -60,6 +60,9 @@
296 | +[tool.hatch.version]
297 | +source = "vcs"
298 | +
299 | [tool.coverage.run]
300 | source = ["scib-metrics"]
301 | omit = [
302 | ```
303 |
304 | Don't forget to update the [Making a release section](contributing.md#publishing-a-release) in this document accordingly, after you are done!
305 |
306 | [hatch-vcs]: https://pypi.org/project/hatch-vcs/
307 |
308 | ### Automated template sync
309 |
310 | Automated template sync is enabled by default. This means that every night, a GitHub action runs [cruft][] to check
311 | if a new version of the `scverse-cookiecutter` template got released. If there are any new changes, a pull request
312 | proposing these changes is created automatically. This helps keeping the repository up-to-date with the latest
313 | coding standards.
314 |
It may happen that a template sync results in a merge conflict. If this is the case a `*.rej` file with the
316 | diff is created. You need to manually address these changes and remove the `.rej` file when you are done.
317 | The pull request can only be merged after all `*.rej` files have been removed.
318 |
319 | :::{tip}
320 | The following hints may be useful to work with the template sync:
321 |
- GitHub automatically disables scheduled actions if there has been no activity in the repository for 60 days.
323 | You can re-enable or manually trigger the sync by navigating to `Actions` -> `Sync Template` in your GitHub repository.
324 | - If you want to ignore certain files from the template update, you can add them to the `[tool.cruft]` section in the
325 | `pyproject.toml` file in the root of your repository. More details are described in the
326 | [cruft documentation][cruft-update-project].
327 | - To disable the sync entirely, simply remove the file `.github/workflows/sync.yaml`.
328 |
329 | :::
330 |
331 | [cruft]: https://cruft.github.io/cruft/
332 | [cruft-update-project]: https://cruft.github.io/cruft/#updating-a-project
333 |
334 | ### Making a release
335 |
336 | #### Updating the version number
337 |
338 | Before making a release, you need to update the version number. Please adhere to [Semantic Versioning][semver], in brief
339 |
340 | > Given a version number MAJOR.MINOR.PATCH, increment the:
341 | >
342 | > 1. MAJOR version when you make incompatible API changes,
343 | > 2. MINOR version when you add functionality in a backwards compatible manner, and
344 | > 3. PATCH version when you make backwards compatible bug fixes.
345 | >
346 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format.
347 | > We use [bump2version][] to automatically update the version number in all places and automatically create a git tag.
348 | > Run one of the following commands in the root of the repository
349 |
350 | ```bash
351 | bump2version patch
352 | bump2version minor
353 | bump2version major
354 | ```
355 |
356 | Once you are done, run
357 |
358 | ```
359 | git push --tags
360 | ```
361 |
362 | to publish the created tag on GitHub.
363 |
364 | [bump2version]: https://github.com/c4urself/bump2version
365 |
366 | #### Upload on PyPI
367 |
368 | Please follow the [Python packaging tutorial][].
369 |
370 | It is possible to automate this with GitHub actions, see also [this feature request][pypi-feature-request]
371 | in the cookiecutter-scverse template.
372 |
373 | [python packaging tutorial]: https://packaging.python.org/en/latest/tutorials/packaging-projects/#generating-distribution-archives
374 | [pypi-feature-request]: https://github.com/scverse/cookiecutter-scverse/issues/88
375 |
376 | ### Writing documentation
377 |
378 | Please write documentation for your package. This project uses [sphinx][] with the following features:
379 |
380 | - the [myst][] extension allows to write documentation in markdown/Markedly Structured Text
- [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension).
382 | - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
383 | - [Sphinx autodoc typehints][], to automatically reference annotated input and output types
384 |
385 | See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
386 | on how to write documentation.
387 |
388 | ### Tutorials with myst-nb and jupyter notebooks
389 |
390 | The documentation is set-up to render jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][].
391 | Currently, only notebooks in `.ipynb` format are supported that will be included with both their input and output cells.
It is your responsibility to update and re-run the notebook whenever necessary.
393 |
394 | If you are interested in automatically running notebooks as part of the continuous integration, please check
395 | out [this feature request](https://github.com/scverse/cookiecutter-scverse/issues/40) in the `cookiecutter-scverse`
396 | repository.
397 |
398 | #### Hints
399 |
400 | - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
401 | if you do so can sphinx automatically create a link to the external documentation.
402 | - If building the documentation fails because of a missing link that is outside your control, you can add an entry to
403 | the `nitpick_ignore` list in `docs/conf.py`
404 |
405 | #### Building the docs locally
406 |
407 | ```bash
408 | cd docs
409 | make html
410 | open _build/html/index.html
411 | ```
412 |
413 |
414 |
415 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html
416 | [codecov]: https://about.codecov.io/sign-up/
417 | [codecov docs]: https://docs.codecov.com/docs
418 | [pre-commit.ci]: https://pre-commit.ci/
419 | [readthedocs.org]: https://readthedocs.org/
420 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/
421 | [jupytext]: https://jupytext.readthedocs.io/en/latest/
422 | [pre-commit]: https://pre-commit.com/
423 | [anndata]: https://github.com/scverse/anndata
424 | [mudata]: https://github.com/scverse/mudata
425 | [pytest]: https://docs.pytest.org/
426 | [semver]: https://semver.org/
427 | [sphinx]: https://www.sphinx-doc.org/en/master/
428 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html
429 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
430 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html
431 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints
432 |
--------------------------------------------------------------------------------
/docs/tutorials.md:
--------------------------------------------------------------------------------
1 | # Tutorials
2 |
3 | ## Walkthrough
4 |
5 | ```{toctree}
6 | :maxdepth: 1
7 |
8 | notebooks/lung_example
9 | notebooks/large_scale
10 | ```
11 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "hatchling.build"
3 | requires = ["hatchling"]
4 |
5 |
6 | [project]
7 | name = "scib-metrics"
8 | version = "0.5.4"
9 | description = "Accelerated and Python-only scIB metrics"
10 | readme = "README.md"
11 | requires-python = ">=3.10"
12 | license = { file = "LICENSE" }
13 | authors = [{ name = "Adam Gayoso" }]
14 | maintainers = [{ name = "Adam Gayoso", email = "adamgayoso@berkeley.edu" }]
15 | urls.Documentation = "https://scib-metrics.readthedocs.io/"
16 | urls.Source = "https://github.com/yoseflab/scib-metrics"
17 | urls.Home-page = "https://github.com/yoseflab/scib-metrics"
18 | dependencies = [
19 | "anndata",
20 | "chex",
21 | "jax",
22 | "jaxlib",
23 | "numpy",
24 | "pandas",
25 | "scipy",
26 | "scikit-learn",
27 | "scanpy>=1.9",
28 | "rich",
29 | "pynndescent",
30 | "igraph>0.9.0",
31 | "matplotlib",
32 | "plottable",
33 | "tqdm",
34 | "umap-learn>=0.5.0",
35 | ]
36 |
37 | [project.optional-dependencies]
38 | dev = ["pre-commit", "twine>=4.0.2"]
39 | doc = [
40 | "sphinx>=4",
41 | "sphinx-book-theme>=1.0",
42 | "myst-nb",
43 | "sphinxcontrib-bibtex>=1.0.0",
44 | "scanpydoc[typehints]>=0.7.4",
45 | "sphinxext-opengraph",
46 | # For notebooks
47 | "ipython",
48 | "ipykernel",
49 | "sphinx-copybutton",
50 | "numba>=0.57.1",
51 | ]
52 | test = [
53 | "pytest",
54 | "coverage",
55 | "scib>=1.1.4",
56 | "harmonypy",
57 | "joblib",
58 | # For vscode Python extension testing
59 | "flake8",
60 | "black",
61 | "numba>=0.57.1",
62 | ]
63 | parallel = ["joblib"]
64 | tutorial = [
65 | "rich",
66 | "scanorama",
67 | "harmony-pytorch",
68 | "scvi-tools",
69 | "pyliger",
70 | "numexpr", # missing liger dependency
71 | "plotnine", # missing liger dependency
72 | "mygene", # missing liger dependency
73 | "goatools", # missing liger dependency
74 | "adjustText", # missing liger dependency
75 | ]
76 |
77 | [tool.hatch.build.targets.wheel]
78 | packages = ['src/scib_metrics']
79 |
80 | [tool.coverage.run]
81 | source = ["scib_metrics"]
82 | omit = ["**/test_*.py"]
83 |
84 | [tool.pytest.ini_options]
85 | testpaths = ["tests"]
86 | xfail_strict = true
87 |
88 |
89 | [tool.ruff]
90 | src = ["src"]
91 | line-length = 120
92 | lint.select = [
93 | "F", # Errors detected by Pyflakes
94 | "E", # Error detected by Pycodestyle
95 | "W", # Warning detected by Pycodestyle
96 | "I", # isort
97 | "D", # pydocstyle
98 | "B", # flake8-bugbear
99 | "TID", # flake8-tidy-imports
100 | "C4", # flake8-comprehensions
101 | "BLE", # flake8-blind-except
102 | "UP", # pyupgrade
103 | "RUF100", # Report unused noqa directives
104 | "ICN", # flake8-import-conventions
105 | "TCH", # flake8-type-checking
106 | "FA", # flake8-future-annotations
107 | ]
108 | lint.ignore = [
109 | # line too long -> we accept long comment lines; formatter gets rid of long code lines
110 | "E501",
111 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
112 | "E731",
113 | # allow I, O, l as variable names -> I is the identity matrix
114 | "E741",
115 | # Missing docstring in public package
116 | "D104",
117 | # Missing docstring in public module
118 | "D100",
119 | # Missing docstring in __init__
120 | "D107",
121 | # Errors from function calls in argument defaults. These are fine when the result is immutable.
122 | "B008",
# __magic__ methods are often self-explanatory, allow missing docstrings
124 | "D105",
125 | # First line should be in imperative mood; try rephrasing
126 | "D401",
127 | ## Disable one in each pair of mutually incompatible rules
128 | # We don’t want a blank line before a class docstring
129 | "D203",
130 | # We want docstrings to start immediately after the opening triple quote
131 | "D213",
132 | # Missing argument description in the docstring TODO: enable
133 | "D417",
134 | # No explicit stacklevel argument
135 | "B028",
136 | ]
137 | extend-include = ["*.ipynb"]
138 |
139 | [tool.ruff.lint.per-file-ignores]
140 | "docs/*" = ["I", "BLE001", "E402", "B018"]
141 | "tests/*" = ["D", "B018"]
142 | "*/__init__.py" = ["F401"]
143 | "src/scib_metrics/__init__.py" = ["I"]
144 |
145 | [tool.ruff.format]
146 | docstring-code-format = true
147 |
148 | [tool.ruff.lint.pydocstyle]
149 | convention = "numpy"
150 |
151 | [tool.jupytext]
152 | formats = "ipynb,md"
153 |
154 | [tool.cruft]
155 | skip = [
156 | "tests",
157 | "src/**/__init__.py",
158 | "src/**/basic.py",
159 | "docs/api.md",
160 | "docs/changelog.md",
161 | "docs/references.bib",
162 | "docs/references.md",
163 | "docs/notebooks/example.ipynb",
164 | ]
165 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

# This is a shim to allow Github to detect the package, build is done with hatch

import setuptools

# Minimal call on purpose: all real metadata lives in pyproject.toml ([project] table).
if __name__ == "__main__":
    setuptools.setup(name="scib-metrics")
9 |
--------------------------------------------------------------------------------
/src/scib_metrics/__init__.py:
--------------------------------------------------------------------------------
import logging
from importlib.metadata import version

from . import nearest_neighbors, utils

# NOTE: isort is disabled for this file (see "src/scib_metrics/__init__.py" = ["I"]
# under [tool.ruff.lint.per-file-ignores] in pyproject.toml), so this import order
# is intentional and preserved as written.
from .metrics import (
    graph_connectivity,
    isolated_labels,
    kbet,
    kbet_per_label,
    clisi_knn,
    ilisi_knn,
    lisi_knn,
    nmi_ari_cluster_labels_kmeans,
    nmi_ari_cluster_labels_leiden,
    pcr_comparison,
    silhouette_batch,
    silhouette_label,
)
from ._settings import settings

# Public API re-exported at the package root.
__all__ = [
    "utils",
    "nearest_neighbors",
    "isolated_labels",
    "pcr_comparison",
    "silhouette_label",
    "silhouette_batch",
    "ilisi_knn",
    "clisi_knn",
    "lisi_knn",
    "nmi_ari_cluster_labels_kmeans",
    "nmi_ari_cluster_labels_leiden",
    "kbet",
    "kbet_per_label",
    "graph_connectivity",
    "settings",
]

# Version is resolved from the installed distribution metadata, not hard-coded here.
__version__ = version("scib-metrics")

settings.verbosity = logging.INFO
# Jax sets the root logger, this prevents double output.
logger = logging.getLogger("scib_metrics")
logger.propagate = False
45 |
--------------------------------------------------------------------------------
/src/scib_metrics/_settings.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import Literal
4 |
5 | from rich.console import Console
6 | from rich.logging import RichHandler
7 |
8 | scib_logger = logging.getLogger("scib_metrics")
9 |
10 |
11 | class ScibConfig:
12 | """Config manager for scib-metrics.
13 |
14 | Examples
15 | --------
16 | To set the progress bar style, choose one of "rich", "tqdm"
17 |
18 | >>> scib_metrics.settings.progress_bar_style = "rich"
19 |
20 | To set the verbosity
21 |
22 | >>> import logging
23 | >>> scib_metrics.settings.verbosity = logging.INFO
24 | """
25 |
26 | def __init__(
27 | self,
28 | verbosity: int = logging.INFO,
29 | progress_bar_style: Literal["rich", "tqdm"] = "tqdm",
30 | jax_preallocate_gpu_memory: bool = False,
31 | ):
32 | if progress_bar_style not in ["rich", "tqdm"]:
33 | raise ValueError("Progress bar style must be in ['rich', 'tqdm']")
34 | self.progress_bar_style = progress_bar_style
35 | self.jax_preallocate_gpu_memory = jax_preallocate_gpu_memory
36 | self.verbosity = verbosity
37 |
38 | @property
39 | def progress_bar_style(self) -> str:
40 | """Library to use for progress bar."""
41 | return self._pbar_style
42 |
43 | @progress_bar_style.setter
44 | def progress_bar_style(self, pbar_style: Literal["tqdm", "rich"]):
45 | """Library to use for progress bar."""
46 | self._pbar_style = pbar_style
47 |
48 | @property
49 | def verbosity(self) -> int:
50 | """Verbosity level (default `logging.INFO`).
51 |
52 | Returns
53 | -------
54 | verbosity: int
55 | """
56 | return self._verbosity
57 |
58 | @verbosity.setter
59 | def verbosity(self, level: str | int):
60 | """Set verbosity level.
61 |
62 | If "scib_metrics" logger has no StreamHandler, add one.
63 | Else, set its level to `level`.
64 |
65 | Parameters
66 | ----------
67 | level
68 | Sets "scib_metrics" logging level to `level`
69 | force_terminal
70 | Rich logging option, set to False if piping to file output.
71 | """
72 | self._verbosity = level
73 | scib_logger.setLevel(level)
74 | if len(scib_logger.handlers) == 0:
75 | console = Console(force_terminal=True)
76 | if console.is_jupyter is True:
77 | console.is_jupyter = False
78 | ch = RichHandler(level=level, show_path=False, console=console, show_time=False)
79 | formatter = logging.Formatter("%(message)s")
80 | ch.setFormatter(formatter)
81 | scib_logger.addHandler(ch)
82 | else:
83 | scib_logger.setLevel(level)
84 |
85 | def reset_logging_handler(self) -> None:
86 | """Reset "scib_metrics" log handler to a basic RichHandler().
87 |
88 | This is useful if piping outputs to a file.
89 |
90 | Returns
91 | -------
92 | None
93 | """
94 | scib_logger.removeHandler(scib_logger.handlers[0])
95 | ch = RichHandler(level=self._verbosity, show_path=False, show_time=False)
96 | formatter = logging.Formatter("%(message)s")
97 | ch.setFormatter(formatter)
98 | scib_logger.addHandler(ch)
99 |
100 | def jax_fix_no_kernel_image(self) -> None:
101 | """Fix for JAX error "No kernel image is available for execution on the device"."""
102 | os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"
103 |
104 | @property
105 | def jax_preallocate_gpu_memory(self):
106 | """Jax GPU memory allocation settings.
107 |
108 | If False, Jax will ony preallocate GPU memory it needs.
109 | If float in (0, 1), Jax will preallocate GPU memory to that
110 | fraction of the GPU memory.
111 |
112 | Returns
113 | -------
114 | jax_preallocate_gpu_memory: bool or float
115 | """
116 | return self._jax_gpu
117 |
118 | @jax_preallocate_gpu_memory.setter
119 | def jax_preallocate_gpu_memory(self, value: float | bool):
120 | # see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html#gpu-memory-allocation
121 | if value is False:
122 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
123 | elif isinstance(value, float):
124 | if value >= 1 or value <= 0:
125 | raise ValueError("Need to use a value between 0 and 1")
126 | # format is ".XX"
127 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(value)[1:4]
128 | else:
129 | raise ValueError("value not understood, need bool or float in (0, 1)")
130 | self._jax_gpu = value
131 |
132 |
# Module-level singleton; imported elsewhere as ``scib_metrics.settings``.
settings = ScibConfig()
134 |
--------------------------------------------------------------------------------
/src/scib_metrics/_types.py:
--------------------------------------------------------------------------------
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from jax import Array

# Dense array from either NumPy or JAX.
NdArray = np.ndarray | jnp.ndarray
# Plain Python int or a JAX Array — presumably a seed or PRNG key; verify at call sites.
IntOrKey = int | Array
# Any supported array input, including SciPy sparse matrices.
ArrayLike = np.ndarray | sp.spmatrix | jnp.ndarray
9 |
--------------------------------------------------------------------------------
/src/scib_metrics/benchmark/__init__.py:
--------------------------------------------------------------------------------
# Re-export the benchmarking entry points at the subpackage level.
from ._core import BatchCorrection, Benchmarker, BioConservation

__all__ = ["Benchmarker", "BioConservation", "BatchCorrection"]
4 |
--------------------------------------------------------------------------------
/src/scib_metrics/benchmark/_core.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | from collections.abc import Callable
4 | from dataclasses import asdict, dataclass
5 | from enum import Enum
6 | from functools import partial
7 | from typing import Any
8 |
9 | import matplotlib as mpl
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import pandas as pd
13 | import scanpy as sc
14 | from anndata import AnnData
15 | from plottable import ColumnDefinition, Table
16 | from plottable.cmap import normed_cmap
17 | from plottable.plots import bar
18 | from sklearn.preprocessing import MinMaxScaler
19 | from tqdm import tqdm
20 |
21 | import scib_metrics
22 | from scib_metrics.nearest_neighbors import NeighborsResults, pynndescent
23 |
# A metric specification is either a bool toggle or a dict of keyword args
# forwarded to the metric function.
Kwargs = dict[str, Any]
MetricType = bool | Kwargs

# Internal obs/uns/result keys used by the per-embedding AnnData objects.
_LABELS = "labels"
_BATCH = "batch"
_X_PRE = "X_pre"
_METRIC_TYPE = "Metric Type"
_AGGREGATE_SCORE = "Aggregate score"

# Mapping of metric fn names to clean DataFrame column names
metric_name_cleaner = {
    "silhouette_label": "Silhouette label",
    "silhouette_batch": "Silhouette batch",
    "isolated_labels": "Isolated labels",
    "nmi_ari_cluster_labels_leiden_nmi": "Leiden NMI",
    "nmi_ari_cluster_labels_leiden_ari": "Leiden ARI",
    "nmi_ari_cluster_labels_kmeans_nmi": "KMeans NMI",
    "nmi_ari_cluster_labels_kmeans_ari": "KMeans ARI",
    "clisi_knn": "cLISI",
    "ilisi_knn": "iLISI",
    "kbet_per_label": "KBET",
    "graph_connectivity": "Graph connectivity",
    "pcr_comparison": "PCR comparison",
}
48 |
49 |
@dataclass(frozen=True)
class BioConservation:
    """Specification of bio conservation metrics to run in the pipeline.

    Metrics can be included using a boolean flag. Custom keyword args can be
    used by passing a dictionary here. Keyword args should not set data-related
    parameters, such as `X` or `labels`.
    """

    # Each field is either a bool (True = run with defaults, False = skip)
    # or a kwargs dict to forward to the metric function of the same name.
    isolated_labels: MetricType = True
    nmi_ari_cluster_labels_leiden: MetricType = False
    nmi_ari_cluster_labels_kmeans: MetricType = True
    silhouette_label: MetricType = True
    clisi_knn: MetricType = True
64 |
65 |
@dataclass(frozen=True)
class BatchCorrection:
    """Specification of which batch correction metrics to run in the pipeline.

    Metrics can be included using a boolean flag. Custom keyword args can be
    used by passing a dictionary here. Keyword args should not set data-related
    parameters, such as `X` or `labels`.
    """

    # Each field is either a bool (True = run with defaults, False = skip)
    # or a kwargs dict to forward to the metric function of the same name.
    silhouette_batch: MetricType = True
    ilisi_knn: MetricType = True
    kbet_per_label: MetricType = True
    graph_connectivity: MetricType = True
    pcr_comparison: MetricType = True
80 |
81 |
class MetricAnnDataAPI(Enum):
    """Specification of the AnnData API for a metric.

    Each attribute maps a metric name to a callable ``(ad, fn) -> result``
    that pulls the correct inputs out of a prepared AnnData: ``ad.X`` is the
    embedding, ``ad.obs`` holds the batch/label columns, ``ad.obsm[_X_PRE]``
    the pre-integration embedding, and ``ad.uns["{k}_neighbor_res"]`` the
    precomputed k-NN results (k in {15, 50, 90}, see Benchmarker.prepare).

    Note: functions in an Enum body are not converted to members, so callers
    retrieve these with ``getattr(MetricAnnDataAPI, name)`` (as Benchmarker
    does), not with member lookup.
    """

    isolated_labels = lambda ad, fn: fn(ad.X, ad.obs[_LABELS], ad.obs[_BATCH])
    nmi_ari_cluster_labels_leiden = lambda ad, fn: fn(ad.uns["15_neighbor_res"], ad.obs[_LABELS])
    nmi_ari_cluster_labels_kmeans = lambda ad, fn: fn(ad.X, ad.obs[_LABELS])
    silhouette_label = lambda ad, fn: fn(ad.X, ad.obs[_LABELS])
    clisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_LABELS])
    graph_connectivity = lambda ad, fn: fn(ad.uns["15_neighbor_res"], ad.obs[_LABELS])
    silhouette_batch = lambda ad, fn: fn(ad.X, ad.obs[_LABELS], ad.obs[_BATCH])
    pcr_comparison = lambda ad, fn: fn(ad.obsm[_X_PRE], ad.X, ad.obs[_BATCH], categorical=True)
    ilisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_BATCH])
    kbet_per_label = lambda ad, fn: fn(ad.uns["50_neighbor_res"], ad.obs[_BATCH], ad.obs[_LABELS])
95 |
96 |
class Benchmarker:
    """Benchmarking pipeline for the single-cell integration task.

    Parameters
    ----------
    adata
        AnnData object containing the raw count data and integrated embeddings as obsm keys.
    batch_key
        Key in `adata.obs` that contains the batch information.
    label_key
        Key in `adata.obs` that contains the cell type labels.
    embedding_obsm_keys
        List of obsm keys that contain the embeddings to be benchmarked.
    bio_conservation_metrics
        Specification of which bio conservation metrics to run in the pipeline.
    batch_correction_metrics
        Specification of which batch correction metrics to run in the pipeline.
    pre_integrated_embedding_obsm_key
        Obsm key containing a non-integrated embedding of the data. If `None`, the embedding will be computed
        in the prepare step. See the notes below for more information.
    n_jobs
        Number of jobs to use for parallelization of neighbor search.
    progress_bar
        Whether to show a progress bar for :meth:`~scib_metrics.benchmark.Benchmarker.prepare` and
        :meth:`~scib_metrics.benchmark.Benchmarker.benchmark`.

    Notes
    -----
    `adata.X` should contain a form of the data that is not integrated, but is normalized. The `prepare` method will
    use `adata.X` for PCA via :func:`~scanpy.tl.pca`, which also only uses features masked via `adata.var['highly_variable']`.

    See further usage examples in the following tutorial:

    1. :doc:`/notebooks/lung_example`
    """

    def __init__(
        self,
        adata: AnnData,
        batch_key: str,
        label_key: str,
        embedding_obsm_keys: list[str],
        bio_conservation_metrics: BioConservation | None = BioConservation(),
        batch_correction_metrics: BatchCorrection | None = BatchCorrection(),
        pre_integrated_embedding_obsm_key: str | None = None,
        n_jobs: int = 1,
        progress_bar: bool = True,
    ):
        self._adata = adata
        self._embedding_obsm_keys = embedding_obsm_keys
        self._pre_integrated_embedding_obsm_key = pre_integrated_embedding_obsm_key
        self._bio_conservation_metrics = bio_conservation_metrics
        self._batch_correction_metrics = batch_correction_metrics
        # Results frame: one row per metric, one column per embedding plus the
        # metric-type column.
        self._results = pd.DataFrame(columns=list(self._embedding_obsm_keys) + [_METRIC_TYPE])
        self._emb_adatas = {}
        # Neighbor counts required by the metrics (see MetricAnnDataAPI).
        self._neighbor_values = (15, 50, 90)
        self._prepared = False
        self._benchmarked = False
        self._batch_key = batch_key
        self._label_key = label_key
        self._n_jobs = n_jobs
        self._progress_bar = progress_bar

        if self._bio_conservation_metrics is None and self._batch_correction_metrics is None:
            raise ValueError("Either batch or bio metrics must be defined.")

        self._metric_collection_dict = {}
        if self._bio_conservation_metrics is not None:
            self._metric_collection_dict.update({"Bio conservation": self._bio_conservation_metrics})
        if self._batch_correction_metrics is not None:
            self._metric_collection_dict.update({"Batch correction": self._batch_correction_metrics})

    def prepare(self, neighbor_computer: Callable[[np.ndarray, int], NeighborsResults] | None = None) -> None:
        """Prepare the data for benchmarking.

        Parameters
        ----------
        neighbor_computer
            Function that computes the neighbors of the data. If `None`, the neighbors will be computed
            with :func:`~scib_metrics.utils.nearest_neighbors.pynndescent`. The function should take as input
            the data and the number of neighbors to compute and return a :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults`
            object.
        """
        # Compute PCA
        if self._pre_integrated_embedding_obsm_key is None:
            # This is how scib does it
            # https://github.com/theislab/scib/blob/896f689e5fe8c57502cb012af06bed1a9b2b61d2/scib/metrics/pcr.py#L197
            sc.tl.pca(self._adata, use_highly_variable=False)
            self._pre_integrated_embedding_obsm_key = "X_pca"

        # One lightweight AnnData per embedding, sharing obs metadata.
        for emb_key in self._embedding_obsm_keys:
            self._emb_adatas[emb_key] = AnnData(self._adata.obsm[emb_key], obs=self._adata.obs)
            self._emb_adatas[emb_key].obs[_BATCH] = np.asarray(self._adata.obs[self._batch_key].values)
            self._emb_adatas[emb_key].obs[_LABELS] = np.asarray(self._adata.obs[self._label_key].values)
            self._emb_adatas[emb_key].obsm[_X_PRE] = self._adata.obsm[self._pre_integrated_embedding_obsm_key]

        # Compute neighbors
        progress = self._emb_adatas.values()
        if self._progress_bar:
            progress = tqdm(progress, desc="Computing neighbors")

        for ad in progress:
            # Search once with the largest k, then subset for the smaller ks.
            if neighbor_computer is not None:
                neigh_result = neighbor_computer(ad.X, max(self._neighbor_values))
            else:
                neigh_result = pynndescent(
                    ad.X, n_neighbors=max(self._neighbor_values), random_state=0, n_jobs=self._n_jobs
                )
            for n in self._neighbor_values:
                ad.uns[f"{n}_neighbor_res"] = neigh_result.subset_neighbors(n=n)

        self._prepared = True

    def benchmark(self) -> None:
        """Run the pipeline."""
        if self._benchmarked:
            warnings.warn(
                "The benchmark has already been run. Running it again will overwrite the previous results.",
                UserWarning,
            )

        if not self._prepared:
            self.prepare()

        # Count enabled metrics only. Iterating `.values()` is essential:
        # iterating the dict itself yields the (always truthy) field-name
        # strings and would overcount disabled metrics in the progress bar.
        num_metrics = sum(
            [sum([v is not False for v in asdict(met_col).values()]) for met_col in self._metric_collection_dict.values()]
        )

        progress_embs = self._emb_adatas.items()
        if self._progress_bar:
            progress_embs = tqdm(self._emb_adatas.items(), desc="Embeddings", position=0, colour="green")

        for emb_key, ad in progress_embs:
            pbar = None
            if self._progress_bar:
                pbar = tqdm(total=num_metrics, desc="Metrics", position=1, leave=False, colour="blue")
            for metric_type, metric_collection in self._metric_collection_dict.items():
                for metric_name, use_metric_or_kwargs in asdict(metric_collection).items():
                    if use_metric_or_kwargs:
                        pbar.set_postfix_str(f"{metric_type}: {metric_name}") if pbar is not None else None
                        metric_fn = getattr(scib_metrics, metric_name)
                        if isinstance(use_metric_or_kwargs, dict):
                            # Kwargs in this case
                            metric_fn = partial(metric_fn, **use_metric_or_kwargs)
                        metric_value = getattr(MetricAnnDataAPI, metric_name)(ad, metric_fn)
                        # nmi/ari metrics return a dict
                        if isinstance(metric_value, dict):
                            for k, v in metric_value.items():
                                self._results.loc[f"{metric_name}_{k}", emb_key] = v
                                self._results.loc[f"{metric_name}_{k}", _METRIC_TYPE] = metric_type
                        else:
                            self._results.loc[metric_name, emb_key] = metric_value
                            self._results.loc[metric_name, _METRIC_TYPE] = metric_type
                        pbar.update(1) if pbar is not None else None

        self._benchmarked = True

    def get_results(self, min_max_scale: bool = True, clean_names: bool = True) -> pd.DataFrame:
        """Return the benchmarking results.

        Parameters
        ----------
        min_max_scale
            Whether to min max scale the results.
        clean_names
            Whether to clean the metric names.

        Returns
        -------
        The benchmarking results.
        """
        df = self._results.transpose()
        df.index.name = "Embedding"
        # Drop the metric-type row before scaling; it is re-attached below.
        df = df.loc[df.index != _METRIC_TYPE]
        if min_max_scale:
            # Use sklearn to min max scale
            df = pd.DataFrame(
                MinMaxScaler().fit_transform(df),
                columns=df.columns,
                index=df.index,
            )
        if clean_names:
            df = df.rename(columns=metric_name_cleaner)
        df = df.transpose()
        df[_METRIC_TYPE] = self._results[_METRIC_TYPE].values

        # Compute per-class aggregate scores (mean over each metric type).
        per_class_score = df.groupby(_METRIC_TYPE).mean().transpose()
        # This is the default scIB weighting from the manuscript
        if self._batch_correction_metrics is not None and self._bio_conservation_metrics is not None:
            per_class_score["Total"] = (
                0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"]
            )
        df = pd.concat([df.transpose(), per_class_score], axis=1)
        df.loc[_METRIC_TYPE, per_class_score.columns] = _AGGREGATE_SCORE
        return df

    def plot_results_table(self, min_max_scale: bool = True, show: bool = True, save_dir: str | None = None) -> Table:
        """Plot the benchmarking results.

        Parameters
        ----------
        min_max_scale
            Whether to min max scale the results.
        show
            Whether to show the plot.
        save_dir
            The directory to save the plot to. If `None`, the plot is not saved.
        """
        num_embeds = len(self._embedding_obsm_keys)
        cmap_fn = lambda col_data: normed_cmap(col_data, cmap=mpl.cm.PRGn, num_stds=2.5)
        df = self.get_results(min_max_scale=min_max_scale)
        # Do not want to plot what kind of metric it is
        plot_df = df.drop(_METRIC_TYPE, axis=0)
        # Sort by total score
        if self._batch_correction_metrics is not None and self._bio_conservation_metrics is not None:
            sort_col = "Total"
        elif self._batch_correction_metrics is not None:
            sort_col = "Batch correction"
        else:
            sort_col = "Bio conservation"
        plot_df = plot_df.sort_values(by=sort_col, ascending=False).astype(np.float64)
        plot_df["Method"] = plot_df.index

        # Split columns by metric type, using df as it doesn't have the new method col
        score_cols = df.columns[df.loc[_METRIC_TYPE] == _AGGREGATE_SCORE]
        other_cols = df.columns[df.loc[_METRIC_TYPE] != _AGGREGATE_SCORE]
        column_definitions = [
            ColumnDefinition("Method", width=1.5, textprops={"ha": "left", "weight": "bold"}),
        ]
        # Circles for the metric values
        column_definitions += [
            ColumnDefinition(
                col,
                title=col.replace(" ", "\n", 1),
                width=1,
                textprops={
                    "ha": "center",
                    "bbox": {"boxstyle": "circle", "pad": 0.25},
                },
                cmap=cmap_fn(plot_df[col]),
                group=df.loc[_METRIC_TYPE, col],
                formatter="{:.2f}",
            )
            for i, col in enumerate(other_cols)
        ]
        # Bars for the aggregate scores
        column_definitions += [
            ColumnDefinition(
                col,
                width=1,
                title=col.replace(" ", "\n", 1),
                plot_fn=bar,
                plot_kw={
                    "cmap": mpl.cm.YlGnBu,
                    "plot_bg_bar": False,
                    "annotate": True,
                    "height": 0.9,
                    "formatter": "{:.2f}",
                },
                group=df.loc[_METRIC_TYPE, col],
                border="left" if i == 0 else None,
            )
            for i, col in enumerate(score_cols)
        ]
        # Allow to manipulate text post-hoc (in illustrator)
        with mpl.rc_context({"svg.fonttype": "none"}):
            fig, ax = plt.subplots(figsize=(len(df.columns) * 1.25, 3 + 0.3 * num_embeds))
            tab = Table(
                plot_df,
                cell_kw={
                    "linewidth": 0,
                    "edgecolor": "k",
                },
                column_definitions=column_definitions,
                ax=ax,
                row_dividers=True,
                footer_divider=True,
                textprops={"fontsize": 10, "ha": "center"},
                row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 5))},
                col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
                column_border_kw={"linewidth": 1, "linestyle": "-"},
                index_col="Method",
            ).autoset_fontcolors(colnames=plot_df.columns)
        if show:
            plt.show()
        if save_dir is not None:
            fig.savefig(os.path.join(save_dir, "scib_results.svg"), facecolor=ax.get_facecolor(), dpi=300)

        return tab
387 |
--------------------------------------------------------------------------------
/src/scib_metrics/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from ._graph_connectivity import graph_connectivity
2 | from ._isolated_labels import isolated_labels
3 | from ._kbet import kbet, kbet_per_label
4 | from ._lisi import clisi_knn, ilisi_knn, lisi_knn
5 | from ._nmi_ari import nmi_ari_cluster_labels_kmeans, nmi_ari_cluster_labels_leiden
6 | from ._pcr_comparison import pcr_comparison
7 | from ._silhouette import silhouette_batch, silhouette_label
8 |
9 | __all__ = [
10 | "isolated_labels",
11 | "pcr_comparison",
12 | "silhouette_label",
13 | "silhouette_batch",
14 | "ilisi_knn",
15 | "clisi_knn",
16 | "lisi_knn",
17 | "nmi_ari_cluster_labels_kmeans",
18 | "nmi_ari_cluster_labels_leiden",
19 | "kbet",
20 | "kbet_per_label",
21 | "graph_connectivity",
22 | ]
23 |
--------------------------------------------------------------------------------
/src/scib_metrics/metrics/_graph_connectivity.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from scipy.sparse.csgraph import connected_components
4 |
5 | from scib_metrics.nearest_neighbors import NeighborsResults
6 |
7 |
def graph_connectivity(X: "NeighborsResults", labels: np.ndarray) -> float:
    """Quantify the connectivity of the subgraph per cell type label.

    For each label, the kNN distance graph restricted to the cells with that
    label is split into connected components; the label's score is the
    fraction of its cells in the largest component. The returned value is the
    mean of these fractions, so 1.0 means every label forms one connected
    subgraph.

    Parameters
    ----------
    X
        A :class:`~scib_metrics.nearest_neighbors.NeighborsResults` object
        providing the kNN distance graph over all cells.
    labels
        Array of shape (n_cells,) representing label values for each cell.

    Returns
    -------
    Mean largest-component fraction over labels, in (0, 1].
    """
    # TODO(adamgayoso): Utils for validating inputs
    clust_res = []

    graph = X.knn_graph_distances

    for label in np.unique(labels):
        mask = labels == label
        # Allow pandas Series labels: extract the underlying boolean array.
        if hasattr(mask, "values"):
            mask = mask.values
        graph_sub = graph[mask][:, mask]
        _, comps = connected_components(graph_sub, connection="strong")
        # np.unique replaces the deprecated pd.value_counts (removed in
        # pandas 3.x); only the component sizes are needed.
        _, comp_sizes = np.unique(comps, return_counts=True)
        clust_res.append(comp_sizes.max() / comp_sizes.sum())

    return np.mean(clust_res)
36 |
--------------------------------------------------------------------------------
/src/scib_metrics/metrics/_isolated_labels.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from scib_metrics.utils import silhouette_samples
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
def isolated_labels(
    X: np.ndarray,
    labels: np.ndarray,
    batch: np.ndarray,
    rescale: bool = True,
    iso_threshold: int | None = None,
) -> float:
    """Isolated label score :cite:p:`luecken2022benchmarking`.

    Scores how well isolated labels (labels present in few batches) are
    distinguished from the rest of the data, using the average-width
    silhouette (ASW) of each isolated label against all other labels. The
    original scib package defaults to a cluster-based F1 procedure; the ASW
    variant is used here for speed and simplicity.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values.
    batch
        Array of shape (n_cells,) representing batch values.
    rescale
        Scale asw into the range [0, 1].
    iso_threshold
        Max number of batches per label for label to be considered as
        isolated, if integer. If `None`, considers minimum number of
        batches that labels are present in.

    Returns
    -------
    isolated_label_score
    """
    iso_labels = _get_isolated_labels(labels, batch, iso_threshold)

    # Silhouette widths are computed once over all cells, then averaged
    # within each isolated label.
    asw = silhouette_samples(X, labels)
    if rescale:
        asw = (asw + 1) / 2

    per_label_scores = pd.Series({lab: np.mean(asw[labels == lab]) for lab in iso_labels})
    return per_label_scores.mean()
58 |
59 |
60 | def _get_isolated_labels(labels: np.ndarray, batch: np.ndarray, iso_threshold: float):
61 | """Get labels that are isolated depending on the number of batches."""
62 | tmp = pd.DataFrame()
63 | label_key = "label"
64 | batch_key = "batch"
65 | tmp[label_key] = labels
66 | tmp[batch_key] = batch
67 | tmp = tmp.drop_duplicates()
68 | batch_per_lab = tmp.groupby(label_key).agg({batch_key: "count"})
69 |
70 | # threshold for determining when label is considered isolated
71 | if iso_threshold is None:
72 | iso_threshold = batch_per_lab.min().tolist()[0]
73 |
74 | logging.info(f"isolated labels: no more than {iso_threshold} batches per label")
75 |
76 | labels = batch_per_lab[batch_per_lab[batch_key] <= iso_threshold].index.tolist()
77 | if len(labels) == 0:
78 | logging.info(f"no isolated labels with less than {iso_threshold} batches")
79 |
80 | return np.array(labels)
81 |
--------------------------------------------------------------------------------
/src/scib_metrics/metrics/_kbet.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import partial
3 |
4 | import chex
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 | import pandas as pd
9 | import scipy
10 |
11 | from scib_metrics._types import NdArray
12 | from scib_metrics.nearest_neighbors import NeighborsResults
13 | from scib_metrics.utils import diffusion_nn, get_ndarray
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
def _chi2_cdf(df: int | NdArray, x: NdArray) -> float:
    """Cumulative distribution function of the chi-squared distribution.

    Uses the identity chi2.cdf(x, df) == gammainc(df / 2, x / 2), where
    gammainc is the regularized lower incomplete gamma function. See
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.chdtr.html
    for explanation of gammainc.
    """
    half_df = df / 2
    half_x = x / 2
    return jax.scipy.special.gammainc(half_df, half_x)
25 |
26 |
@partial(jax.jit, static_argnums=2)
def _kbet(
    neigh_batch_ids: jnp.ndarray, batches: jnp.ndarray, n_batches: int
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Per-cell kBET chi-squared statistics and p-values.

    Parameters
    ----------
    neigh_batch_ids
        (n_cells, n_neighbors) integer batch code of each cell's neighbors.
    batches
        (n_cells,) integer batch code per cell; defines the expected
        (global) batch frequencies.
    n_batches
        Total number of batches (static for jit).

    Returns
    -------
    test_statistics
        (n_cells,) chi-squared statistic per cell.
    p_values
        (n_cells,) upper-tail p-value per cell with ``n_batches - 1`` dof.
        (Return annotation fixed: previously declared ``float``.)
    """
    # Expected batch frequencies from the global batch composition.
    expected_freq = jnp.bincount(batches, length=n_batches)
    expected_freq = expected_freq / jnp.sum(expected_freq)
    dof = n_batches - 1

    # Observed neighbor-batch counts per cell (one bincount per row).
    observed_counts = jax.vmap(partial(jnp.bincount, length=n_batches))(neigh_batch_ids)
    expected_counts = expected_freq * neigh_batch_ids.shape[1]
    test_statistics = jnp.sum(jnp.square(observed_counts - expected_counts) / expected_counts, axis=1)
    p_values = 1 - jax.vmap(_chi2_cdf, in_axes=(None, 0))(dof, test_statistics)

    return test_statistics, p_values
39 |
40 |
def kbet(X: NeighborsResults, batches: np.ndarray, alpha: float = 0.05) -> tuple[float, np.ndarray, np.ndarray]:
    """Compute kbet :cite:p:`buttner2018`.

    This implementation is inspired by the implementation in Pegasus:
    https://pegasus.readthedocs.io/en/stable/index.html

    A higher acceptance rate means more mixing of batches. This implementation does
    not exactly mirror the default original implementation, as there is currently no
    `adapt` option.

    Note that this is also not equivalent to the kbet used in the original scib package,
    as that one computes kbet for each cell type label. To achieve this, use
    :func:`scib_metrics.kbet_per_label`.

    Parameters
    ----------
    X
        A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object.
    batches
        Array of shape (n_cells,) representing batch values
        for each cell.
    alpha
        Significance level for the statistical test.

    Returns
    -------
    acceptance_rate
        Kbet acceptance rate of the sample.
    stat_mean
        Mean Kbet chi-square statistic over all cells.
    pvalue_mean
        Mean Kbet p-value over all cells.
    """
    if len(batches) != len(X.indices):
        raise ValueError("Length of batches does not match number of cells.")
    knn_idx = X.indices
    # Encode batches as integer codes so they can be used for indexing/bincount.
    batches = np.asarray(pd.Categorical(batches).codes)
    # Batch code of every neighbor of every cell, shape (n_cells, k).
    neigh_batch_ids = batches[knn_idx]
    chex.assert_equal_shape([neigh_batch_ids, knn_idx])
    n_batches = jnp.unique(batches).shape[0]
    # Per-cell chi-squared statistics and p-values (jit-compiled helper).
    test_statistics, p_values = _kbet(neigh_batch_ids, batches, n_batches)
    test_statistics = get_ndarray(test_statistics)
    p_values = get_ndarray(p_values)
    # A cell "accepts" the null (well-mixed neighborhood) when p >= alpha.
    acceptance_rate = (p_values >= alpha).mean()

    return acceptance_rate, test_statistics, p_values
87 |
88 |
def kbet_per_label(
    X: NeighborsResults,
    batches: np.ndarray,
    labels: np.ndarray,
    alpha: float = 0.05,
    diffusion_n_comps: int = 100,
    return_df: bool = False,
) -> float | tuple[float, pd.DataFrame]:
    """Compute kBET score per cell type label as in :cite:p:`luecken2022benchmarking`.

    This approximates the method used in the original scib package. Notably, the underlying
    kbet might have some inconsistencies with the R implementation. Furthermore, to equalize
    the neighbor graphs of cell type subsets we use diffusion distance approximated with diffusion
    maps. Increasing `diffusion_n_comps` will increase the accuracy of the approximation.

    Parameters
    ----------
    X
        A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object.
    batches
        Array of shape (n_cells,) representing batch values
        for each cell.
    labels
        Array of shape (n_cells,) representing cell type label values
        for each cell.
    alpha
        Significance level for the statistical test.
    diffusion_n_comps
        Number of diffusion components to use for diffusion distance approximation.
    return_df
        Return dataframe of results in addition to score.

    Returns
    -------
    kbet_score
        Kbet score over all cells. Higher means more integrated, as in the kBET acceptance rate.
    df
        Dataframe with kBET score per cell type label.

    Notes
    -----
    This function requires X to be cell-cell connectivities, not distances.
    """
    if len(batches) != len(X.indices):
        raise ValueError("Length of batches does not match number of cells.")
    if len(labels) != len(X.indices):
        raise ValueError("Length of labels does not match number of cells.")
    # set upper bound for k0
    size_max = 2**31 - 1
    batches = np.asarray(pd.Categorical(batches).codes)
    labels = np.asarray(labels)

    conn_graph = X.knn_graph_connectivities

    # prepare call of kBET per cluster
    clusters, counts = np.unique(labels, return_counts=True)
    # BUGFIX: small clusters (<= 10 cells) must be the ones SKIPPED (scored
    # NaN); kBET runs on the remaining, sufficiently large clusters. The
    # previous comparison was inverted.
    skipped = clusters[counts <= 10]
    clusters = clusters[counts > 10]
    kbet_scores = {"cluster": list(skipped), "kBET": [np.nan] * len(skipped)}
    logger.info(f"{len(skipped)} clusters consist of a single batch or are too small. Skip.")

    for clus in clusters:
        # subset by label
        mask = labels == clus
        conn_graph_sub = conn_graph[mask, :][:, mask]
        conn_graph_sub.sort_indices()
        n_obs = conn_graph_sub.shape[0]
        batches_sub = batches[mask]

        # Heuristic neighborhood size: a quarter of the mean batch size,
        # clipped to [10, 70].
        quarter_mean = np.floor(np.mean(pd.Series(batches_sub).value_counts()) / 4).astype("int")
        k0 = np.min([70, np.max([10, quarter_mean])])
        # check k0 for reasonability
        if k0 * n_obs >= size_max:
            k0 = np.floor(size_max / n_obs).astype("int")

        n_comp, labs = scipy.sparse.csgraph.connected_components(conn_graph_sub, connection="strong")

        if n_comp == 1:  # a single component to compute kBET on
            try:
                diffusion_n_comps = np.min([diffusion_n_comps, n_obs - 1])
                nn_graph_sub = diffusion_nn(conn_graph_sub, k=k0, n_comps=diffusion_n_comps)
                # call kBET
                score, _, _ = kbet(
                    nn_graph_sub,
                    batches=batches_sub,
                    alpha=alpha,
                )
            except ValueError:
                logger.info("Diffusion distance failed. Skip.")
                score = 0  # i.e. 100% rejection

        else:
            # check the number of components where kBET can be computed upon
            comp_size = pd.Series(labs).value_counts()
            # check which components are small
            comp_size_thresh = 3 * k0
            idx_nonan = np.flatnonzero(np.in1d(labs, comp_size[comp_size >= comp_size_thresh].index))

            # check if 75% of all cells can be used for kBET run
            if len(idx_nonan) / len(labs) >= 0.75:
                # create another subset of components, assume they are not visited in a diffusion process
                conn_graph_sub_sub = conn_graph_sub[idx_nonan, :][:, idx_nonan]
                conn_graph_sub_sub.sort_indices()

                try:
                    diffusion_n_comps = np.min([diffusion_n_comps, conn_graph_sub_sub.shape[0] - 1])
                    nn_results_sub_sub = diffusion_nn(conn_graph_sub_sub, k=k0, n_comps=diffusion_n_comps)
                    # call kBET
                    score, _, _ = kbet(
                        nn_results_sub_sub,
                        batches=batches_sub[idx_nonan],
                        alpha=alpha,
                    )
                except ValueError:
                    logger.info("Diffusion distance failed. Skip.")
                    score = 0  # i.e. 100% rejection
            else:  # if there are too many too small connected components, set kBET score to 0
                score = 0  # i.e. 100% rejection

        kbet_scores["cluster"].append(clus)
        kbet_scores["kBET"].append(score)

    kbet_scores = pd.DataFrame.from_dict(kbet_scores)
    kbet_scores = kbet_scores.reset_index(drop=True)

    # Small/skipped clusters contribute NaN and are excluded from the mean.
    final_score = np.nanmean(kbet_scores["kBET"])
    if not return_df:
        return final_score
    else:
        return final_score, kbet_scores
217 |
--------------------------------------------------------------------------------
/src/scib_metrics/metrics/_lisi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | from scib_metrics.nearest_neighbors import NeighborsResults
5 | from scib_metrics.utils import compute_simpson_index
6 |
7 |
def lisi_knn(X: NeighborsResults, labels: np.ndarray, perplexity: float | None = None) -> np.ndarray:
    """Compute the local inverse simpson index (LISI) for each cell :cite:p:`korsunsky2019harmony`.

    Parameters
    ----------
    X
        A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object.
    labels
        Array of shape (n_cells,) representing label values
        for each cell.
    perplexity
        Parameter controlling effective neighborhood size. If None, the
        perplexity is set to the number of neighbors // 3.
        (Annotation fixed: implicit-Optional ``float = None`` made explicit.)

    Returns
    -------
    lisi
        Array of shape (n_cells,) with the LISI score for each cell.
    """
    # Integer-encode labels for the simpson index computation.
    labels = np.asarray(pd.Categorical(labels).codes)
    knn_dists, knn_idx = X.distances, X.indices
    # Row indices aligned with the kNN arrays, shape (n_cells, 1).
    row_idx = np.arange(X.n_samples)[:, np.newaxis]

    if perplexity is None:
        # Default: a third of the neighborhood size.
        perplexity = np.floor(knn_idx.shape[1] / 3)

    n_labels = len(np.unique(labels))

    # Per-cell Simpson index; LISI is its reciprocal.
    simpson = compute_simpson_index(
        knn_dists=knn_dists, knn_idx=knn_idx, row_idx=row_idx, labels=labels, n_labels=n_labels, perplexity=perplexity
    )
    return 1 / simpson
40 |
41 |
def ilisi_knn(X: NeighborsResults, batches: np.ndarray, perplexity: float = None, scale: bool = True) -> float:
    """Compute the integration local inverse simpson index (iLISI) for each cell :cite:p:`korsunsky2019harmony`.

    Returns a scaled version of the iLISI score for each cell, by default :cite:p:`luecken2022benchmarking`.

    Parameters
    ----------
    X
        A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object.
    batches
        Array of shape (n_cells,) representing batch values
        for each cell.
    perplexity
        Parameter controlling effective neighborhood size. If None, the
        perplexity is set to the number of neighbors // 3.
    scale
        Scale lisi into the range [0, 1]. If True, higher values are better.

    Returns
    -------
    ilisi
        iLISI score.
    """
    # Integer-encode batches, then score each cell's neighborhood mixing.
    batch_codes = np.asarray(pd.Categorical(batches).codes)
    per_cell_lisi = lisi_knn(X, batch_codes, perplexity=perplexity)
    score = np.nanmedian(per_cell_lisi)
    if not scale:
        return score
    # Map the median from [1, n_batches] onto [0, 1].
    n_batches = len(np.unique(batch_codes))
    return (score - 1) / (n_batches - 1)
72 |
73 |
def clisi_knn(X: NeighborsResults, labels: np.ndarray, perplexity: float = None, scale: bool = True) -> float:
    """Compute the cell-type local inverse simpson index (cLISI) for each cell :cite:p:`korsunsky2019harmony`.

    Returns a scaled version of the cLISI score for each cell, by default :cite:p:`luecken2022benchmarking`.

    Parameters
    ----------
    X
        A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object.
    labels
        Array of shape (n_cells,) representing cell type label values
        for each cell.
    perplexity
        Parameter controlling effective neighborhood size. If None, the
        perplexity is set to the number of neighbors // 3.
    scale
        Scale lisi into the range [0, 1]. If True, higher values are better.

    Returns
    -------
    clisi
        cLISI score.
    """
    # Integer-encode labels, then score each cell's neighborhood purity.
    label_codes = np.asarray(pd.Categorical(labels).codes)
    per_cell_lisi = lisi_knn(X, label_codes, perplexity=perplexity)
    score = np.nanmedian(per_cell_lisi)
    if not scale:
        return score
    # Invert and map from [1, n_labels] onto [0, 1] so higher is better.
    n_labels = len(np.unique(label_codes))
    return (n_labels - score) / (n_labels - 1)
104 |
--------------------------------------------------------------------------------
/src/scib_metrics/metrics/_nmi_ari.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | import warnings
4 |
5 | import igraph
6 | import numpy as np
7 | from scipy.sparse import spmatrix
8 | from sklearn.metrics.cluster import adjusted_rand_score, normalized_mutual_info_score
9 | from sklearn.utils import check_array
10 |
11 | from scib_metrics.nearest_neighbors import NeighborsResults
12 | from scib_metrics.utils import KMeans
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
def _compute_clustering_kmeans(X: np.ndarray, n_clusters: int) -> np.ndarray:
    """Cluster ``X`` into ``n_clusters`` groups with k-means and return per-cell labels."""
    # KMeans.fit returns self, so the call chains.
    return KMeans(n_clusters).fit(X).labels_
21 |
22 |
def _compute_clustering_leiden(connectivity_graph: spmatrix, resolution: float, seed: int) -> np.ndarray:
    """Run Leiden community detection on a connectivity graph and return cluster memberships."""
    # Seed igraph's RNG for reproducible clustering.
    igraph.set_random_number_generator(random.Random(seed))
    # The connectivity graph with the umap method is symmetric, but we need to first make it
    # directed to have both sets of edges as is done in scanpy. See test for more details.
    graph = igraph.Graph.Weighted_Adjacency(connectivity_graph, mode="directed")
    graph.to_undirected(mode="each")
    partition = graph.community_leiden(objective_function="modularity", weights="weight", resolution=resolution)
    return np.asarray(partition.membership)
33 |
34 |
def _compute_nmi_ari_cluster_labels(
    X: spmatrix,
    labels: np.ndarray,
    resolution: float = 1.0,
    seed: int = 42,
) -> tuple[float, float]:
    """Cluster with Leiden at one resolution and return (NMI, ARI) against ``labels``."""
    predicted = _compute_clustering_leiden(X, resolution, seed)
    return (
        normalized_mutual_info_score(labels, predicted, average_method="arithmetic"),
        adjusted_rand_score(labels, predicted),
    )
45 |
46 |
def nmi_ari_cluster_labels_kmeans(X: np.ndarray, labels: np.ndarray) -> dict[str, float]:
    """Compute nmi and ari between k-means clusters and labels.

    This deviates from the original implementation in scib by using k-means
    with k equal to the known number of cell types/labels. This leads to
    a more efficient computation of the nmi and ari scores.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values

    Returns
    -------
    nmi
        Normalized mutual information score
    ari
        Adjusted rand index score
    """
    X = check_array(X, accept_sparse=False, ensure_2d=True)
    # Use as many clusters as there are ground-truth labels.
    predicted = _compute_clustering_kmeans(X, len(np.unique(labels)))
    return {
        "nmi": normalized_mutual_info_score(labels, predicted, average_method="arithmetic"),
        "ari": adjusted_rand_score(labels, predicted),
    }
75 |
76 |
def nmi_ari_cluster_labels_leiden(
    X: NeighborsResults,
    labels: np.ndarray,
    optimize_resolution: bool = True,
    resolution: float = 1.0,
    n_jobs: int = 1,
    seed: int = 42,
) -> dict[str, float]:
    """Compute nmi and ari between leiden clusters and labels.

    This deviates from the original implementation in scib by using leiden instead of
    louvain clustering. Installing joblib allows for parallelization of the leiden
    resolution optimization.

    Parameters
    ----------
    X
        A :class:`~scib_metrics.nearest_neighbors.NeighborsResults` object.
    labels
        Array of shape (n_cells,) representing label values
    optimize_resolution
        Whether to optimize the resolution parameter of leiden clustering by searching over
        10 values
    resolution
        Resolution parameter of leiden clustering. Only used if optimize_resolution is False.
    n_jobs
        Number of jobs for parallelizing resolution optimization via joblib. If -1, all CPUs
        are used.
    seed
        Seed used for reproducibility of clustering.

    Returns
    -------
    nmi
        Normalized mutual information score
    ari
        Adjusted rand index score
    """
    conn_graph = X.knn_graph_connectivities
    if not optimize_resolution:
        # BUGFIX: forward the user-supplied seed (previously fell back to the default).
        nmi, ari = _compute_nmi_ari_cluster_labels(conn_graph, labels, resolution, seed=seed)
        return {"nmi": nmi, "ari": ari}

    # Search resolutions 0.2, 0.4, ..., 2.0 and keep the one with the best NMI.
    n = 10
    resolutions = np.array([2 * x / n for x in range(1, n + 1)])
    try:
        from joblib import Parallel, delayed

        out = Parallel(n_jobs=n_jobs)(
            # BUGFIX: forward the user-supplied seed (previously fell back to the default).
            delayed(_compute_nmi_ari_cluster_labels)(conn_graph, labels, r, seed=seed)
            for r in resolutions
        )
    except ImportError:
        warnings.warn("Using for loop over clustering resolutions. `pip install joblib` for parallelization.")
        out = [_compute_nmi_ari_cluster_labels(conn_graph, labels, r, seed=seed) for r in resolutions]
    nmi_ari = np.array(out)
    best = np.argmax(nmi_ari[:, 0])
    nmi, ari = nmi_ari[best, :]
    return {"nmi": nmi, "ari": ari}
136 |
--------------------------------------------------------------------------------
/src/scib_metrics/metrics/_pcr_comparison.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from scib_metrics._types import NdArray
4 | from scib_metrics.utils import principal_component_regression
5 |
6 |
def pcr_comparison(
    X_pre: NdArray,
    X_post: NdArray,
    covariate: NdArray,
    scale: bool = True,
    **kwargs,
) -> float:
    """Principal component regression (PCR) comparison :cite:p:`buttner2018`.

    Compare the explained variance before and after integration.

    Parameters
    ----------
    X_pre
        Pre-integration array of shape (n_cells, n_features).
    X_post
        Post-integration array of shape (n_cells, n_features).
    covariate
        Array of shape (n_cells,) or (n_cells, 1) representing batch/covariate values.
    scale
        Whether to scale the score between 0 and 1. If True, larger values correspond to
        larger differences in variance contributions between `X_pre` and `X_post`.
    kwargs
        Keyword arguments passed into :func:`~scib_metrics.utils.principal_component_regression`.

    Returns
    -------
    pcr_compared: float
        Principal component regression score comparing the explained variance before and
        after integration.
    """
    if X_pre.shape[0] != X_post.shape[0]:
        raise ValueError("Dimension mismatch: `X_pre` and `X_post` must have the same number of samples.")
    if covariate.shape[0] != X_pre.shape[0]:
        raise ValueError("Dimension mismatch: `X_pre` and `covariate` must have the same number of samples.")

    pcr_pre = principal_component_regression(X_pre, covariate, **kwargs)
    pcr_post = principal_component_regression(X_post, covariate, **kwargs)

    if scale:
        # Relative reduction in covariate-explained variance after integration.
        pcr_compared = (pcr_pre - pcr_post) / pcr_pre
        if pcr_compared < 0:
            warnings.warn(
                "PCR comparison score is negative, meaning variance contribution "
                "increased after integration. Setting to 0."
            )
            pcr_compared = 0.0  # clamp as a float so the declared return type holds
    else:
        pcr_compared = pcr_post - pcr_pre

    return pcr_compared
58 |
--------------------------------------------------------------------------------
/src/scib_metrics/metrics/_silhouette.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | from scib_metrics.utils import silhouette_samples
5 |
6 |
def silhouette_label(X: np.ndarray, labels: np.ndarray, rescale: bool = True, chunk_size: int = 256) -> float:
    """Average silhouette width (ASW) :cite:p:`luecken2022benchmarking`.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    rescale
        Scale asw into the range [0, 1].
    chunk_size
        Size of chunks to process at a time for distance computation.

    Returns
    -------
    silhouette score
    """
    # Mean silhouette over all cells.
    asw = np.mean(silhouette_samples(X, labels, chunk_size=chunk_size))
    if rescale:
        # Map from [-1, 1] into [0, 1].
        asw = (asw + 1) / 2
    # asw is already a scalar; the original wrapped it in a redundant second np.mean.
    return asw
29 |
30 |
def silhouette_batch(
    X: np.ndarray, labels: np.ndarray, batch: np.ndarray, rescale: bool = True, chunk_size: int = 256
) -> float:
    """Average silhouette width (ASW) with respect to batch ids within each label :cite:p:`luecken2022benchmarking`.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    batch
        Array of shape (n_cells,) representing batch values
    rescale
        Scale asw into the range [0, 1]. If True, higher values are better.
    chunk_size
        Size of chunks to process at a time for distance computation.

    Returns
    -------
    silhouette score
    """
    per_group_frames = []
    for group in np.unique(labels):
        mask = labels == group
        group_X = X[mask]
        group_batches = batch[mask]
        n_batches = len(np.unique(group_batches))

        # Silhouette is not informative with a single batch or one cell per batch.
        if n_batches == 1 or n_batches == group_X.shape[0]:
            continue

        # Magnitude of the batch silhouette; 0 means batches are well mixed.
        scores = np.abs(silhouette_samples(group_X, group_batches, chunk_size=chunk_size))
        if rescale:
            # Flip so that well-mixed batches map to 1 (higher is better).
            scores = 1 - scores

        per_group_frames.append(
            pd.DataFrame(
                {
                    "group": [group] * len(scores),
                    "silhouette_score": scores,
                }
            )
        )

    combined = pd.concat(per_group_frames).reset_index(drop=True)
    # Mean within each label group, then mean across groups.
    return combined.groupby("group").mean()["silhouette_score"].mean()
87 |
--------------------------------------------------------------------------------
/src/scib_metrics/nearest_neighbors/__init__.py:
--------------------------------------------------------------------------------
1 | from ._dataclass import NeighborsResults
2 | from ._jax import jax_approx_min_k
3 | from ._pynndescent import pynndescent
4 |
5 | __all__ = [
6 | "pynndescent",
7 | "jax_approx_min_k",
8 | "NeighborsResults",
9 | ]
10 |
--------------------------------------------------------------------------------
/src/scib_metrics/nearest_neighbors/_dataclass.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from functools import cached_property
3 |
4 | import chex
5 | import numpy as np
6 | from scipy.sparse import coo_matrix, csr_matrix
7 | from umap.umap_ import fuzzy_simplicial_set
8 |
9 |
@dataclass
class NeighborsResults:
    """Nearest neighbors results data store.

    Attributes
    ----------
    distances : np.ndarray
        Array of distances to the nearest neighbors.
    indices : np.ndarray
        Array of indices of the nearest neighbors. Self should always
        be included here; however, some approximate algorithms may not return
        the self edge.
    """

    indices: np.ndarray
    distances: np.ndarray

    def __post_init__(self):
        # Every neighbor index must have a matching distance entry.
        chex.assert_equal_shape([self.indices, self.distances])

    @property
    def n_samples(self) -> int:
        """Number of samples (cells)."""
        # BUGFIX (annotation only): this returns an int, not an np.ndarray.
        return self.indices.shape[0]

    @property
    def n_neighbors(self) -> int:
        """Number of neighbors."""
        # BUGFIX (annotation only): this returns an int, not an np.ndarray.
        return self.indices.shape[1]

    @cached_property
    def knn_graph_distances(self) -> csr_matrix:
        """Return the sparse weighted adjacency matrix."""
        n_samples, n_neighbors = self.indices.shape
        # Every row stores exactly `n_neighbors` entries, so the CSR row
        # pointer is a simple arithmetic ramp.
        rowptr = np.arange(0, n_samples * n_neighbors + 1, n_neighbors)
        # Create CSR matrix
        return csr_matrix((self.distances.ravel(), self.indices.ravel(), rowptr), shape=(n_samples, n_samples))

    @cached_property
    def knn_graph_connectivities(self) -> coo_matrix:
        """Compute connectivities using the UMAP approach.

        Connectivities (similarities) are computed from distances
        using the approach from the UMAP method, which is also used by scanpy.
        """
        # Placeholder matrix: fuzzy_simplicial_set works from the precomputed
        # knn indices/distances below and ignores the raw data.
        conn_graph = coo_matrix(([], ([], [])), shape=(self.n_samples, 1))
        connectivities = fuzzy_simplicial_set(
            conn_graph,
            n_neighbors=self.n_neighbors,
            random_state=None,
            metric=None,
            knn_indices=self.indices,
            knn_dists=self.distances,
            set_op_mix_ratio=1.0,
            local_connectivity=1.0,
        )
        # Keep only the graph from the returned tuple.
        return connectivities[0]

    def subset_neighbors(self, n: int) -> "NeighborsResults":
        """Subset down to `n` neighbors."""
        if n > self.n_neighbors:
            raise ValueError("n must be smaller than the number of neighbors")
        return self.__class__(indices=self.indices[:, :n], distances=self.distances[:, :n])
74 |
--------------------------------------------------------------------------------
/src/scib_metrics/nearest_neighbors/_jax.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 |
7 | from scib_metrics.utils import cdist, get_ndarray
8 |
9 | from ._dataclass import NeighborsResults
10 |
11 |
@functools.partial(jax.jit, static_argnames=["k", "recall_target"])
def _euclidean_ann(qy: jnp.ndarray, db: jnp.ndarray, k: int, recall_target: float = 0.95):
    """Return (distances, indices) of the (approximate) k nearest database points per query.

    Distances are plain euclidean distances as computed by :func:`cdist`
    (the original docstring's "half squared L2" did not match the code).
    """
    dists = cdist(qy, db)
    return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
17 |
18 |
def jax_approx_min_k(
    X: np.ndarray, n_neighbors: int, recall_target: float = 0.95, chunk_size: int = 2048
) -> NeighborsResults:
    """Run approximate nearest neighbor search using jax.

    On TPU backends, this is approximate nearest neighbor search. On other backends, this is exact nearest neighbor search.

    Parameters
    ----------
    X
        Data matrix.
    n_neighbors
        Number of neighbors to search for.
    recall_target
        Target recall for approximate nearest neighbor search.
    chunk_size
        Number of query points to search for at once.
    """
    database = jnp.asarray(X)
    n_points = database.shape[0]
    # Query in chunks to bound the size of the pairwise distance matrix.
    index_chunks = []
    dist_chunks = []
    for start in range(0, n_points, chunk_size):
        stop = min(start + chunk_size, n_points)
        queries = database[start:stop]
        chunk_dists, chunk_indices = _euclidean_ann(queries, database, k=n_neighbors, recall_target=recall_target)
        index_chunks.append(chunk_indices)
        dist_chunks.append(chunk_dists)
    all_indices = jnp.concatenate(index_chunks, axis=0)
    all_dists = jnp.concatenate(dist_chunks, axis=0)
    return NeighborsResults(indices=get_ndarray(all_indices), distances=get_ndarray(all_dists))
51 |
--------------------------------------------------------------------------------
/src/scib_metrics/nearest_neighbors/_pynndescent.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pynndescent import NNDescent
3 |
4 | from ._dataclass import NeighborsResults
5 |
6 |
def pynndescent(X: np.ndarray, n_neighbors: int, random_state: int = 0, n_jobs: int = 1) -> NeighborsResults:
    """Run pynndescent approximate nearest neighbor search.

    Parameters
    ----------
    X
        Data matrix.
    n_neighbors
        Number of neighbors to search for.
    random_state
        Random state.
    n_jobs
        Number of jobs to use.
    """
    n_obs = X.shape[0]
    # Variables from umap (https://github.com/lmcinnes/umap/blob/3f19ce19584de4cf99e3d0ae779ba13a57472cd9/umap/umap_.py#LL326-L327)
    # which is used by scanpy under the hood
    num_trees = min(64, 5 + int(round(n_obs**0.5 / 20.0)))
    num_iters = max(5, int(round(np.log2(n_obs))))

    search_index = NNDescent(
        X,
        n_neighbors=n_neighbors,
        random_state=random_state,
        low_memory=True,
        n_jobs=n_jobs,
        compressed=False,
        n_trees=num_trees,
        n_iters=num_iters,
        max_candidates=60,
    )
    indices, distances = search_index.neighbor_graph
    return NeighborsResults(indices=indices, distances=distances)
41 |
--------------------------------------------------------------------------------
/src/scib_metrics/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from ._diffusion_nn import diffusion_nn
2 | from ._dist import cdist, pdist_squareform
3 | from ._kmeans import KMeans
4 | from ._lisi import compute_simpson_index
5 | from ._pca import pca
6 | from ._pcr import principal_component_regression
7 | from ._silhouette import silhouette_samples
8 | from ._utils import check_square, convert_knn_graph_to_idx, get_ndarray, one_hot
9 |
10 | __all__ = [
11 | "silhouette_samples",
12 | "cdist",
13 | "pdist_squareform",
14 | "get_ndarray",
15 | "KMeans",
16 | "pca",
17 | "principal_component_regression",
18 | "one_hot",
19 | "compute_simpson_index",
20 | "convert_knn_graph_to_idx",
21 | "check_square",
22 | "diffusion_nn",
23 | ]
24 |
--------------------------------------------------------------------------------
/src/scib_metrics/utils/_diffusion_nn.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Literal
3 |
4 | import numpy as np
5 | import scipy
6 | from scipy.sparse import csr_matrix, issparse
7 |
8 | from scib_metrics import nearest_neighbors
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | _EPS = 1e-8
13 |
14 |
15 | def _compute_transitions(X: csr_matrix, density_normalize: bool = True):
16 | """Code from scanpy.
17 |
18 | https://github.com/scverse/scanpy/blob/2e98705347ea484c36caa9ba10de1987b09081bf/scanpy/neighbors/__init__.py#L899
19 | """
20 | # TODO(adamgayoso): Refactor this with Jax
21 | # density normalization as of Coifman et al. (2005)
22 | # ensures that kernel matrix is independent of sampling density
23 | if density_normalize:
24 | # q[i] is an estimate for the sampling density at point i
25 | # it's also the degree of the underlying graph
26 | q = np.asarray(X.sum(axis=0))
27 | if not issparse(X):
28 | Q = np.diag(1.0 / q)
29 | else:
30 | Q = scipy.sparse.spdiags(1.0 / q, 0, X.shape[0], X.shape[0])
31 | K = Q @ X @ Q
32 | else:
33 | K = X
34 |
35 | # z[i] is the square root of the row sum of K
36 | z = np.sqrt(np.asarray(K.sum(axis=0)))
37 | if not issparse(K):
38 | Z = np.diag(1.0 / z)
39 | else:
40 | Z = scipy.sparse.spdiags(1.0 / z, 0, K.shape[0], K.shape[0])
41 | transitions_sym = Z @ K @ Z
42 |
43 | return transitions_sym
44 |
45 |
46 | def _compute_eigen(
47 | transitions_sym: csr_matrix,
48 | n_comps: int = 15,
49 | sort: Literal["decrease", "increase"] = "decrease",
50 | ):
51 | """Compute eigen decomposition of transition matrix.
52 |
53 | https://github.com/scverse/scanpy/blob/2e98705347ea484c36caa9ba10de1987b09081bf/scanpy/neighbors/__init__.py
54 | """
55 | # TODO(adamgayoso): Refactor this with Jax
56 | matrix = transitions_sym
57 | # compute the spectrum
58 | if n_comps == 0:
59 | evals, evecs = scipy.linalg.eigh(matrix)
60 | else:
61 | n_comps = min(matrix.shape[0] - 1, n_comps)
62 | # ncv = max(2 * n_comps + 1, int(np.sqrt(matrix.shape[0])))
63 | ncv = None
64 | which = "LM" if sort == "decrease" else "SM"
65 | # it pays off to increase the stability with a bit more precision
66 | matrix = matrix.astype(np.float64)
67 |
68 | evals, evecs = scipy.sparse.linalg.eigsh(matrix, k=n_comps, which=which, ncv=ncv)
69 | evals, evecs = evals.astype(np.float32), evecs.astype(np.float32)
70 | if sort == "decrease":
71 | evals = evals[::-1]
72 | evecs = evecs[:, ::-1]
73 |
74 | return evals, evecs
75 |
76 |
77 | def _get_sparse_matrix_from_indices_distances_numpy(indices, distances, n_obs, n_neighbors):
78 | """Code from scanpy."""
79 | n_nonzero = n_obs * n_neighbors
80 | indptr = np.arange(0, n_nonzero + 1, n_neighbors)
81 | D = csr_matrix(
82 | (
83 | distances.copy().ravel(), # copy the data, otherwise strange behavior here
84 | indices.copy().ravel(),
85 | indptr,
86 | ),
87 | shape=(n_obs, n_obs),
88 | )
89 | D.eliminate_zeros()
90 | D.sort_indices()
91 | return D
92 |
93 |
def diffusion_nn(X: csr_matrix, k: int, n_comps: int = 100) -> nearest_neighbors.NeighborsResults:
    """Diffusion-based neighbors.

    This function generates a nearest neighbour list from a connectivities matrix.
    This allows us to select a consistent number of nearest neighbors across all methods.

    This differs from the original scIB implementation by leveraging diffusion maps. Here we
    embed the data with diffusion maps in which euclidean distance represents well the diffusion
    distance. We then use pynndescent to find the nearest neighbours in this embedding space.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_cells) with non-zero values
        representing connectivities.
    k
        Number of nearest neighbours to select.
    n_comps
        Number of components for diffusion map

    Returns
    -------
    Neighbors results
    """
    transitions = _compute_transitions(X)
    evals, evecs = _compute_eigen(transitions, n_comps=n_comps)
    evals += _EPS  # Avoid division by zero
    # Multiscale such that the number of steps t gets "integrated out"
    scaled_evals = np.array([e if e == 1 else e / (1 - e) for e in evals])
    embedding = evecs * scaled_evals
    # k + 1 neighbors so the self edge is included alongside the k true neighbours
    # (NeighborsResults expects self to be present).
    return nearest_neighbors.pynndescent(embedding, n_neighbors=k + 1)
128 |
--------------------------------------------------------------------------------
/src/scib_metrics/utils/_dist.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 |
5 |
6 | @jax.jit
7 | def _euclidean_distance(x: np.array, y: np.array) -> float:
8 | dist = jnp.sqrt(jnp.sum((x - y) ** 2))
9 | return dist
10 |
11 |
@jax.jit
def cdist(x: np.ndarray, y: np.ndarray) -> jnp.ndarray:
    """Jax implementation of :func:`scipy.spatial.distance.cdist`.

    Uses euclidean distance.

    Parameters
    ----------
    x
        Array of shape (n_cells_a, n_features)
    y
        Array of shape (n_cells_b, n_features)

    Returns
    -------
    dist
        Array of shape (n_cells_a, n_cells_b)
    """

    # Outer vmap over rows of x; inner vmap computes euclidean distances to every row of y.
    def _dists_from(row):
        return jax.vmap(lambda other: jnp.sqrt(jnp.sum((row - other) ** 2)))(y)

    return jax.vmap(_dists_from)(x)
31 |
32 |
@jax.jit
def pdist_squareform(X: np.ndarray) -> jnp.ndarray:
    """Jax implementation of :func:`scipy.spatial.distance.pdist` and :func:`scipy.spatial.distance.squareform`.

    Uses euclidean distance.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features)

    Returns
    -------
    dist
        Array of shape (n_cells, n_cells)
    """
    n_cells = X.shape[0]
    upper = jnp.triu_indices(n_cells)

    def _pair_dist(carry, pair):
        row_i, row_j = pair
        return carry, jnp.sqrt(jnp.sum((carry[row_i] - carry[row_j]) ** 2))

    # Scan over all upper-triangle index pairs, then scatter into the matrix.
    _, upper_dists = jax.lax.scan(_pair_dist, X, (upper[0], upper[1]))
    dist_mat = jnp.zeros((n_cells, n_cells)).at[upper].set(upper_dists)
    # Mirror the upper triangle onto the lower triangle.
    return jnp.maximum(dist_mat, dist_mat.T)
60 |
--------------------------------------------------------------------------------
/src/scib_metrics/utils/_kmeans.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Literal
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | import numpy as np
7 | from jax import Array
8 | from sklearn.utils import check_array
9 |
10 | from scib_metrics._types import IntOrKey
11 |
12 | from ._dist import cdist
13 | from ._utils import get_ndarray, validate_seed
14 |
15 |
16 | def _tolerance(X: jnp.ndarray, tol: float) -> float:
17 | """Return a tolerance which is dependent on the dataset."""
18 | variances = np.var(X, axis=0)
19 | return np.mean(variances) * tol
20 |
21 |
22 | def _initialize_random(X: jnp.ndarray, n_clusters: int, key: Array) -> jnp.ndarray:
23 | """Initialize cluster centroids randomly."""
24 | n_obs = X.shape[0]
25 | key, subkey = jax.random.split(key)
26 | indices = jax.random.choice(subkey, n_obs, (n_clusters,), replace=False)
27 | initial_state = X[indices]
28 | return initial_state
29 |
30 |
@partial(jax.jit, static_argnums=1)
def _initialize_plus_plus(X: jnp.ndarray, n_clusters: int, key: Array) -> jnp.ndarray:
    """Initialize cluster centroids with k-means++ algorithm.

    The first centroid is sampled uniformly; each subsequent centroid is
    sampled with probability proportional to its squared distance from the
    nearest centroid chosen so far, trying several candidates per step and
    keeping the one that minimizes the resulting potential.
    """
    n_obs = X.shape[0]
    key, subkey = jax.random.split(key)
    initial_centroid_idx = jax.random.choice(subkey, n_obs, (1,), replace=False)
    initial_centroid = X[initial_centroid_idx].ravel()
    # Squared distance of every observation to the first (and so far only) centroid.
    dist_sq = jnp.square(cdist(initial_centroid[jnp.newaxis, :], X)).ravel()
    initial_state = {"min_dist_sq": dist_sq, "centroid": initial_centroid, "key": key}
    # Number of candidate centroids sampled per step.
    n_local_trials = 2 + int(np.log(n_clusters))

    def _step(state, _):
        # Sampling distribution proportional to squared distance to nearest centroid.
        prob = state["min_dist_sq"] / jnp.sum(state["min_dist_sq"])
        # note that observations already chosen as centers will have 0 probability
        # and will not be chosen again
        state["key"], subkey = jax.random.split(state["key"])
        next_centroid_idx_candidates = jax.random.choice(subkey, n_obs, (n_local_trials,), replace=False, p=prob)
        next_centroid_candidates = X[next_centroid_idx_candidates]
        # candidates by observations
        dist_sq_candidates = jnp.square(cdist(next_centroid_candidates, X))
        # Nearest-centroid distance if each candidate were added to the chosen set.
        dist_sq_candidates = jnp.minimum(state["min_dist_sq"][jnp.newaxis, :], dist_sq_candidates)
        candidates_pot = dist_sq_candidates.sum(axis=1)

        # Decide which candidate is the best
        best_candidate = jnp.argmin(candidates_pot)
        min_dist_sq = dist_sq_candidates[best_candidate]
        best_candidate = next_centroid_idx_candidates[best_candidate]

        state["min_dist_sq"] = min_dist_sq.ravel()
        state["centroid"] = X[best_candidate].ravel()
        return state, state["centroid"]

    # Scan produces the remaining n_clusters - 1 centroids, one per step.
    _, centroids = jax.lax.scan(_step, initial_state, jnp.arange(n_clusters - 1))
    centroids = jnp.concatenate([initial_centroid[jnp.newaxis, :], centroids])
    return centroids
66 |
67 |
@jax.jit
def _get_dist_labels(X: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray:
    """Return squared distances to every centroid and the closest-centroid label per observation."""
    sq_dist = jnp.square(cdist(X, centroids))
    return sq_dist, jnp.argmin(sq_dist, axis=1)
74 |
75 |
class KMeans:
    """Jax implementation of :class:`sklearn.cluster.KMeans`.

    This implementation is limited to Euclidean distance.

    Parameters
    ----------
    n_clusters
        Number of clusters.
    init
        Cluster centroid initialization method. One of the following:

        * ``'k-means++'``: Sample initial cluster centroids based on an
          empirical distribution of the points' contributions to the
          overall inertia.
        * ``'random'``: Uniformly sample observations as initial centroids
    n_init
        Number of times the k-means algorithm will be initialized.
    max_iter
        Maximum number of iterations of the k-means algorithm for a single run.
    tol
        Relative tolerance with regards to inertia to declare convergence.
    seed
        Random seed.
    """

    def __init__(
        self,
        n_clusters: int = 8,
        init: Literal["k-means++", "random"] = "k-means++",
        n_init: int = 1,
        max_iter: int = 300,
        tol: float = 1e-4,
        seed: IntOrKey = 0,
    ):
        self.n_clusters = n_clusters
        self.n_init = n_init
        self.max_iter = max_iter
        # Relative tolerance; converted to an absolute tolerance in fit().
        self.tol_scale = tol
        self.seed: jax.Array = validate_seed(seed)

        if init not in ["k-means++", "random"]:
            raise ValueError("Invalid init method, must be one of ['k-means++' or 'random'].")
        if init == "k-means++":
            self._initialize = _initialize_plus_plus
        else:
            self._initialize = _initialize_random

    def fit(self, X: np.ndarray):
        """Fit the model to the data."""
        X = check_array(X, dtype=np.float32, order="C")
        # Absolute tolerance scaled by the data's mean per-feature variance.
        self.tol = _tolerance(X, self.tol_scale)
        # Subtract mean for numerical accuracy
        mean = X.mean(axis=0)
        X -= mean
        self._fit(X)
        # Undo the in-place centering (X may alias the caller's array) and
        # shift the fitted centroids back into the original coordinates.
        X += mean
        self.cluster_centroids_ += mean
        return self

    def _fit(self, X: np.ndarray):
        # Run k-means n_init times with independent keys; keep the lowest-inertia run.
        all_centroids, all_inertias = jax.lax.map(
            lambda key: self._kmeans_full_run(X, key), jax.random.split(self.seed, self.n_init)
        )
        best = jnp.argmin(all_inertias)
        self.cluster_centroids_ = get_ndarray(all_centroids[best])
        self.inertia_ = get_ndarray(all_inertias[best])
        _, labels = _get_dist_labels(X, self.cluster_centroids_)
        self.labels_ = get_ndarray(labels)

    @partial(jax.jit, static_argnums=(0,))
    def _kmeans_full_run(self, X: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:
        def _kmeans_step(state):
            centroids, old_inertia, _, n_iter = state
            # TODO(adamgayoso): Efficiently compute argmin and min simultaneously.
            dist, new_labels = _get_dist_labels(X, centroids)
            # From https://colab.research.google.com/drive/1AwS4haUx6swF82w3nXr6QKhajdF8aSvA?usp=sharing
            counts = (new_labels[jnp.newaxis, :] == jnp.arange(self.n_clusters)[:, jnp.newaxis]).sum(
                axis=1, keepdims=True
            )
            # Guard against empty clusters when averaging.
            counts = jnp.clip(counts, min=1, max=None)
            # Sum over points in a centroid by zeroing others out
            new_centroids = (
                jnp.sum(
                    jnp.where(
                        # axes: (data points, clusters, data dimension)
                        new_labels[:, jnp.newaxis, jnp.newaxis]
                        == jnp.arange(self.n_clusters)[jnp.newaxis, :, jnp.newaxis],
                        X[:, jnp.newaxis, :],
                        0.0,
                    ),
                    axis=0,
                )
                / counts
            )
            new_inertia = jnp.sum(jnp.min(dist, axis=1))
            n_iter = n_iter + 1
            return new_centroids, new_inertia, old_inertia, n_iter

        def _kmeans_convergence(state):
            _, new_inertia, old_inertia, n_iter = state
            # BUGFIX: continue only while NOT converged AND under the iteration cap.
            # The previous `logical_or` kept iterating for the full `max_iter`
            # after convergence and could even exceed `max_iter` when the
            # inertia change never dropped below tol.
            not_converged = jnp.abs(old_inertia - new_inertia) > self.tol
            under_cap = n_iter < self.max_iter
            return jnp.logical_and(not_converged, under_cap)[0]

        centroids = self._initialize(X, self.n_clusters, key)
        # State: (centroids, new_inertia, old_inertia, n_iter). old_inertia starts
        # at 0 so the first convergence check sees an infinite change (|0 - inf|)
        # and the loop is entered (with inf/inf the check was NaN, hence the old
        # logical_or workaround).
        state = (centroids, jnp.inf, 0.0, jnp.array([0.0]))
        state = jax.lax.while_loop(_kmeans_convergence, _kmeans_step, state)
        # Compute final inertia
        centroids = state[0]
        dist, _ = _get_dist_labels(X, centroids)
        final_inertia = jnp.sum(jnp.min(dist, axis=1))
        return centroids, final_inertia
--------------------------------------------------------------------------------
/src/scib_metrics/utils/_lisi.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import chex
4 | import jax
5 | import jax.numpy as jnp
6 | import numpy as np
7 |
8 | from ._utils import get_ndarray
9 |
10 | NdArray = np.ndarray | jnp.ndarray
11 |
12 |
@chex.dataclass
class _NeighborProbabilityState:
    """Carry state for the perplexity binary search over the kernel scale ``beta``."""

    # Entropy of the current neighbor distribution (see _Hbeta).
    H: float
    # Current normalized neighbor probability row.
    P: chex.ArrayDevice
    # H minus the target log(perplexity); the search drives this toward 0.
    Hdiff: float
    # Scale factor in exp(-dist * beta) (larger beta -> narrower kernel).
    beta: float
    # Lower bound of the binary search over beta.
    betamin: float
    # Upper bound of the binary search over beta.
    betamax: float
    # Number of search iterations performed so far (capped at 50).
    tries: int
22 |
23 |
24 | @jax.jit
25 | def _Hbeta(knn_dists_row: jnp.ndarray, row_self_mask: jnp.ndarray, beta: float) -> tuple[jnp.ndarray, jnp.ndarray]:
26 | P = jnp.exp(-knn_dists_row * beta)
27 | # Mask out self edges to be zero
28 | P = jnp.where(row_self_mask, P, 0)
29 | sumP = jnp.nansum(P)
30 | H = jnp.where(sumP == 0, 0, jnp.log(sumP) + beta * jnp.nansum(knn_dists_row * P) / sumP)
31 | P = jnp.where(sumP == 0, jnp.zeros_like(knn_dists_row), P / sumP)
32 | return H, P
33 |
34 |
@jax.jit
def _get_neighbor_probability(
    knn_dists_row: jnp.ndarray, row_self_mask: jnp.ndarray, perplexity: float, tol: float
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Find the Gaussian kernel precision matching ``perplexity`` for one kNN row.

    Binary-searches ``beta`` until the entropy of the neighbor distribution is
    within ``tol`` of ``log(perplexity)``, or 50 iterations elapse. Returns the
    final entropy and neighbor probabilities from ``_Hbeta``.
    """
    # Initial guess and unbounded bracketing interval for the search.
    beta = 1
    betamin = -jnp.inf
    betamax = jnp.inf
    H, P = _Hbeta(knn_dists_row, row_self_mask, beta)
    Hdiff = H - jnp.log(perplexity)

    def _get_neighbor_probability_step(state):
        # One bisection step: tighten the bracket on the side indicated by the
        # sign of Hdiff, then re-evaluate the entropy at the new beta.
        Hdiff = state.Hdiff
        beta = state.beta
        betamin = state.betamin
        betamax = state.betamax
        tries = state.tries

        new_betamin = jnp.where(Hdiff > 0, beta, betamin)
        new_betamax = jnp.where(Hdiff > 0, betamax, beta)
        # While the bracket is still unbounded on the chosen side, double or
        # halve beta; otherwise take the midpoint of the bracket.
        new_beta = jnp.where(
            Hdiff > 0,
            jnp.where(betamax == jnp.inf, beta * 2, (beta + betamax) / 2),
            jnp.where(betamin == -jnp.inf, beta / 2, (beta + betamin) / 2),
        )
        new_H, new_P = _Hbeta(knn_dists_row, row_self_mask, new_beta)
        new_Hdiff = new_H - jnp.log(perplexity)
        return _NeighborProbabilityState(
            H=new_H, P=new_P, Hdiff=new_Hdiff, beta=new_beta, betamin=new_betamin, betamax=new_betamax, tries=tries + 1
        )

    def _get_neighbor_probability_convergence(state):
        # Continue while not converged and still under the iteration cap.
        Hdiff, tries = state.Hdiff, state.tries
        return jnp.logical_and(jnp.abs(Hdiff) >= tol, tries < 50)

    init_state = _NeighborProbabilityState(H=H, P=P, Hdiff=Hdiff, beta=beta, betamin=betamin, betamax=betamax, tries=0)
    final_state = jax.lax.while_loop(_get_neighbor_probability_convergence, _get_neighbor_probability_step, init_state)
    return final_state.H, final_state.P
72 |
73 |
def _compute_simpson_index_cell(
    knn_dists_row: jnp.ndarray,
    knn_labels_row: jnp.ndarray,
    row_self_mask: jnp.ndarray,
    n_batches: int,
    perplexity: float,
    tol: float,
) -> jnp.ndarray:
    """Simpson index of the batch composition in one cell's neighborhood.

    Returns -1 when the neighbor distribution is degenerate (zero entropy).
    """
    H, P = _get_neighbor_probability(knn_dists_row, row_self_mask, perplexity, tol)

    def _simpson_from_probs():
        # Sum neighbor probabilities per batch label, then take sum of squares.
        batch_probs = jnp.bincount(knn_labels_row, weights=P, length=n_batches)
        shapes_match = knn_labels_row.shape[0] == P.shape[0]
        return jnp.where(shapes_match, jnp.dot(batch_probs, batch_probs), 1)

    return jnp.where(H == 0, -1, _simpson_from_probs())
89 |
90 |
def compute_simpson_index(
    knn_dists: NdArray,
    knn_idx: NdArray,
    row_idx: NdArray,
    labels: NdArray,
    n_labels: int,
    perplexity: float = 30,
    tol: float = 1e-5,
) -> np.ndarray:
    """Compute the Simpson index for each cell.

    Parameters
    ----------
    knn_dists
        KNN distances of size (n_cells, n_neighbors).
    knn_idx
        KNN indices of size (n_cells, n_neighbors) corresponding to distances.
    row_idx
        Idx of each row (n_cells, 1).
    labels
        Cell labels of size (n_cells,).
    n_labels
        Number of labels.
    perplexity
        Measure of the effective number of neighbors.
    tol
        Tolerance for binary search.

    Returns
    -------
    simpson_index
        Simpson index of size (n_cells,).
    """
    dists = jnp.array(knn_dists)
    neighbor_idx = jnp.array(knn_idx)
    label_arr = jnp.array(labels)
    rows = jnp.array(row_idx)
    # Labels of each cell's neighbors and a mask excluding self edges.
    neighbor_labels = label_arr[neighbor_idx]
    non_self = neighbor_idx != rows
    per_cell = partial(_compute_simpson_index_cell, n_batches=n_labels, perplexity=perplexity, tol=tol)
    return get_ndarray(jax.vmap(per_cell)(dists, neighbor_labels, non_self))
133 |
--------------------------------------------------------------------------------
/src/scib_metrics/utils/_pca.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from chex import dataclass
3 | from jax import jit
4 |
5 | from scib_metrics._types import NdArray
6 |
7 | from ._utils import get_ndarray
8 |
9 |
@dataclass
class _SVDResult:
    """SVD result.

    Attributes
    ----------
    u
        Array of shape (n_cells, n_components) containing the left singular vectors.
    s
        Array of shape (n_components,) containing the singular values.
    v
        Array of shape (n_components, n_features) containing the right singular vectors.
    """

    # Fields mirror the (u, s, v) triple returned by jnp.linalg.svd.
    u: NdArray
    s: NdArray
    v: NdArray
27 |
28 |
@dataclass
class _PCAResult:
    """PCA result.

    Attributes
    ----------
    coordinates
        Array of shape (n_cells, n_components) containing the PCA coordinates.
    components
        Array of shape (n_components, n_features) containing the PCA components.
    variance
        Array of shape (n_components,) containing the explained variance of each PC.
    variance_ratio
        Array of shape (n_components,) containing the explained variance ratio of each PC.
    svd
        Dataclass containing the SVD data.
    """

    coordinates: NdArray
    components: NdArray
    variance: NdArray
    variance_ratio: NdArray
    # Full SVD data; populated only when pca(..., return_svd=True).
    svd: _SVDResult | None = None
52 |
53 |
54 | def _svd_flip(
55 | u: NdArray,
56 | v: NdArray,
57 | u_based_decision: bool = True,
58 | ):
59 | """Sign correction to ensure deterministic output from SVD.
60 |
61 | Jax implementation of :func:`~sklearn.utils.extmath.svd_flip`.
62 |
63 | Parameters
64 | ----------
65 | u
66 | Left singular vectors of shape (M, K).
67 | v
68 | Right singular vectors of shape (K, N).
69 | u_based_decision
70 | If True, use the columns of u as the basis for sign flipping.
71 | """
72 | if u_based_decision:
73 | max_abs_cols = jnp.argmax(jnp.abs(u), axis=0)
74 | signs = jnp.sign(u[max_abs_cols, jnp.arange(u.shape[1])])
75 | else:
76 | max_abs_rows = jnp.argmax(jnp.abs(v), axis=1)
77 | signs = jnp.sign(v[jnp.arange(v.shape[0]), max_abs_rows])
78 | u_ = u * signs
79 | v_ = v * signs[:, None]
80 | return u_, v_
81 |
82 |
def pca(
    X: NdArray,
    n_components: int | None = None,
    return_svd: bool = False,
) -> _PCAResult:
    """Principal component analysis (PCA).

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    n_components
        Number of components to keep. If None, all components are kept.
    return_svd
        If True, also return the results from SVD.

    Returns
    -------
    results: _PCAResult
    """
    # At most min(n_cells, n_features) components exist.
    max_components = min(X.shape)
    if n_components and n_components > max_components:
        raise ValueError(f"n_components = {n_components} must be <= min(n_cells, n_features) = {max_components}")
    if not n_components:
        n_components = max_components

    u, s, v, variance, variance_ratio = _pca(X)

    svd = _SVDResult(u=get_ndarray(u), s=get_ndarray(s), v=get_ndarray(v)) if return_svd else None
    return _PCAResult(
        coordinates=get_ndarray(u[:, :n_components] * s[:n_components]),
        components=get_ndarray(v[:n_components]),
        variance=get_ndarray(variance[:n_components]),
        variance_ratio=get_ndarray(variance_ratio[:n_components]),
        svd=svd,
    )
124 |
125 |
@jit
def _pca(
    X: NdArray,
) -> tuple[NdArray, NdArray, NdArray, NdArray, NdArray]:
    """Jitted PCA core: thin SVD of the column-centered matrix.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).

    Returns
    -------
    u: NdArray
        Left singular vectors of shape (M, K).
    s: NdArray
        Singular values of shape (K,).
    v: NdArray
        Right singular vectors of shape (K, N).
    variance: NdArray
        Array of shape (K,) containing the explained variance of each PC.
    variance_ratio: NdArray
        Array of shape (K,) containing the explained variance ratio of each PC.
    """
    centered = X - jnp.mean(X, axis=0)
    u, s, v = jnp.linalg.svd(centered, full_matrices=False)
    # Deterministic signs, following sklearn's convention.
    u, v = _svd_flip(u, v)
    # Sample variance (ddof=1) explained by each component.
    variance = (s**2) / (X.shape[0] - 1)
    variance_ratio = variance / jnp.sum(variance)
    return u, s, v, variance, variance_ratio
159 |
--------------------------------------------------------------------------------
/src/scib_metrics/utils/_pcr.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import numpy as np
3 | import pandas as pd
4 | from jax import jit
5 |
6 | from scib_metrics._types import NdArray
7 |
8 | from ._pca import pca
9 | from ._utils import one_hot
10 |
11 |
def principal_component_regression(
    X: NdArray,
    covariate: NdArray,
    categorical: bool = False,
    n_components: int | None = None,
) -> float:
    """Principal component regression (PCR) :cite:p:`buttner2018`.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    covariate
        Array of shape (n_cells,) or (n_cells, 1) representing batch/covariate values.
    categorical
        If True, batch will be treated as categorical and one-hot encoded.
    n_components
        Number of components to compute, passed into :func:`~scib_metrics.utils.pca`.
        If None, all components are used.

    Returns
    -------
    pcr: float
        Principal component regression using the first n_components principal components.
    """
    if len(X.shape) != 2:
        raise ValueError("Dimension mismatch: X must be 2-dimensional.")
    if X.shape[0] != covariate.shape[0]:
        raise ValueError("Dimension mismatch: X and batch must have the same number of samples.")

    if categorical:
        # Integer-encode categories, then one-hot encode the design matrix.
        covariate = one_hot(np.asarray(pd.Categorical(covariate).codes))
    else:
        covariate = np.asarray(covariate).reshape((covariate.shape[0], 1))

    pca_results = pca(X, n_components=n_components)

    # Center the design matrix so the regression needs no intercept term.
    covariate = covariate - jnp.mean(covariate, axis=0)
    return float(_pcr(pca_results.coordinates, covariate, pca_results.variance))
54 |
55 |
56 | @jit
57 | def _pcr(
58 | X_pca: NdArray,
59 | covariate: NdArray,
60 | var: NdArray,
61 | ) -> NdArray:
62 | """Principal component regression.
63 |
64 | Parameters
65 | ----------
66 | X_pca
67 | Array of shape (n_cells, n_components) containing PCA coordinates. Must be standardized.
68 | covariate
69 | Array of shape (n_cells, 1) or (n_cells, n_classes) containing batch/covariate values. Must be standardized
70 | if not categorical (one-hot).
71 | var
72 | Array of shape (n_components,) containing the explained variance of each PC.
73 | """
74 | residual_sum = jnp.linalg.lstsq(covariate, X_pca)[1]
75 | total_sum = jnp.sum((X_pca - jnp.mean(X_pca, axis=0, keepdims=True)) ** 2, axis=0)
76 | r2 = jnp.maximum(0, 1 - residual_sum / total_sum)
77 |
78 | return jnp.dot(jnp.ravel(r2), var) / jnp.sum(var)
79 |
--------------------------------------------------------------------------------
/src/scib_metrics/utils/_silhouette.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from ._dist import cdist
9 | from ._utils import get_ndarray
10 |
11 |
12 | @jax.jit
13 | def _silhouette_reduce(
14 | D_chunk: jnp.ndarray, start: int, labels: jnp.ndarray, label_freqs: jnp.ndarray
15 | ) -> tuple[jnp.ndarray, jnp.ndarray]:
16 | """Accumulate silhouette statistics for vertical chunk of X.
17 |
18 | Follows scikit-learn implementation.
19 |
20 | Parameters
21 | ----------
22 | D_chunk
23 | Array of shape (n_chunk_samples, n_samples)
24 | Precomputed distances for a chunk.
25 | start
26 | First index in the chunk.
27 | labels
28 | Array of shape (n_samples,)
29 | Corresponding cluster labels, encoded as {0, ..., n_clusters-1}.
30 | label_freqs
31 | Distribution of cluster labels in ``labels``.
32 | """
33 | # accumulate distances from each sample to each cluster
34 | D_chunk_len = D_chunk.shape[0]
35 |
36 | # If running into memory issues, use fori_loop instead of vmap
37 | # clust_dists = jnp.zeros((D_chunk_len, len(label_freqs)), dtype=D_chunk.dtype)
38 | # def _bincount(i, _data):
39 | # clust_dists, D_chunk, labels, label_freqs = _data
40 | # clust_dists = clust_dists.at[i].set(jnp.bincount(labels, weights=D_chunk[i], length=label_freqs.shape[0]))
41 | # return clust_dists, D_chunk, labels, label_freqs
42 |
43 | # clust_dists = jax.lax.fori_loop(
44 | # 0, D_chunk_len, lambda i, _data: _bincount(i, _data), (clust_dists, D_chunk, labels, label_freqs)
45 | # )[0]
46 |
47 | clust_dists = jax.vmap(partial(jnp.bincount, length=label_freqs.shape[0]), in_axes=(None, 0))(labels, D_chunk)
48 |
49 | # intra_index selects intra-cluster distances within clust_dists
50 | intra_index = (jnp.arange(D_chunk_len), jax.lax.dynamic_slice(labels, (start,), (D_chunk_len,)))
51 | # intra_clust_dists are averaged over cluster size outside this function
52 | intra_clust_dists = clust_dists[intra_index]
53 | # of the remaining distances we normalise and extract the minimum
54 | clust_dists = clust_dists.at[intra_index].set(jnp.inf)
55 | clust_dists /= label_freqs
56 | inter_clust_dists = clust_dists.min(axis=1)
57 | return intra_clust_dists, inter_clust_dists
58 |
59 |
def _pairwise_distances_chunked(X: jnp.ndarray, chunk_size: int, reduce_fn: callable) -> jnp.ndarray:
    """Apply ``reduce_fn`` to pairwise distances computed chunk-by-chunk.

    Processing ``chunk_size`` rows at a time bounds the memory needed for the
    distance matrix.
    """
    n_samples = X.shape[0]
    intra_chunks, inter_chunks = [], []
    for chunk_start in range(0, n_samples, chunk_size):
        chunk_stop = min(chunk_start + chunk_size, n_samples)
        intra, inter = reduce_fn(cdist(X[chunk_start:chunk_stop], X), start=chunk_start)
        intra_chunks.append(intra)
        inter_chunks.append(inter)
    return jnp.concatenate(intra_chunks), jnp.concatenate(inter_chunks)
73 |
74 |
def silhouette_samples(X: np.ndarray, labels: np.ndarray, chunk_size: int = 256) -> np.ndarray:
    """Compute the Silhouette Coefficient for each observation.

    Implements :func:`sklearn.metrics.silhouette_samples`.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features) representing a
        feature array.
    labels
        Array of shape (n_cells,) representing label values
        for each observation.
    chunk_size
        Number of samples to process at a time for distance computation.

    Returns
    -------
    silhouette scores array of shape (n_cells,)
    """
    if X.shape[0] != labels.shape[0]:
        raise ValueError("X and labels should have the same number of samples")

    # Encode labels as {0, ..., n_clusters-1} codes.
    codes = jnp.asarray(pd.Categorical(labels).codes)
    label_freqs = jnp.bincount(codes)
    reduce_fn = partial(_silhouette_reduce, labels=codes, label_freqs=label_freqs)
    intra, inter = _pairwise_distances_chunked(X, chunk_size=chunk_size, reduce_fn=reduce_fn)

    # Mean intra-cluster distance excludes the sample itself; singleton
    # clusters produce 0/0 -> NaN, zeroed by nan_to_num below.
    intra = intra / jnp.take(label_freqs - 1, codes, mode="clip")
    scores = (inter - intra) / jnp.maximum(intra, inter)
    return get_ndarray(jnp.nan_to_num(scores))
109 |
--------------------------------------------------------------------------------
/src/scib_metrics/utils/_utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 | from chex import ArrayDevice
7 | from jax import Array, nn
8 | from scipy.sparse import csr_matrix
9 | from sklearn.neighbors import NearestNeighbors
10 | from sklearn.utils import check_array
11 |
12 | from scib_metrics._types import ArrayLike, IntOrKey, NdArray
13 |
14 |
def get_ndarray(x: ArrayDevice) -> np.ndarray:
    """Convert Jax device array to Numpy array.

    Transfers the data to host memory first, then copies into a numpy array.
    """
    host_data = jax.device_get(x)
    return np.array(host_data)
18 |
19 |
def one_hot(y: NdArray, n_classes: int | None = None) -> jnp.ndarray:
    """One-hot encode an array. Wrapper around :func:`~jax.nn.one_hot`.

    Parameters
    ----------
    y
        Array of shape (n_cells,) or (n_cells, 1).
    n_classes
        Number of classes. If None, inferred from the data.

    Returns
    -------
    one_hot: jnp.ndarray
        Array of shape (n_cells, n_classes).
    """
    # Infer the class count from the data when not provided (short-circuits
    # so the max is only computed when needed).
    num_classes = n_classes or int(jax.device_get(jnp.max(y))) + 1
    return nn.one_hot(jnp.ravel(y), num_classes)
37 |
38 |
def validate_seed(seed: IntOrKey) -> Array:
    """Validate a seed and return a Jax random key.

    Plain ints are promoted via ``jax.random.PRNGKey``; existing keys are
    passed through unchanged.
    """
    if isinstance(seed, int):
        return jax.random.PRNGKey(seed)
    return seed
42 |
43 |
def check_square(X: ArrayLike):
    """Check if a matrix is square.

    Raises
    ------
    ValueError
        If ``X`` does not have as many rows as columns.
    """
    n_rows, n_cols = X.shape[0], X.shape[1]
    if n_rows != n_cols:
        raise ValueError("X must be a square matrix")
48 |
49 |
def convert_knn_graph_to_idx(X: csr_matrix) -> tuple[np.ndarray, np.ndarray]:
    """Convert a kNN graph to indices and distances.

    Parameters
    ----------
    X
        Square sparse (csr) adjacency matrix whose stored entries are
        precomputed neighbor distances; every row must contain the same
        number of stored neighbors.

    Returns
    -------
    tuple of (distances, indices), each of shape (n_cells, n_neighbors),
    as returned by :meth:`sklearn.neighbors.NearestNeighbors.kneighbors`.

    Raises
    ------
    ValueError
        If ``X`` is not square or rows have differing neighbor counts.
    """
    check_array(X, accept_sparse="csr")
    check_square(X)

    # Count stored neighbors per row; they must all be equal.
    # (Previously this reused the name `n_neighbors` for the counts array and
    # called np.unique on it twice.)
    _, counts_per_row = np.unique(X.nonzero()[0], return_counts=True)
    unique_counts = np.unique(counts_per_row)
    if len(unique_counts) > 1:
        raise ValueError("Each cell must have the same number of neighbors.")

    n_neighbors = int(unique_counts[0])
    with warnings.catch_warnings():
        # sklearn warns when fitting on a precomputed sparse matrix; expected here.
        warnings.filterwarnings("ignore", message="Precomputed sparse input")
        nn_obj = NearestNeighbors(n_neighbors=n_neighbors, metric="precomputed").fit(X)
        return nn_obj.kneighbors(X)
65 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YosefLab/scib-metrics/5d01bb46f4d2317b4b3379851c7202df5be36c68/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_benchmarker.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from scib_metrics.benchmark import BatchCorrection, Benchmarker, BioConservation
4 | from scib_metrics.nearest_neighbors import jax_approx_min_k
5 | from tests.utils.data import dummy_benchmarker_adata
6 |
7 |
def test_benchmarker():
    """End-to-end benchmark run with explicitly supplied metric collections."""
    adata, embedding_keys, batch_key, labels_key = dummy_benchmarker_adata()
    benchmarker = Benchmarker(
        adata,
        batch_key,
        labels_key,
        embedding_keys,
        batch_correction_metrics=BatchCorrection(),
        bio_conservation_metrics=BioConservation(),
    )
    benchmarker.benchmark()
    assert isinstance(benchmarker.get_results(), pd.DataFrame)
    benchmarker.plot_results_table()
22 |
23 |
def test_benchmarker_default():
    """Benchmark run using the default metric collections."""
    adata, embedding_keys, batch_key, labels_key = dummy_benchmarker_adata()
    benchmarker = Benchmarker(adata, batch_key, labels_key, embedding_keys)
    benchmarker.benchmark()
    assert isinstance(benchmarker.get_results(), pd.DataFrame)
    benchmarker.plot_results_table()
36 |
37 |
def test_benchmarker_custom_metric_booleans():
    """Disabled metrics are absent from the results; enabled ones are present."""
    bio_metrics = BioConservation(
        isolated_labels=False, nmi_ari_cluster_labels_leiden=False, silhouette_label=False, clisi_knn=True
    )
    batch_metrics = BatchCorrection(kbet_per_label=False, graph_connectivity=False, ilisi_knn=True)
    adata, embedding_keys, batch_key, labels_key = dummy_benchmarker_adata()
    benchmarker = Benchmarker(
        adata,
        batch_key,
        labels_key,
        embedding_keys,
        batch_correction_metrics=batch_metrics,
        bio_conservation_metrics=bio_metrics,
    )
    benchmarker.benchmark()
    results = benchmarker.get_results(clean_names=False)
    assert isinstance(results, pd.DataFrame)
    disabled = ("isolated_labels", "nmi_ari_cluster_labels_leiden", "silhouette_label", "kbet_per_label", "graph_connectivity")
    for metric in disabled:
        assert metric not in results.columns
    for metric in ("clisi_knn", "ilisi_knn"):
        assert metric in results.columns
55 |
56 |
def test_benchmarker_custom_metric_callable():
    """Metric settings may be passed as a kwargs dict (here: clisi perplexity)."""
    bio_metrics = BioConservation(clisi_knn={"perplexity": 10})
    adata, embedding_keys, batch_key, labels_key = dummy_benchmarker_adata()
    benchmarker = Benchmarker(
        adata, batch_key, labels_key, embedding_keys, bio_conservation_metrics=bio_metrics, batch_correction_metrics=BatchCorrection()
    )
    benchmarker.benchmark()
    assert "clisi_knn" in benchmarker.get_results(clean_names=False).columns
66 |
67 |
def test_benchmarker_custom_near_neighs():
    """Benchmark run with a user-supplied nearest-neighbor function."""
    adata, embedding_keys, batch_key, labels_key = dummy_benchmarker_adata()
    benchmarker = Benchmarker(
        adata,
        batch_key,
        labels_key,
        embedding_keys,
        bio_conservation_metrics=BioConservation(),
        batch_correction_metrics=BatchCorrection(),
    )
    benchmarker.prepare(neighbor_computer=jax_approx_min_k)
    benchmarker.benchmark()
    assert isinstance(benchmarker.get_results(), pd.DataFrame)
    benchmarker.plot_results_table()
83 |
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import anndata
2 | import igraph
3 | import jax.numpy as jnp
4 | import numpy as np
5 | import pandas as pd
6 | import pytest
7 | import scanpy as sc
8 | from harmonypy import compute_lisi as harmonypy_lisi
9 | from scib.metrics import isolated_labels_asw
10 | from scipy.spatial.distance import cdist as sp_cdist
11 | from scipy.spatial.distance import pdist, squareform
12 | from sklearn.cluster import KMeans as SKMeans
13 | from sklearn.datasets import make_blobs
14 | from sklearn.metrics import silhouette_samples as sk_silhouette_samples
15 | from sklearn.metrics.pairwise import pairwise_distances_argmin
16 | from sklearn.neighbors import NearestNeighbors
17 |
18 | import scib_metrics
19 | from scib_metrics.nearest_neighbors import NeighborsResults
20 | from tests.utils.data import dummy_x_labels, dummy_x_labels_batch
21 |
22 | scib_metrics.settings.jax_fix_no_kernel_image()
23 |
24 |
def test_package_has_version():
    """Importing the package exposes a ``__version__`` attribute."""
    _ = scib_metrics.__version__
27 |
28 |
def test_cdist():
    """Our cdist matches scipy's pairwise distance matrix."""
    a = jnp.array([[1, 2], [3, 4]])
    b = jnp.array([[5, 6], [7, 8]])
    assert np.allclose(scib_metrics.utils.cdist(a, b), sp_cdist(a, b))
33 |
34 |
def test_pdist():
    """pdist_squareform matches scipy's squareform(pdist(...))."""
    pts = jnp.array([[1, 2], [3, 4]])
    assert np.allclose(scib_metrics.utils.pdist_squareform(pts), squareform(pdist(pts)))
38 |
39 |
def test_silhouette_samples():
    """Per-sample silhouette scores agree with scikit-learn."""
    X, labels = dummy_x_labels()
    expected = sk_silhouette_samples(X, labels)
    assert np.allclose(scib_metrics.utils.silhouette_samples(X, labels), expected, atol=1e-5)
43 |
44 |
def test_silhouette_label():
    """Label silhouette is positive on dummy data and supports rescale=False."""
    X, labels = dummy_x_labels()
    assert scib_metrics.silhouette_label(X, labels) > 0
    scib_metrics.silhouette_label(X, labels, rescale=False)
50 |
51 |
def test_silhouette_batch():
    """Batch silhouette is positive on dummy data and runs repeatedly."""
    X, labels, batch = dummy_x_labels_batch()
    assert scib_metrics.silhouette_batch(X, labels, batch) > 0
    scib_metrics.silhouette_batch(X, labels, batch)
57 |
58 |
def test_compute_simpson_index():
    """Smoke test: compute_simpson_index runs on a kNN graph from sklearn.

    The previous version computed a full distance matrix with
    ``scib_metrics.utils.cdist(X, X)`` and immediately overwrote it with the
    kneighbors output — that dead line is removed.
    """
    X, labels = dummy_x_labels()
    nbrs = NearestNeighbors(n_neighbors=30, algorithm="kd_tree").fit(X)
    D, knn_idx = nbrs.kneighbors(X)
    row_idx = np.arange(X.shape[0])[:, None]
    scib_metrics.utils.compute_simpson_index(
        jnp.array(D), jnp.array(knn_idx), jnp.array(row_idx), jnp.array(labels), len(np.unique(labels))
    )
68 |
69 |
@pytest.mark.parametrize("n_neighbors", [30, 60, 72])
def test_lisi_knn(n_neighbors):
    """LISI computed from a kNN graph matches harmonypy within tolerance."""
    perplexity = n_neighbors // 3
    X, labels = dummy_x_labels()
    dists, inds = NearestNeighbors(n_neighbors=n_neighbors, algorithm="kd_tree").fit(X).kneighbors(X)
    ours = scib_metrics.lisi_knn(NeighborsResults(indices=inds, distances=dists), labels, perplexity=perplexity)
    label_df = pd.DataFrame(labels, columns=["labels"])
    reference = harmonypy_lisi(X, label_df, label_colnames=["labels"], perplexity=perplexity)[:, 0]
    np.testing.assert_allclose(ours, reference, rtol=5e-5, atol=5e-5)
82 |
83 |
def test_ilisi_clisi_knn():
    """iLISI and cLISI accept NeighborsResults input."""
    neighbors, labels, batches = dummy_x_labels_batch(x_is_neighbors_results=True)
    scib_metrics.ilisi_knn(neighbors, batches, perplexity=10)
    scib_metrics.clisi_knn(neighbors, labels, perplexity=10)
88 |
89 |
def test_nmi_ari_cluster_labels_kmeans():
    """KMeans-based NMI/ARI returns float scores."""
    X, labels = dummy_x_labels()
    scores = scib_metrics.nmi_ari_cluster_labels_kmeans(X, labels)
    assert isinstance(scores["nmi"], float)
    assert isinstance(scores["ari"], float)
96 |
97 |
def test_nmi_ari_cluster_labels_leiden_parallel():
    """Leiden NMI/ARI with resolution optimization across two jobs."""
    neighbors, labels = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True)
    scores = scib_metrics.nmi_ari_cluster_labels_leiden(neighbors, labels, optimize_resolution=True, n_jobs=2)
    assert isinstance(scores["nmi"], float)
    assert isinstance(scores["ari"], float)
104 |
105 |
def test_nmi_ari_cluster_labels_leiden_single_resolution():
    """Leiden NMI/ARI at a single fixed resolution."""
    neighbors, labels = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True)
    scores = scib_metrics.nmi_ari_cluster_labels_leiden(neighbors, labels, optimize_resolution=False, resolution=0.1)
    assert isinstance(scores["nmi"], float)
    assert isinstance(scores["ari"], float)
112 |
113 |
def test_nmi_ari_cluster_labels_leiden_reproducibility():
    """Repeated Leiden runs at a fixed resolution yield identical scores."""
    neighbors, labels = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True)
    first = scib_metrics.nmi_ari_cluster_labels_leiden(neighbors, labels, optimize_resolution=False, resolution=3.0)
    second = scib_metrics.nmi_ari_cluster_labels_leiden(neighbors, labels, optimize_resolution=False, resolution=3.0)
    assert first["nmi"] == second["nmi"]
    assert first["ari"] == second["ari"]
122 |
123 |
def test_leiden_graph_construction():
    """Our igraph construction is isomorphic to scanpy's, with equal weights."""
    neighbors, _ = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True)
    connectivities = neighbors.knn_graph_connectivities
    graph = igraph.Graph.Weighted_Adjacency(connectivities, mode="directed")
    graph.to_undirected(mode="each")
    scanpy_graph = sc._utils.get_igraph_from_adjacency(connectivities, directed=False)
    assert graph.isomorphic(scanpy_graph)
    np.testing.assert_equal(graph.es["weight"], scanpy_graph.es["weight"])
132 |
133 |
def test_isolated_labels():
    """Isolated-label ASW matches scib's reference implementation."""
    X, labels, batch = dummy_x_labels_batch()
    ours = scib_metrics.isolated_labels(X, labels, batch)
    adata = anndata.AnnData(X)
    adata.obsm["embed"] = X
    adata.obs["batch"] = batch
    adata.obs["labels"] = labels
    reference = isolated_labels_asw(adata, "labels", "batch", "embed", iso_threshold=5)
    np.testing.assert_allclose(np.array(ours), np.array(reference))
143 |
144 |
def test_kmeans():
    """JAX KMeans matches scikit-learn on well-separated blobs.

    Compares final inertia values and, after matching centroid order between
    the two implementations, the fraction of identically-assigned points.
    Removes the dead ``len(centers)`` statement and the unused
    ``labels_true`` / ``k_means_cluster_centers`` intermediates.
    """
    centers = np.array([[1, 1], [-1, -1], [1, -1]]) * 2
    X, _ = make_blobs(n_samples=3000, centers=centers, cluster_std=0.7, random_state=42)

    kmeans = scib_metrics.utils.KMeans(n_clusters=3)
    kmeans.fit(X)
    assert kmeans.labels_.shape == (X.shape[0],)

    skmeans = SKMeans(n_clusters=3)
    skmeans.fit(X)
    np.testing.assert_allclose(np.array([skmeans.inertia_]), np.array([kmeans.inertia_]), atol=4e-2)

    # Reorder cluster centroids between methods and measure assignment accuracy.
    order = pairwise_distances_argmin(kmeans.cluster_centroids_, skmeans.cluster_centers_)
    jax_labels = pairwise_distances_argmin(X, kmeans.cluster_centroids_)
    sk_labels = pairwise_distances_argmin(X, skmeans.cluster_centers_[order])

    accuracy = (jax_labels == sk_labels).sum() / len(jax_labels)
    assert accuracy > 0.995
170 |
171 |
def test_kbet():
    """kBET returns a float acceptance rate plus per-cell stats and p-values."""
    neighbors, _, batch = dummy_x_labels_batch(x_is_neighbors_results=True)
    acceptance_rate, stats, pvalues = scib_metrics.kbet(neighbors, batch)
    assert isinstance(acceptance_rate, float)
    n_cells = neighbors.indices.shape[0]
    assert len(stats) == n_cells
    assert len(pvalues) == n_cells
178 |
179 |
def test_kbet_per_label():
    """Per-label kBET aggregates to a single float score."""
    neighbors, labels, batch = dummy_x_labels_batch(x_is_neighbors_results=True)
    assert isinstance(scib_metrics.kbet_per_label(neighbors, batch, labels), float)
184 |
185 |
def test_graph_connectivity():
    """Graph connectivity returns a float score."""
    neighbors, labels = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True)
    assert isinstance(scib_metrics.graph_connectivity(neighbors, labels), float)
190 |
--------------------------------------------------------------------------------
/tests/test_neighbors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import scanpy as sc
4 |
5 | from scib_metrics.nearest_neighbors import jax_approx_min_k, pynndescent
6 | from tests.utils.data import dummy_benchmarker_adata
7 |
8 |
def test_jax_neighbors():
    """jax_approx_min_k returns distances with shape (n_obs, k)."""
    adata, embedding_keys, _, _ = dummy_benchmarker_adata()
    result = jax_approx_min_k(adata.obsm[embedding_keys[0]], 10)
    assert result.distances.shape == (adata.n_obs, 10)
13 |
14 |
@pytest.mark.parametrize("n", [5, 10, 20, 21])
def test_neighbors_results(n):
    """Connectivities from subset neighbors match scanpy's UMAP connectivities."""
    adata, embedding_keys, *_ = dummy_benchmarker_adata()
    result = pynndescent(adata.obsm[embedding_keys[0]], n_neighbors=n).subset_neighbors(n=n)
    ours = result.knn_graph_connectivities
    reference = sc.neighbors._connectivity.umap(
        result.indices[:, :n], result.distances[:, :n], n_obs=adata.n_obs, n_neighbors=n
    )
    np.testing.assert_allclose(ours.toarray(), reference.toarray())
27 |
--------------------------------------------------------------------------------
/tests/test_pcr_comparison.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | import pytest
4 |
5 | import scib_metrics
6 | from tests.utils.sampling import categorical_sample, normal_sample, poisson_sample
7 |
PCR_COMPARISON_PARAMS = list(product([100], [100], [False, True]))


@pytest.mark.parametrize("n_obs, n_vars, categorical", PCR_COMPARISON_PARAMS)
def test_pcr_comparison(n_obs, n_vars, categorical):
    """pcr_comparison with scale=True yields a score in [0, 1]."""
    X_pre = poisson_sample(n_obs, n_vars, seed=0)
    X_post = poisson_sample(n_obs, n_vars, seed=1)
    if categorical:
        covariate = categorical_sample(n_obs, int(n_obs / 5))
    else:
        covariate = normal_sample(n_obs, seed=0)

    score = scib_metrics.pcr_comparison(X_pre, X_post, covariate, scale=True)
    assert 0 <= score <= 1
19 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YosefLab/scib-metrics/5d01bb46f4d2317b4b3379851c7202df5be36c68/tests/utils/__init__.py
--------------------------------------------------------------------------------
/tests/utils/data.py:
--------------------------------------------------------------------------------
1 | import anndata
2 | import numpy as np
3 | from scipy.sparse import csr_matrix
4 | from sklearn.neighbors import NearestNeighbors
5 |
6 | import scib_metrics
7 | from scib_metrics.nearest_neighbors import NeighborsResults
8 |
9 |
def dummy_x_labels(symmetric_positive=False, x_is_neighbors_results=False):
    """Generate a deterministic (100, 10) Gaussian matrix and 4-class integer labels.

    With ``symmetric_positive`` the matrix is replaced by its elementwise-absolute
    Gram matrix; with ``x_is_neighbors_results`` it is replaced by a precomputed
    30-NN :class:`NeighborsResults` object built from pairwise distances.
    """
    generator = np.random.default_rng(seed=42)
    features = generator.normal(size=(100, 10))
    labels = generator.integers(0, 4, size=(100,))
    if symmetric_positive:
        # Symmetric, non-negative matrix (abs of the Gram matrix).
        features = np.abs(features @ features.T)
    if x_is_neighbors_results:
        pairwise = csr_matrix(scib_metrics.utils.cdist(features, features))
        knn = NearestNeighbors(n_neighbors=30, metric="precomputed").fit(pairwise)
        dist, ind = knn.kneighbors(pairwise)
        features = NeighborsResults(indices=ind, distances=dist)
    return features, labels
22 |
23 |
def dummy_x_labels_batch(x_is_neighbors_results=False):
    """Return the output of :func:`dummy_x_labels` plus a random 4-class batch vector."""
    batch_rng = np.random.default_rng(seed=43)
    X, labels = dummy_x_labels(x_is_neighbors_results=x_is_neighbors_results)
    batch = batch_rng.integers(0, 4, size=(100,))
    return X, labels, batch
29 |
30 |
def dummy_benchmarker_adata():
    """Build an AnnData with labels, batch, and five identical embeddings for Benchmarker tests."""
    X, labels, batch = dummy_x_labels_batch(x_is_neighbors_results=False)
    adata = anndata.AnnData(X)
    labels_key, batch_key = "labels", "batch"
    adata.obs[labels_key] = labels
    adata.obs[batch_key] = batch
    # Register the same matrix under five embedding keys.
    embedding_keys = [f"X_emb_{i}" for i in range(5)]
    for key in embedding_keys:
        adata.obsm[key] = X
    return adata, embedding_keys, labels_key, batch_key
44 |
--------------------------------------------------------------------------------
/tests/utils/sampling.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax import Array
4 |
5 | IntOrKey = int | Array
6 |
7 |
8 | def _validate_seed(seed: IntOrKey) -> Array:
9 | return jax.random.PRNGKey(seed) if isinstance(seed, int) else seed
10 |
11 |
def categorical_sample(n_obs: int, n_cats: int, seed: IntOrKey = 0) -> jnp.ndarray:
    """Sample ``n_obs`` category indices uniformly over ``n_cats`` categories."""
    logits = jnp.ones(n_cats)
    return jax.random.categorical(_validate_seed(seed), logits, shape=(n_obs,))
18 |
19 |
def normal_sample(n_obs: int, mean: float = 0.0, var: float = 1.0, seed: IntOrKey = 0) -> jnp.ndarray:
    """Draw one length-``n_obs`` sample from N(mean, var * I) via a multivariate normal."""
    loc = mean * jnp.ones(n_obs)
    cov = var * jnp.eye(n_obs)
    return jax.random.multivariate_normal(_validate_seed(seed), loc, cov)
27 |
28 |
def poisson_sample(n_obs: int, n_vars: int, rate: float = 1.0, seed: IntOrKey = 0) -> jnp.ndarray:
    """Draw an ``(n_obs, n_vars)`` matrix of Poisson(rate) counts."""
    return jax.random.poisson(_validate_seed(seed), rate, shape=(n_obs, n_vars))
36 |
--------------------------------------------------------------------------------
/tests/utils/test_pca.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | import jax.numpy as jnp
4 | import numpy as np
5 | import pytest
6 | from sklearn.decomposition import PCA
7 |
8 | import scib_metrics
9 | from tests.utils.sampling import poisson_sample
10 |
# All (n_obs, n_vars) combinations over three sizes each.
PCA_PARAMS = [(n_obs, n_vars) for n_obs in (10, 100, 1000) for n_vars in (10, 100, 1000)]
12 |
13 |
@pytest.mark.parametrize("n_obs, n_vars", PCA_PARAMS)
def test_pca(n_obs: int, n_vars: int):
    """Compare scib_metrics.utils.pca against sklearn's PCA on Poisson data."""

    def _run_case(n_obs: int, n_vars: int, n_components: int, eps: float = 1e-4):
        X = poisson_sample(n_obs, n_vars)
        max_components = min(X.shape)
        pca = scib_metrics.utils.pca(X, n_components=n_components, return_svd=True)

        # Shape sanity checks on the truncated outputs.
        assert pca.coordinates.shape == (X.shape[0], n_components)
        assert pca.components.shape == (n_components, X.shape[1])
        assert pca.variance.shape == (n_components,)
        assert pca.variance_ratio.shape == (n_components,)
        # The returned SVD keeps every component, not just n_components.
        assert pca.svd is not None
        assert pca.svd.u.shape == (X.shape[0], max_components)
        assert pca.svd.s.shape == (max_components,)
        assert pca.svd.v.shape == (max_components, X.shape[1])

        # Value checks against sklearn's full SVD solver.
        # TODO(martinkim0): Coordinates and components are not compared for now;
        # TODO(martinkim0): the implementations differ very slightly and it is unclear why.
        pca_true = PCA(n_components=n_components, svd_solver="full").fit(X)
        # assert np.allclose(pca_true.transform(X), pca.coordinates, atol=eps)
        # assert np.allclose(pca_true.components_, pca.components, atol=eps)
        assert np.allclose(pca_true.singular_values_, pca.svd.s[:n_components], atol=eps)
        assert np.allclose(pca_true.explained_variance_, pca.variance, atol=eps)
        assert np.allclose(pca_true.explained_variance_ratio_, pca.variance_ratio, atol=eps)

        # The arpack solver only supports n_components strictly below the maximum.
        if n_components < max_components:
            pca_true = PCA(n_components=n_components, svd_solver="arpack").fit(X)
            # assert jnp.allclose(pca_true.transform(X), pca.coordinates, atol=eps)
            # assert jnp.allclose(pca_true.components_, pca.components, atol=eps)
            assert jnp.allclose(pca_true.singular_values_, pca.svd.s[:n_components], atol=eps)
            assert jnp.allclose(pca_true.explained_variance_, pca.variance, atol=eps)
            assert jnp.allclose(pca_true.explained_variance_ratio_, pca.variance_ratio, atol=eps)

    max_components = min(n_obs, n_vars)
    for n_components in (max_components, max_components - 1, max_components // 2, 1):
        _run_case(n_obs, n_vars, n_components=n_components)
55 |
--------------------------------------------------------------------------------
/tests/utils/test_pcr.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | from scib.metrics import pc_regression
7 |
8 | import scib_metrics
9 | from scib_metrics.utils import get_ndarray
10 | from tests.utils.sampling import categorical_sample, normal_sample, poisson_sample
11 |
# TODO(martinkim0): Categorical covariates are currently not tested because of
# TODO(martinkim0): reproducibility issues with the original scib. See comment in PR #16.
PCR_PARAMS = [(n_obs, n_vars, False) for n_obs in (10, 100, 1000) for n_vars in (10, 100, 1000)]
15 |
16 |
@pytest.mark.parametrize("n_obs, n_vars, categorical", PCR_PARAMS)
def test_pcr(n_obs, n_vars, categorical):
    """Compare principal_component_regression against scib's pc_regression."""

    def _run_case(n_obs: int, n_vars: int, n_components: int, categorical: bool, eps=1e-3, seed=123):
        X = poisson_sample(n_obs, n_vars, seed=seed)
        if categorical:
            covariate = categorical_sample(n_obs, n_obs // 5)
        else:
            covariate = normal_sample(n_obs, seed=seed)

        # Reference value from the original scib implementation.
        cov_np = get_ndarray(covariate)
        pcr_true = pc_regression(
            get_ndarray(X),
            pd.Categorical(cov_np) if categorical else cov_np,
            n_comps=n_components,
        )
        pcr = scib_metrics.utils.principal_component_regression(
            X,
            covariate,
            categorical=categorical,
            n_components=n_components,
        )
        assert np.allclose(pcr_true, pcr, atol=eps)

    max_components = min(n_obs, n_vars)
    for n_components in (max_components, max_components - 1, max_components // 2, 1):
        _run_case(n_obs, n_vars, n_components=n_components, categorical=categorical)
41 |
--------------------------------------------------------------------------------