├── .codecov.yaml ├── .cruft.json ├── .editorconfig ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml └── workflows │ ├── build.yaml │ ├── release.yaml │ ├── test_linux.yaml │ ├── test_linux_cuda.yaml │ ├── test_linux_pre.yaml │ ├── test_macos.yaml │ ├── test_macos_m1.yaml │ └── test_windows.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _static │ ├── .gitkeep │ └── css │ │ └── custom.css ├── _templates │ ├── .gitkeep │ ├── autosummary │ │ └── class.rst │ └── class_no_inherited.rst ├── api.md ├── changelog.md ├── conf.py ├── contributing.md ├── extensions │ ├── .gitkeep │ └── typed_returns.py ├── index.md ├── notebooks │ ├── large_scale.ipynb │ └── lung_example.ipynb ├── references.bib ├── references.md ├── template_usage.md └── tutorials.md ├── pyproject.toml ├── setup.py ├── src └── scib_metrics │ ├── __init__.py │ ├── _settings.py │ ├── _types.py │ ├── benchmark │ ├── __init__.py │ └── _core.py │ ├── metrics │ ├── __init__.py │ ├── _graph_connectivity.py │ ├── _isolated_labels.py │ ├── _kbet.py │ ├── _lisi.py │ ├── _nmi_ari.py │ ├── _pcr_comparison.py │ └── _silhouette.py │ ├── nearest_neighbors │ ├── __init__.py │ ├── _dataclass.py │ ├── _jax.py │ └── _pynndescent.py │ └── utils │ ├── __init__.py │ ├── _diffusion_nn.py │ ├── _dist.py │ ├── _kmeans.py │ ├── _lisi.py │ ├── _pca.py │ ├── _pcr.py │ ├── _silhouette.py │ └── _utils.py └── tests ├── __init__.py ├── test_benchmarker.py ├── test_metrics.py ├── test_neighbors.py ├── test_pcr_comparison.py └── utils ├── __init__.py ├── data.py ├── sampling.py ├── test_pca.py └── test_pcr.py /.codecov.yaml: -------------------------------------------------------------------------------- 1 | # Based on pydata/xarray 2 | codecov: 3 | require_ci_to_pass: no 4 | 5 | coverage: 6 | status: 7 | project: 8 | default: 9 | # Require 1% coverage, i.e., always succeed 10 | target: 1 11 | patch: false 
12 | changes: false 13 | 14 | comment: 15 | layout: diff, flags, files 16 | behavior: once 17 | require_base: no 18 | -------------------------------------------------------------------------------- /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "https://github.com/scverse/cookiecutter-scverse", 3 | "commit": "87a407a65408d75a949c0b54b19fd287475a56f8", 4 | "checkout": "v0.4.0", 5 | "context": { 6 | "cookiecutter": { 7 | "project_name": "scib-metrics", 8 | "package_name": "scib_metrics", 9 | "project_description": "Accelerated and Python-only scIB metrics", 10 | "author_full_name": "Adam Gayoso", 11 | "author_email": "adamgayoso@berkeley.edu", 12 | "github_user": "adamgayoso", 13 | "project_repo": "https://github.com/yoseflab/scib-metrics", 14 | "license": "BSD 3-Clause License", 15 | "_copy_without_render": [ 16 | ".github/workflows/build.yaml", 17 | ".github/workflows/test.yaml", 18 | "docs/_templates/autosummary/**.rst" 19 | ], 20 | "_render_devdocs": false, 21 | "_jinja2_env_vars": { 22 | "lstrip_blocks": true, 23 | "trim_blocks": true 24 | }, 25 | "_template": "https://github.com/scverse/cookiecutter-scverse" 26 | } 27 | }, 28 | "directory": null, 29 | "skip": [ 30 | ".github/workflows/**.yaml", 31 | ".pre-commit-config.yaml", 32 | "pyproject.toml", 33 | "tests", 34 | "src/**/__init__.py", 35 | "src/**/basic.py", 36 | "docs/api.md", 37 | "docs/changelog.md", 38 | "docs/references.bib", 39 | "docs/references.md", 40 | "docs/notebooks/example.ipynb", 41 | "tests", 42 | "src/**/__init__.py", 43 | "src/**/basic.py", 44 | "docs/api.md", 45 | "docs/changelog.md", 46 | "docs/references.bib", 47 | "docs/references.md", 48 | "docs/notebooks/example.ipynb" 49 | ] 50 | } 51 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 
5 | indent_size = 4 6 | end_of_line = lf 7 | charset = utf-8 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [*.{yml,yaml}] 12 | indent_size = 2 13 | 14 | [.cruft.json] 15 | indent_size = 2 16 | 17 | [Makefile] 18 | indent_style = tab 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Report something that is broken or incorrect 3 | labels: bug 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) 9 | detailing how to provide the necessary information for us to reproduce your bug. In brief: 10 | * Please provide exact steps how to reproduce the bug in a clean Python environment. 11 | * In case it's not clear what's causing this bug, please provide the data or the data generation procedure. 12 | * Sometimes it is not possible to share the data, but usually it is possible to replicate problems on publicly 13 | available datasets or to share a subset of your data. 14 | 15 | - type: textarea 16 | id: report 17 | attributes: 18 | label: Report 19 | description: A clear and concise description of what the bug is. 
20 | validations: 21 | required: true 22 | 23 | - type: textarea 24 | id: versions 25 | attributes: 26 | label: Version information 27 | description: | 28 | Please paste below the output of 29 | 30 | ```python 31 | import session_info 32 | session_info.show(html=False, dependencies=True) 33 | ``` 34 | placeholder: | 35 | ----- 36 | anndata 0.8.0rc2.dev27+ge524389 37 | session_info 1.0.0 38 | ----- 39 | asttokens NA 40 | awkward 1.8.0 41 | backcall 0.2.0 42 | cython_runtime NA 43 | dateutil 2.8.2 44 | debugpy 1.6.0 45 | decorator 5.1.1 46 | entrypoints 0.4 47 | executing 0.8.3 48 | h5py 3.7.0 49 | ipykernel 6.15.0 50 | jedi 0.18.1 51 | mpl_toolkits NA 52 | natsort 8.1.0 53 | numpy 1.22.4 54 | packaging 21.3 55 | pandas 1.4.2 56 | parso 0.8.3 57 | pexpect 4.8.0 58 | pickleshare 0.7.5 59 | pkg_resources NA 60 | prompt_toolkit 3.0.29 61 | psutil 5.9.1 62 | ptyprocess 0.7.0 63 | pure_eval 0.2.2 64 | pydev_ipython NA 65 | pydevconsole NA 66 | pydevd 2.8.0 67 | pydevd_file_utils NA 68 | pydevd_plugins NA 69 | pydevd_tracing NA 70 | pygments 2.12.0 71 | pytz 2022.1 72 | scipy 1.8.1 73 | setuptools 62.5.0 74 | setuptools_scm NA 75 | six 1.16.0 76 | stack_data 0.3.0 77 | tornado 6.1 78 | traitlets 5.3.0 79 | wcwidth 0.2.5 80 | zmq 23.1.0 81 | ----- 82 | IPython 8.4.0 83 | jupyter_client 7.3.4 84 | jupyter_core 4.10.0 85 | ----- 86 | Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50) [GCC 10.3.0] 87 | Linux-5.18.6-arch1-1-x86_64-with-glibc2.35 88 | ----- 89 | Session information updated at 2022-07-07 17:55 90 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Scverse Community Forum 4 | url: https://discourse.scverse.org/ 5 | about: If you have questions about “How to do X”, please ask them here. 
6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Propose a new feature for scib-metrics 3 | labels: enhancement 4 | body: 5 | - type: textarea 6 | id: description 7 | attributes: 8 | label: Description of feature 9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered. 10 | validations: 11 | required: true 12 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | package: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 3.11 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.11" 22 | cache: "pip" 23 | cache-dependency-path: "**/pyproject.toml" 24 | - name: Install build dependencies 25 | run: python -m pip install --upgrade pip wheel twine build 26 | - name: Build package 27 | run: python -m build 28 | - name: Check package 29 | run: twine check --strict dist/*.whl 30 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/ 8 | jobs: 9 | release: 10 | name: Upload release to PyPI 11 | runs-on: ubuntu-latest 12 | environment: 13 | name: pypi 14 | url: 
https://pypi.org/p/scib-metrics 15 | permissions: 16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 17 | steps: 18 | - uses: actions/checkout@v4 19 | with: 20 | filter: blob:none 21 | fetch-depth: 0 22 | - uses: actions/setup-python@v4 23 | with: 24 | python-version: "3.x" 25 | cache: "pip" 26 | - run: pip install build 27 | - run: python -m build 28 | - name: Publish package distributions to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | -------------------------------------------------------------------------------- /.github/workflows/test_linux.yaml: -------------------------------------------------------------------------------- 1 | name: Test (Linux) 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | test: 15 | runs-on: ${{ matrix.os }} 16 | defaults: 17 | run: 18 | shell: bash -e {0} # -e to fail on error 19 | 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | os: [ubuntu-latest] 24 | python: ["3.10", "3.11", "3.12"] 25 | 26 | name: Integration 27 | 28 | env: 29 | OS: ${{ matrix.os }} 30 | PYTHON: ${{ matrix.python }} 31 | 32 | steps: 33 | - uses: actions/checkout@v4 34 | - name: Set up Python ${{ matrix.python }} 35 | uses: actions/setup-python@v5 36 | with: 37 | python-version: ${{ matrix.python }} 38 | cache: "pip" 39 | cache-dependency-path: "**/pyproject.toml" 40 | 41 | - name: Install test dependencies 42 | run: | 43 | python -m pip install --upgrade pip wheel 44 | 45 | - name: Install dependencies 46 | run: | 47 | pip install ".[dev,test]" 48 | 49 | - name: Test 50 | env: 51 | MPLBACKEND: agg 52 | PLATFORM: ${{ matrix.os }} 53 | DISPLAY: :42 54 | run: | 55 | coverage run -m pytest -v --color=yes 56 | - name: Report coverage 57 | run: | 58 | coverage report 59 | - name: Upload coverage 60 | uses: codecov/codecov-action@v4 61 | with: 62 | token: ${{ 
secrets.CODECOV_TOKEN }} 63 | -------------------------------------------------------------------------------- /.github/workflows/test_linux_cuda.yaml: -------------------------------------------------------------------------------- 1 | name: Test (Linux, CUDA) 2 | 3 | on: 4 | pull_request: 5 | branches: [main] 6 | types: [labeled, synchronize, opened] 7 | schedule: 8 | - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day 9 | workflow_dispatch: 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | test: 17 | # if PR has label "cuda tests" or "all tests" or if scheduled or manually triggered 18 | if: >- 19 | ( 20 | contains(github.event.pull_request.labels.*.name, 'cuda tests') || 21 | contains(github.event.pull_request.labels.*.name, 'all tests') || 22 | contains(github.event_name, 'schedule') || 23 | contains(github.event_name, 'workflow_dispatch') 24 | ) 25 | runs-on: [self-hosted, Linux, X64, CUDA] 26 | defaults: 27 | run: 28 | shell: bash -e {0} # -e to fail on error 29 | 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | python: ["3.11"] 34 | cuda: ["11"] 35 | 36 | container: 37 | image: scverse/scvi-tools:py${{ matrix.python }}-cu${{ matrix.cuda }}-base 38 | options: --user root --gpus all 39 | 40 | name: Integration (CUDA) 41 | 42 | env: 43 | OS: ${{ matrix.os }} 44 | PYTHON: ${{ matrix.python }} 45 | 46 | steps: 47 | - uses: actions/checkout@v4 48 | 49 | - name: Install dependencies 50 | run: | 51 | pip install ".[dev,test]" 52 | 53 | - name: Test 54 | env: 55 | MPLBACKEND: agg 56 | PLATFORM: ${{ matrix.os }} 57 | DISPLAY: :42 58 | run: | 59 | coverage run -m pytest -v --color=yes 60 | - name: Report coverage 61 | run: | 62 | coverage report 63 | - name: Upload coverage 64 | uses: codecov/codecov-action@v4 65 | with: 66 | token: ${{ secrets.CODECOV_TOKEN }} 67 | -------------------------------------------------------------------------------- 
/.github/workflows/test_linux_pre.yaml: -------------------------------------------------------------------------------- 1 | name: Test (Linux, prereleases) 2 | 3 | on: 4 | pull_request: 5 | branches: [main] 6 | types: [labeled, synchronize, opened] 7 | schedule: 8 | - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day 9 | workflow_dispatch: 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | test: 17 | # if PR has label "cuda tests" or "all tests" or if scheduled or manually triggered 18 | if: >- 19 | ( 20 | contains(github.event.pull_request.labels.*.name, 'prerelease tests') || 21 | contains(github.event.pull_request.labels.*.name, 'all tests') || 22 | contains(github.event_name, 'schedule') || 23 | contains(github.event_name, 'workflow_dispatch') 24 | ) 25 | runs-on: ${{ matrix.os }} 26 | defaults: 27 | run: 28 | shell: bash -e {0} # -e to fail on error 29 | 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | os: [ubuntu-latest] 34 | python: ["3.10", "3.11", "3.12"] 35 | 36 | name: Integration (Prereleases) 37 | 38 | env: 39 | OS: ${{ matrix.os }} 40 | PYTHON: ${{ matrix.python }} 41 | 42 | steps: 43 | - uses: actions/checkout@v4 44 | - name: Set up Python ${{ matrix.python }} 45 | uses: actions/setup-python@v5 46 | with: 47 | python-version: ${{ matrix.python }} 48 | cache: "pip" 49 | cache-dependency-path: "**/pyproject.toml" 50 | 51 | - name: Install test dependencies 52 | run: | 53 | python -m pip install --upgrade pip wheel 54 | 55 | - name: Install dependencies 56 | run: | 57 | pip install --pre ".[dev,test]" 58 | 59 | - name: Test 60 | env: 61 | MPLBACKEND: agg 62 | PLATFORM: ${{ matrix.os }} 63 | DISPLAY: :42 64 | run: | 65 | coverage run -m pytest -v --color=yes 66 | - name: Report coverage 67 | run: | 68 | coverage report 69 | - name: Upload coverage 70 | uses: codecov/codecov-action@v4 71 | with: 72 | token: ${{ secrets.CODECOV_TOKEN }} 73 | 
-------------------------------------------------------------------------------- /.github/workflows/test_macos.yaml: -------------------------------------------------------------------------------- 1 | name: Test (MacOS) 2 | 3 | on: 4 | schedule: 5 | - cron: "0 10 * * *" # runs at 10:00 UTC -> 03:00 PST every day 6 | workflow_dispatch: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | test: 14 | runs-on: ${{ matrix.os }} 15 | defaults: 16 | run: 17 | shell: bash -e {0} # -e to fail on error 18 | 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | os: [macos-latest] 23 | python: ["3.10", "3.11", "3.12"] 24 | 25 | name: Integration 26 | 27 | env: 28 | OS: ${{ matrix.os }} 29 | PYTHON: ${{ matrix.python }} 30 | 31 | steps: 32 | - uses: actions/checkout@v4 33 | - name: Set up Python ${{ matrix.python }} 34 | uses: actions/setup-python@v5 35 | with: 36 | python-version: ${{ matrix.python }} 37 | cache: "pip" 38 | cache-dependency-path: "**/pyproject.toml" 39 | 40 | - name: Install test dependencies 41 | run: | 42 | python -m pip install --upgrade pip wheel 43 | 44 | - name: Install dependencies 45 | run: | 46 | pip install ".[dev,test]" 47 | 48 | - name: Test 49 | env: 50 | MPLBACKEND: agg 51 | PLATFORM: ${{ matrix.os }} 52 | DISPLAY: :42 53 | run: | 54 | coverage run -m pytest -v --color=yes 55 | - name: Report coverage 56 | run: | 57 | coverage report 58 | - name: Upload coverage 59 | uses: codecov/codecov-action@v4 60 | with: 61 | token: ${{ secrets.CODECOV_TOKEN }} 62 | -------------------------------------------------------------------------------- /.github/workflows/test_macos_m1.yaml: -------------------------------------------------------------------------------- 1 | name: Test (MacOS M1) 2 | 3 | on: 4 | schedule: 5 | - cron: "0 10 * * *" # runs at 10:00 UTC -> 03:00 PST every day 6 | workflow_dispatch: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | 
cancel-in-progress: true 11 | 12 | jobs: 13 | test: 14 | runs-on: ${{ matrix.os }} 15 | defaults: 16 | run: 17 | shell: bash -e {0} # -e to fail on error 18 | 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | os: [macos-14] 23 | python: ["3.10", "3.11", "3.12"] 24 | 25 | name: Integration 26 | 27 | env: 28 | OS: ${{ matrix.os }} 29 | PYTHON: ${{ matrix.python }} 30 | 31 | steps: 32 | - uses: actions/checkout@v4 33 | - name: Set up Python ${{ matrix.python }} 34 | uses: actions/setup-python@v5 35 | with: 36 | python-version: ${{ matrix.python }} 37 | cache: "pip" 38 | cache-dependency-path: "**/pyproject.toml" 39 | 40 | - name: Install test dependencies 41 | run: | 42 | python -m pip install --upgrade pip wheel 43 | 44 | - name: Install dependencies 45 | run: | 46 | pip install ".[dev,test]" 47 | 48 | - name: Test 49 | env: 50 | MPLBACKEND: agg 51 | PLATFORM: ${{ matrix.os }} 52 | DISPLAY: :42 53 | run: | 54 | coverage run -m pytest -v --color=yes 55 | - name: Report coverage 56 | run: | 57 | coverage report 58 | - name: Upload coverage 59 | uses: codecov/codecov-action@v4 60 | with: 61 | token: ${{ secrets.CODECOV_TOKEN }} 62 | -------------------------------------------------------------------------------- /.github/workflows/test_windows.yaml: -------------------------------------------------------------------------------- 1 | name: Test (Windows) 2 | 3 | on: 4 | schedule: 5 | - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day 6 | workflow_dispatch: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | test: 14 | runs-on: ${{ matrix.os }} 15 | defaults: 16 | run: 17 | shell: bash -e {0} # -e to fail on error 18 | 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | os: [windows-latest] 23 | python: ["3.10", "3.11", "3.12"] 24 | 25 | name: Integration 26 | 27 | env: 28 | OS: ${{ matrix.os }} 29 | PYTHON: ${{ matrix.python }} 30 | 31 | steps: 32 | - uses: actions/checkout@v4 33 
| - name: Set up Python ${{ matrix.python }} 34 | uses: actions/setup-python@v5 35 | with: 36 | python-version: ${{ matrix.python }} 37 | cache: "pip" 38 | cache-dependency-path: "**/pyproject.toml" 39 | 40 | - name: Install test dependencies 41 | run: | 42 | python -m pip install --upgrade pip wheel 43 | 44 | - name: Install dependencies 45 | run: | 46 | pip install ".[dev,test]" 47 | 48 | - name: Test 49 | env: 50 | MPLBACKEND: agg 51 | PLATFORM: ${{ matrix.os }} 52 | DISPLAY: :42 53 | run: | 54 | coverage run -m pytest -v --color=yes 55 | - name: Report coverage 56 | run: | 57 | coverage report 58 | - name: Upload coverage 59 | uses: codecov/codecov-action@v4 60 | with: 61 | token: ${{ secrets.CODECOV_TOKEN }} 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temp files 2 | .DS_Store 3 | *~ 4 | buck-out/ 5 | 6 | # Compiled files 7 | .venv/ 8 | __pycache__/ 9 | .mypy_cache/ 10 | .ruff_cache/ 11 | 12 | # Distribution / packaging 13 | /build/ 14 | /dist/ 15 | /*.egg-info/ 16 | 17 | # Tests and coverage 18 | /.pytest_cache/ 19 | /.cache/ 20 | /data/ 21 | /node_modules/ 22 | 23 | # docs 24 | /docs/generated/ 25 | /docs/_build/ 26 | 27 | # IDEs 28 | /.idea/ 29 | /.vscode/ 30 | 31 | # node 32 | /node_modules/ 33 | 34 | docs/notebooks/data/ 35 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | default_stages: 5 | - pre-commit 6 | - pre-push 7 | minimum_pre_commit_version: 2.16.0 8 | repos: 9 | - repo: https://github.com/pre-commit/mirrors-prettier 10 | rev: v4.0.0-alpha.8 11 | hooks: 12 | - id: prettier 13 | - repo: https://github.com/astral-sh/ruff-pre-commit 14 | rev: v0.11.11 15 | hooks: 16 | - id: ruff 17 | types_or: 
[python, pyi, jupyter] 18 | args: [--fix, --exit-non-zero-on-fix] 19 | - id: ruff-format 20 | types_or: [python, pyi, jupyter] 21 | - repo: https://github.com/pre-commit/pre-commit-hooks 22 | rev: v5.0.0 23 | hooks: 24 | - id: detect-private-key 25 | - id: check-ast 26 | - id: end-of-file-fixer 27 | - id: mixed-line-ending 28 | args: [--fix=lf] 29 | - id: trailing-whitespace 30 | - id: check-case-conflict 31 | # Check that there are no merge conflicts (could be generated by template sync) 32 | - id: check-merge-conflict 33 | args: [--assume-in-merge] 34 | - repo: local 35 | hooks: 36 | - id: forbid-to-commit 37 | name: Don't commit rej files 38 | entry: | 39 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. 40 | Fix the merge conflicts manually and remove the .rej files. 41 | language: fail 42 | files: '.*\.rej$' 43 | - repo: https://github.com/kynan/nbstripout 44 | rev: 0.8.1 45 | hooks: 46 | - id: nbstripout 47 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html 2 | version: 2 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.10" 7 | sphinx: 8 | configuration: docs/conf.py 9 | # disable this for more lenient docs builds 10 | fail_on_warning: false 11 | python: 12 | install: 13 | - method: pip 14 | path: . 15 | extra_requirements: 16 | - doc 17 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog][], 6 | and this project adheres to [Semantic Versioning][]. 
7 | 8 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/ 9 | [semantic versioning]: https://semver.org/spec/v2.0.0.html 10 | 11 | ## 0.6.0 (unreleased) 12 | 13 | ## 0.5.4 (2025-04-23) 14 | 15 | ### Fixed 16 | 17 | - Apply default values for benchmarker metrics {pr}`203`. 18 | 19 | ## 0.5.3 (2025-02-17) 20 | 21 | ### Removed 22 | 23 | - Reverted a change that was needed for scib-autotune in scvi-tools {pr}`189`. 24 | 25 | ## 0.5.2 (2025-02-13) 26 | 27 | ### Added 28 | 29 | - Add `progress_bar` argument to {class}`scib_metrics.benchmark.Benchmarker` {pr}`152`. 30 | - Add ability of {class}`scib_metrics.benchmark.Benchmarker` plotting code to handle missing sets of metrics {pr}`181`. 31 | - Add random score in case of aggregate metrics not selected to be used in scib autotune in scvi-tools, {pr}`188`. 32 | 33 | ### Changed 34 | 35 | - Changed Leiden clustering now has a seed argument for reproducibility {pr}`173`. 36 | - Changed passing `None` to `bio_conservation_metrics` or `batch_correction_metrics` in {class}`scib_metrics.benchmark.Benchmarker` now implies to skip this set of metrics {pr}`181`. 37 | 38 | ### Fixed 39 | 40 | - Fix neighbors connectivities in test to use new scanpy fn {pr}`170`. 41 | - Fix Kmeans test {pr}`172`. 42 | - Fix deprecation and future warnings {pr}`171`. 43 | - Fix lisi return type and docstring {pr}`182`. 44 | 45 | ## 0.5.1 (2024-02-23) 46 | 47 | ### Changed 48 | 49 | - Replace removed {class}`jax.random.KeyArray` with {class}`jax.Array` {pr}`135`. 50 | 51 | ## 0.5.0 (2024-01-04) 52 | 53 | ### Changed 54 | 55 | - Refactor all relevant metrics to use `NeighborsResults` as input instead of sparse 56 | distance/connectivity matrices {pr}`129`. 57 | 58 | ## 0.4.1 (2023-10-08) 59 | 60 | ### Fixed 61 | 62 | - Fix KMeans. All previous versions had a bug with KMeans and ARI/NMI metrics are not reliable 63 | with this clustering {pr}`115`. 
64 | 65 | ## 0.4.0 (2023-09-19) 66 | 67 | ### Added 68 | 69 | - Update isolated labels to use newest scib methodology {pr}`108`. 70 | 71 | ### Fixed 72 | 73 | - Fix jax one-hot error {pr}`107`. 74 | 75 | ### Removed 76 | 77 | - Drop Python 3.8 {pr}`107`. 78 | 79 | ## 0.3.3 (2023-03-29) 80 | 81 | ### Fixed 82 | 83 | - Large scale tutorial now properly uses gpu index {pr}`92` 84 | 85 | ## 0.3.2 (2023-03-13) 86 | 87 | ### Changed 88 | 89 | - Switch to Ruff for linting/formatting {pr}`87` 90 | - Update cookiecutter template {pr}`88` 91 | 92 | ## 0.3.1 (2023-02-16) 93 | 94 | ### Changed 95 | 96 | - Expose chunk size for silhouette {pr}`82` 97 | 98 | ## 0.3.0 (2023-02-16) 99 | 100 | ### Changed 101 | 102 | - Rename `KmeansJax` to `Kmeans` and fix ++ initialization, use Kmeans as default in benchmarker instead of Leiden {pr}`81`. 103 | - Warn about joblib, add progress bar postfix str {pr}`80` 104 | 105 | ## 0.2.0 (2023-02-02) 106 | 107 | ### Added 108 | 109 | - Allow custom nearest neighbors methods in Benchmarker {pr}`78`. 110 | 111 | ## 0.1.1 (2023-01-04) 112 | 113 | ### Added 114 | 115 | - Add new tutorial and fix scalability of lisi {pr}`71`. 116 | 117 | ## 0.1.0 (2023-01-03) 118 | 119 | ### Added 120 | 121 | - Add benchmarking pipeline with plotting {pr}`52` {pr}`69`. 122 | 123 | ### Fixed 124 | 125 | - Fix diffusion distance computation, affecting kbet {pr}`70`. 126 | 127 | ## 0.0.9 (2022-12-16) 128 | 129 | ### Added 130 | 131 | - Add kbet {pr}`60`. 132 | - Add graph connectivty metric {pr}`61`. 133 | 134 | ## 0.0.8 (2022-11-18) 135 | 136 | ### Changed 137 | 138 | - Switch to random kmeans initialization due to kmeans++ complexity issues {pr}`54`. 139 | 140 | ### Fixed 141 | 142 | - Begin fixes to make kmeans++ initialization faster {pr}`49`. 143 | 144 | ## 0.0.7 (2022-10-31) 145 | 146 | ### Changed 147 | 148 | - Move PCR to utils module in favor of PCR comparison {pr}`46`. 
149 | 150 | ### Fixed 151 | 152 | - Fix memory issue in `KMeansJax` by using `_kmeans_full_run` with `map` instead of `vmap` {pr}`45`. 153 | 154 | ## 0.0.6 (2022-10-25) 155 | 156 | ### Changed 157 | 158 | - Reimplement silhouette in a memory constant way, pdist using lax scan {pr}`42`. 159 | 160 | ## 0.0.5 (2022-10-24) 161 | 162 | ### Added 163 | 164 | - Standardize language of docstring {pr}`30`. 165 | - Use K-means++ initialization {pr}`23`. 166 | - Add pc regression and pc comparsion {pr}`16` {pr}`38`. 167 | - Lax'd silhouette {pr}`33`. 168 | - Cookicutter template sync {pr}`35`. 169 | 170 | ## 0.0.4 (2022-10-10) 171 | 172 | ### Added 173 | 174 | - NMI/ARI metric with Leiden clustering resolution optimization {pr}`24`. 175 | - iLISI/cLISI metrics {pr}`20`. 176 | 177 | ## 0.0.1 - 0.0.3 178 | 179 | See the [GitHub releases][https://github.com/yoseflab/scib-metrics/releases] for details. 180 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Adam Gayoso 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scib-metrics 2 | 3 | [![Stars][badge-stars]][link-stars] 4 | [![PyPI][badge-pypi]][link-pypi] 5 | [![PyPIDownloads][badge-downloads]][link-downloads] 6 | [![Docs][badge-docs]][link-docs] 7 | [![Build][badge-build]][link-build] 8 | [![Coverage][badge-cov]][link-cov] 9 | [![Discourse][badge-discourse]][link-discourse] 10 | [![Chat][badge-zulip]][link-zulip] 11 | 12 | [badge-stars]: https://img.shields.io/github/stars/YosefLab/scib-metrics?logo=GitHub&color=yellow 13 | [link-stars]: https://github.com/YosefLab/scib-metrics/stargazers 14 | [badge-pypi]: https://img.shields.io/pypi/v/scib-metrics.svg 15 | [link-pypi]: https://pypi.org/project/scib-metrics 16 | [badge-downloads]: https://static.pepy.tech/badge/scib-metrics 17 | [link-downloads]: https://pepy.tech/project/scib-metrics 18 | [badge-docs]: https://readthedocs.org/projects/scib-metrics/badge/?version=latest 19 | [link-docs]: https://scib-metrics.readthedocs.io/en/latest/?badge=latest 20 | [badge-build]: 
https://github.com/YosefLab/scib-metrics/actions/workflows/build.yaml/badge.svg 21 | [link-build]: https://github.com/YosefLab/scib-metrics/actions/workflows/build.yaml/ 22 | [badge-cov]: https://codecov.io/gh/YosefLab/scib-metrics/branch/main/graph/badge.svg 23 | [link-cov]: https://codecov.io/gh/YosefLab/scib-metrics 24 | [badge-discourse]: https://img.shields.io/discourse/posts?color=yellow&logo=discourse&server=https%3A%2F%2Fdiscourse.scverse.org 25 | [link-discourse]: https://discourse.scverse.org/ 26 | [badge-zulip]: https://img.shields.io/badge/zulip-join_chat-brightgreen.svg 27 | [link-zulip]: https://scverse.zulipchat.com/ 28 | 29 | Accelerated and Python-only metrics for benchmarking single-cell integration outputs. 30 | 31 | This package contains implementations of metrics for evaluating the performance of single-cell omics data integration methods. The implementations of these metrics use [JAX](https://jax.readthedocs.io/en/latest/) when possible for jit-compilation and hardware acceleration. All implementations are in Python. 32 | 33 | Currently we are porting metrics used in the scIB [manuscript](https://www.nature.com/articles/s41592-021-01336-8) (and [code](https://github.com/theislab/scib)). Deviations from the original implementations are documented. However, metric values from this repository should not be compared to the scIB repository. 34 | 35 | ## Getting started 36 | 37 | Please refer to the [documentation][link-docs]. 38 | 39 | ## Installation 40 | 41 | You need to have Python 3.10 or newer installed on your system. If you don't have 42 | Python installed, we recommend installing [Miniconda](https://docs.conda.io/en/latest/miniconda.html). 43 | 44 | There are several options to install scib-metrics: 45 | 46 | 1. Install the latest release on PyPI: 47 | 48 | ```bash 49 | pip install scib-metrics 50 | ``` 51 | 52 | 2. 
To leverage hardware acceleration (e.g., GPU), please install the appropriate version of [JAX](https://github.com/google/jax#installation) separately. This is often easier using conda-distributed versions of JAX.
the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YosefLab/scib-metrics/5d01bb46f4d2317b4b3379851c7202df5be36c68/docs/_static/.gitkeep -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */ 2 | div.cell_output table.dataframe { 3 | font-size: 0.8em; 4 | } 5 | -------------------------------------------------------------------------------- /docs/_templates/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YosefLab/scib-metrics/5d01bb46f4d2317b4b3379851c7202df5be36c68/docs/_templates/.gitkeep -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. 
autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | Attributes table 12 | ~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autosummary:: 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | Methods table 24 | ~~~~~~~~~~~~~ 25 | 26 | .. autosummary:: 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | 35 | {% block attributes_documentation %} 36 | {% if attributes %} 37 | Attributes 38 | ~~~~~~~~~~~ 39 | 40 | {% for item in attributes %} 41 | 42 | .. autoattribute:: {{ [objname, item] | join(".") }} 43 | {%- endfor %} 44 | 45 | {% endif %} 46 | {% endblock %} 47 | 48 | {% block methods_documentation %} 49 | {% if methods %} 50 | Methods 51 | ~~~~~~~ 52 | 53 | {% for item in methods %} 54 | {%- if item != '__init__' %} 55 | 56 | .. automethod:: {{ [objname, item] | join(".") }} 57 | {%- endif -%} 58 | {%- endfor %} 59 | 60 | {% endif %} 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /docs/_templates/class_no_inherited.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | :show-inheritance: 9 | 10 | {% block attributes %} 11 | {% if attributes %} 12 | Attributes table 13 | ~~~~~~~~~~~~~~~~ 14 | 15 | .. autosummary:: 16 | {% for item in attributes %} 17 | {%- if item not in inherited_members%} 18 | ~{{ fullname }}.{{ item }} 19 | {%- endif -%} 20 | {%- endfor %} 21 | {% endif %} 22 | {% endblock %} 23 | 24 | 25 | {% block methods %} 26 | {% if methods %} 27 | Methods table 28 | ~~~~~~~~~~~~~~ 29 | 30 | .. 
autosummary:: 31 | {% for item in methods %} 32 | {%- if item != '__init__' and item not in inherited_members%} 33 | ~{{ fullname }}.{{ item }} 34 | {%- endif -%} 35 | 36 | {%- endfor %} 37 | {% endif %} 38 | {% endblock %} 39 | 40 | {% block attributes_documentation %} 41 | {% if attributes %} 42 | Attributes 43 | ~~~~~~~~~~ 44 | 45 | {% for item in attributes %} 46 | {{ item }} 47 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 48 | 49 | .. autoattribute:: {{ [objname, item] | join(".") }} 50 | {%- endfor %} 51 | 52 | {% endif %} 53 | {% endblock %} 54 | 55 | {% block methods_documentation %} 56 | {% if methods %} 57 | Methods 58 | ~~~~~~~ 59 | 60 | {% for item in methods %} 61 | {%- if item != '__init__' and item not in inherited_members%} 62 | {{ item }} 63 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 64 | .. automethod:: {{ [objname, item] | join(".") }} 65 | {%- endif -%} 66 | {%- endfor %} 67 | 68 | {% endif %} 69 | {% endblock %} 70 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ## Benchmarking pipeline 4 | 5 | Import as: 6 | 7 | ``` 8 | from scib_metrics.benchmark import Benchmarker 9 | ``` 10 | 11 | ```{eval-rst} 12 | .. module:: scib_metrics.benchmark 13 | .. currentmodule:: scib_metrics 14 | 15 | .. autosummary:: 16 | :toctree: generated 17 | 18 | benchmark.Benchmarker 19 | benchmark.BioConservation 20 | benchmark.BatchCorrection 21 | ``` 22 | 23 | ## Metrics 24 | 25 | Import as: 26 | 27 | ``` 28 | import scib_metrics 29 | scib_metrics.ilisi_knn(...) 30 | ``` 31 | 32 | ```{eval-rst} 33 | .. module:: scib_metrics 34 | .. currentmodule:: scib_metrics 35 | 36 | .. 
autosummary:: 37 | :toctree: generated 38 | 39 | isolated_labels 40 | nmi_ari_cluster_labels_kmeans 41 | nmi_ari_cluster_labels_leiden 42 | pcr_comparison 43 | silhouette_label 44 | silhouette_batch 45 | ilisi_knn 46 | clisi_knn 47 | kbet 48 | kbet_per_label 49 | graph_connectivity 50 | ``` 51 | 52 | ## Utils 53 | 54 | ```{eval-rst} 55 | .. module:: scib_metrics.utils 56 | .. currentmodule:: scib_metrics 57 | 58 | .. autosummary:: 59 | :toctree: generated 60 | 61 | utils.cdist 62 | utils.pdist_squareform 63 | utils.silhouette_samples 64 | utils.KMeans 65 | utils.pca 66 | utils.principal_component_regression 67 | utils.one_hot 68 | utils.compute_simpson_index 69 | utils.convert_knn_graph_to_idx 70 | utils.check_square 71 | utils.diffusion_nn 72 | ``` 73 | 74 | ### Nearest neighbors 75 | 76 | ```{eval-rst} 77 | .. module:: scib_metrics.nearest_neighbors 78 | .. currentmodule:: scib_metrics 79 | 80 | .. autosummary:: 81 | :toctree: generated 82 | 83 | nearest_neighbors.pynndescent 84 | nearest_neighbors.jax_approx_min_k 85 | nearest_neighbors.NeighborsResults 86 | ``` 87 | 88 | ## Settings 89 | 90 | An instance of the {class}`~scib_metrics._settings.ScibConfig` is available as `scib_metrics.settings` and allows configuring scib_metrics. 91 | 92 | ```{eval-rst} 93 | .. autosummary:: 94 | :toctree: reference/ 95 | :nosignatures: 96 | 97 | _settings.ScibConfig 98 | ``` 99 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CHANGELOG.md 2 | 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | from typing import Any 9 | import subprocess 10 | import os 11 | import importlib 12 | import inspect 13 | import re 14 | import sys 15 | from datetime import datetime 16 | from importlib.metadata import metadata 17 | from pathlib import Path 18 | 19 | HERE = Path(__file__).parent 20 | sys.path.insert(0, str(HERE / "extensions")) 21 | 22 | 23 | # -- Project information ----------------------------------------------------- 24 | 25 | project_name = "scib-metrics" 26 | info = metadata(project_name) 27 | package_name = "scib_metrics" 28 | author = info["Author"] 29 | copyright = f"{datetime.now():%Y}, {author}." 30 | version = info["Version"] 31 | repository_url = f"https://github.com/YosefLab/{project_name}" 32 | 33 | # The full version, including alpha/beta/rc tags 34 | release = info["Version"] 35 | 36 | bibtex_bibfiles = ["references.bib"] 37 | templates_path = ["_templates"] 38 | nitpicky = True # Warn about broken links 39 | needs_sphinx = "4.0" 40 | 41 | html_context = { 42 | "display_github": True, # Integrate GitHub 43 | "github_user": "yoseflab", # Username 44 | "github_repo": project_name, # Repo name 45 | "github_version": "main", # Version 46 | "conf_py_path": "/docs/", # Path in the checkout to the docs root 47 | } 48 | 49 | # -- General configuration --------------------------------------------------- 50 | 51 | # Add any Sphinx extension module names here, as strings. 52 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 
53 | extensions = [ 54 | "myst_nb", 55 | "sphinx.ext.autodoc", 56 | "sphinx_copybutton", 57 | "sphinx.ext.linkcode", 58 | "sphinx.ext.intersphinx", 59 | "sphinx.ext.autosummary", 60 | "sphinx.ext.napoleon", 61 | "sphinx.ext.extlinks", 62 | "sphinxcontrib.bibtex", 63 | "sphinx_autodoc_typehints", 64 | "sphinx.ext.mathjax", 65 | "IPython.sphinxext.ipython_console_highlighting", 66 | "sphinxext.opengraph", 67 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 68 | ] 69 | 70 | autosummary_generate = True 71 | autodoc_member_order = "groupwise" 72 | default_role = "literal" 73 | bibtex_reference_style = "author_year" 74 | napoleon_google_docstring = False 75 | napoleon_numpy_docstring = True 76 | napoleon_include_init_with_doc = False 77 | napoleon_use_rtype = True # having a separate entry generally helps readability 78 | napoleon_use_param = True 79 | myst_heading_anchors = 6 # create anchors for h1-h6 80 | myst_enable_extensions = [ 81 | "amsmath", 82 | "colon_fence", 83 | "deflist", 84 | "dollarmath", 85 | "html_image", 86 | "html_admonition", 87 | ] 88 | myst_url_schemes = ("http", "https", "mailto") 89 | nb_output_stderr = "remove" 90 | nb_execution_mode = "off" 91 | nb_merge_streams = True 92 | typehints_defaults = "braces" 93 | 94 | source_suffix = { 95 | ".rst": "restructuredtext", 96 | ".ipynb": "myst-nb", 97 | ".myst": "myst-nb", 98 | } 99 | 100 | intersphinx_mapping = { 101 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None), 102 | "ipython": ("https://ipython.readthedocs.io/en/stable/", None), 103 | "matplotlib": ("https://matplotlib.org/", None), 104 | "numpy": ("https://numpy.org/doc/stable/", None), 105 | "pandas": ("https://pandas.pydata.org/docs/", None), 106 | "python": ("https://docs.python.org/3", None), 107 | "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), 108 | "sklearn": ("https://scikit-learn.org/stable/", None), 109 | "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), 110 | "jax": 
("https://jax.readthedocs.io/en/latest/", None), 111 | "plottable": ("https://plottable.readthedocs.io/en/latest/", None), 112 | } 113 | 114 | # List of patterns, relative to source directory, that match files and 115 | # directories to ignore when looking for source files. 116 | # This pattern also affects html_static_path and html_extra_path. 117 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 118 | 119 | # extlinks config 120 | extlinks = { 121 | "issue": (f"{repository_url}/issues/%s", "#%s"), 122 | "pr": (f"{repository_url}/pull/%s", "#%s"), 123 | "ghuser": ("https://github.com/%s", "@%s"), 124 | } 125 | 126 | # -- Linkcode settings ------------------------------------------------- 127 | 128 | 129 | def git(*args): 130 | """Run a git command and return the output.""" 131 | return subprocess.check_output(["git", *args]).strip().decode() 132 | 133 | 134 | # https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192 135 | # Current git reference. 
Uses branch/tag name if found, otherwise uses commit hash 136 | git_ref = None 137 | try: 138 | git_ref = git("name-rev", "--name-only", "--no-undefined", "HEAD") 139 | git_ref = re.sub(r"^(remotes/[^/]+|tags)/", "", git_ref) 140 | except Exception: 141 | pass 142 | 143 | # (if no name found or relative ref, use commit hash instead) 144 | if not git_ref or re.search(r"[\^~]", git_ref): 145 | try: 146 | git_ref = git("rev-parse", "HEAD") 147 | except Exception: 148 | git_ref = "main" 149 | 150 | # https://github.com/DisnakeDev/disnake/blob/7853da70b13fcd2978c39c0b7efa59b34d298186/docs/conf.py#L192 151 | github_repo = "https://github.com/" + html_context["github_user"] + "/" + project_name 152 | _project_module_path = os.path.dirname(importlib.util.find_spec(package_name).origin) # type: ignore 153 | 154 | 155 | def linkcode_resolve(domain, info): 156 | """Resolve links for the linkcode extension.""" 157 | if domain != "py": 158 | return None 159 | 160 | try: 161 | obj: Any = sys.modules[info["module"]] 162 | for part in info["fullname"].split("."): 163 | obj = getattr(obj, part) 164 | obj = inspect.unwrap(obj) 165 | 166 | if isinstance(obj, property): 167 | obj = inspect.unwrap(obj.fget) # type: ignore 168 | 169 | path = os.path.relpath(inspect.getsourcefile(obj), start=_project_module_path) # type: ignore 170 | src, lineno = inspect.getsourcelines(obj) 171 | except Exception: 172 | return None 173 | 174 | path = f"{path}#L{lineno}-L{lineno + len(src) - 1}" 175 | return f"{github_repo}/blob/{git_ref}/src/{package_name}/{path}" 176 | 177 | 178 | # -- Options for HTML output ------------------------------------------------- 179 | 180 | # The theme to use for HTML and HTML Help pages. See the documentation for 181 | # a list of builtin themes. 
182 | # 183 | html_theme = "sphinx_book_theme" 184 | html_static_path = ["_static"] 185 | html_css_files = ["css/custom.css"] 186 | html_title = "scib-metrics" 187 | 188 | html_theme_options = { 189 | "repository_url": github_repo, 190 | "use_repository_button": True, 191 | } 192 | 193 | pygments_style = "default" 194 | 195 | nitpick_ignore = [ 196 | # If building the documentation fails because of a missing link that is outside your control, 197 | # you can add an exception to this list. 198 | ] 199 | 200 | 201 | def setup(app): 202 | """App setup hook.""" 203 | app.add_config_value( 204 | "recommonmark_config", 205 | { 206 | "auto_toc_tree_section": "Contents", 207 | "enable_auto_toc_tree": True, 208 | "enable_math": True, 209 | "enable_inline_math": False, 210 | "enable_eval_rst": True, 211 | }, 212 | True, 213 | ) 214 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing guide 2 | 3 | Scanpy provides extensive [developer documentation][scanpy developer guide], most of which applies to this project, too. 4 | This document will not reproduce the entire content from there. Instead, it aims at summarizing the most important 5 | information to get you started on contributing. 6 | 7 | We assume that you are already familiar with git and with making pull requests on GitHub. If not, please refer 8 | to the [scanpy developer guide][]. 9 | 10 | ## Installing dev dependencies 11 | 12 | In addition to the packages needed to _use_ this package, you need additional python packages to _run tests_ and _build 13 | the documentation_. It's easy to install them using `pip`: 14 | 15 | ```bash 16 | cd scib-metrics 17 | pip install -e ".[dev,test,doc]" 18 | ``` 19 | 20 | ## Code-style 21 | 22 | This package uses [pre-commit][] to enforce consistent code-styles. 
This package uses [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added
- [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension).
104 | - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks)) 105 | - [Sphinx autodoc typehints][], to automatically reference annotated input and output types 106 | - Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/) 107 | 108 | See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information 109 | on how to write documentation. 110 | 111 | ### Tutorials with myst-nb and jupyter notebooks 112 | 113 | The documentation is set-up to render jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][]. 114 | Currently, only notebooks in `.ipynb` format are supported that will be included with both their input and output cells. 115 | It is your responsibility to update and re-run the notebook whenever necessary. 116 | 117 | If you are interested in automatically running notebooks as part of the continuous integration, please check 118 | out [this feature request](https://github.com/scverse/cookiecutter-scverse/issues/40) in the `cookiecutter-scverse` 119 | repository. 120 | 121 | #### Hints 122 | 123 | - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only 124 | if you do so can sphinx automatically create a link to the external documentation. 
125 | - If building the documentation fails because of a missing link that is outside your control, you can add an entry to 126 | the `nitpick_ignore` list in `docs/conf.py` 127 | 128 | #### Building the docs locally 129 | 130 | ```bash 131 | cd docs 132 | make html 133 | open _build/html/index.html 134 | ``` 135 | 136 | 137 | 138 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html 139 | [cookiecutter-scverse-instance]: https://cookiecutter-scverse-instance.readthedocs.io/en/latest/template_usage.html 140 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui 141 | [codecov]: https://about.codecov.io/sign-up/ 142 | [codecov docs]: https://docs.codecov.com/docs 143 | [codecov bot]: https://docs.codecov.com/docs/team-bot 144 | [codecov app]: https://github.com/apps/codecov 145 | [pre-commit.ci]: https://pre-commit.ci/ 146 | [readthedocs.org]: https://readthedocs.org/ 147 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/ 148 | [jupytext]: https://jupytext.readthedocs.io/en/latest/ 149 | [pre-commit]: https://pre-commit.com/ 150 | [anndata]: https://github.com/scverse/anndata 151 | [mudata]: https://github.com/scverse/mudata 152 | [pytest]: https://docs.pytest.org/ 153 | [semver]: https://semver.org/ 154 | [sphinx]: https://www.sphinx-doc.org/en/master/ 155 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html 156 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 157 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html 158 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints 159 | [pypi]: https://pypi.org/ 160 | [managing GitHub releases]: https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository 161 | -------------------------------------------------------------------------------- /docs/extensions/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YosefLab/scib-metrics/5d01bb46f4d2317b4b3379851c7202df5be36c68/docs/extensions/.gitkeep -------------------------------------------------------------------------------- /docs/extensions/typed_returns.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py 2 | # with some minor adjustment 3 | from __future__ import annotations 4 | from sphinx.ext.napoleon import NumpyDocstring 5 | from typing import TYPE_CHECKING 6 | import re 7 | 8 | if TYPE_CHECKING: 9 | from collections.abc import Generator, Iterable 10 | 11 | from sphinx.application import Sphinx 12 | 13 | 14 | def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: 15 | for line in lines: 16 | if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line): 17 | yield f"-{m['param']} (:class:`~{m['type']}`)" 18 | else: 19 | yield line 20 | 21 | 22 | def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]: 23 | lines_raw = self._dedent(self._consume_to_next_section()) 24 | if lines_raw[0] == ":": 25 | del lines_raw[0] 26 | lines = self._format_block(":returns: ", list(_process_return(lines_raw))) 27 | if lines and lines[-1]: 28 | lines.append("") 29 | return lines 30 | 31 | 32 | def setup(app: Sphinx): 33 | """Set app.""" 34 | NumpyDocstring._parse_returns_section = _parse_returns_section 35 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ```{include} ../README.md 2 | 3 | ``` 4 | 5 | ```{toctree} 6 | :hidden: true 7 | :maxdepth: 1 8 | 9 | api.md 10 | tutorials.md 11 | changelog.md 12 | template_usage.md 13 | contributing.md 14 | references.md 15 | ``` 16 | 
-------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{luecken2022benchmarking, 2 | title = {Benchmarking atlas-level data integration in single-cell genomics}, 3 | author = {Luecken, Malte D and B{\"u}ttner, Maren and Chaichoompu, Kridsadakorn and Danese, Anna and Interlandi, Marta and M{\"u}ller, Michaela F and Strobl, Daniel C and Zappia, Luke and Dugas, Martin and Colom{\'e}-Tatch{\'e}, Maria and others}, 4 | journal = {Nature methods}, 5 | volume = {19}, 6 | number = {1}, 7 | pages = {41--50}, 8 | year = {2022}, 9 | publisher = {Nature Publishing Group} 10 | } 11 | 12 | 13 | @article{korsunsky2019harmony, 14 | title = {Fast, sensitive and accurate integration of single-cell data with 15 | Harmony}, 16 | author = {Korsunsky, Ilya and Millard, Nghia and Fan, Jean and Slowikowski, 17 | Kamil and Zhang, Fan and Wei, Kevin and Baglaenko, Yuriy and 18 | Brenner, Michael and Loh, Po-Ru and Raychaudhuri, Soumya}, 19 | journal = {Nat. Methods}, 20 | volume = {16}, 21 | number = {12}, 22 | pages = {1289--1296}, 23 | month = {dec}, 24 | year = {2019}, 25 | } 26 | 27 | @article{buttner2018, 28 | title = {A test metric for assessing single-cell {RNA}-seq batch correction}, 29 | author = {Maren B\"{u}ttner and Zhichao Miao and F. Alexander Wolf and Sarah A. Teichmann and Fabian J. 
Theis}, 30 | doi = {10.1038/s41592-018-0254-1}, 31 | year = {2018}, 32 | month = dec, 33 | journal = {Nature Methods}, 34 | volume = {16}, 35 | number = {1}, 36 | pages = {43--49}, 37 | publisher = {Springer Science and Business Media {LLC}} 38 | } 39 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | ```{bibliography} 4 | :cited: 5 | ``` 6 | -------------------------------------------------------------------------------- /docs/template_usage.md: -------------------------------------------------------------------------------- 1 | # Using this template 2 | 3 | Welcome to the developer guidelines! This document is split into two parts: 4 | 5 | 1. The [repository setup](#setting-up-the-repository). This section is relevant primarily for the repository maintainer and shows how to connect 6 | continuous integration services and documents initial set-up of the repository. 7 | 2. The [contributor guide](contributing.md#contributing-guide). It contains information relevant to all developers who want to make a contribution. 8 | 9 | ## Setting up the repository 10 | 11 | ### First commit 12 | 13 | If you are reading this, you should have just completed the repository creation with : 14 | 15 | ```bash 16 | cruft create https://github.com/scverse/cookiecutter-scverse 17 | ``` 18 | 19 | and you should have 20 | 21 | ``` 22 | cd scib-metrics 23 | ``` 24 | 25 | into the new project directory. Now that you have created a new repository locally, the first step is to push it to github. To do this, you'd have to create a **new repository** on github. 26 | You can follow the instructions directly on [github quickstart guide][]. 
If you are familiar with git and know how to handle git conflicts, you can go ahead with your preferred choice.
Just go ahead and re-add the modified file and try to commit again: 50 | 51 | ```bash 52 | git add -u # update all tracked files 53 | git commit -m "first commit" 54 | ``` 55 | 56 | ::: 57 | 58 | Now that all the files of the newly created project have been committed, go ahead with the remaining steps: 59 | 60 | ```bash 61 | # update the `origin` of your local repo with the remote github link 62 | git remote add origin https://github.com/adamgayoso/scib-metrics.git 63 | # rename the default branch to main 64 | git branch -M main 65 | # push all your files to remote 66 | git push -u origin main 67 | ``` 68 | 69 | Your project should now be available at `https://github.com/adamgayoso/scib-metrics`. While the repository at this point can be directly used, there are a few remaining steps that need to be done in order to achieve full functionality. 70 | 71 | ### Coverage tests with _Codecov_ 72 | 73 | Coverage tells what fraction of the code is "covered" by unit tests, thereby encouraging contributors to 74 | [write tests](contributing.md#writing-tests). 75 | To enable coverage checks, head over to [codecov][] and sign in with your GitHub account. 76 | You'll find more information in the "getting started" section of the [codecov docs][]. 77 | 78 | In the `Actions` tab of your projects' github repository, you can see that the workflows are failing due to the **Upload coverage** step. The error message in the workflow should display something like: 79 | 80 | ``` 81 | ... 82 | Retrying 5/5 in 2s.. 83 | {'detail': ErrorDetail(string='Could not find a repository, try using repo upload token', code='not_found')} 84 | Error: 404 Client Error: Not Found for url: 85 | ... 86 | ``` 87 | 88 | While [codecov docs][] has a very extensive documentation on how to get started, _if_ you are using the default settings of this template we can assume that you are using [codecov][] in a github action workflow and hence you can make use of the [codecov bot][].
89 | 90 | To set it up, simply go to the [codecov app][] page and follow the instructions to activate it for your repository. 91 | Once the activation is completed, go back to the `Actions` tab and re-run the failing workflows. 92 | 93 | The workflows should now succeed and you will be able to find the code coverage at this link: `https://app.codecov.io/gh/adamgayoso/scib-metrics`. You might have to wait a couple of minutes and the coverage of this repository should be ~60%. 94 | 95 | If your repository is private, you will have to specify an additional token in the repository secrets. In brief, you need to: 96 | 97 | 1. Generate a Codecov Token by clicking _setup repo_ in the codecov dashboard. 98 | - If you have already set up codecov in the repository by following the previous steps, you can directly go to the codecov repo webpage. 99 | 2. Go to _Settings_ and copy **only** the token `_______-____-...`. 100 | 3. Go to _Settings_ of your newly created repository on GitHub. 101 | 4. Go to _Security > Secrets > Actions_. 102 | 5. Create a new repository secret with name `CODECOV_TOKEN` and paste the token generated by codecov. 103 | 6. Paste these additional lines in `/.github/workflows/test.yaml` under the **Upload coverage** step: 104 | ```bash 105 | - name: Upload coverage 106 | uses: codecov/codecov-action@v3 107 | with: 108 | token: ${{ secrets.CODECOV_TOKEN }} 109 | ``` 110 | 7. Go back to the github `Actions` page and re-run previously failed jobs. 111 | 112 | ### Documentation on _readthedocs_ 113 | 114 | We recommend using [readthedocs.org][] (RTD) to build and host the documentation for your project. 115 | To enable readthedocs, head over to [their website][readthedocs.org] and sign in with your GitHub account. 116 | On the RTD dashboard choose "Import a Project" and follow the instructions to add your repository. 117 | 118 | - Make sure to choose the correct name of the default branch.
On GitHub, the name of the default branch should be `main` (it has 119 | recently changed from `master` to `main`). 120 | - We recommend enabling documentation builds for pull requests (PRs). This ensures that a PR doesn't introduce changes 121 | that break the documentation. To do so, go to `Admin -> Advanced Settings`, check the 122 | `Build pull requests for this projects` option, and click `Save`. For more information, please refer to 123 | the [official RTD documentation](https://docs.readthedocs.io/en/stable/pull-requests.html). 124 | - If you find the RTD builds are failing, you can disable the `fail_on_warning` option in `.readthedocs.yaml`. 125 | 126 | If your project is private, there are ways to enable docs rendering on [readthedocs.org][] but it is more cumbersome and requires a different subscription for read the docs. See a guide [here](https://docs.readthedocs.io/en/stable/guides/importing-private-repositories.html). 127 | 128 | ### Pre-commit checks 129 | 130 | [Pre-commit][] checks are fast programs that 131 | check code for errors, inconsistencies and code styles, before the code 132 | is committed. 133 | 134 | This template uses a number of pre-commit checks. In this section we'll detail what is used, where they're defined, and how to modify these checks. 135 | 136 | #### Pre-commit CI 137 | 138 | We recommend setting up [pre-commit.ci][] to enforce consistency checks on every commit 139 | and pull-request. 140 | 141 | To do so, head over to [pre-commit.ci][] and click "Sign In With GitHub". Follow 142 | the instructions to enable pre-commit.ci for your account or your organization. You 143 | may choose to enable the service for an entire organization or on a per-repository basis. 144 | 145 | Once authorized, pre-commit.ci should automatically be activated.
146 | 147 | #### Overview of pre-commit hooks used by the template 148 | 149 | The following pre-commit hooks are for code style and format: 150 | 151 | - [black](https://black.readthedocs.io/en/stable/): 152 | standard code formatter in Python. 153 | - [blacken-docs](https://github.com/asottile/blacken-docs): 154 | black on Python code in docs. 155 | - [prettier](https://prettier.io/docs/en/index.html): 156 | standard code formatter for non-Python files (e.g. YAML). 157 | - [ruff][] based checks: 158 | - [isort](https://beta.ruff.rs/docs/rules/#isort-i) (rule category: `I`): 159 | sort module imports into sections and types. 160 | - [pydocstyle](https://beta.ruff.rs/docs/rules/#pydocstyle-d) (rule category: `D`): 161 | pydocstyle extension of flake8. 162 | - [flake8-tidy-imports](https://beta.ruff.rs/docs/rules/#flake8-tidy-imports-tid) (rule category: `TID`): 163 | tidy module imports. 164 | - [flake8-comprehensions](https://beta.ruff.rs/docs/rules/#flake8-comprehensions-c4) (rule category: `C4`): 165 | write better list/set/dict comprehensions. 166 | - [pyupgrade](https://beta.ruff.rs/docs/rules/#pyupgrade-up) (rule category: `UP`): 167 | upgrade syntax for newer versions of the language. 168 | 169 | The following pre-commit hooks are for errors and inconsistencies: 170 | 171 | - [pre-commit-hooks](https://github.com/pre-commit/pre-commit-hooks): generic pre-commit hooks for text files. 172 | - **detect-private-key**: checks for the existence of private keys. 173 | - **check-ast**: check whether files parse as valid python. 174 | - **end-of-file-fixer**: check files end in a newline and only a newline. 175 | - **mixed-line-ending**: checks mixed line ending. 176 | - **trailing-whitespace**: trims trailing whitespace. 177 | - **check-case-conflict**: check files that would conflict with case-insensitive file systems. 178 | - **forbid-to-commit**: Make sure that `*.rej` files cannot be committed.
179 | These files are created by the [automated template sync](#automated-template-sync) 180 | if there's a merge conflict and need to be addressed manually. 181 | - [ruff][] based checks: 182 | - [pyflakes](https://beta.ruff.rs/docs/rules/#pyflakes-f) (rule category: `F`): 183 | various checks for errors. 184 | - [pycodestyle](https://beta.ruff.rs/docs/rules/#pycodestyle-e-w) (rule category: `E`, `W`): 185 | various checks for errors. 186 | - [flake8-bugbear](https://beta.ruff.rs/docs/rules/#flake8-bugbear-b) (rule category: `B`): 187 | find possible bugs and design issues in program. 188 | - [flake8-blind-except](https://beta.ruff.rs/docs/rules/#flake8-blind-except-ble) (rule category: `BLE`): 189 | checks for blind, catch-all `except` statements. 190 | - [Ruff-specific rules](https://beta.ruff.rs/docs/rules/#ruff-specific-rules-ruf) (rule category: `RUF`): 191 | - `RUF100`: remove unnecessary `# noqa` comments () 192 | 193 | #### How to add or remove pre-commit checks 194 | 195 | The [pre-commit checks](#pre-commit-checks) check for both correctness and stylistic errors. 196 | In some cases it might overshoot and you may have good reasons to ignore certain warnings. 197 | This section shows you where these checks are defined, and how to enable/disable them. 198 | 199 | ##### pre-commit 200 | 201 | You can add or remove pre-commit checks by simply deleting relevant lines in the `.pre-commit-config.yaml` file under the repository root. 202 | Some pre-commit checks have additional options that can be specified either in the `pyproject.toml` (for `ruff` and `black`) or tool-specific 203 | config files, such as `.prettierrc.yml` for **prettier**. 204 | 205 | ##### Ruff 206 | 207 | This template configures `ruff` through the `[tool.ruff]` entry in the `pyproject.toml`. 208 | For further information on `ruff` configuration, see [the docs](https://beta.ruff.rs/docs/configuration/). 209 | 210 | Ruff assigns code to the rules it checks (e.g.
`E401`) and groups them under a rule category (e.g. `E`). 211 | Rule categories are selectively enabled by including them under the `select` key: 212 | 213 | ```toml 214 | [tool.ruff] 215 | ... 216 | 217 | select = [ 218 | "F", # Errors detected by Pyflakes 219 | "E", # Error detected by Pycodestyle 220 | "W", # Warning detected by Pycodestyle 221 | "I", # isort 222 | ... 223 | ] 224 | ``` 225 | 226 | The `ignore` entry is used to disable specific rules for the entire project. 227 | Add the rule code(s) you want to ignore and don't forget to add a comment explaining why. 228 | You can find a long list of checks that this template disables by default sitting there already. 229 | 230 | ```toml 231 | ignore = [ 232 | ... 233 | # __magic__ methods are are often self-explanatory, allow missing docstrings 234 | "D105", 235 | ... 236 | ] 237 | ``` 238 | 239 | Checks can be ignored per-file (or glob pattern) with `[tool.ruff.per-file-ignores]`. 240 | 241 | ```toml 242 | [tool.ruff.per-file-ignores] 243 | "docs/*" = ["I"] 244 | "tests/*" = ["D"] 245 | "*/__init__.py" = ["F401"] 246 | ``` 247 | 248 | To ignore a specific rule on a per-case basis, you can add a `# noqa: [, , …]` comment to the offending line. 249 | Specify the rule code(s) to ignore, with e.g. `# noqa: E731`. Check the [Ruff guide][] for reference. 250 | 251 | ```{note} 252 | The `RUF100` check will remove rule codes that are no longer necessary from `noqa` comments. 253 | If you want to add a code that comes from a tool other than Ruff, 254 | add it to Ruff’s [`external = [...]`](https://beta.ruff.rs/docs/settings/#external) setting to prevent `RUF100` from removing it. 
255 | ``` 256 | 257 | [ruff]: https://beta.ruff.rs/docs/ 258 | [ruff guide]: https://beta.ruff.rs/docs/configuration/#suppressing-errors 259 | 260 | ### API design 261 | 262 | Scverse ecosystem packages should operate on [AnnData][] and/or [MuData][] data structures and typically use an API 263 | as originally [introduced by scanpy][scanpy-api] with the following submodules: 264 | 265 | - `pp` for preprocessing 266 | - `tl` for tools (that, compared to `pp` generate interpretable output, often associated with a corresponding plotting 267 | function) 268 | - `pl` for plotting functions 269 | 270 | You may add additional submodules as appropriate. While we encourage to follow a scanpy-like API for ecosystem packages, 271 | there may also be good reasons to choose a different approach, e.g. using an object-oriented API. 272 | 273 | [scanpy-api]: https://scanpy.readthedocs.io/en/stable/usage-principles.html 274 | 275 | ### Using VCS-based versioning 276 | 277 | By default, the template uses hard-coded version numbers that are set in `pyproject.toml` and [managed with 278 | bump2version](contributing.md#publishing-a-release). If you prefer to have your project automatically infer version numbers from git 279 | tags, it is straightforward to switch to vcs-based versioning using [hatch-vcs][]. 280 | 281 | In `pyproject.toml` add the following changes, and you are good to go! 
282 | 283 | ```diff 284 | --- a/pyproject.toml 285 | +++ b/pyproject.toml 286 | @@ -1,11 +1,11 @@ 287 | [build-system] 288 | build-backend = "hatchling.build" 289 | -requires = ["hatchling"] 290 | +requires = ["hatchling", "hatch-vcs"] 291 | [project] 292 | name = "scib-metrics" 293 | -version = "0.3.1dev" 294 | +dynamic = ["version"] 295 | @@ -60,6 +60,9 @@ 296 | +[tool.hatch.version] 297 | +source = "vcs" 298 | + 299 | [tool.coverage.run] 300 | source = ["scib-metrics"] 301 | omit = [ 302 | ``` 303 | 304 | Don't forget to update the [Making a release section](contributing.md#publishing-a-release) in this document accordingly, after you are done! 305 | 306 | [hatch-vcs]: https://pypi.org/project/hatch-vcs/ 307 | 308 | ### Automated template sync 309 | 310 | Automated template sync is enabled by default. This means that every night, a GitHub action runs [cruft][] to check 311 | if a new version of the `scverse-cookiecutter` template got released. If there are any new changes, a pull request 312 | proposing these changes is created automatically. This helps keep the repository up-to-date with the latest 313 | coding standards. 314 | 315 | It may happen that a template sync results in a merge conflict. If this is the case a `*.rej` file with the 316 | diff is created. You need to manually address these changes and remove the `.rej` file when you are done. 317 | The pull request can only be merged after all `*.rej` files have been removed. 318 | 319 | :::{tip} 320 | The following hints may be useful to work with the template sync: 321 | 322 | - GitHub automatically disables scheduled actions if there has been no activity in the repository for 60 days. 323 | You can re-enable or manually trigger the sync by navigating to `Actions` -> `Sync Template` in your GitHub repository. 324 | - If you want to ignore certain files from the template update, you can add them to the `[tool.cruft]` section in the 325 | `pyproject.toml` file in the root of your repository.
More details are described in the 326 | [cruft documentation][cruft-update-project]. 327 | - To disable the sync entirely, simply remove the file `.github/workflows/sync.yaml`. 328 | 329 | ::: 330 | 331 | [cruft]: https://cruft.github.io/cruft/ 332 | [cruft-update-project]: https://cruft.github.io/cruft/#updating-a-project 333 | 334 | ### Making a release 335 | 336 | #### Updating the version number 337 | 338 | Before making a release, you need to update the version number. Please adhere to [Semantic Versioning][semver], in brief 339 | 340 | > Given a version number MAJOR.MINOR.PATCH, increment the: 341 | > 342 | > 1. MAJOR version when you make incompatible API changes, 343 | > 2. MINOR version when you add functionality in a backwards compatible manner, and 344 | > 3. PATCH version when you make backwards compatible bug fixes. 345 | > 346 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format. 347 | > We use [bump2version][] to automatically update the version number in all places and automatically create a git tag. 348 | > Run one of the following commands in the root of the repository 349 | 350 | ```bash 351 | bump2version patch 352 | bump2version minor 353 | bump2version major 354 | ``` 355 | 356 | Once you are done, run 357 | 358 | ``` 359 | git push --tags 360 | ``` 361 | 362 | to publish the created tag on GitHub. 363 | 364 | [bump2version]: https://github.com/c4urself/bump2version 365 | 366 | #### Upload on PyPI 367 | 368 | Please follow the [Python packaging tutorial][]. 369 | 370 | It is possible to automate this with GitHub actions, see also [this feature request][pypi-feature-request] 371 | in the cookiecutter-scverse template. 
372 | 373 | [python packaging tutorial]: https://packaging.python.org/en/latest/tutorials/packaging-projects/#generating-distribution-archives 374 | [pypi-feature-request]: https://github.com/scverse/cookiecutter-scverse/issues/88 375 | 376 | ### Writing documentation 377 | 378 | Please write documentation for your package. This project uses [sphinx][] with the following features: 379 | 380 | - the [myst][] extension allows to write documentation in markdown/Markedly Structured Text 381 | - [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension). 382 | - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks)) 383 | - [Sphinx autodoc typehints][], to automatically reference annotated input and output types 384 | 385 | See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information 386 | on how to write documentation. 387 | 388 | ### Tutorials with myst-nb and jupyter notebooks 389 | 390 | The documentation is set-up to render jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][]. 391 | Currently, only notebooks in `.ipynb` format are supported that will be included with both their input and output cells. 392 | It is your responsibility to update and re-run the notebook whenever necessary. 393 | 394 | If you are interested in automatically running notebooks as part of the continuous integration, please check 395 | out [this feature request](https://github.com/scverse/cookiecutter-scverse/issues/40) in the `cookiecutter-scverse` 396 | repository. 397 | 398 | #### Hints 399 | 400 | - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only 401 | if you do so can sphinx automatically create a link to the external documentation.
402 | - If building the documentation fails because of a missing link that is outside your control, you can add an entry to 403 | the `nitpick_ignore` list in `docs/conf.py` 404 | 405 | #### Building the docs locally 406 | 407 | ```bash 408 | cd docs 409 | make html 410 | open _build/html/index.html 411 | ``` 412 | 413 | 414 | 415 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html 416 | [codecov]: https://about.codecov.io/sign-up/ 417 | [codecov docs]: https://docs.codecov.com/docs 418 | [pre-commit.ci]: https://pre-commit.ci/ 419 | [readthedocs.org]: https://readthedocs.org/ 420 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/ 421 | [jupytext]: https://jupytext.readthedocs.io/en/latest/ 422 | [pre-commit]: https://pre-commit.com/ 423 | [anndata]: https://github.com/scverse/anndata 424 | [mudata]: https://github.com/scverse/mudata 425 | [pytest]: https://docs.pytest.org/ 426 | [semver]: https://semver.org/ 427 | [sphinx]: https://www.sphinx-doc.org/en/master/ 428 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html 429 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 430 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html 431 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints 432 | -------------------------------------------------------------------------------- /docs/tutorials.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | ## Walkthrough 4 | 5 | ```{toctree} 6 | :maxdepth: 1 7 | 8 | notebooks/lung_example 9 | notebooks/large_scale 10 | ``` 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | 6 | [project] 7 | name = "scib-metrics" 8 | version = 
"0.5.4" 9 | description = "Accelerated and Python-only scIB metrics" 10 | readme = "README.md" 11 | requires-python = ">=3.10" 12 | license = { file = "LICENSE" } 13 | authors = [{ name = "Adam Gayoso" }] 14 | maintainers = [{ name = "Adam Gayoso", email = "adamgayoso@berkeley.edu" }] 15 | urls.Documentation = "https://scib-metrics.readthedocs.io/" 16 | urls.Source = "https://github.com/yoseflab/scib-metrics" 17 | urls.Home-page = "https://github.com/yoseflab/scib-metrics" 18 | dependencies = [ 19 | "anndata", 20 | "chex", 21 | "jax", 22 | "jaxlib", 23 | "numpy", 24 | "pandas", 25 | "scipy", 26 | "scikit-learn", 27 | "scanpy>=1.9", 28 | "rich", 29 | "pynndescent", 30 | "igraph>0.9.0", 31 | "matplotlib", 32 | "plottable", 33 | "tqdm", 34 | "umap-learn>=0.5.0", 35 | ] 36 | 37 | [project.optional-dependencies] 38 | dev = ["pre-commit", "twine>=4.0.2"] 39 | doc = [ 40 | "sphinx>=4", 41 | "sphinx-book-theme>=1.0", 42 | "myst-nb", 43 | "sphinxcontrib-bibtex>=1.0.0", 44 | "scanpydoc[typehints]>=0.7.4", 45 | "sphinxext-opengraph", 46 | # For notebooks 47 | "ipython", 48 | "ipykernel", 49 | "sphinx-copybutton", 50 | "numba>=0.57.1", 51 | ] 52 | test = [ 53 | "pytest", 54 | "coverage", 55 | "scib>=1.1.4", 56 | "harmonypy", 57 | "joblib", 58 | # For vscode Python extension testing 59 | "flake8", 60 | "black", 61 | "numba>=0.57.1", 62 | ] 63 | parallel = ["joblib"] 64 | tutorial = [ 65 | "rich", 66 | "scanorama", 67 | "harmony-pytorch", 68 | "scvi-tools", 69 | "pyliger", 70 | "numexpr", # missing liger dependency 71 | "plotnine", # missing liger dependency 72 | "mygene", # missing liger dependency 73 | "goatools", # missing liger dependency 74 | "adjustText", # missing liger dependency 75 | ] 76 | 77 | [tool.hatch.build.targets.wheel] 78 | packages = ['src/scib_metrics'] 79 | 80 | [tool.coverage.run] 81 | source = ["scib_metrics"] 82 | omit = ["**/test_*.py"] 83 | 84 | [tool.pytest.ini_options] 85 | testpaths = ["tests"] 86 | xfail_strict = true 87 | 88 | 89 | [tool.ruff] 90 | 
src = ["src"] 91 | line-length = 120 92 | lint.select = [ 93 | "F", # Errors detected by Pyflakes 94 | "E", # Error detected by Pycodestyle 95 | "W", # Warning detected by Pycodestyle 96 | "I", # isort 97 | "D", # pydocstyle 98 | "B", # flake8-bugbear 99 | "TID", # flake8-tidy-imports 100 | "C4", # flake8-comprehensions 101 | "BLE", # flake8-blind-except 102 | "UP", # pyupgrade 103 | "RUF100", # Report unused noqa directives 104 | "ICN", # flake8-import-conventions 105 | "TCH", # flake8-type-checking 106 | "FA", # flake8-future-annotations 107 | ] 108 | lint.ignore = [ 109 | # line too long -> we accept long comment lines; formatter gets rid of long code lines 110 | "E501", 111 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient 112 | "E731", 113 | # allow I, O, l as variable names -> I is the identity matrix 114 | "E741", 115 | # Missing docstring in public package 116 | "D104", 117 | # Missing docstring in public module 118 | "D100", 119 | # Missing docstring in __init__ 120 | "D107", 121 | # Errors from function calls in argument defaults. These are fine when the result is immutable. 
122 | "B008", 123 | # __magic__ methods are are often self-explanatory, allow missing docstrings 124 | "D105", 125 | # First line should be in imperative mood; try rephrasing 126 | "D401", 127 | ## Disable one in each pair of mutually incompatible rules 128 | # We don’t want a blank line before a class docstring 129 | "D203", 130 | # We want docstrings to start immediately after the opening triple quote 131 | "D213", 132 | # Missing argument description in the docstring TODO: enable 133 | "D417", 134 | # No explicit stacklevel argument 135 | "B028", 136 | ] 137 | extend-include = ["*.ipynb"] 138 | 139 | [tool.ruff.lint.per-file-ignores] 140 | "docs/*" = ["I", "BLE001", "E402", "B018"] 141 | "tests/*" = ["D", "B018"] 142 | "*/__init__.py" = ["F401"] 143 | "src/scib_metrics/__init__.py" = ["I"] 144 | 145 | [tool.ruff.format] 146 | docstring-code-format = true 147 | 148 | [tool.ruff.lint.pydocstyle] 149 | convention = "numpy" 150 | 151 | [tool.jupytext] 152 | formats = "ipynb,md" 153 | 154 | [tool.cruft] 155 | skip = [ 156 | "tests", 157 | "src/**/__init__.py", 158 | "src/**/basic.py", 159 | "docs/api.md", 160 | "docs/changelog.md", 161 | "docs/references.bib", 162 | "docs/references.md", 163 | "docs/notebooks/example.ipynb", 164 | ] 165 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # This is a shim to allow Github to detect the package, build is done with hatch 4 | 5 | import setuptools 6 | 7 | if __name__ == "__main__": 8 | setuptools.setup(name="scib-metrics") 9 | -------------------------------------------------------------------------------- /src/scib_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from importlib.metadata import version 3 | 4 | from . 
import nearest_neighbors, utils 5 | from .metrics import ( 6 | graph_connectivity, 7 | isolated_labels, 8 | kbet, 9 | kbet_per_label, 10 | clisi_knn, 11 | ilisi_knn, 12 | lisi_knn, 13 | nmi_ari_cluster_labels_kmeans, 14 | nmi_ari_cluster_labels_leiden, 15 | pcr_comparison, 16 | silhouette_batch, 17 | silhouette_label, 18 | ) 19 | from ._settings import settings 20 | 21 | __all__ = [ 22 | "utils", 23 | "nearest_neighbors", 24 | "isolated_labels", 25 | "pcr_comparison", 26 | "silhouette_label", 27 | "silhouette_batch", 28 | "ilisi_knn", 29 | "clisi_knn", 30 | "lisi_knn", 31 | "nmi_ari_cluster_labels_kmeans", 32 | "nmi_ari_cluster_labels_leiden", 33 | "kbet", 34 | "kbet_per_label", 35 | "graph_connectivity", 36 | "settings", 37 | ] 38 | 39 | __version__ = version("scib-metrics") 40 | 41 | settings.verbosity = logging.INFO 42 | # Jax sets the root logger, this prevents double output. 43 | logger = logging.getLogger("scib_metrics") 44 | logger.propagate = False 45 | -------------------------------------------------------------------------------- /src/scib_metrics/_settings.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Literal 4 | 5 | from rich.console import Console 6 | from rich.logging import RichHandler 7 | 8 | scib_logger = logging.getLogger("scib_metrics") 9 | 10 | 11 | class ScibConfig: 12 | """Config manager for scib-metrics. 
13 | 14 | Examples 15 | -------- 16 | To set the progress bar style, choose one of "rich", "tqdm" 17 | 18 | >>> scib_metrics.settings.progress_bar_style = "rich" 19 | 20 | To set the verbosity 21 | 22 | >>> import logging 23 | >>> scib_metrics.settings.verbosity = logging.INFO 24 | """ 25 | 26 | def __init__( 27 | self, 28 | verbosity: int = logging.INFO, 29 | progress_bar_style: Literal["rich", "tqdm"] = "tqdm", 30 | jax_preallocate_gpu_memory: bool = False, 31 | ): 32 | if progress_bar_style not in ["rich", "tqdm"]: 33 | raise ValueError("Progress bar style must be in ['rich', 'tqdm']") 34 | self.progress_bar_style = progress_bar_style 35 | self.jax_preallocate_gpu_memory = jax_preallocate_gpu_memory 36 | self.verbosity = verbosity 37 | 38 | @property 39 | def progress_bar_style(self) -> str: 40 | """Library to use for progress bar.""" 41 | return self._pbar_style 42 | 43 | @progress_bar_style.setter 44 | def progress_bar_style(self, pbar_style: Literal["tqdm", "rich"]): 45 | """Library to use for progress bar.""" 46 | self._pbar_style = pbar_style 47 | 48 | @property 49 | def verbosity(self) -> int: 50 | """Verbosity level (default `logging.INFO`). 51 | 52 | Returns 53 | ------- 54 | verbosity: int 55 | """ 56 | return self._verbosity 57 | 58 | @verbosity.setter 59 | def verbosity(self, level: str | int): 60 | """Set verbosity level. 61 | 62 | If "scib_metrics" logger has no StreamHandler, add one. 63 | Else, set its level to `level`. 64 | 65 | Parameters 66 | ---------- 67 | level 68 | Sets "scib_metrics" logging level to `level` 69 | force_terminal 70 | Rich logging option, set to False if piping to file output. 
71 | """ 72 | self._verbosity = level 73 | scib_logger.setLevel(level) 74 | if len(scib_logger.handlers) == 0: 75 | console = Console(force_terminal=True) 76 | if console.is_jupyter is True: 77 | console.is_jupyter = False 78 | ch = RichHandler(level=level, show_path=False, console=console, show_time=False) 79 | formatter = logging.Formatter("%(message)s") 80 | ch.setFormatter(formatter) 81 | scib_logger.addHandler(ch) 82 | else: 83 | scib_logger.setLevel(level) 84 | 85 | def reset_logging_handler(self) -> None: 86 | """Reset "scib_metrics" log handler to a basic RichHandler(). 87 | 88 | This is useful if piping outputs to a file. 89 | 90 | Returns 91 | ------- 92 | None 93 | """ 94 | scib_logger.removeHandler(scib_logger.handlers[0]) 95 | ch = RichHandler(level=self._verbosity, show_path=False, show_time=False) 96 | formatter = logging.Formatter("%(message)s") 97 | ch.setFormatter(formatter) 98 | scib_logger.addHandler(ch) 99 | 100 | def jax_fix_no_kernel_image(self) -> None: 101 | """Fix for JAX error "No kernel image is available for execution on the device".""" 102 | os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1" 103 | 104 | @property 105 | def jax_preallocate_gpu_memory(self): 106 | """Jax GPU memory allocation settings. 107 | 108 | If False, Jax will ony preallocate GPU memory it needs. 109 | If float in (0, 1), Jax will preallocate GPU memory to that 110 | fraction of the GPU memory. 
111 | 112 | Returns 113 | ------- 114 | jax_preallocate_gpu_memory: bool or float 115 | """ 116 | return self._jax_gpu 117 | 118 | @jax_preallocate_gpu_memory.setter 119 | def jax_preallocate_gpu_memory(self, value: float | bool): 120 | # see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html#gpu-memory-allocation 121 | if value is False: 122 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 123 | elif isinstance(value, float): 124 | if value >= 1 or value <= 0: 125 | raise ValueError("Need to use a value between 0 and 1") 126 | # format is ".XX" 127 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(value)[1:4] 128 | else: 129 | raise ValueError("value not understood, need bool or float in (0, 1)") 130 | self._jax_gpu = value 131 | 132 | 133 | settings = ScibConfig() 134 | -------------------------------------------------------------------------------- /src/scib_metrics/_types.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import scipy.sparse as sp 4 | from jax import Array 5 | 6 | NdArray = np.ndarray | jnp.ndarray 7 | IntOrKey = int | Array 8 | ArrayLike = np.ndarray | sp.spmatrix | jnp.ndarray 9 | -------------------------------------------------------------------------------- /src/scib_metrics/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | from ._core import BatchCorrection, Benchmarker, BioConservation 2 | 3 | __all__ = ["Benchmarker", "BioConservation", "BatchCorrection"] 4 | -------------------------------------------------------------------------------- /src/scib_metrics/benchmark/_core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from collections.abc import Callable 4 | from dataclasses import asdict, dataclass 5 | from enum import Enum 6 | from functools import partial 7 | from typing import Any 8 | 9 | 
import matplotlib as mpl 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import pandas as pd 13 | import scanpy as sc 14 | from anndata import AnnData 15 | from plottable import ColumnDefinition, Table 16 | from plottable.cmap import normed_cmap 17 | from plottable.plots import bar 18 | from sklearn.preprocessing import MinMaxScaler 19 | from tqdm import tqdm 20 | 21 | import scib_metrics 22 | from scib_metrics.nearest_neighbors import NeighborsResults, pynndescent 23 | 24 | Kwargs = dict[str, Any] 25 | MetricType = bool | Kwargs 26 | 27 | _LABELS = "labels" 28 | _BATCH = "batch" 29 | _X_PRE = "X_pre" 30 | _METRIC_TYPE = "Metric Type" 31 | _AGGREGATE_SCORE = "Aggregate score" 32 | 33 | # Mapping of metric fn names to clean DataFrame column names 34 | metric_name_cleaner = { 35 | "silhouette_label": "Silhouette label", 36 | "silhouette_batch": "Silhouette batch", 37 | "isolated_labels": "Isolated labels", 38 | "nmi_ari_cluster_labels_leiden_nmi": "Leiden NMI", 39 | "nmi_ari_cluster_labels_leiden_ari": "Leiden ARI", 40 | "nmi_ari_cluster_labels_kmeans_nmi": "KMeans NMI", 41 | "nmi_ari_cluster_labels_kmeans_ari": "KMeans ARI", 42 | "clisi_knn": "cLISI", 43 | "ilisi_knn": "iLISI", 44 | "kbet_per_label": "KBET", 45 | "graph_connectivity": "Graph connectivity", 46 | "pcr_comparison": "PCR comparison", 47 | } 48 | 49 | 50 | @dataclass(frozen=True) 51 | class BioConservation: 52 | """Specification of bio conservation metrics to run in the pipeline. 53 | 54 | Metrics can be included using a boolean flag. Custom keyword args can be 55 | used by passing a dictionary here. Keyword args should not set data-related 56 | parameters, such as `X` or `labels`. 
57 | """ 58 | 59 | isolated_labels: MetricType = True 60 | nmi_ari_cluster_labels_leiden: MetricType = False 61 | nmi_ari_cluster_labels_kmeans: MetricType = True 62 | silhouette_label: MetricType = True 63 | clisi_knn: MetricType = True 64 | 65 | 66 | @dataclass(frozen=True) 67 | class BatchCorrection: 68 | """Specification of which batch correction metrics to run in the pipeline. 69 | 70 | Metrics can be included using a boolean flag. Custom keyword args can be 71 | used by passing a dictionary here. Keyword args should not set data-related 72 | parameters, such as `X` or `labels`. 73 | """ 74 | 75 | silhouette_batch: MetricType = True 76 | ilisi_knn: MetricType = True 77 | kbet_per_label: MetricType = True 78 | graph_connectivity: MetricType = True 79 | pcr_comparison: MetricType = True 80 | 81 | 82 | class MetricAnnDataAPI(Enum): 83 | """Specification of the AnnData API for a metric.""" 84 | 85 | isolated_labels = lambda ad, fn: fn(ad.X, ad.obs[_LABELS], ad.obs[_BATCH]) 86 | nmi_ari_cluster_labels_leiden = lambda ad, fn: fn(ad.uns["15_neighbor_res"], ad.obs[_LABELS]) 87 | nmi_ari_cluster_labels_kmeans = lambda ad, fn: fn(ad.X, ad.obs[_LABELS]) 88 | silhouette_label = lambda ad, fn: fn(ad.X, ad.obs[_LABELS]) 89 | clisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_LABELS]) 90 | graph_connectivity = lambda ad, fn: fn(ad.uns["15_neighbor_res"], ad.obs[_LABELS]) 91 | silhouette_batch = lambda ad, fn: fn(ad.X, ad.obs[_LABELS], ad.obs[_BATCH]) 92 | pcr_comparison = lambda ad, fn: fn(ad.obsm[_X_PRE], ad.X, ad.obs[_BATCH], categorical=True) 93 | ilisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_BATCH]) 94 | kbet_per_label = lambda ad, fn: fn(ad.uns["50_neighbor_res"], ad.obs[_BATCH], ad.obs[_LABELS]) 95 | 96 | 97 | class Benchmarker: 98 | """Benchmarking pipeline for the single-cell integration task. 99 | 100 | Parameters 101 | ---------- 102 | adata 103 | AnnData object containing the raw count data and integrated embeddings as obsm keys. 
    batch_key
        Key in `adata.obs` that contains the batch information.
    label_key
        Key in `adata.obs` that contains the cell type labels.
    embedding_obsm_keys
        List of obsm keys that contain the embeddings to be benchmarked.
    bio_conservation_metrics
        Specification of which bio conservation metrics to run in the pipeline.
    batch_correction_metrics
        Specification of which batch correction metrics to run in the pipeline.
    pre_integrated_embedding_obsm_key
        Obsm key containing a non-integrated embedding of the data. If `None`, the embedding will be computed
        in the prepare step. See the notes below for more information.
    n_jobs
        Number of jobs to use for parallelization of neighbor search.
    progress_bar
        Whether to show a progress bar for :meth:`~scib_metrics.benchmark.Benchmarker.prepare` and
        :meth:`~scib_metrics.benchmark.Benchmarker.benchmark`.

    Notes
    -----
    `adata.X` should contain a form of the data that is not integrated, but is normalized. The `prepare` method will
    use `adata.X` for PCA via :func:`~scanpy.tl.pca`, which is run with `use_highly_variable=False`, i.e., on all
    features rather than only those masked via `adata.var['highly_variable']`.

    See further usage examples in the following tutorial:

    1.
:doc:`/notebooks/lung_example` 131 | """ 132 | 133 | def __init__( 134 | self, 135 | adata: AnnData, 136 | batch_key: str, 137 | label_key: str, 138 | embedding_obsm_keys: list[str], 139 | bio_conservation_metrics: BioConservation | None = BioConservation(), 140 | batch_correction_metrics: BatchCorrection | None = BatchCorrection(), 141 | pre_integrated_embedding_obsm_key: str | None = None, 142 | n_jobs: int = 1, 143 | progress_bar: bool = True, 144 | ): 145 | self._adata = adata 146 | self._embedding_obsm_keys = embedding_obsm_keys 147 | self._pre_integrated_embedding_obsm_key = pre_integrated_embedding_obsm_key 148 | self._bio_conservation_metrics = bio_conservation_metrics 149 | self._batch_correction_metrics = batch_correction_metrics 150 | self._results = pd.DataFrame(columns=list(self._embedding_obsm_keys) + [_METRIC_TYPE]) 151 | self._emb_adatas = {} 152 | self._neighbor_values = (15, 50, 90) 153 | self._prepared = False 154 | self._benchmarked = False 155 | self._batch_key = batch_key 156 | self._label_key = label_key 157 | self._n_jobs = n_jobs 158 | self._progress_bar = progress_bar 159 | 160 | if self._bio_conservation_metrics is None and self._batch_correction_metrics is None: 161 | raise ValueError("Either batch or bio metrics must be defined.") 162 | 163 | self._metric_collection_dict = {} 164 | if self._bio_conservation_metrics is not None: 165 | self._metric_collection_dict.update({"Bio conservation": self._bio_conservation_metrics}) 166 | if self._batch_correction_metrics is not None: 167 | self._metric_collection_dict.update({"Batch correction": self._batch_correction_metrics}) 168 | 169 | def prepare(self, neighbor_computer: Callable[[np.ndarray, int], NeighborsResults] | None = None) -> None: 170 | """Prepare the data for benchmarking. 171 | 172 | Parameters 173 | ---------- 174 | neighbor_computer 175 | Function that computes the neighbors of the data. 
If `None`, the neighbors will be computed 176 | with :func:`~scib_metrics.utils.nearest_neighbors.pynndescent`. The function should take as input 177 | the data and the number of neighbors to compute and return a :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` 178 | object. 179 | """ 180 | # Compute PCA 181 | if self._pre_integrated_embedding_obsm_key is None: 182 | # This is how scib does it 183 | # https://github.com/theislab/scib/blob/896f689e5fe8c57502cb012af06bed1a9b2b61d2/scib/metrics/pcr.py#L197 184 | sc.tl.pca(self._adata, use_highly_variable=False) 185 | self._pre_integrated_embedding_obsm_key = "X_pca" 186 | 187 | for emb_key in self._embedding_obsm_keys: 188 | self._emb_adatas[emb_key] = AnnData(self._adata.obsm[emb_key], obs=self._adata.obs) 189 | self._emb_adatas[emb_key].obs[_BATCH] = np.asarray(self._adata.obs[self._batch_key].values) 190 | self._emb_adatas[emb_key].obs[_LABELS] = np.asarray(self._adata.obs[self._label_key].values) 191 | self._emb_adatas[emb_key].obsm[_X_PRE] = self._adata.obsm[self._pre_integrated_embedding_obsm_key] 192 | 193 | # Compute neighbors 194 | progress = self._emb_adatas.values() 195 | if self._progress_bar: 196 | progress = tqdm(progress, desc="Computing neighbors") 197 | 198 | for ad in progress: 199 | if neighbor_computer is not None: 200 | neigh_result = neighbor_computer(ad.X, max(self._neighbor_values)) 201 | else: 202 | neigh_result = pynndescent( 203 | ad.X, n_neighbors=max(self._neighbor_values), random_state=0, n_jobs=self._n_jobs 204 | ) 205 | for n in self._neighbor_values: 206 | ad.uns[f"{n}_neighbor_res"] = neigh_result.subset_neighbors(n=n) 207 | 208 | self._prepared = True 209 | 210 | def benchmark(self) -> None: 211 | """Run the pipeline.""" 212 | if self._benchmarked: 213 | warnings.warn( 214 | "The benchmark has already been run. 
Running it again will overwrite the previous results.", 215 | UserWarning, 216 | ) 217 | 218 | if not self._prepared: 219 | self.prepare() 220 | 221 | num_metrics = sum( 222 | [sum([v is not False for v in asdict(met_col)]) for met_col in self._metric_collection_dict.values()] 223 | ) 224 | 225 | progress_embs = self._emb_adatas.items() 226 | if self._progress_bar: 227 | progress_embs = tqdm(self._emb_adatas.items(), desc="Embeddings", position=0, colour="green") 228 | 229 | for emb_key, ad in progress_embs: 230 | pbar = None 231 | if self._progress_bar: 232 | pbar = tqdm(total=num_metrics, desc="Metrics", position=1, leave=False, colour="blue") 233 | for metric_type, metric_collection in self._metric_collection_dict.items(): 234 | for metric_name, use_metric_or_kwargs in asdict(metric_collection).items(): 235 | if use_metric_or_kwargs: 236 | pbar.set_postfix_str(f"{metric_type}: {metric_name}") if pbar is not None else None 237 | metric_fn = getattr(scib_metrics, metric_name) 238 | if isinstance(use_metric_or_kwargs, dict): 239 | # Kwargs in this case 240 | metric_fn = partial(metric_fn, **use_metric_or_kwargs) 241 | metric_value = getattr(MetricAnnDataAPI, metric_name)(ad, metric_fn) 242 | # nmi/ari metrics return a dict 243 | if isinstance(metric_value, dict): 244 | for k, v in metric_value.items(): 245 | self._results.loc[f"{metric_name}_{k}", emb_key] = v 246 | self._results.loc[f"{metric_name}_{k}", _METRIC_TYPE] = metric_type 247 | else: 248 | self._results.loc[metric_name, emb_key] = metric_value 249 | self._results.loc[metric_name, _METRIC_TYPE] = metric_type 250 | pbar.update(1) if pbar is not None else None 251 | 252 | self._benchmarked = True 253 | 254 | def get_results(self, min_max_scale: bool = True, clean_names: bool = True) -> pd.DataFrame: 255 | """Return the benchmarking results. 256 | 257 | Parameters 258 | ---------- 259 | min_max_scale 260 | Whether to min max scale the results. 261 | clean_names 262 | Whether to clean the metric names. 
263 | 264 | Returns 265 | ------- 266 | The benchmarking results. 267 | """ 268 | df = self._results.transpose() 269 | df.index.name = "Embedding" 270 | df = df.loc[df.index != _METRIC_TYPE] 271 | if min_max_scale: 272 | # Use sklearn to min max scale 273 | df = pd.DataFrame( 274 | MinMaxScaler().fit_transform(df), 275 | columns=df.columns, 276 | index=df.index, 277 | ) 278 | if clean_names: 279 | df = df.rename(columns=metric_name_cleaner) 280 | df = df.transpose() 281 | df[_METRIC_TYPE] = self._results[_METRIC_TYPE].values 282 | 283 | # Compute scores 284 | per_class_score = df.groupby(_METRIC_TYPE).mean().transpose() 285 | # This is the default scIB weighting from the manuscript 286 | if self._batch_correction_metrics is not None and self._bio_conservation_metrics is not None: 287 | per_class_score["Total"] = ( 288 | 0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"] 289 | ) 290 | df = pd.concat([df.transpose(), per_class_score], axis=1) 291 | df.loc[_METRIC_TYPE, per_class_score.columns] = _AGGREGATE_SCORE 292 | return df 293 | 294 | def plot_results_table(self, min_max_scale: bool = True, show: bool = True, save_dir: str | None = None) -> Table: 295 | """Plot the benchmarking results. 296 | 297 | Parameters 298 | ---------- 299 | min_max_scale 300 | Whether to min max scale the results. 301 | show 302 | Whether to show the plot. 303 | save_dir 304 | The directory to save the plot to. If `None`, the plot is not saved. 
305 | """ 306 | num_embeds = len(self._embedding_obsm_keys) 307 | cmap_fn = lambda col_data: normed_cmap(col_data, cmap=mpl.cm.PRGn, num_stds=2.5) 308 | df = self.get_results(min_max_scale=min_max_scale) 309 | # Do not want to plot what kind of metric it is 310 | plot_df = df.drop(_METRIC_TYPE, axis=0) 311 | # Sort by total score 312 | if self._batch_correction_metrics is not None and self._bio_conservation_metrics is not None: 313 | sort_col = "Total" 314 | elif self._batch_correction_metrics is not None: 315 | sort_col = "Batch correction" 316 | else: 317 | sort_col = "Bio conservation" 318 | plot_df = plot_df.sort_values(by=sort_col, ascending=False).astype(np.float64) 319 | plot_df["Method"] = plot_df.index 320 | 321 | # Split columns by metric type, using df as it doesn't have the new method col 322 | score_cols = df.columns[df.loc[_METRIC_TYPE] == _AGGREGATE_SCORE] 323 | other_cols = df.columns[df.loc[_METRIC_TYPE] != _AGGREGATE_SCORE] 324 | column_definitions = [ 325 | ColumnDefinition("Method", width=1.5, textprops={"ha": "left", "weight": "bold"}), 326 | ] 327 | # Circles for the metric values 328 | column_definitions += [ 329 | ColumnDefinition( 330 | col, 331 | title=col.replace(" ", "\n", 1), 332 | width=1, 333 | textprops={ 334 | "ha": "center", 335 | "bbox": {"boxstyle": "circle", "pad": 0.25}, 336 | }, 337 | cmap=cmap_fn(plot_df[col]), 338 | group=df.loc[_METRIC_TYPE, col], 339 | formatter="{:.2f}", 340 | ) 341 | for i, col in enumerate(other_cols) 342 | ] 343 | # Bars for the aggregate scores 344 | column_definitions += [ 345 | ColumnDefinition( 346 | col, 347 | width=1, 348 | title=col.replace(" ", "\n", 1), 349 | plot_fn=bar, 350 | plot_kw={ 351 | "cmap": mpl.cm.YlGnBu, 352 | "plot_bg_bar": False, 353 | "annotate": True, 354 | "height": 0.9, 355 | "formatter": "{:.2f}", 356 | }, 357 | group=df.loc[_METRIC_TYPE, col], 358 | border="left" if i == 0 else None, 359 | ) 360 | for i, col in enumerate(score_cols) 361 | ] 362 | # Allow to manipulate text 
post-hoc (in illustrator) 363 | with mpl.rc_context({"svg.fonttype": "none"}): 364 | fig, ax = plt.subplots(figsize=(len(df.columns) * 1.25, 3 + 0.3 * num_embeds)) 365 | tab = Table( 366 | plot_df, 367 | cell_kw={ 368 | "linewidth": 0, 369 | "edgecolor": "k", 370 | }, 371 | column_definitions=column_definitions, 372 | ax=ax, 373 | row_dividers=True, 374 | footer_divider=True, 375 | textprops={"fontsize": 10, "ha": "center"}, 376 | row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 5))}, 377 | col_label_divider_kw={"linewidth": 1, "linestyle": "-"}, 378 | column_border_kw={"linewidth": 1, "linestyle": "-"}, 379 | index_col="Method", 380 | ).autoset_fontcolors(colnames=plot_df.columns) 381 | if show: 382 | plt.show() 383 | if save_dir is not None: 384 | fig.savefig(os.path.join(save_dir, "scib_results.svg"), facecolor=ax.get_facecolor(), dpi=300) 385 | 386 | return tab 387 | -------------------------------------------------------------------------------- /src/scib_metrics/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from ._graph_connectivity import graph_connectivity 2 | from ._isolated_labels import isolated_labels 3 | from ._kbet import kbet, kbet_per_label 4 | from ._lisi import clisi_knn, ilisi_knn, lisi_knn 5 | from ._nmi_ari import nmi_ari_cluster_labels_kmeans, nmi_ari_cluster_labels_leiden 6 | from ._pcr_comparison import pcr_comparison 7 | from ._silhouette import silhouette_batch, silhouette_label 8 | 9 | __all__ = [ 10 | "isolated_labels", 11 | "pcr_comparison", 12 | "silhouette_label", 13 | "silhouette_batch", 14 | "ilisi_knn", 15 | "clisi_knn", 16 | "lisi_knn", 17 | "nmi_ari_cluster_labels_kmeans", 18 | "nmi_ari_cluster_labels_leiden", 19 | "kbet", 20 | "kbet_per_label", 21 | "graph_connectivity", 22 | ] 23 | -------------------------------------------------------------------------------- /src/scib_metrics/metrics/_graph_connectivity.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from scipy.sparse.csgraph import connected_components 4 | 5 | from scib_metrics.nearest_neighbors import NeighborsResults 6 | 7 | 8 | def graph_connectivity(X: NeighborsResults, labels: np.ndarray) -> float: 9 | """Quantify the connectivity of the subgraph per cell type label. 10 | 11 | Parameters 12 | ---------- 13 | X 14 | Array of shape (n_cells, n_cells) with non-zero values 15 | representing distances to exactly each cell's k nearest neighbors. 16 | labels 17 | Array of shape (n_cells,) representing label values 18 | for each cell. 19 | """ 20 | # TODO(adamgayoso): Utils for validating inputs 21 | clust_res = [] 22 | 23 | graph = X.knn_graph_distances 24 | 25 | for label in np.unique(labels): 26 | mask = labels == label 27 | if hasattr(mask, "values"): 28 | mask = mask.values 29 | graph_sub = graph[mask] 30 | graph_sub = graph_sub[:, mask] 31 | _, comps = connected_components(graph_sub, connection="strong") 32 | tab = pd.value_counts(comps) 33 | clust_res.append(tab.max() / sum(tab)) 34 | 35 | return np.mean(clust_res) 36 | -------------------------------------------------------------------------------- /src/scib_metrics/metrics/_isolated_labels.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from scib_metrics.utils import silhouette_samples 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def isolated_labels( 12 | X: np.ndarray, 13 | labels: np.ndarray, 14 | batch: np.ndarray, 15 | rescale: bool = True, 16 | iso_threshold: int | None = None, 17 | ) -> float: 18 | """Isolated label score :cite:p:`luecken2022benchmarking`. 19 | 20 | Score how well labels of isolated labels are distiguished in the dataset by 21 | average-width silhouette score (ASW) on isolated label vs all other labels. 
22 | 23 | The default of the original scib package is to use a cluster-based F1 scoring 24 | procedure, but here we use the ASW for speed and simplicity. 25 | 26 | Parameters 27 | ---------- 28 | X 29 | Array of shape (n_cells, n_features). 30 | labels 31 | Array of shape (n_cells,) representing label values 32 | batch 33 | Array of shape (n_cells,) representing batch values 34 | rescale 35 | Scale asw into the range [0, 1]. 36 | iso_threshold 37 | Max number of batches per label for label to be considered as 38 | isolated, if integer. If `None`, considers minimum number of 39 | batches that labels are present in 40 | 41 | Returns 42 | ------- 43 | isolated_label_score 44 | """ 45 | scores = {} 46 | isolated_labels = _get_isolated_labels(labels, batch, iso_threshold) 47 | 48 | silhouette_all = silhouette_samples(X, labels) 49 | if rescale: 50 | silhouette_all = (silhouette_all + 1) / 2 51 | 52 | for label in isolated_labels: 53 | score = np.mean(silhouette_all[labels == label]) 54 | scores[label] = score 55 | scores = pd.Series(scores) 56 | 57 | return scores.mean() 58 | 59 | 60 | def _get_isolated_labels(labels: np.ndarray, batch: np.ndarray, iso_threshold: float): 61 | """Get labels that are isolated depending on the number of batches.""" 62 | tmp = pd.DataFrame() 63 | label_key = "label" 64 | batch_key = "batch" 65 | tmp[label_key] = labels 66 | tmp[batch_key] = batch 67 | tmp = tmp.drop_duplicates() 68 | batch_per_lab = tmp.groupby(label_key).agg({batch_key: "count"}) 69 | 70 | # threshold for determining when label is considered isolated 71 | if iso_threshold is None: 72 | iso_threshold = batch_per_lab.min().tolist()[0] 73 | 74 | logging.info(f"isolated labels: no more than {iso_threshold} batches per label") 75 | 76 | labels = batch_per_lab[batch_per_lab[batch_key] <= iso_threshold].index.tolist() 77 | if len(labels) == 0: 78 | logging.info(f"no isolated labels with less than {iso_threshold} batches") 79 | 80 | return np.array(labels) 81 | 
-------------------------------------------------------------------------------- /src/scib_metrics/metrics/_kbet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | 4 | import chex 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import pandas as pd 9 | import scipy 10 | 11 | from scib_metrics._types import NdArray 12 | from scib_metrics.nearest_neighbors import NeighborsResults 13 | from scib_metrics.utils import diffusion_nn, get_ndarray 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def _chi2_cdf(df: int | NdArray, x: NdArray) -> float: 19 | """Chi2 cdf. 20 | 21 | See https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.chdtr.html 22 | for explanation of gammainc. 23 | """ 24 | return jax.scipy.special.gammainc(df / 2, x / 2) 25 | 26 | 27 | @partial(jax.jit, static_argnums=2) 28 | def _kbet(neigh_batch_ids: jnp.ndarray, batches: jnp.ndarray, n_batches: int) -> float: 29 | expected_freq = jnp.bincount(batches, length=n_batches) 30 | expected_freq = expected_freq / jnp.sum(expected_freq) 31 | dof = n_batches - 1 32 | 33 | observed_counts = jax.vmap(partial(jnp.bincount, length=n_batches))(neigh_batch_ids) 34 | expected_counts = expected_freq * neigh_batch_ids.shape[1] 35 | test_statistics = jnp.sum(jnp.square(observed_counts - expected_counts) / expected_counts, axis=1) 36 | p_values = 1 - jax.vmap(_chi2_cdf, in_axes=(None, 0))(dof, test_statistics) 37 | 38 | return test_statistics, p_values 39 | 40 | 41 | def kbet(X: NeighborsResults, batches: np.ndarray, alpha: float = 0.05) -> float: 42 | """Compute kbet :cite:p:`buttner2018`. 43 | 44 | This implementation is inspired by the implementation in Pegasus: 45 | https://pegasus.readthedocs.io/en/stable/index.html 46 | 47 | A higher acceptance rate means more mixing of batches. 
def kbet_per_label(
    X: NeighborsResults,
    batches: np.ndarray,
    labels: np.ndarray,
    alpha: float = 0.05,
    diffusion_n_comps: int = 100,
    return_df: bool = False,
) -> float | tuple[float, pd.DataFrame]:
    """Compute kBET score per cell type label as in :cite:p:`luecken2022benchmarking`.

    This approximates the method used in the original scib package. Notably, the underlying
    kbet might have some inconsistencies with the R implementation. Furthermore, to equalize
    the neighbor graphs of cell type subsets we use diffusion distance approximated with diffusion
    maps. Increasing `diffusion_n_comps` will increase the accuracy of the approximation.

    Parameters
    ----------
    X
        A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object.
    batches
        Array of shape (n_cells,) representing batch values
        for each cell.
    labels
        Array of shape (n_cells,) representing cell type label values
        for each cell.
    alpha
        Significance level for the statistical test.
    diffusion_n_comps
        Number of diffusion components to use for diffusion distance approximation.
    return_df
        Return dataframe of results in addition to score.

    Returns
    -------
    kbet_score
        Kbet score over all cells. Higher means more integrated, as in the kBET acceptance rate.
    df
        Dataframe with kBET score per cell type label.

    Notes
    -----
    This function requires X to be cell-cell connectivities, not distances.
    """
    if len(batches) != len(X.indices):
        raise ValueError("Length of batches does not match number of cells.")
    if len(labels) != len(X.indices):
        raise ValueError("Length of labels does not match number of cells.")
    # set upper bound for k0
    size_max = 2**31 - 1
    batches = np.asarray(pd.Categorical(batches).codes)
    labels = np.asarray(labels)

    conn_graph = X.knn_graph_connectivities

    # prepare call of kBET per cluster
    clusters, counts = np.unique(labels, return_counts=True)
    # BUGFIX: clusters with 10 or fewer cells must be *skipped* (scored NaN)
    # and kBET computed on the sufficiently large clusters; the two
    # comparisons were previously inverted.
    skipped = clusters[counts <= 10]
    clusters = clusters[counts > 10]
    kbet_scores = {"cluster": list(skipped), "kBET": [np.nan] * len(skipped)}
    logger.info(f"{len(skipped)} clusters consist of a single batch or are too small. Skip.")

    for clus in clusters:
        # subset by label
        mask = labels == clus
        conn_graph_sub = conn_graph[mask, :][:, mask]
        conn_graph_sub.sort_indices()
        n_obs = conn_graph_sub.shape[0]
        batches_sub = batches[mask]

        # Heuristic neighborhood size: a quarter of the mean per-batch size
        # within the cluster, clipped to [10, 70].
        quarter_mean = np.floor(np.mean(pd.Series(batches_sub).value_counts()) / 4).astype("int")
        k0 = np.min([70, np.max([10, quarter_mean])])
        # check k0 for reasonability
        if k0 * n_obs >= size_max:
            k0 = np.floor(size_max / n_obs).astype("int")

        n_comp, labs = scipy.sparse.csgraph.connected_components(conn_graph_sub, connection="strong")

        if n_comp == 1:  # a single component to compute kBET on
            try:
                diffusion_n_comps = np.min([diffusion_n_comps, n_obs - 1])
                nn_graph_sub = diffusion_nn(conn_graph_sub, k=k0, n_comps=diffusion_n_comps)
                # call kBET
                score, _, _ = kbet(
                    nn_graph_sub,
                    batches=batches_sub,
                    alpha=alpha,
                )
            except ValueError:
                logger.info("Diffusion distance failed. Skip.")
                score = 0  # i.e. 100% rejection

        else:
            # check the number of components where kBET can be computed upon
            comp_size = pd.Series(labs).value_counts()
            # check which components are small
            comp_size_thresh = 3 * k0
            idx_nonan = np.flatnonzero(np.in1d(labs, comp_size[comp_size >= comp_size_thresh].index))

            # check if 75% of all cells can be used for kBET run
            if len(idx_nonan) / len(labs) >= 0.75:
                # create another subset of components, assume they are not visited in a diffusion process
                conn_graph_sub_sub = conn_graph_sub[idx_nonan, :][:, idx_nonan]
                conn_graph_sub_sub.sort_indices()

                try:
                    diffusion_n_comps = np.min([diffusion_n_comps, conn_graph_sub_sub.shape[0] - 1])
                    nn_results_sub_sub = diffusion_nn(conn_graph_sub_sub, k=k0, n_comps=diffusion_n_comps)
                    # call kBET
                    score, _, _ = kbet(
                        nn_results_sub_sub,
                        batches=batches_sub[idx_nonan],
                        alpha=alpha,
                    )
                except ValueError:
                    logger.info("Diffusion distance failed. Skip.")
                    score = 0  # i.e. 100% rejection
            else:  # if there are too many too small connected components, set kBET score to 0
                score = 0  # i.e. 100% rejection

        kbet_scores["cluster"].append(clus)
        kbet_scores["kBET"].append(score)

    kbet_scores = pd.DataFrame.from_dict(kbet_scores)
    kbet_scores = kbet_scores.reset_index(drop=True)

    # Skipped clusters are NaN and excluded from the mean.
    final_score = np.nanmean(kbet_scores["kBET"])
    if not return_df:
        return final_score
    else:
        return final_score, kbet_scores
26 | """ 27 | labels = np.asarray(pd.Categorical(labels).codes) 28 | knn_dists, knn_idx = X.distances, X.indices 29 | row_idx = np.arange(X.n_samples)[:, np.newaxis] 30 | 31 | if perplexity is None: 32 | perplexity = np.floor(knn_idx.shape[1] / 3) 33 | 34 | n_labels = len(np.unique(labels)) 35 | 36 | simpson = compute_simpson_index( 37 | knn_dists=knn_dists, knn_idx=knn_idx, row_idx=row_idx, labels=labels, n_labels=n_labels, perplexity=perplexity 38 | ) 39 | return 1 / simpson 40 | 41 | 42 | def ilisi_knn(X: NeighborsResults, batches: np.ndarray, perplexity: float = None, scale: bool = True) -> float: 43 | """Compute the integration local inverse simpson index (iLISI) for each cell :cite:p:`korsunsky2019harmony`. 44 | 45 | Returns a scaled version of the iLISI score for each cell, by default :cite:p:`luecken2022benchmarking`. 46 | 47 | Parameters 48 | ---------- 49 | X 50 | A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object. 51 | batches 52 | Array of shape (n_cells,) representing batch values 53 | for each cell. 54 | perplexity 55 | Parameter controlling effective neighborhood size. If None, the 56 | perplexity is set to the number of neighbors // 3. 57 | scale 58 | Scale lisi into the range [0, 1]. If True, higher values are better. 59 | 60 | Returns 61 | ------- 62 | ilisi 63 | iLISI score. 64 | """ 65 | batches = np.asarray(pd.Categorical(batches).codes) 66 | lisi = lisi_knn(X, batches, perplexity=perplexity) 67 | ilisi = np.nanmedian(lisi) 68 | if scale: 69 | nbatches = len(np.unique(batches)) 70 | ilisi = (ilisi - 1) / (nbatches - 1) 71 | return ilisi 72 | 73 | 74 | def clisi_knn(X: NeighborsResults, labels: np.ndarray, perplexity: float = None, scale: bool = True) -> float: 75 | """Compute the cell-type local inverse simpson index (cLISI) for each cell :cite:p:`korsunsky2019harmony`. 76 | 77 | Returns a scaled version of the cLISI score for each cell, by default :cite:p:`luecken2022benchmarking`. 
78 | 79 | Parameters 80 | ---------- 81 | X 82 | A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object. 83 | labels 84 | Array of shape (n_cells,) representing cell type label values 85 | for each cell. 86 | perplexity 87 | Parameter controlling effective neighborhood size. If None, the 88 | perplexity is set to the number of neighbors // 3. 89 | scale 90 | Scale lisi into the range [0, 1]. If True, higher values are better. 91 | 92 | Returns 93 | ------- 94 | clisi 95 | cLISI score. 96 | """ 97 | labels = np.asarray(pd.Categorical(labels).codes) 98 | lisi = lisi_knn(X, labels, perplexity=perplexity) 99 | clisi = np.nanmedian(lisi) 100 | if scale: 101 | nlabels = len(np.unique(labels)) 102 | clisi = (nlabels - clisi) / (nlabels - 1) 103 | return clisi 104 | -------------------------------------------------------------------------------- /src/scib_metrics/metrics/_nmi_ari.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import warnings 4 | 5 | import igraph 6 | import numpy as np 7 | from scipy.sparse import spmatrix 8 | from sklearn.metrics.cluster import adjusted_rand_score, normalized_mutual_info_score 9 | from sklearn.utils import check_array 10 | 11 | from scib_metrics.nearest_neighbors import NeighborsResults 12 | from scib_metrics.utils import KMeans 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def _compute_clustering_kmeans(X: np.ndarray, n_clusters: int) -> np.ndarray: 18 | kmeans = KMeans(n_clusters) 19 | kmeans.fit(X) 20 | return kmeans.labels_ 21 | 22 | 23 | def _compute_clustering_leiden(connectivity_graph: spmatrix, resolution: float, seed: int) -> np.ndarray: 24 | rng = random.Random(seed) 25 | igraph.set_random_number_generator(rng) 26 | # The connectivity graph with the umap method is symmetric, but we need to first make it directed 27 | # to have both sets of edges as is done in scanpy. See test for more details. 
def _compute_nmi_ari_cluster_labels(
    X: spmatrix,
    labels: np.ndarray,
    resolution: float = 1.0,
    seed: int = 42,
) -> tuple[float, float]:
    """Cluster the graph with Leiden at ``resolution`` and score against ``labels``.

    Returns a ``(nmi, ari)`` tuple.
    """
    labels_pred = _compute_clustering_leiden(X, resolution, seed)
    nmi = normalized_mutual_info_score(labels, labels_pred, average_method="arithmetic")
    ari = adjusted_rand_score(labels, labels_pred)
    return nmi, ari


def nmi_ari_cluster_labels_kmeans(X: np.ndarray, labels: np.ndarray) -> dict[str, float]:
    """Compute nmi and ari between k-means clusters and labels.

    This deviates from the original implementation in scib by using k-means
    with k equal to the known number of cell types/labels. This leads to
    a more efficient computation of the nmi and ari scores.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values

    Returns
    -------
    nmi
        Normalized mutual information score
    ari
        Adjusted rand index score
    """
    X = check_array(X, accept_sparse=False, ensure_2d=True)
    # k is set to the number of ground-truth labels (see docstring).
    n_clusters = len(np.unique(labels))
    labels_pred = _compute_clustering_kmeans(X, n_clusters)
    nmi = normalized_mutual_info_score(labels, labels_pred, average_method="arithmetic")
    ari = adjusted_rand_score(labels, labels_pred)

    return {"nmi": nmi, "ari": ari}


def nmi_ari_cluster_labels_leiden(
    X: NeighborsResults,
    labels: np.ndarray,
    optimize_resolution: bool = True,
    resolution: float = 1.0,
    n_jobs: int = 1,
    seed: int = 42,
) -> dict[str, float]:
    """Compute nmi and ari between leiden clusters and labels.

    This deviates from the original implementation in scib by using leiden instead of
    louvain clustering. Installing joblib allows for parallelization of the leiden
    resoution optimization.

    Parameters
    ----------
    X
        A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object.
    labels
        Array of shape (n_cells,) representing label values
    optimize_resolution
        Whether to optimize the resolution parameter of leiden clustering by searching over
        10 values
    resolution
        Resolution parameter of leiden clustering. Only used if optimize_resolution is False.
    n_jobs
        Number of jobs for parallelizing resolution optimization via joblib. If -1, all CPUs
        are used.
    seed
        Seed used for reproducibility of clustering.

    Returns
    -------
    nmi
        Normalized mutual information score
    ari
        Adjusted rand index score
    """
    conn_graph = X.knn_graph_connectivities
    if optimize_resolution:
        n = 10
        # Resolutions 0.2, 0.4, ..., 2.0.
        resolutions = np.array([2 * x / n for x in range(1, n + 1)])
        try:
            from joblib import Parallel, delayed

            # Pass the seed explicitly; previously the parallel path silently
            # fell back to the default seed, making results irreproducible.
            out = Parallel(n_jobs=n_jobs)(
                delayed(_compute_nmi_ari_cluster_labels)(conn_graph, labels, r, seed=seed) for r in resolutions
            )
        except ImportError:
            warnings.warn(
                "Using for loop over clustering resolutions. `pip install joblib` for parallelization.",
                stacklevel=2,
            )
            out = [_compute_nmi_ari_cluster_labels(conn_graph, labels, r, seed=seed) for r in resolutions]
        nmi_ari = np.array(out)
        # Pick the resolution with the best NMI and report both scores at it.
        nmi_ind = np.argmax(nmi_ari[:, 0])
        nmi, ari = nmi_ari[nmi_ind, :]
        return {"nmi": nmi, "ari": ari}
    else:
        nmi, ari = _compute_nmi_ari_cluster_labels(conn_graph, labels, resolution, seed=seed)

    return {"nmi": nmi, "ari": ari}
def silhouette_label(X: np.ndarray, labels: np.ndarray, rescale: bool = True, chunk_size: int = 256) -> float:
    """Average silhouette width (ASW) :cite:p:`luecken2022benchmarking`.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    rescale
        Scale asw into the range [0, 1].
    chunk_size
        Size of chunks to process at a time for distance computation.

    Returns
    -------
    silhouette score
    """
    # Mean over per-cell silhouette values, which lie in [-1, 1].
    asw = np.mean(silhouette_samples(X, labels, chunk_size=chunk_size))
    if rescale:
        # Map from [-1, 1] to [0, 1] so that higher is better.
        asw = (asw + 1) / 2
    # `asw` is already a scalar; the previous extra np.mean was redundant.
    return float(asw)
@dataclass
class NeighborsResults:
    """Nearest neighbors results data store.

    Attributes
    ----------
    indices : np.ndarray
        Array of shape (n_samples, n_neighbors) with indices of the nearest
        neighbors. Self should always be included here; however, some
        approximate algorithms may not return the self edge.
    distances : np.ndarray
        Array of shape (n_samples, n_neighbors) with distances to the nearest
        neighbors, aligned element-wise with ``indices``.
    """

    indices: np.ndarray
    distances: np.ndarray

    def __post_init__(self):
        # Indices and distances must be aligned element-wise.
        chex.assert_equal_shape([self.indices, self.distances])

    @property
    def n_samples(self) -> int:
        """Number of samples (cells)."""
        # Note: returns a plain int; the previous np.ndarray annotation was wrong.
        return self.indices.shape[0]

    @property
    def n_neighbors(self) -> int:
        """Number of neighbors."""
        return self.indices.shape[1]

    @cached_property
    def knn_graph_distances(self) -> csr_matrix:
        """Return the sparse weighted adjacency matrix of distances."""
        n_samples, n_neighbors = self.indices.shape
        # Every row has exactly n_neighbors entries, so the CSR row pointer
        # is a simple arithmetic progression.
        rowptr = np.arange(0, n_samples * n_neighbors + 1, n_neighbors)
        return csr_matrix((self.distances.ravel(), self.indices.ravel(), rowptr), shape=(n_samples, n_samples))

    @cached_property
    def knn_graph_connectivities(self) -> coo_matrix:
        """Compute connectivities using the UMAP approach.

        Connectivities (similarities) are computed from distances
        using the approach from the UMAP method, which is also used by scanpy.
        """
        # fuzzy_simplicial_set only uses the precomputed knn_indices/knn_dists
        # here; the first argument is a placeholder in this mode.
        conn_graph = coo_matrix(([], ([], [])), shape=(self.n_samples, 1))
        connectivities = fuzzy_simplicial_set(
            conn_graph,
            n_neighbors=self.n_neighbors,
            random_state=None,
            metric=None,
            knn_indices=self.indices,
            knn_dists=self.distances,
            set_op_mix_ratio=1.0,
            local_connectivity=1.0,
        )
        # fuzzy_simplicial_set returns a tuple; the graph is the first element.
        return connectivities[0]

    def subset_neighbors(self, n: int) -> "NeighborsResults":
        """Subset down to `n` neighbors."""
        if n > self.n_neighbors:
            raise ValueError("n must be smaller than the number of neighbors")
        # Neighbor columns are assumed sorted by distance, so the first n
        # columns are the n nearest -- TODO confirm for all backends.
        return self.__class__(indices=self.indices[:, :n], distances=self.distances[:, :n])
def pynndescent(X: np.ndarray, n_neighbors: int, random_state: int = 0, n_jobs: int = 1) -> NeighborsResults:
    """Run pynndescent approximate nearest neighbor search.

    Parameters
    ----------
    X
        Data matrix.
    n_neighbors
        Number of neighbors to search for.
    random_state
        Random state.
    n_jobs
        Number of jobs to use.
    """
    # Heuristics from umap (https://github.com/lmcinnes/umap/blob/3f19ce19584de4cf99e3d0ae779ba13a57472cd9/umap/umap_.py#LL326-L327)
    # which is used by scanpy under the hood
    n_obs = X.shape[0]
    n_trees = min(64, 5 + int(round(n_obs**0.5 / 20.0)))
    n_iters = max(5, int(round(np.log2(n_obs))))
    max_candidates = 60

    search_index = NNDescent(
        X,
        n_neighbors=n_neighbors,
        random_state=random_state,
        low_memory=True,
        n_jobs=n_jobs,
        compressed=False,
        n_trees=n_trees,
        n_iters=n_iters,
        max_candidates=max_candidates,
    )
    neighbor_indices, neighbor_distances = search_index.neighbor_graph

    return NeighborsResults(indices=neighbor_indices, distances=neighbor_distances)
logger = logging.getLogger(__name__)

# Small constant added to eigenvalues to avoid division by zero downstream.
_EPS = 1e-8


def _compute_transitions(X: csr_matrix, density_normalize: bool = True):
    """Compute the symmetrized diffusion transition matrix of a connectivity graph.

    Code from scanpy.

    https://github.com/scverse/scanpy/blob/2e98705347ea484c36caa9ba10de1987b09081bf/scanpy/neighbors/__init__.py#L899

    Parameters
    ----------
    X
        Square (sparse or dense) matrix of non-negative connectivities.
    density_normalize
        If True, apply the density normalization of Coifman et al. (2005)
        before symmetrizing.
    """
    # TODO(adamgayoso): Refactor this with Jax
    # density normalization as of Coifman et al. (2005)
    # ensures that kernel matrix is independent of sampling density
    if density_normalize:
        # q[i] is an estimate for the sampling density at point i
        # it's also the degree of the underlying graph
        q = np.asarray(X.sum(axis=0))
        # Dense and sparse inputs need different diagonal constructors.
        if not issparse(X):
            Q = np.diag(1.0 / q)
        else:
            Q = scipy.sparse.spdiags(1.0 / q, 0, X.shape[0], X.shape[0])
        K = Q @ X @ Q
    else:
        K = X

    # z[i] is the square root of the row sum of K
    z = np.sqrt(np.asarray(K.sum(axis=0)))
    if not issparse(K):
        Z = np.diag(1.0 / z)
    else:
        Z = scipy.sparse.spdiags(1.0 / z, 0, K.shape[0], K.shape[0])
    # Symmetric normalization: D^{-1/2} K D^{-1/2}.
    transitions_sym = Z @ K @ Z

    return transitions_sym


def _compute_eigen(
    transitions_sym: csr_matrix,
    n_comps: int = 15,
    sort: Literal["decrease", "increase"] = "decrease",
):
    """Compute eigen decomposition of transition matrix.

    https://github.com/scverse/scanpy/blob/2e98705347ea484c36caa9ba10de1987b09081bf/scanpy/neighbors/__init__.py

    Parameters
    ----------
    transitions_sym
        Symmetrized transition matrix (see ``_compute_transitions``).
    n_comps
        Number of eigenpairs to compute; 0 requests a full dense decomposition.
    sort
        Whether to return eigenvalues in decreasing or increasing order.

    Returns
    -------
    Tuple of (eigenvalues, eigenvectors), both float32.
    """
    # TODO(adamgayoso): Refactor this with Jax
    matrix = transitions_sym
    # compute the spectrum
    if n_comps == 0:
        # Full dense decomposition (only feasible for small matrices).
        evals, evecs = scipy.linalg.eigh(matrix)
    else:
        # eigsh requires k < n, so cap the number of components.
        n_comps = min(matrix.shape[0] - 1, n_comps)
        # ncv = max(2 * n_comps + 1, int(np.sqrt(matrix.shape[0])))
        ncv = None
        # "LM" = largest magnitude eigenvalues first, "SM" = smallest.
        which = "LM" if sort == "decrease" else "SM"
        # it pays off to increase the stability with a bit more precision
        matrix = matrix.astype(np.float64)

        evals, evecs = scipy.sparse.linalg.eigsh(matrix, k=n_comps, which=which, ncv=ncv)
        evals, evecs = evals.astype(np.float32), evecs.astype(np.float32)
    if sort == "decrease":
        # eigsh returns ascending order; reverse for decreasing.
        evals = evals[::-1]
        evecs = evecs[:, ::-1]

    return evals, evecs


def _get_sparse_matrix_from_indices_distances_numpy(indices, distances, n_obs, n_neighbors):
    """Build a (n_obs, n_obs) CSR distance matrix from a kNN indices/distances pair.

    Code from scanpy.
    """
    # Each row has exactly n_neighbors entries, so the row pointer is a
    # simple arithmetic progression.
    n_nonzero = n_obs * n_neighbors
    indptr = np.arange(0, n_nonzero + 1, n_neighbors)
    D = csr_matrix(
        (
            distances.copy().ravel(),  # copy the data, otherwise strange behavior here
            indices.copy().ravel(),
            indptr,
        ),
        shape=(n_obs, n_obs),
    )
    # Drop explicit zeros (e.g. self edges with distance 0) and normalize layout.
    D.eliminate_zeros()
    D.sort_indices()
    return D
@jax.jit
def _euclidean_distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Euclidean distance between two vectors of shape (n_features,).

    Note: previous annotations used ``np.array`` (a function, not a type) and
    declared a ``float`` return, but the jitted function returns a 0-d array.
    """
    dist = jnp.sqrt(jnp.sum((x - y) ** 2))
    return dist


@jax.jit
def cdist(x: np.ndarray, y: np.ndarray) -> jnp.ndarray:
    """Jax implementation of :func:`scipy.spatial.distance.cdist`.

    Uses euclidean distance.

    Parameters
    ----------
    x
        Array of shape (n_cells_a, n_features)
    y
        Array of shape (n_cells_b, n_features)

    Returns
    -------
    dist
        Array of shape (n_cells_a, n_cells_b)
    """
    # Nested vmap computes all pairwise distances without materializing
    # intermediate broadcasted arrays in Python.
    return jax.vmap(lambda x1: jax.vmap(lambda y1: _euclidean_distance(x1, y1))(y))(x)


@jax.jit
def pdist_squareform(X: np.ndarray) -> jnp.ndarray:
    """Jax implementation of :func:`scipy.spatial.distance.pdist` and :func:`scipy.spatial.distance.squareform`.

    Uses euclidean distance.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features)

    Returns
    -------
    dist
        Array of shape (n_cells, n_cells)
    """
    n_cells = X.shape[0]
    # Only the upper triangle (including diagonal) is computed explicitly.
    inds = jnp.triu_indices(n_cells)

    def _body_fn(X, i_j):
        i, j = i_j
        return X, _euclidean_distance(X[i], X[j])

    dist_mat = jnp.zeros((n_cells, n_cells))
    dist_mat = dist_mat.at[inds].set(jax.lax.scan(_body_fn, X, (inds[0], inds[1]))[1])
    # Mirror the upper triangle into the lower triangle to get the full
    # symmetric squareform matrix.
    dist_mat = jnp.maximum(dist_mat, dist_mat.T)
    return dist_mat
@jax.jit
def _get_dist_labels(X: jnp.ndarray, centroids: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Get the squared distances to each centroid and the assigned labels."""
    dist = jnp.square(cdist(X, centroids))
    labels = jnp.argmin(dist, axis=1)
    return dist, labels


class KMeans:
    """Jax implementation of :class:`sklearn.cluster.KMeans`.

    This implementation is limited to Euclidean distance.

    Parameters
    ----------
    n_clusters
        Number of clusters.
    init
        Cluster centroid initialization method. One of the following:

        * ``'k-means++'``: Sample initial cluster centroids based on an
          empirical distribution of the points' contributions to the
          overall inertia.
        * ``'random'``: Uniformly sample observations as initial centroids
    n_init
        Number of times the k-means algorithm will be initialized.
    max_iter
        Maximum number of iterations of the k-means algorithm for a single run.
    tol
        Relative tolerance with regards to inertia to declare convergence.
    seed
        Random seed.
    """

    def __init__(
        self,
        n_clusters: int = 8,
        init: Literal["k-means++", "random"] = "k-means++",
        n_init: int = 1,
        max_iter: int = 300,
        tol: float = 1e-4,
        seed: IntOrKey = 0,
    ):
        self.n_clusters = n_clusters
        self.n_init = n_init
        self.max_iter = max_iter
        # Stored as a scale factor; the absolute tolerance is derived from the
        # data variance in `fit` (see `_tolerance`).
        self.tol_scale = tol
        self.seed: jax.Array = validate_seed(seed)

        if init not in ["k-means++", "random"]:
            raise ValueError("Invalid init method, must be one of ['k-means++' or 'random'].")
        if init == "k-means++":
            self._initialize = _initialize_plus_plus
        else:
            self._initialize = _initialize_random

    def fit(self, X: np.ndarray):
        """Fit the model to the data."""
        X = check_array(X, dtype=np.float32, order="C")
        self.tol = _tolerance(X, self.tol_scale)
        # Subtract mean for numerical accuracy.
        # NOTE(review): if check_array returned the input unchanged, this
        # temporarily mutates the caller's array; it is restored below.
        mean = X.mean(axis=0)
        X -= mean
        self._fit(X)
        X += mean
        self.cluster_centroids_ += mean
        return self

    def _fit(self, X: np.ndarray):
        # Run n_init independent initializations and keep the best (lowest inertia).
        all_centroids, all_inertias = jax.lax.map(
            lambda key: self._kmeans_full_run(X, key), jax.random.split(self.seed, self.n_init)
        )
        i = jnp.argmin(all_inertias)
        self.cluster_centroids_ = get_ndarray(all_centroids[i])
        self.inertia_ = get_ndarray(all_inertias[i])
        _, labels = _get_dist_labels(X, self.cluster_centroids_)
        self.labels_ = get_ndarray(labels)

    @partial(jax.jit, static_argnums=(0,))
    def _kmeans_full_run(self, X: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:
        def _kmeans_step(state):
            # state = (centroids, latest inertia, previous inertia, n_iter)
            centroids, prev_inertia, _, n_iter = state
            # TODO(adamgayoso): Efficiently compute argmin and min simultaneously.
            dist, new_labels = _get_dist_labels(X, centroids)
            # From https://colab.research.google.com/drive/1AwS4haUx6swF82w3nXr6QKhajdF8aSvA?usp=sharing
            counts = (new_labels[jnp.newaxis, :] == jnp.arange(self.n_clusters)[:, jnp.newaxis]).sum(
                axis=1, keepdims=True
            )
            # Avoid division by zero for empty clusters.
            counts = jnp.clip(counts, min=1, max=None)
            # Sum over points in a centroid by zeroing others out
            new_centroids = (
                jnp.sum(
                    jnp.where(
                        # axes: (data points, clusters, data dimension)
                        new_labels[:, jnp.newaxis, jnp.newaxis]
                        == jnp.arange(self.n_clusters)[jnp.newaxis, :, jnp.newaxis],
                        X[:, jnp.newaxis, :],
                        0.0,
                    ),
                    axis=0,
                )
                / counts
            )
            new_inertia = jnp.sum(jnp.min(dist, axis=1))
            n_iter = n_iter + 1
            return new_centroids, new_inertia, prev_inertia, n_iter

        def _kmeans_convergence(state):
            _, new_inertia, old_inertia, n_iter = state
            cond1 = jnp.abs(old_inertia - new_inertia) > self.tol
            cond2 = n_iter < self.max_iter
            # BUGFIX: continue only while NOT converged AND under the iteration
            # cap. The previous `logical_or` never stopped early on the
            # tolerance and could loop past max_iter when unconverged.
            return jnp.logical_and(cond1, cond2)[0]

        centroids = self._initialize(X, self.n_clusters, key)
        # state = (centroids, latest inertia, previous inertia, n_iter).
        # The inertia pair is seeded with (inf, 0) rather than (inf, inf) so the
        # first convergence check is inf > tol (True) instead of NaN (False).
        state = (centroids, jnp.inf, 0.0, jnp.array([0.0]))
        state = jax.lax.while_loop(_kmeans_convergence, _kmeans_step, state)
        # Compute final inertia for the returned centroids.
        centroids = state[0]
        dist, _ = _get_dist_labels(X, centroids)
        final_intertia = jnp.sum(jnp.min(dist, axis=1))
        return centroids, final_intertia
from functools import partial

import chex
import jax
import jax.numpy as jnp
import numpy as np

from ._utils import get_ndarray

# Arrays accepted by the public API: host (numpy) or device (jax) arrays.
NdArray = np.ndarray | jnp.ndarray


@chex.dataclass
class _NeighborProbabilityState:
    # State of the per-cell binary search over `beta` (the precision of the
    # Gaussian kernel over kNN distances) used to match a target perplexity.
    H: float  # current entropy of the neighbor distribution
    P: chex.ArrayDevice  # current neighbor probabilities
    Hdiff: float  # H - log(perplexity); search stops when |Hdiff| < tol
    beta: float  # current kernel precision
    betamin: float  # lower bracket of the binary search (-inf until set)
    betamax: float  # upper bracket of the binary search (inf until set)
    tries: int  # iteration counter, capped at 50


@jax.jit
def _Hbeta(knn_dists_row: jnp.ndarray, row_self_mask: jnp.ndarray, beta: float) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Entropy H and normalized probabilities P of a Gaussian kernel with precision `beta`.

    `row_self_mask` is True for non-self neighbors; self edges get probability 0.
    """
    P = jnp.exp(-knn_dists_row * beta)
    # Mask out self edges to be zero
    P = jnp.where(row_self_mask, P, 0)
    sumP = jnp.nansum(P)
    # Guard against an all-zero row (sumP == 0) to avoid log(0) and 0-division.
    H = jnp.where(sumP == 0, 0, jnp.log(sumP) + beta * jnp.nansum(knn_dists_row * P) / sumP)
    P = jnp.where(sumP == 0, jnp.zeros_like(knn_dists_row), P / sumP)
    return H, P


@jax.jit
def _get_neighbor_probability(
    knn_dists_row: jnp.ndarray, row_self_mask: jnp.ndarray, perplexity: float, tol: float
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Binary-search `beta` so the neighbor distribution's entropy matches log(perplexity)."""
    beta = 1
    betamin = -jnp.inf
    betamax = jnp.inf
    H, P = _Hbeta(knn_dists_row, row_self_mask, beta)
    Hdiff = H - jnp.log(perplexity)

    def _get_neighbor_probability_step(state):
        Hdiff = state.Hdiff
        beta = state.beta
        betamin = state.betamin
        betamax = state.betamax
        tries = state.tries

        # Standard bisection update: entropy too high -> raise beta (narrower
        # kernel); too low -> lower beta. Brackets double/halve until both are set.
        new_betamin = jnp.where(Hdiff > 0, beta, betamin)
        new_betamax = jnp.where(Hdiff > 0, betamax, beta)
        new_beta = jnp.where(
            Hdiff > 0,
            jnp.where(betamax == jnp.inf, beta * 2, (beta + betamax) / 2),
            jnp.where(betamin == -jnp.inf, beta / 2, (beta + betamin) / 2),
        )
        new_H, new_P = _Hbeta(knn_dists_row, row_self_mask, new_beta)
        new_Hdiff = new_H - jnp.log(perplexity)
        return _NeighborProbabilityState(
            H=new_H, P=new_P, Hdiff=new_Hdiff, beta=new_beta, betamin=new_betamin, betamax=new_betamax, tries=tries + 1
        )

    def _get_neighbor_probability_convergence(state):
        # Continue while not within tolerance and under the 50-iteration cap.
        Hdiff, tries = state.Hdiff, state.tries
        return jnp.logical_and(jnp.abs(Hdiff) >= tol, tries < 50)

    init_state = _NeighborProbabilityState(H=H, P=P, Hdiff=Hdiff, beta=beta, betamin=betamin, betamax=betamax, tries=0)
    final_state = jax.lax.while_loop(_get_neighbor_probability_convergence, _get_neighbor_probability_step, init_state)
    return final_state.H, final_state.P


def _compute_simpson_index_cell(
    knn_dists_row: jnp.ndarray,
    knn_labels_row: jnp.ndarray,
    row_self_mask: jnp.ndarray,
    n_batches: int,
    perplexity: float,
    tol: float,
) -> jnp.ndarray:
    """Simpson index of one cell's neighborhood label distribution.

    Returns -1 as a sentinel when the entropy degenerates to 0 (no usable
    neighbor probabilities).
    """
    H, P = _get_neighbor_probability(knn_dists_row, row_self_mask, perplexity, tol)

    def _non_zero_H_simpson():
        # Per-label probability mass; Simpson index is sum of squared masses.
        sumP = jnp.bincount(knn_labels_row, weights=P, length=n_batches)
        return jnp.where(knn_labels_row.shape[0] == P.shape[0], jnp.dot(sumP, sumP), 1)

    return jnp.where(H == 0, -1, _non_zero_H_simpson())


def compute_simpson_index(
    knn_dists: NdArray,
    knn_idx: NdArray,
    row_idx: NdArray,
    labels: NdArray,
    n_labels: int,
    perplexity: float = 30,
    tol: float = 1e-5,
) -> np.ndarray:
    """Compute the Simpson index for each cell.

    Parameters
    ----------
    knn_dists
        KNN distances of size (n_cells, n_neighbors).
    knn_idx
        KNN indices of size (n_cells, n_neighbors) corresponding to distances.
    row_idx
        Idx of each row (n_cells, 1).
    labels
        Cell labels of size (n_cells,).
    n_labels
        Number of labels.
    perplexity
        Measure of the effective number of neighbors.
    tol
        Tolerance for binary search.

    Returns
    -------
    simpson_index
        Simpson index of size (n_cells,).
    """
    knn_dists = jnp.array(knn_dists)
    knn_idx = jnp.array(knn_idx)
    labels = jnp.array(labels)
    row_idx = jnp.array(row_idx)
    # Labels of each cell's neighbors, shape (n_cells, n_neighbors).
    knn_labels = labels[knn_idx]
    # True where the neighbor is NOT the cell itself (self edges are excluded
    # from the probability computation).
    self_mask = knn_idx != row_idx
    simpson_fn = partial(_compute_simpson_index_cell, n_batches=n_labels, perplexity=perplexity, tol=tol)
    # Vectorize the per-cell computation over all rows.
    out = jax.vmap(simpson_fn)(knn_dists, knn_labels, self_mask)
    return get_ndarray(out)
45 | """ 46 | 47 | coordinates: NdArray 48 | components: NdArray 49 | variance: NdArray 50 | variance_ratio: NdArray 51 | svd: _SVDResult | None = None 52 | 53 | 54 | def _svd_flip( 55 | u: NdArray, 56 | v: NdArray, 57 | u_based_decision: bool = True, 58 | ): 59 | """Sign correction to ensure deterministic output from SVD. 60 | 61 | Jax implementation of :func:`~sklearn.utils.extmath.svd_flip`. 62 | 63 | Parameters 64 | ---------- 65 | u 66 | Left singular vectors of shape (M, K). 67 | v 68 | Right singular vectors of shape (K, N). 69 | u_based_decision 70 | If True, use the columns of u as the basis for sign flipping. 71 | """ 72 | if u_based_decision: 73 | max_abs_cols = jnp.argmax(jnp.abs(u), axis=0) 74 | signs = jnp.sign(u[max_abs_cols, jnp.arange(u.shape[1])]) 75 | else: 76 | max_abs_rows = jnp.argmax(jnp.abs(v), axis=1) 77 | signs = jnp.sign(v[jnp.arange(v.shape[0]), max_abs_rows]) 78 | u_ = u * signs 79 | v_ = v * signs[:, None] 80 | return u_, v_ 81 | 82 | 83 | def pca( 84 | X: NdArray, 85 | n_components: int | None = None, 86 | return_svd: bool = False, 87 | ) -> _PCAResult: 88 | """Principal component analysis (PCA). 89 | 90 | Parameters 91 | ---------- 92 | X 93 | Array of shape (n_cells, n_features). 94 | n_components 95 | Number of components to keep. If None, all components are kept. 96 | return_svd 97 | If True, also return the results from SVD. 
98 | 99 | Returns 100 | ------- 101 | results: _PCAData 102 | """ 103 | max_components = min(X.shape) 104 | if n_components and n_components > max_components: 105 | raise ValueError(f"n_components = {n_components} must be <= min(n_cells, n_features) = {max_components}") 106 | n_components = n_components or max_components 107 | 108 | u, s, v, variance, variance_ratio = _pca(X) 109 | 110 | # Select n_components 111 | coordinates = u[:, :n_components] * s[:n_components] 112 | components = v[:n_components] 113 | variance_ = variance[:n_components] 114 | variance_ratio_ = variance_ratio[:n_components] 115 | 116 | results = _PCAResult( 117 | coordinates=get_ndarray(coordinates), 118 | components=get_ndarray(components), 119 | variance=get_ndarray(variance_), 120 | variance_ratio=get_ndarray(variance_ratio_), 121 | svd=_SVDResult(u=get_ndarray(u), s=get_ndarray(s), v=get_ndarray(v)) if return_svd else None, 122 | ) 123 | return results 124 | 125 | 126 | @jit 127 | def _pca( 128 | X: NdArray, 129 | ) -> tuple[NdArray, NdArray, NdArray, NdArray, NdArray]: 130 | """Principal component analysis. 131 | 132 | Parameters 133 | ---------- 134 | X 135 | Array of shape (n_cells, n_features). 136 | 137 | Returns 138 | ------- 139 | u: NdArray 140 | Left singular vectors of shape (M, K). 141 | s: NdArray 142 | Singular values of shape (K,). 143 | v: NdArray 144 | Right singular vectors of shape (K, N). 145 | variance: NdArray 146 | Array of shape (K,) containing the explained variance of each PC. 147 | variance_ratio: NdArray 148 | Array of shape (K,) containing the explained variance ratio of each PC. 
149 | """ 150 | X_ = X - jnp.mean(X, axis=0) 151 | u, s, v = jnp.linalg.svd(X_, full_matrices=False) 152 | u, v = _svd_flip(u, v) 153 | 154 | variance = (s**2) / (X.shape[0] - 1) 155 | total_variance = jnp.sum(variance) 156 | variance_ratio = variance / total_variance 157 | 158 | return u, s, v, variance, variance_ratio 159 | -------------------------------------------------------------------------------- /src/scib_metrics/utils/_pcr.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import pandas as pd 4 | from jax import jit 5 | 6 | from scib_metrics._types import NdArray 7 | 8 | from ._pca import pca 9 | from ._utils import one_hot 10 | 11 | 12 | def principal_component_regression( 13 | X: NdArray, 14 | covariate: NdArray, 15 | categorical: bool = False, 16 | n_components: int | None = None, 17 | ) -> float: 18 | """Principal component regression (PCR) :cite:p:`buttner2018`. 19 | 20 | Parameters 21 | ---------- 22 | X 23 | Array of shape (n_cells, n_features). 24 | covariate 25 | Array of shape (n_cells,) or (n_cells, 1) representing batch/covariate values. 26 | categorical 27 | If True, batch will be treated as categorical and one-hot encoded. 28 | n_components: 29 | Number of components to compute, passed into :func:`~scib_metrics.utils.pca`. 30 | If None, all components are used. 31 | 32 | Returns 33 | ------- 34 | pcr: float 35 | Principal component regression using the first n_components principal components. 
36 | """ 37 | if len(X.shape) != 2: 38 | raise ValueError("Dimension mismatch: X must be 2-dimensional.") 39 | if X.shape[0] != covariate.shape[0]: 40 | raise ValueError("Dimension mismatch: X and batch must have the same number of samples.") 41 | if categorical: 42 | covariate = np.asarray(pd.Categorical(covariate).codes) 43 | else: 44 | covariate = np.asarray(covariate) 45 | 46 | covariate = one_hot(covariate) if categorical else covariate.reshape((covariate.shape[0], 1)) 47 | 48 | pca_results = pca(X, n_components=n_components) 49 | 50 | # Center inputs for no intercept 51 | covariate = covariate - jnp.mean(covariate, axis=0) 52 | pcr = _pcr(pca_results.coordinates, covariate, pca_results.variance) 53 | return float(pcr) 54 | 55 | 56 | @jit 57 | def _pcr( 58 | X_pca: NdArray, 59 | covariate: NdArray, 60 | var: NdArray, 61 | ) -> NdArray: 62 | """Principal component regression. 63 | 64 | Parameters 65 | ---------- 66 | X_pca 67 | Array of shape (n_cells, n_components) containing PCA coordinates. Must be standardized. 68 | covariate 69 | Array of shape (n_cells, 1) or (n_cells, n_classes) containing batch/covariate values. Must be standardized 70 | if not categorical (one-hot). 71 | var 72 | Array of shape (n_components,) containing the explained variance of each PC. 
73 | """ 74 | residual_sum = jnp.linalg.lstsq(covariate, X_pca)[1] 75 | total_sum = jnp.sum((X_pca - jnp.mean(X_pca, axis=0, keepdims=True)) ** 2, axis=0) 76 | r2 = jnp.maximum(0, 1 - residual_sum / total_sum) 77 | 78 | return jnp.dot(jnp.ravel(r2), var) / jnp.sum(var) 79 | -------------------------------------------------------------------------------- /src/scib_metrics/utils/_silhouette.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from ._dist import cdist 9 | from ._utils import get_ndarray 10 | 11 | 12 | @jax.jit 13 | def _silhouette_reduce( 14 | D_chunk: jnp.ndarray, start: int, labels: jnp.ndarray, label_freqs: jnp.ndarray 15 | ) -> tuple[jnp.ndarray, jnp.ndarray]: 16 | """Accumulate silhouette statistics for vertical chunk of X. 17 | 18 | Follows scikit-learn implementation. 19 | 20 | Parameters 21 | ---------- 22 | D_chunk 23 | Array of shape (n_chunk_samples, n_samples) 24 | Precomputed distances for a chunk. 25 | start 26 | First index in the chunk. 27 | labels 28 | Array of shape (n_samples,) 29 | Corresponding cluster labels, encoded as {0, ..., n_clusters-1}. 30 | label_freqs 31 | Distribution of cluster labels in ``labels``. 
32 | """ 33 | # accumulate distances from each sample to each cluster 34 | D_chunk_len = D_chunk.shape[0] 35 | 36 | # If running into memory issues, use fori_loop instead of vmap 37 | # clust_dists = jnp.zeros((D_chunk_len, len(label_freqs)), dtype=D_chunk.dtype) 38 | # def _bincount(i, _data): 39 | # clust_dists, D_chunk, labels, label_freqs = _data 40 | # clust_dists = clust_dists.at[i].set(jnp.bincount(labels, weights=D_chunk[i], length=label_freqs.shape[0])) 41 | # return clust_dists, D_chunk, labels, label_freqs 42 | 43 | # clust_dists = jax.lax.fori_loop( 44 | # 0, D_chunk_len, lambda i, _data: _bincount(i, _data), (clust_dists, D_chunk, labels, label_freqs) 45 | # )[0] 46 | 47 | clust_dists = jax.vmap(partial(jnp.bincount, length=label_freqs.shape[0]), in_axes=(None, 0))(labels, D_chunk) 48 | 49 | # intra_index selects intra-cluster distances within clust_dists 50 | intra_index = (jnp.arange(D_chunk_len), jax.lax.dynamic_slice(labels, (start,), (D_chunk_len,))) 51 | # intra_clust_dists are averaged over cluster size outside this function 52 | intra_clust_dists = clust_dists[intra_index] 53 | # of the remaining distances we normalise and extract the minimum 54 | clust_dists = clust_dists.at[intra_index].set(jnp.inf) 55 | clust_dists /= label_freqs 56 | inter_clust_dists = clust_dists.min(axis=1) 57 | return intra_clust_dists, inter_clust_dists 58 | 59 | 60 | def _pairwise_distances_chunked(X: jnp.ndarray, chunk_size: int, reduce_fn: callable) -> jnp.ndarray: 61 | """Compute pairwise distances in chunks to reduce memory usage.""" 62 | n_samples = X.shape[0] 63 | n_chunks = jnp.ceil(n_samples / chunk_size).astype(int) 64 | intra_dists_all = [] 65 | inter_dists_all = [] 66 | for i in range(n_chunks): 67 | start = i * chunk_size 68 | end = min((i + 1) * chunk_size, n_samples) 69 | intra_cluster_dists, inter_cluster_dists = reduce_fn(cdist(X[start:end], X), start=start) 70 | intra_dists_all.append(intra_cluster_dists) 71 | 
inter_dists_all.append(inter_cluster_dists) 72 | return jnp.concatenate(intra_dists_all), jnp.concatenate(inter_dists_all) 73 | 74 | 75 | def silhouette_samples(X: np.ndarray, labels: np.ndarray, chunk_size: int = 256) -> np.ndarray: 76 | """Compute the Silhouette Coefficient for each observation. 77 | 78 | Implements :func:`sklearn.metrics.silhouette_samples`. 79 | 80 | Parameters 81 | ---------- 82 | X 83 | Array of shape (n_cells, n_features) representing a 84 | feature array. 85 | labels 86 | Array of shape (n_cells,) representing label values 87 | for each observation. 88 | chunk_size 89 | Number of samples to process at a time for distance computation. 90 | 91 | Returns 92 | ------- 93 | silhouette scores array of shape (n_cells,) 94 | """ 95 | if X.shape[0] != labels.shape[0]: 96 | raise ValueError("X and labels should have the same number of samples") 97 | labels = pd.Categorical(labels).codes 98 | labels = jnp.asarray(labels) 99 | label_freqs = jnp.bincount(labels) 100 | reduce_fn = partial(_silhouette_reduce, labels=labels, label_freqs=label_freqs) 101 | results = _pairwise_distances_chunked(X, chunk_size=chunk_size, reduce_fn=reduce_fn) 102 | intra_clust_dists, inter_clust_dists = results 103 | 104 | denom = jnp.take(label_freqs - 1, labels, mode="clip") 105 | intra_clust_dists /= denom 106 | sil_samples = inter_clust_dists - intra_clust_dists 107 | sil_samples /= jnp.maximum(intra_clust_dists, inter_clust_dists) 108 | return get_ndarray(jnp.nan_to_num(sil_samples)) 109 | -------------------------------------------------------------------------------- /src/scib_metrics/utils/_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | from chex import ArrayDevice 7 | from jax import Array, nn 8 | from scipy.sparse import csr_matrix 9 | from sklearn.neighbors import NearestNeighbors 10 | from sklearn.utils import check_array 11 | 12 
| from scib_metrics._types import ArrayLike, IntOrKey, NdArray 13 | 14 | 15 | def get_ndarray(x: ArrayDevice) -> np.ndarray: 16 | """Convert Jax device array to Numpy array.""" 17 | return np.array(jax.device_get(x)) 18 | 19 | 20 | def one_hot(y: NdArray, n_classes: int | None = None) -> jnp.ndarray: 21 | """One-hot encode an array. Wrapper around :func:`~jax.nn.one_hot`. 22 | 23 | Parameters 24 | ---------- 25 | y 26 | Array of shape (n_cells,) or (n_cells, 1). 27 | n_classes 28 | Number of classes. If None, inferred from the data. 29 | 30 | Returns 31 | ------- 32 | one_hot: jnp.ndarray 33 | Array of shape (n_cells, n_classes). 34 | """ 35 | n_classes = n_classes or int(jax.device_get(jnp.max(y))) + 1 36 | return nn.one_hot(jnp.ravel(y), n_classes) 37 | 38 | 39 | def validate_seed(seed: IntOrKey) -> Array: 40 | """Validate a seed and return a Jax random key.""" 41 | return jax.random.PRNGKey(seed) if isinstance(seed, int) else seed 42 | 43 | 44 | def check_square(X: ArrayLike): 45 | """Check if a matrix is square.""" 46 | if X.shape[0] != X.shape[1]: 47 | raise ValueError("X must be a square matrix") 48 | 49 | 50 | def convert_knn_graph_to_idx(X: csr_matrix) -> tuple[np.ndarray, np.ndarray]: 51 | """Convert a kNN graph to indices and distances.""" 52 | check_array(X, accept_sparse="csr") 53 | check_square(X) 54 | 55 | n_neighbors = np.unique(X.nonzero()[0], return_counts=True)[1] 56 | if len(np.unique(n_neighbors)) > 1: 57 | raise ValueError("Each cell must have the same number of neighbors.") 58 | 59 | n_neighbors = int(np.unique(n_neighbors)[0]) 60 | with warnings.catch_warnings(): 61 | warnings.filterwarnings("ignore", message="Precomputed sparse input") 62 | nn_obj = NearestNeighbors(n_neighbors=n_neighbors, metric="precomputed").fit(X) 63 | kneighbors = nn_obj.kneighbors(X) 64 | return kneighbors 65 | -------------------------------------------------------------------------------- /tests/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YosefLab/scib-metrics/5d01bb46f4d2317b4b3379851c7202df5be36c68/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_benchmarker.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from scib_metrics.benchmark import BatchCorrection, Benchmarker, BioConservation 4 | from scib_metrics.nearest_neighbors import jax_approx_min_k 5 | from tests.utils.data import dummy_benchmarker_adata 6 | 7 | 8 | def test_benchmarker(): 9 | ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata() 10 | bm = Benchmarker( 11 | ad, 12 | batch_key, 13 | labels_key, 14 | emb_keys, 15 | batch_correction_metrics=BatchCorrection(), 16 | bio_conservation_metrics=BioConservation(), 17 | ) 18 | bm.benchmark() 19 | results = bm.get_results() 20 | assert isinstance(results, pd.DataFrame) 21 | bm.plot_results_table() 22 | 23 | 24 | def test_benchmarker_default(): 25 | ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata() 26 | bm = Benchmarker( 27 | ad, 28 | batch_key, 29 | labels_key, 30 | emb_keys, 31 | ) 32 | bm.benchmark() 33 | results = bm.get_results() 34 | assert isinstance(results, pd.DataFrame) 35 | bm.plot_results_table() 36 | 37 | 38 | def test_benchmarker_custom_metric_booleans(): 39 | bioc = BioConservation( 40 | isolated_labels=False, nmi_ari_cluster_labels_leiden=False, silhouette_label=False, clisi_knn=True 41 | ) 42 | bc = BatchCorrection(kbet_per_label=False, graph_connectivity=False, ilisi_knn=True) 43 | ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata() 44 | bm = Benchmarker(ad, batch_key, labels_key, emb_keys, batch_correction_metrics=bc, bio_conservation_metrics=bioc) 45 | bm.benchmark() 46 | results = bm.get_results(clean_names=False) 47 | assert isinstance(results, pd.DataFrame) 48 | assert "isolated_labels" not in 
results.columns 49 | assert "nmi_ari_cluster_labels_leiden" not in results.columns 50 | assert "silhouette_label" not in results.columns 51 | assert "clisi_knn" in results.columns 52 | assert "kbet_per_label" not in results.columns 53 | assert "graph_connectivity" not in results.columns 54 | assert "ilisi_knn" in results.columns 55 | 56 | 57 | def test_benchmarker_custom_metric_callable(): 58 | bioc = BioConservation(clisi_knn={"perplexity": 10}) 59 | ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata() 60 | bm = Benchmarker( 61 | ad, batch_key, labels_key, emb_keys, bio_conservation_metrics=bioc, batch_correction_metrics=BatchCorrection() 62 | ) 63 | bm.benchmark() 64 | results = bm.get_results(clean_names=False) 65 | assert "clisi_knn" in results.columns 66 | 67 | 68 | def test_benchmarker_custom_near_neighs(): 69 | ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata() 70 | bm = Benchmarker( 71 | ad, 72 | batch_key, 73 | labels_key, 74 | emb_keys, 75 | bio_conservation_metrics=BioConservation(), 76 | batch_correction_metrics=BatchCorrection(), 77 | ) 78 | bm.prepare(neighbor_computer=jax_approx_min_k) 79 | bm.benchmark() 80 | results = bm.get_results() 81 | assert isinstance(results, pd.DataFrame) 82 | bm.plot_results_table() 83 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import igraph 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | import scanpy as sc 8 | from harmonypy import compute_lisi as harmonypy_lisi 9 | from scib.metrics import isolated_labels_asw 10 | from scipy.spatial.distance import cdist as sp_cdist 11 | from scipy.spatial.distance import pdist, squareform 12 | from sklearn.cluster import KMeans as SKMeans 13 | from sklearn.datasets import make_blobs 14 | from sklearn.metrics import silhouette_samples as 
sk_silhouette_samples 15 | from sklearn.metrics.pairwise import pairwise_distances_argmin 16 | from sklearn.neighbors import NearestNeighbors 17 | 18 | import scib_metrics 19 | from scib_metrics.nearest_neighbors import NeighborsResults 20 | from tests.utils.data import dummy_x_labels, dummy_x_labels_batch 21 | 22 | scib_metrics.settings.jax_fix_no_kernel_image() 23 | 24 | 25 | def test_package_has_version(): 26 | scib_metrics.__version__ 27 | 28 | 29 | def test_cdist(): 30 | x = jnp.array([[1, 2], [3, 4]]) 31 | y = jnp.array([[5, 6], [7, 8]]) 32 | assert np.allclose(scib_metrics.utils.cdist(x, y), sp_cdist(x, y)) 33 | 34 | 35 | def test_pdist(): 36 | x = jnp.array([[1, 2], [3, 4]]) 37 | assert np.allclose(scib_metrics.utils.pdist_squareform(x), squareform(pdist(x))) 38 | 39 | 40 | def test_silhouette_samples(): 41 | X, labels = dummy_x_labels() 42 | assert np.allclose(scib_metrics.utils.silhouette_samples(X, labels), sk_silhouette_samples(X, labels), atol=1e-5) 43 | 44 | 45 | def test_silhouette_label(): 46 | X, labels = dummy_x_labels() 47 | score = scib_metrics.silhouette_label(X, labels) 48 | assert score > 0 49 | scib_metrics.silhouette_label(X, labels, rescale=False) 50 | 51 | 52 | def test_silhouette_batch(): 53 | X, labels, batch = dummy_x_labels_batch() 54 | score = scib_metrics.silhouette_batch(X, labels, batch) 55 | assert score > 0 56 | scib_metrics.silhouette_batch(X, labels, batch) 57 | 58 | 59 | def test_compute_simpson_index(): 60 | X, labels = dummy_x_labels() 61 | D = scib_metrics.utils.cdist(X, X) 62 | nbrs = NearestNeighbors(n_neighbors=30, algorithm="kd_tree").fit(X) 63 | D, knn_idx = nbrs.kneighbors(X) 64 | row_idx = np.arange(X.shape[0])[:, None] 65 | scib_metrics.utils.compute_simpson_index( 66 | jnp.array(D), jnp.array(knn_idx), jnp.array(row_idx), jnp.array(labels), len(np.unique(labels)) 67 | ) 68 | 69 | 70 | @pytest.mark.parametrize("n_neighbors", [30, 60, 72]) 71 | def test_lisi_knn(n_neighbors): 72 | perplexity = n_neighbors // 3 73 | 
X, labels = dummy_x_labels() 74 | nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="kd_tree").fit(X) 75 | dists, inds = nbrs.kneighbors(X) 76 | neigh_results = NeighborsResults(indices=inds, distances=dists) 77 | lisi_res = scib_metrics.lisi_knn(neigh_results, labels, perplexity=perplexity) 78 | harmonypy_lisi_res = harmonypy_lisi( 79 | X, pd.DataFrame(labels, columns=["labels"]), label_colnames=["labels"], perplexity=perplexity 80 | )[:, 0] 81 | np.testing.assert_allclose(lisi_res, harmonypy_lisi_res, rtol=5e-5, atol=5e-5) 82 | 83 | 84 | def test_ilisi_clisi_knn(): 85 | X, labels, batches = dummy_x_labels_batch(x_is_neighbors_results=True) 86 | scib_metrics.ilisi_knn(X, batches, perplexity=10) 87 | scib_metrics.clisi_knn(X, labels, perplexity=10) 88 | 89 | 90 | def test_nmi_ari_cluster_labels_kmeans(): 91 | X, labels = dummy_x_labels() 92 | out = scib_metrics.nmi_ari_cluster_labels_kmeans(X, labels) 93 | nmi, ari = out["nmi"], out["ari"] 94 | assert isinstance(nmi, float) 95 | assert isinstance(ari, float) 96 | 97 | 98 | def test_nmi_ari_cluster_labels_leiden_parallel(): 99 | X, labels = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True) 100 | out = scib_metrics.nmi_ari_cluster_labels_leiden(X, labels, optimize_resolution=True, n_jobs=2) 101 | nmi, ari = out["nmi"], out["ari"] 102 | assert isinstance(nmi, float) 103 | assert isinstance(ari, float) 104 | 105 | 106 | def test_nmi_ari_cluster_labels_leiden_single_resolution(): 107 | X, labels = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True) 108 | out = scib_metrics.nmi_ari_cluster_labels_leiden(X, labels, optimize_resolution=False, resolution=0.1) 109 | nmi, ari = out["nmi"], out["ari"] 110 | assert isinstance(nmi, float) 111 | assert isinstance(ari, float) 112 | 113 | 114 | def test_nmi_ari_cluster_labels_leiden_reproducibility(): 115 | X, labels = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True) 116 | out1 = 
scib_metrics.nmi_ari_cluster_labels_leiden(X, labels, optimize_resolution=False, resolution=3.0) 117 | out2 = scib_metrics.nmi_ari_cluster_labels_leiden(X, labels, optimize_resolution=False, resolution=3.0) 118 | nmi1, ari1 = out1["nmi"], out1["ari"] 119 | nmi2, ari2 = out2["nmi"], out2["ari"] 120 | assert nmi1 == nmi2 121 | assert ari1 == ari2 122 | 123 | 124 | def test_leiden_graph_construction(): 125 | X, _ = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True) 126 | conn_graph = X.knn_graph_connectivities 127 | g = igraph.Graph.Weighted_Adjacency(conn_graph, mode="directed") 128 | g.to_undirected(mode="each") 129 | sc_g = sc._utils.get_igraph_from_adjacency(conn_graph, directed=False) 130 | assert g.isomorphic(sc_g) 131 | np.testing.assert_equal(g.es["weight"], sc_g.es["weight"]) 132 | 133 | 134 | def test_isolated_labels(): 135 | X, labels, batch = dummy_x_labels_batch() 136 | pred = scib_metrics.isolated_labels(X, labels, batch) 137 | adata = anndata.AnnData(X) 138 | adata.obsm["embed"] = X 139 | adata.obs["batch"] = batch 140 | adata.obs["labels"] = labels 141 | target = isolated_labels_asw(adata, "labels", "batch", "embed", iso_threshold=5) 142 | np.testing.assert_allclose(np.array(pred), np.array(target)) 143 | 144 | 145 | def test_kmeans(): 146 | centers = np.array([[1, 1], [-1, -1], [1, -1]]) * 2 147 | len(centers) 148 | X, labels_true = make_blobs(n_samples=3000, centers=centers, cluster_std=0.7, random_state=42) 149 | 150 | kmeans = scib_metrics.utils.KMeans(n_clusters=3) 151 | kmeans.fit(X) 152 | assert kmeans.labels_.shape == (X.shape[0],) 153 | 154 | skmeans = SKMeans(n_clusters=3) 155 | skmeans.fit(X) 156 | sk_inertia = np.array([skmeans.inertia_]) 157 | jax_inertia = np.array([kmeans.inertia_]) 158 | np.testing.assert_allclose(sk_inertia, jax_inertia, atol=4e-2) 159 | 160 | # Reorder cluster centroids between methods and measure accuracy 161 | k_means_cluster_centers = kmeans.cluster_centroids_ 162 | order = 
pairwise_distances_argmin(kmeans.cluster_centroids_, skmeans.cluster_centers_) 163 | sk_means_cluster_centers = skmeans.cluster_centers_[order] 164 | 165 | k_means_labels = pairwise_distances_argmin(X, k_means_cluster_centers) 166 | sk_means_labels = pairwise_distances_argmin(X, sk_means_cluster_centers) 167 | 168 | accuracy = (k_means_labels == sk_means_labels).sum() / len(k_means_labels) 169 | assert accuracy > 0.995 170 | 171 | 172 | def test_kbet(): 173 | X, _, batch = dummy_x_labels_batch(x_is_neighbors_results=True) 174 | acc_rate, stats, pvalues = scib_metrics.kbet(X, batch) 175 | assert isinstance(acc_rate, float) 176 | assert len(stats) == X.indices.shape[0] 177 | assert len(pvalues) == X.indices.shape[0] 178 | 179 | 180 | def test_kbet_per_label(): 181 | X, labels, batch = dummy_x_labels_batch(x_is_neighbors_results=True) 182 | score = scib_metrics.kbet_per_label(X, batch, labels) 183 | assert isinstance(score, float) 184 | 185 | 186 | def test_graph_connectivity(): 187 | X, labels = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True) 188 | metric = scib_metrics.graph_connectivity(X, labels) 189 | assert isinstance(metric, float) 190 | -------------------------------------------------------------------------------- /tests/test_neighbors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import scanpy as sc 4 | 5 | from scib_metrics.nearest_neighbors import jax_approx_min_k, pynndescent 6 | from tests.utils.data import dummy_benchmarker_adata 7 | 8 | 9 | def test_jax_neighbors(): 10 | ad, emb_keys, _, _ = dummy_benchmarker_adata() 11 | output = jax_approx_min_k(ad.obsm[emb_keys[0]], 10) 12 | assert output.distances.shape == (ad.n_obs, 10) 13 | 14 | 15 | @pytest.mark.parametrize("n", [5, 10, 20, 21]) 16 | def test_neighbors_results(n): 17 | adata, embedding_keys, *_ = dummy_benchmarker_adata() 18 | neigh_result = pynndescent(adata.obsm[embedding_keys[0]], 
n_neighbors=n) 19 | neigh_result = neigh_result.subset_neighbors(n=n) 20 | new_connect = neigh_result.knn_graph_connectivities 21 | 22 | sc_connect = sc.neighbors._connectivity.umap( 23 | neigh_result.indices[:, :n], neigh_result.distances[:, :n], n_obs=adata.n_obs, n_neighbors=n 24 | ) 25 | 26 | np.testing.assert_allclose(new_connect.toarray(), sc_connect.toarray()) 27 | -------------------------------------------------------------------------------- /tests/test_pcr_comparison.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import pytest 4 | 5 | import scib_metrics 6 | from tests.utils.sampling import categorical_sample, normal_sample, poisson_sample 7 | 8 | PCR_COMPARISON_PARAMS = list(product([100], [100], [False, True])) 9 | 10 | 11 | @pytest.mark.parametrize("n_obs, n_vars, categorical", PCR_COMPARISON_PARAMS) 12 | def test_pcr_comparison(n_obs, n_vars, categorical): 13 | X_pre = poisson_sample(n_obs, n_vars, seed=0) 14 | X_post = poisson_sample(n_obs, n_vars, seed=1) 15 | covariate = categorical_sample(n_obs, int(n_obs / 5)) if categorical else normal_sample(n_obs, seed=0) 16 | 17 | score = scib_metrics.pcr_comparison(X_pre, X_post, covariate, scale=True) 18 | assert score >= 0 and score <= 1 19 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YosefLab/scib-metrics/5d01bb46f4d2317b4b3379851c7202df5be36c68/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/data.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import numpy as np 3 | from scipy.sparse import csr_matrix 4 | from sklearn.neighbors import NearestNeighbors 5 | 6 | import scib_metrics 7 | from scib_metrics.nearest_neighbors 
import NeighborsResults 8 | 9 | 10 | def dummy_x_labels(symmetric_positive=False, x_is_neighbors_results=False): 11 | rng = np.random.default_rng(seed=42) 12 | X = rng.normal(size=(100, 10)) 13 | labels = rng.integers(0, 4, size=(100,)) 14 | if symmetric_positive: 15 | X = np.abs(X @ X.T) 16 | if x_is_neighbors_results: 17 | dist_mat = csr_matrix(scib_metrics.utils.cdist(X, X)) 18 | nbrs = NearestNeighbors(n_neighbors=30, metric="precomputed").fit(dist_mat) 19 | dist, ind = nbrs.kneighbors(dist_mat) 20 | X = NeighborsResults(indices=ind, distances=dist) 21 | return X, labels 22 | 23 | 24 | def dummy_x_labels_batch(x_is_neighbors_results=False): 25 | rng = np.random.default_rng(seed=43) 26 | X, labels = dummy_x_labels(x_is_neighbors_results=x_is_neighbors_results) 27 | batch = rng.integers(0, 4, size=(100,)) 28 | return X, labels, batch 29 | 30 | 31 | def dummy_benchmarker_adata(): 32 | X, labels, batch = dummy_x_labels_batch(x_is_neighbors_results=False) 33 | adata = anndata.AnnData(X) 34 | labels_key = "labels" 35 | batch_key = "batch" 36 | adata.obs[labels_key] = labels 37 | adata.obs[batch_key] = batch 38 | embedding_keys = [] 39 | for i in range(5): 40 | key = f"X_emb_{i}" 41 | adata.obsm[key] = X 42 | embedding_keys.append(key) 43 | return adata, embedding_keys, labels_key, batch_key 44 | -------------------------------------------------------------------------------- /tests/utils/sampling.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import Array 4 | 5 | IntOrKey = int | Array 6 | 7 | 8 | def _validate_seed(seed: IntOrKey) -> Array: 9 | return jax.random.PRNGKey(seed) if isinstance(seed, int) else seed 10 | 11 | 12 | def categorical_sample( 13 | n_obs: int, 14 | n_cats: int, 15 | seed: IntOrKey = 0, 16 | ) -> jnp.ndarray: 17 | return jax.random.categorical(_validate_seed(seed), jnp.ones(n_cats), shape=(n_obs,)) 18 | 19 | 20 | def normal_sample( 21 | n_obs: int, 22 | 
mean: float = 0.0, 23 | var: float = 1.0, 24 | seed: IntOrKey = 0, 25 | ) -> jnp.ndarray: 26 | return jax.random.multivariate_normal(_validate_seed(seed), jnp.ones(n_obs) * mean, jnp.eye(n_obs) * var) 27 | 28 | 29 | def poisson_sample( 30 | n_obs: int, 31 | n_vars: int, 32 | rate: float = 1.0, 33 | seed: IntOrKey = 0, 34 | ) -> jnp.ndarray: 35 | return jax.random.poisson(_validate_seed(seed), rate, shape=(n_obs, n_vars)) 36 | -------------------------------------------------------------------------------- /tests/utils/test_pca.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import pytest 6 | from sklearn.decomposition import PCA 7 | 8 | import scib_metrics 9 | from tests.utils.sampling import poisson_sample 10 | 11 | PCA_PARAMS = list(product([10, 100, 1000], [10, 100, 1000])) 12 | 13 | 14 | @pytest.mark.parametrize("n_obs, n_vars", PCA_PARAMS) 15 | def test_pca(n_obs: int, n_vars: int): 16 | def _test_pca(n_obs: int, n_vars: int, n_components: int, eps: float = 1e-4): 17 | X = poisson_sample(n_obs, n_vars) 18 | max_components = min(X.shape) 19 | pca = scib_metrics.utils.pca(X, n_components=n_components, return_svd=True) 20 | 21 | # SANITY CHECKS 22 | assert pca.coordinates.shape == (X.shape[0], n_components) 23 | assert pca.components.shape == (n_components, X.shape[1]) 24 | assert pca.variance.shape == (n_components,) 25 | assert pca.variance_ratio.shape == (n_components,) 26 | # SVD should not be truncated to n_components 27 | assert pca.svd is not None 28 | assert pca.svd.u.shape == (X.shape[0], max_components) 29 | assert pca.svd.s.shape == (max_components,) 30 | assert pca.svd.v.shape == (max_components, X.shape[1]) 31 | 32 | # VALUE CHECKS 33 | # TODO(martinkim0): Currently not checking coordinates and components, 34 | # TODO(martinkim0): implementations differ very slightly and not sure why. 
35 | pca_true = PCA(n_components=n_components, svd_solver="full").fit(X) 36 | # assert np.allclose(pca_true.transform(X), pca.coordinates, atol=eps) 37 | # assert np.allclose(pca_true.components_, pca.components, atol=eps) 38 | assert np.allclose(pca_true.singular_values_, pca.svd.s[:n_components], atol=eps) 39 | assert np.allclose(pca_true.explained_variance_, pca.variance, atol=eps) 40 | assert np.allclose(pca_true.explained_variance_ratio_, pca.variance_ratio, atol=eps) 41 | # Use arpack iff n_components < max_components 42 | if n_components < max_components: 43 | pca_true = PCA(n_components=n_components, svd_solver="arpack").fit(X) 44 | # assert jnp.allclose(pca_true.transform(X), pca.coordinates, atol=eps) 45 | # assert jnp.allclose(pca_true.components_, pca.components, atol=eps) 46 | assert jnp.allclose(pca_true.singular_values_, pca.svd.s[:n_components], atol=eps) 47 | assert jnp.allclose(pca_true.explained_variance_, pca.variance, atol=eps) 48 | assert jnp.allclose(pca_true.explained_variance_ratio_, pca.variance_ratio, atol=eps) 49 | 50 | max_components = min(n_obs, n_vars) 51 | _test_pca(n_obs, n_vars, n_components=max_components) 52 | _test_pca(n_obs, n_vars, n_components=max_components - 1) 53 | _test_pca(n_obs, n_vars, n_components=int(max_components / 2)) 54 | _test_pca(n_obs, n_vars, n_components=1) 55 | -------------------------------------------------------------------------------- /tests/utils/test_pcr.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | from scib.metrics import pc_regression 7 | 8 | import scib_metrics 9 | from scib_metrics.utils import get_ndarray 10 | from tests.utils.sampling import categorical_sample, normal_sample, poisson_sample 11 | 12 | PCR_PARAMS = list(product([10, 100, 1000], [10, 100, 1000], [False])) 13 | # TODO(martinkim0): Currently not testing categorical covariates because of 
14 | # TODO(martinkim0): reproducibility issues with original scib. See comment in PR #16. 15 | 16 | 17 | @pytest.mark.parametrize("n_obs, n_vars, categorical", PCR_PARAMS) 18 | def test_pcr(n_obs, n_vars, categorical): 19 | def _test_pcr(n_obs: int, n_vars: int, n_components: int, categorical: bool, eps=1e-3, seed=123): 20 | X = poisson_sample(n_obs, n_vars, seed=seed) 21 | covariate = categorical_sample(n_obs, int(n_obs / 5)) if categorical else normal_sample(n_obs, seed=seed) 22 | 23 | pcr_true = pc_regression( 24 | get_ndarray(X), 25 | pd.Categorical(get_ndarray(covariate)) if categorical else get_ndarray(covariate), 26 | n_comps=n_components, 27 | ) 28 | pcr = scib_metrics.utils.principal_component_regression( 29 | X, 30 | covariate, 31 | categorical=categorical, 32 | n_components=n_components, 33 | ) 34 | assert np.allclose(pcr_true, pcr, atol=eps) 35 | 36 | max_components = min(n_obs, n_vars) 37 | _test_pcr(n_obs, n_vars, n_components=max_components, categorical=categorical) 38 | _test_pcr(n_obs, n_vars, n_components=max_components - 1, categorical=categorical) 39 | _test_pcr(n_obs, n_vars, n_components=int(max_components / 2), categorical=categorical) 40 | _test_pcr(n_obs, n_vars, n_components=1, categorical=categorical) 41 | --------------------------------------------------------------------------------