├── .bumpversion.cfg
├── .editorconfig
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   ├── codecov.yml
│   ├── release.yml
│   └── workflows
│       ├── build_image.yml
│       ├── release.yaml
│       └── test.yaml
├── .gitignore
├── .gitmodules
├── .mypy.ini
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CHANGELOG.md
├── Dockerfile
├── LICENSE
├── README.md
├── asv.conf.json
├── benchmarks
│   ├── README.md
│   ├── __init__.py
│   ├── spatialdata_benchmark.py
│   └── utils.py
├── docs
│   ├── Makefile
│   ├── _static
│   │   ├── .gitkeep
│   │   ├── css
│   │   │   └── custom.css
│   │   └── img
│   │       └── spatialdata_horizontal.png
│   ├── _templates
│   │   ├── .gitkeep
│   │   └── autosummary
│   │       ├── base.rst
│   │       ├── class.rst
│   │       └── function.rst
│   ├── api.md
│   ├── api
│   │   ├── SpatialData.md
│   │   ├── data_formats.md
│   │   ├── dataloader.md
│   │   ├── datasets.md
│   │   ├── io.md
│   │   ├── models.md
│   │   ├── models_utils.md
│   │   ├── operations.md
│   │   ├── testing.md
│   │   ├── transformations.md
│   │   └── transformations_utils.md
│   ├── changelog.md
│   ├── conf.py
│   ├── contributing.md
│   ├── design_doc.md
│   ├── extensions
│   │   └── typed_returns.py
│   ├── glossary.md
│   ├── index.md
│   ├── installation.md
│   ├── references.bib
│   └── references.md
├── pyproject.toml
├── src
│   └── spatialdata
│       ├── __init__.py
│       ├── __main__.py
│       ├── _bridges
│       │   └── __init__.py
│       ├── _core
│       │   ├── __init__.py
│       │   ├── _deepcopy.py
│       │   ├── _elements.py
│       │   ├── _utils.py
│       │   ├── centroids.py
│       │   ├── concatenate.py
│       │   ├── data_extent.py
│       │   ├── operations
│       │   │   ├── __init__.py
│       │   │   ├── _utils.py
│       │   │   ├── aggregate.py
│       │   │   ├── map.py
│       │   │   ├── rasterize.py
│       │   │   ├── rasterize_bins.py
│       │   │   ├── transform.py
│       │   │   └── vectorize.py
│       │   ├── query
│       │   │   ├── __init__.py
│       │   │   ├── _utils.py
│       │   │   ├── relational_query.py
│       │   │   └── spatial_query.py
│       │   ├── spatialdata.py
│       │   └── validation.py
│       ├── _docs.py
│       ├── _io
│       │   ├── __init__.py
│       │   ├── _utils.py
│       │   ├── format.py
│       │   ├── io_points.py
│       │   ├── io_raster.py
│       │   ├── io_shapes.py
│       │   ├── io_table.py
│       │   └── io_zarr.py
│       ├── _logging.py
│       ├── _types.py
│       ├── _utils.py
│       ├── config.py
│       ├── dataloader
│       │   ├── __init__.py
│       │   └── datasets.py
│       ├── datasets.py
│       ├── io
│       │   └── __init__.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── _utils.py
│       │   └── models.py
│       ├── testing.py
│       └── transformations
│           ├── __init__.py
│           ├── _utils.py
│           ├── ngff
│           │   ├── __init__.py
│           │   ├── _utils.py
│           │   ├── ngff_coordinate_system.py
│           │   └── ngff_transformations.py
│           ├── operations.py
│           └── transformations.py
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── core
    │   ├── __init__.py
    │   ├── operations
    │   │   ├── __init__.py
    │   │   ├── test_aggregations.py
    │   │   ├── test_map.py
    │   │   ├── test_rasterize.py
    │   │   ├── test_rasterize_bins.py
    │   │   ├── test_spatialdata_operations.py
    │   │   ├── test_transform.py
    │   │   └── test_vectorize.py
    │   ├── query
    │   │   ├── __init__.py
    │   │   ├── test_relational_query.py
    │   │   ├── test_relational_query_match_sdata_to_table.py
    │   │   └── test_spatial_query.py
    │   ├── test_centroids.py
    │   ├── test_data_extent.py
    │   ├── test_deepcopy.py
    │   ├── test_get_attrs.py
    │   └── test_validation.py
    ├── data
    │   ├── multipolygon.json
    │   ├── points.json
    │   └── polygon.json
    ├── dataloader
    │   ├── __init__.py
    │   └── test_datasets.py
    ├── datasets
    │   ├── __init__.py
    │   └── test_datasets.py
    ├── io
    │   ├── __init__.py
    │   ├── test_format.py
    │   ├── test_metadata.py
    │   ├── test_multi_table.py
    │   ├── test_partial_read.py
    │   ├── test_pyramids_performance.py
    │   ├── test_readwrite.py
    │   ├── test_utils.py
    │   └── test_versions.py
    ├── models
    │   ├── __init__.py
    │   └── test_models.py
    ├── transformations
    │   ├── __init__.py
    │   ├── ngff
    │   │   ├── __init__.py
    │   │   ├── conftest.py
    │   │   ├── test_ngff_coordinate_system.py
    │   │   └── test_ngff_transformations.py
    │   ├── test_transformations.py
    │   └── test_transformations_utils.py
    └── utils
        ├── __init__.py
        ├── test_element_utils.py
        ├── test_sanitize.py
        └── test_testing.py
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.0.1
3 | tag = True
4 | commit = False
5 |
6 | [bumpversion:file:./pyproject.toml]
7 | search = version = "{current_version}"
8 | replace = version = "{new_version}"
9 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | indent_style = space
5 | indent_size = 4
6 | end_of_line = lf
7 | charset = utf-8
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | line_length = 120
11 |
12 | [Makefile]
13 | indent_style = tab
14 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ""
5 | labels: ""
6 | assignees: ""
7 | ---
8 |
9 | **Recommendation: attach a minimal working example**
10 | Generally, the easier it is for us to reproduce the issue, the faster we can work on it. It is not required, but if you can, please:
11 |
12 | 1. Reproduce using the [`blobs` dataset](https://spatialdata.scverse.org/en/stable/api/datasets.html#spatialdata.datasets.blobs)
13 |
14 | ```python
15 | from spatialdata.datasets import blobs
16 |
17 | sdata = blobs()
18 | ```
19 |
20 | You can also use [`blobs_annotating_element`](https://spatialdata.scverse.org/en/stable/api/datasets.html#spatialdata.datasets.blobs_annotating_element) for more
21 | control:
22 |
23 | ```python
24 | from spatialdata.datasets import blobs_annotating_element
25 | sdata = blobs_annotating_element('blobs_labels')
26 | ```
27 |
28 | 2. If the above is not possible, reproduce using a public dataset and explain how we can download the data.
29 | 3. If the data is private, consider sharing an anonymized version/subset via a [Zulip private message](https://scverse.zulipchat.com/#user/480560), or provide screenshots/GIFs showing the behavior.
30 |
31 | **Describe the bug**
32 | A clear and concise description of what the bug is; please report only one bug per issue.
33 |
34 | **To Reproduce**
35 | Steps to reproduce the behavior:
36 |
37 | 1. Go to '...'
38 | 2. Click on '....'
39 | 3. Scroll down to '....'
40 | 4. See error
41 |
42 | **Expected behavior**
43 | A clear and concise description of what you expected to happen.
44 |
45 | **Screenshots**
46 | If applicable, add screenshots to help explain your problem.
47 |
48 | **Desktop (optional):**
49 |
50 | - OS: [e.g. macOS, Windows, Linux]
51 | - Version [e.g. 22]
52 |
53 | **Additional context**
54 | Add any other context about the problem here.
55 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ""
5 | labels: ""
6 | assignees: ""
7 | ---
8 |
9 | **Is your feature request related to a problem? Please describe.**
10 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
11 |
12 | **Describe the solution you'd like**
13 | A clear and concise description of what you want to happen.
14 |
15 | **Describe alternatives you've considered**
16 | A clear and concise description of any alternative solutions or features you've considered.
17 |
18 | **Additional context**
19 | Add any other context or screenshots about the feature request here.
20 |
--------------------------------------------------------------------------------
/.github/codecov.yml:
--------------------------------------------------------------------------------
1 | # Based on pydata/xarray
2 | codecov:
3 |   require_ci_to_pass: false
4 |
5 | coverage:
6 |   status:
7 |     project:
8 |       default:
9 |         # Require 1% coverage, i.e., always succeed
10 |         target: 1
11 |     patch: false
12 |     changes: false
13 |
14 | comment:
15 |   layout: "diff, flags, files"
16 |   behavior: once
17 |   require_base: false
18 |
--------------------------------------------------------------------------------
/.github/release.yml:
--------------------------------------------------------------------------------
1 | changelog:
2 |   exclude:
3 |     labels:
4 |       - release-ignore
5 |     authors:
6 |       - pre-commit-ci
7 |   categories:
8 |     - title: Added
9 |       labels:
10 |         - "release-added"
11 |     - title: Changed
12 |       labels:
13 |         - "release-changed"
14 |     - title: Deprecated
15 |       labels:
16 |         - "release-deprecated"
17 |     - title: Removed
18 |       labels:
19 |         - "release-removed"
20 |     - title: Fixed
21 |       labels:
22 |         - "release-fixed"
23 |     - title: Security
24 |       labels:
25 |         - "release-security"
26 |     - title: Other Changes
27 |       labels:
28 |         - "*"
29 |
--------------------------------------------------------------------------------
/.github/workflows/build_image.yml:
--------------------------------------------------------------------------------
1 | name: Build Docker image
2 |
3 | on:
4 |   workflow_dispatch:
5 |   # schedule:
6 |   #   - cron: '0 0 * * *' # run daily at midnight UTC
7 |
8 | env:
9 |   REGISTRY: ghcr.io
10 |   IMAGE_NAME: ${{ github.repository }}
11 |
12 | concurrency:
13 |   group: ${{ github.workflow }}-${{ github.ref }}
14 |   cancel-in-progress: true
15 |
16 | jobs:
17 |   build:
18 |     runs-on: ubuntu-latest
19 |
20 |     defaults:
21 |       run:
22 |         shell: bash -e {0} # -e to fail on error
23 |
24 |     permissions:
25 |       contents: read
26 |       packages: write
27 |       attestations: write
28 |       id-token: write
29 |
30 |     steps:
31 |       - name: Checkout code
32 |         uses: actions/checkout@v4
33 |
34 |       - name: Set up Python
35 |         uses: actions/setup-python@v4
36 |         with:
37 |           python-version: "3.x"
38 |
39 |       - name: Upgrade pip
40 |         run: pip install --upgrade pip
41 |
42 |       - name: Get latest versions
43 |         id: get_versions
44 |         run: |
45 |           SPATIALDATA_VERSION=$(pip index versions spatialdata | grep "Available versions" | sed 's/Available versions: //' | awk -F', ' '{print $1}')
46 |           SPATIALDATA_IO_VERSION=$(pip index versions spatialdata-io | grep "Available versions" | sed 's/Available versions: //' | awk -F', ' '{print $1}')
47 |           SPATIALDATA_PLOT_VERSION=$(pip index versions spatialdata-plot | grep "Available versions" | sed 's/Available versions: //' | awk -F', ' '{print $1}')
48 |           echo "SPATIALDATA_VERSION=${SPATIALDATA_VERSION}" >> $GITHUB_ENV
49 |           echo "SPATIALDATA_IO_VERSION=${SPATIALDATA_IO_VERSION}" >> $GITHUB_ENV
50 |           echo "SPATIALDATA_PLOT_VERSION=${SPATIALDATA_PLOT_VERSION}" >> $GITHUB_ENV
51 |
52 |       - name: Check if image tag exists
53 |         id: check_tag
54 |         env:
55 |           IMAGE_TAG_SUFFIX: spatialdata${{ env.SPATIALDATA_VERSION }}_spatialdata-io${{ env.SPATIALDATA_IO_VERSION }}_spatialdata-plot${{ env.SPATIALDATA_PLOT_VERSION }}
56 |           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
57 |         run: |
58 |           # Define the API URL
59 |           API_URL="https://api.github.com/orgs/scverse/packages/container/spatialdata/versions"
60 |
61 |           # Fetch all existing versions
62 |           existing_tags=$(curl -s -H "Authorization: token $GITHUB_TOKEN" $API_URL | jq -r '.[].metadata.container.tags[]')
63 |
64 |           # Debug: Output all existing tags
65 |           echo "Existing tags:"
66 |           echo "$existing_tags"
67 |
68 |           # Check if the constructed tag exists
69 |           if echo "$existing_tags" | grep -q "$IMAGE_TAG_SUFFIX"; then
70 |             echo "Image tag $IMAGE_TAG_SUFFIX already exists. Skipping build."
71 |             echo "skip_build=true" >> $GITHUB_ENV
72 |           else
73 |             echo "Image tag $IMAGE_TAG_SUFFIX does not exist. Proceeding with build."
74 |             echo "skip_build=false" >> $GITHUB_ENV
75 |             echo "IMAGE_TAG_SUFFIX=${IMAGE_TAG_SUFFIX}" >> $GITHUB_ENV
76 |           fi
77 |
78 |       - name: Login to GitHub Container Registry
79 |         if: ${{ env.skip_build == 'false' }}
80 |         uses: docker/login-action@v3
81 |         with:
82 |           registry: ${{ env.REGISTRY }}
83 |           username: ${{ github.actor }}
84 |           password: ${{ secrets.GITHUB_TOKEN }}
85 |
86 |       - uses: docker/build-push-action@v5
87 |         if: ${{ env.skip_build == 'false' }}
88 |         env:
89 |           IMAGE_TAG: ${{ env.REGISTRY }}/scverse/spatialdata:${{ env.IMAGE_TAG_SUFFIX }}
90 |         with:
91 |           context: .
92 |           file: ./Dockerfile
93 |           push: true
94 |           cache-from: type=registry,ref=${{ env.REGISTRY }}/scverse/spatialdata:buildcache
95 |           cache-to: type=inline,ref=${{ env.REGISTRY }}/scverse/spatialdata:buildcache
96 |           build-args: |
97 |             SPATIALDATA_VERSION=${{ env.SPATIALDATA_VERSION }}
98 |             SPATIALDATA_IO_VERSION=${{ env.SPATIALDATA_IO_VERSION }}
99 |             SPATIALDATA_PLOT_VERSION=${{ env.SPATIALDATA_PLOT_VERSION }}
100 |           tags: ${{ env.IMAGE_TAG }}
101 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 |   release:
5 |     types: [published]
6 |
7 | jobs:
8 |   package_and_release:
9 |     runs-on: ubuntu-latest
10 |     if: startsWith(github.ref, 'refs/tags/v')
11 |     steps:
12 |       - uses: actions/checkout@v3
13 |       - name: Set up Python 3.12
14 |         uses: actions/setup-python@v4
15 |         with:
16 |           python-version: "3.12"
17 |           cache: pip
18 |       - name: Install build dependencies
19 |         run: python -m pip install --upgrade pip wheel twine build
20 |       - name: Build package
21 |         run: python -m build
22 |       - name: Check package
23 |         run: twine check --strict dist/*.whl
24 |       - name: Install hatch
25 |         run: pip install hatch
26 |       - name: Build project for distribution
27 |         run: hatch build
28 |       - name: Publish a Python distribution to PyPI
29 |         uses: pypa/gh-action-pypi-publish@release/v1
30 |         with:
31 |           password: ${{ secrets.PYPI_API_TOKEN }}
32 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 |   push:
5 |     branches: [main]
6 |     tags:
7 |       - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
8 |   pull_request:
9 |     branches: "*"
10 |
11 | jobs:
12 |   test:
13 |     runs-on: ${{ matrix.os }}
14 |     defaults:
15 |       run:
16 |         shell: bash -e {0} # -e to fail on error
17 |
18 |     strategy:
19 |       fail-fast: false
20 |       matrix:
21 |         python: ["3.10", "3.12"]
22 |         os: [ubuntu-latest]
23 |         include:
24 |           - os: macos-latest
25 |             python: "3.10"
26 |           - os: macos-latest
27 |             python: "3.12"
28 |             pip-flags: "--pre"
29 |             name: "Python 3.12 (pre-release)"
30 |
31 |     env:
32 |       OS: ${{ matrix.os }}
33 |       PYTHON: ${{ matrix.python }}
34 |
35 |     steps:
36 |       - uses: actions/checkout@v2
37 |       - uses: astral-sh/setup-uv@v5
38 |         id: setup-uv
39 |         with:
40 |           version: "latest"
41 |           python-version: ${{ matrix.python }}
42 |       - name: Install dependencies
43 |         run: "uv sync --extra test"
44 |       - name: Test
45 |         env:
46 |           MPLBACKEND: agg
47 |           PLATFORM: ${{ matrix.os }}
48 |           DISPLAY: :42
49 |         run: |
50 |           uv run pytest --cov --color=yes --cov-report=xml
51 |       - name: Upload coverage to Codecov
52 |         uses: codecov/codecov-action@v4
53 |         with:
54 |           name: coverage
55 |           verbose: true
56 |         env:
57 |           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Temp files
2 | .DS_Store
3 | *~
4 | .dmypy.json
5 |
6 | # Compiled files
7 | __pycache__/
8 |
9 | # Distribution / packaging
10 | /build/
11 | /dist/
12 | /*.egg-info/
13 |
14 | # Tests and coverage
15 | /.pytest_cache/
16 | /.cache/
17 | /data/
18 |
19 | # docs
20 | docs/_build
21 | !docs/api/.md
22 | docs/**/generated
23 |
24 | # IDEs
25 | /.idea/
26 | .vscode
27 |
28 | # data
29 | *.zarr/
30 |
31 | # temp files
32 | temp/
33 |
34 |
35 | # symlinks (luca) for extending the refactoring to satellite projects
36 | napari-spatialdata
37 | spatialdata-io
38 | spatialdata-notebooks
39 | spatialdata-plot
40 | spatialdata-sandbox
41 |
42 | # notebooks
43 | *.ipynb
44 |
45 | # version file
46 | _version.py
47 |
48 | # other
49 | node_modules/
50 |
51 | .asv/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "docs/tutorials/notebooks"]
2 |     path = docs/tutorials/notebooks
3 |     url = https://github.com/scverse/spatialdata-notebooks
4 |
--------------------------------------------------------------------------------
/.mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | python_version = 3.10
3 | plugins = numpy.typing.mypy_plugin
4 |
5 | ignore_errors = False
6 | warn_redundant_casts = True
7 | warn_unused_configs = True
8 | warn_unused_ignores = False
9 |
10 | disallow_untyped_calls = False
11 | disallow_untyped_defs = True
12 | disallow_incomplete_defs = True
13 | disallow_any_generics = True
14 |
15 | strict_optional = True
16 | strict_equality = True
17 | warn_return_any = True
18 | warn_unreachable = False
19 | check_untyped_defs = True
20 | ; because of docrep
21 | allow_untyped_decorators = True
22 | no_implicit_optional = True
23 | no_implicit_reexport = True
24 | no_warn_no_return = True
25 |
26 | show_error_codes = True
27 | show_column_numbers = True
28 | error_summary = True
29 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | fail_fast: false
2 | default_language_version:
3 |   python: python3
4 | default_stages:
5 |   - pre-commit
6 |   - pre-push
7 | minimum_pre_commit_version: 2.16.0
8 | ci:
9 |   skip: []
10 | repos:
11 |   - repo: https://github.com/rbubley/mirrors-prettier
12 |     rev: v3.5.3
13 |     hooks:
14 |       - id: prettier
15 |   - repo: https://github.com/pre-commit/mirrors-mypy
16 |     rev: v1.15.0
17 |     hooks:
18 |       - id: mypy
19 |         additional_dependencies: [numpy, types-requests]
20 |         exclude: tests/|docs/
21 |   - repo: https://github.com/astral-sh/ruff-pre-commit
22 |     rev: v0.11.11
23 |     hooks:
24 |       - id: ruff
25 |         args: [--fix, --exit-non-zero-on-fix]
26 |       - id: ruff-format
27 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html
2 | version: 2
3 | build:
4 |   os: ubuntu-20.04
5 |   tools:
6 |     python: "3.10"
7 | sphinx:
8 |   configuration: docs/conf.py
9 |   fail_on_warning: true
10 | python:
11 |   install:
12 |     - method: pip
13 |       path: .
14 |       extra_requirements:
15 |         - docs
16 |         - torch
17 | submodules:
18 |   include:
19 |     - "docs/tutorials/notebooks"
20 |   recursive: true
21 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/CHANGELOG.md
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG TARGETPLATFORM=linux/amd64
2 |
3 | # Use the specified platform to pull the correct base image.
4 | # Override TARGETPLATFORM during build for different architectures, such as linux/arm64 for Apple Silicon.
5 | # For example, to build for ARM64 architecture (e.g., Apple Silicon),
6 | # use the following command on the command line:
7 | #
8 | # docker build --build-arg TARGETPLATFORM=linux/arm64 -t my-arm-image .
9 | #
10 | # Similarly, to build for the default x86_64 architecture, you can use:
11 | #
12 | # docker build --build-arg TARGETPLATFORM=linux/amd64 -t my-amd64-image .
13 | #
14 | FROM --platform=$TARGETPLATFORM ubuntu:latest
15 | LABEL authors="Luca Marconato"
16 |
17 | ENV PYTHONUNBUFFERED=1
18 |
19 | ARG SPATIALDATA_VERSION
20 | ARG SPATIALDATA_IO_VERSION
21 | ARG SPATIALDATA_PLOT_VERSION
22 |
23 | # debugging
24 | RUN echo "Target Platform: ${TARGETPLATFORM}" && \
25 |     echo "spatialdata version: ${SPATIALDATA_VERSION}" && \
26 |     echo "spatialdata-io version: ${SPATIALDATA_IO_VERSION}" && \
27 |     echo "spatialdata-plot version: ${SPATIALDATA_PLOT_VERSION}"
28 |
29 | # Update and install system dependencies.
30 | RUN apt-get update && \
31 |     apt-get install -y --no-install-recommends \
32 |     build-essential \
33 |     python3-venv \
34 |     python3-dev \
35 |     git \
36 |     && rm -rf /var/lib/apt/lists/*
37 |
38 | # setup python virtual environment
39 | RUN python3 -m venv /opt/venv
40 | ENV PATH="/opt/venv/bin:$PATH"
41 | RUN pip install --upgrade pip wheel
42 |
43 | # Install the libraries with specific versions
44 | RUN pip install --no-cache-dir \
45 |     spatialdata[torch]==${SPATIALDATA_VERSION} \
46 |     spatialdata-io==${SPATIALDATA_IO_VERSION} \
47 |     spatialdata-plot==${SPATIALDATA_PLOT_VERSION}
48 |
49 | LABEL spatialdata_version="${SPATIALDATA_VERSION}" \
50 |     spatialdata_io_version="${SPATIALDATA_IO_VERSION}" \
51 |     spatialdata_plot_version="${SPATIALDATA_PLOT_VERSION}"
52 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, scverse®
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # SpatialData: an open and universal framework for processing spatial omics data.
4 |
5 | [![Tests][badge-tests]][link-tests]
6 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/scverse/spatialdata/main.svg)](https://results.pre-commit.ci/latest/github/scverse/spatialdata/main)
7 | [![codecov](https://codecov.io/gh/scverse/spatialdata/branch/main/graph/badge.svg)](https://codecov.io/gh/scverse/spatialdata)
8 | [![Documentation](https://img.shields.io/readthedocs/spatialdata)](https://spatialdata.scverse.org/en/latest/)
9 | [![DOI](https://zenodo.org/badge/487366481.svg)](https://zenodo.org/badge/latestdoi/487366481)
10 | [![Downloads](https://static.pepy.tech/badge/spatialdata)](https://pepy.tech/project/spatialdata)
11 | [![Release](https://github.com/scverse/spatialdata/actions/workflows/release.yaml/badge.svg)](https://github.com/scverse/spatialdata/actions/workflows/release.yaml)
12 | [![PyPI][badge-pypi]][link-pypi]
13 | [![Anaconda](https://img.shields.io/conda/vn/conda-forge/spatialdata)](https://anaconda.org/conda-forge/spatialdata)
14 |
15 | [badge-pypi]: https://badge.fury.io/py/spatialdata.svg
16 | [link-pypi]: https://pypi.org/project/spatialdata/
17 |
18 | SpatialData is a data framework that comprises a FAIR storage format and a collection of Python libraries for performant access, alignment, and processing of uni- and multi-modal spatial omics datasets. This repository contains the core spatialdata library. See the links below to learn more about other packages in the SpatialData ecosystem.
19 |
20 | - [spatialdata-io](https://github.com/scverse/spatialdata-io): load data from common spatial omics technologies into spatialdata.
21 | - [spatialdata-plot](https://github.com/scverse/spatialdata-plot): Static plotting library for spatialdata.
22 | - [napari-spatialdata](https://github.com/scverse/napari-spatialdata): napari plugin for interactive exploration and annotation of spatial data.
23 |
24 | [//]: # "numfocus-fiscal-sponsor-attribution"
25 |
26 | spatialdata is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/).
27 | If you like scverse® and want to support our mission, please consider making a tax-deductible [donation](https://numfocus.org/donate-to-scverse) to help the project pay for developer time, professional services, travel, workshops, and a variety of other needs.
28 |
29 | The spatialdata project also received support from the Chan Zuckerberg Initiative.
30 |
31 |
38 |
39 |
40 | 
41 |
42 | - **The library is currently under review.** We expect there to be changes as the community provides feedback. We have an announcement channel for communicating these changes, please see the contact section below.
43 | - The SpatialData storage format is built on top of the [OME-NGFF](https://ngff.openmicroscopy.org/latest/) specification.
44 |
45 | ## Getting started
46 |
47 | Please refer to the [documentation][link-docs]. In particular:
48 |
49 | - [API documentation][link-api].
50 | - [Design doc][link-design-doc] (includes the roadmap).
51 | - [Example notebooks][link-notebooks].
52 |
53 | Another useful resource to get started is the source code of the [`spatialdata-io`](https://github.com/scverse/spatialdata-io) package, which shows examples of how to read data from common technologies.
54 |
55 | ## Installation
56 |
57 | Check out the docs for more complete [installation instructions](https://spatialdata.scverse.org/en/stable/installation.html). To get started with the "batteries included" installation, you can install via pip:
58 |
59 | ```bash
60 | pip install "spatialdata[extra]"
61 | ```
62 |
63 | ~~or via conda:~~
64 | Update Feb 2025: `spatialdata` cannot currently be installed via `conda` because some transitive dependencies have not yet been updated on `conda-forge`. Please install via `pip`; the latest versions of the `spatialdata` libraries are always available on `PyPI`.
65 |
66 | ```bash
67 | mamba install -c conda-forge spatialdata napari-spatialdata spatialdata-io spatialdata-plot
68 | ```
69 |
70 | ## Limitations
71 |
72 | - The framework is developed on Linux, macOS, and Windows machines, but it is automatically tested only on Linux and macOS; on Windows it is currently tested only manually.
73 |
74 | ## Contact
75 |
76 | To get involved in the discussion, or if you need help getting started, you are welcome to use the following options.
77 |
78 | - Chat via [`scverse` Zulip](https://scverse.zulipchat.com/#narrow/stream/315824-spatial) (public or 1 to 1).
79 | - Forum post in the [scverse discourse forum](https://discourse.scverse.org/).
80 | - Bug report/feature request via the [GitHub issue tracker][issue-tracker].
81 | - Zoom call as part of the SpatialData Community Meetings, held every 2 weeks on Thursday, [schedule here](https://hackmd.io/enWU826vRai-JYaL7TZaSw).
82 |
83 | Finally, and especially relevant for developers building a library upon `spatialdata`, please follow this channel for:
84 |
85 | - Announcements on new features and important changes [Zulip](https://imagesc.zulipchat.com/#narrow/stream/329057-scverse/topic/spatialdata.20announcements).
86 |
87 | ## Citation
88 |
89 | Marconato, L., Palla, G., Yamauchi, K.A. et al. SpatialData: an open and universal data framework for spatial omics. Nat Methods (2024). https://doi.org/10.1038/s41592-024-02212-x
90 |
91 |
92 |
93 | [scverse-discourse]: https://discourse.scverse.org/
94 | [issue-tracker]: https://github.com/scverse/spatialdata/issues
95 | [design doc]: https://scverse-spatialdata.readthedocs.io/en/stable/design_doc.html
96 | [link-docs]: https://spatialdata.scverse.org/en/stable/
97 | [link-api]: https://spatialdata.scverse.org/en/stable/api.html
98 | [link-design-doc]: https://spatialdata.scverse.org/en/stable/design_doc.html
99 | [link-notebooks]: https://spatialdata.scverse.org/en/stable/tutorials/notebooks/notebooks.html
100 | [badge-tests]: https://github.com/scverse/spatialdata/actions/workflows/test.yaml/badge.svg
101 | [link-tests]: https://github.com/scverse/spatialdata/actions/workflows/test.yaml
102 |
--------------------------------------------------------------------------------
/benchmarks/README.md:
--------------------------------------------------------------------------------
1 | # Benchmarking SpatialData code
2 |
3 | This `benchmarks/` folder contains code to benchmark the performance of SpatialData. You can use it to see how the code behaves for different options or data sizes. For more information, check the [SpatialData Contributing Guide](https://spatialdata.scverse.org/en/stable/contributing.html).
4 |
5 | Note that to run the code, your current working directory should be the root of the SpatialData repository, not this `benchmarks/` folder.
6 |
7 | ## Installation
8 |
9 | The benchmarks use the [airspeed velocity](https://asv.readthedocs.io/en/stable/) (asv) framework. Install it with the `benchmark` option:
10 |
11 | ```
12 | pip install -e '.[docs,test,benchmark]'
13 | ```
14 |
15 | ## Usage
16 |
17 | Running all the benchmarks is usually not needed. You run a benchmark using `asv run`. See the [asv documentation](https://asv.readthedocs.io/en/stable/commands.html#asv-run) for useful arguments: for example, `-b`/`--bench` takes a regex pattern that selects the benchmarks you are interested in by function or class method, so `-b timeraw_import_inspect` selects the function `timeraw_import_inspect` in `benchmarks/spatialdata_benchmark.py`. You can run the benchmarks in your current environment with `--python=same`. Some example benchmarks:
18 |
19 | Importing the SpatialData library can take around 4 seconds:
20 |
21 | ```
22 | PYTHONWARNINGS="ignore" asv run --python=same --show-stderr -b timeraw_import_inspect
23 | Couldn't load asv.plugins._mamba_helpers because
24 | No module named 'conda'
25 | · Discovering benchmarks
26 | · Running 1 total benchmarks (1 commits * 1 environments * 1 benchmarks)
27 | [ 0.00%] ·· Benchmarking existing-py_opt_homebrew_Caskroom_mambaforge_base_envs_spatialdata2_bin_python3.12
28 | [50.00%] ··· Running (spatialdata_benchmark.timeraw_import_inspect--).
29 | [100.00%] ··· spatialdata_benchmark.timeraw_import_inspect 3.65±0.2s
30 | ```
31 |
32 | Querying using a bounding box without a spatial index is highly impacted by large amounts of points (transcripts), more than table rows (cells).
33 |
34 | ```
35 | $ PYTHONWARNINGS="ignore" asv run --python=same --show-stderr -b time_query_bounding_box
36 |
37 | [100.00%] ··· ======== ============ ============= ============= ==============
38 | -- filter_table / n_transcripts_per_cell
39 | -------- -------------------------------------------------------
40 | length True / 100 True / 1000 False / 100 False / 1000
41 | ======== ============ ============= ============= ==============
42 | 100 177±5ms 195±4ms 168±0.5ms 186±2ms
43 | 1000 195±3ms 402±2ms 187±3ms 374±4ms
44 | 10000 722±3ms 2.65±0.01s 389±3ms 2.22±0.02s
45 | ======== ============ ============= ============= ==============
46 | ```
47 |
48 | You can use `asv` to run all the benchmarks in their own environment. This can take a long time, so it is not recommended for regular use:
49 |
50 | ```
51 | $ asv run
52 | Couldn't load asv.plugins._mamba_helpers because
53 | No module named 'conda'
54 | · Creating environments....
55 | · Discovering benchmarks..
56 | ·· Uninstalling from virtualenv-py3.12
57 | ·· Building a89d16d8 for virtualenv-py3.12
58 | ·· Installing a89d16d8 into virtualenv-py3.12.............
59 | · Running 6 total benchmarks (1 commits * 1 environments * 6 benchmarks)
60 | [ 0.00%] · For spatialdata commit a89d16d8 :
61 | [ 0.00%] ·· Benchmarking virtualenv-py3.12
62 | [25.00%] ··· Running (spatialdata_benchmark.TimeMapRaster.time_map_blocks--)...
63 | ...
64 | [100.00%] ··· spatialdata_benchmark.timeraw_import_inspect 3.33±0.06s
65 | ```
66 |
67 | ## Notes
68 |
69 | When using PyCharm, remember to set [Configuration](https://www.jetbrains.com/help/pycharm/run-debug-configuration.html) to include the benchmark module, as this is separate from the main code module.
70 |
71 | In Python, you can run a module using the following command:
72 |
73 | ```
74 | python -m benchmarks.spatialdata_benchmark
75 | ```
76 |
--------------------------------------------------------------------------------
/benchmarks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/benchmarks/__init__.py
--------------------------------------------------------------------------------
/benchmarks/spatialdata_benchmark.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 |
3 | # Write the benchmarking functions here.
4 | # See "Writing benchmarks" in the asv docs for more information.
5 | import spatialdata as sd
6 |
7 | from .utils import cluster_blobs
8 |
9 |
10 | class MemorySpatialData:
11 |     """Calculate the peak memory usage for artificial datasets with an increasing number of channels."""
12 |
13 |     # TODO: see what the memory overhead is, e.g., of the Python interpreter...
14 |     def peakmem_list(self):
15 |         sdata: sd.SpatialData = sd.datasets.blobs(n_channels=1)
16 |         return sdata
17 |
18 |     def peakmem_list2(self):
19 |         sdata: sd.SpatialData = sd.datasets.blobs(n_channels=2)
20 |         return sdata
21 |
22 |
23 | def timeraw_import_inspect():
24 |     """Time the import of the spatialdata module."""
25 |     return """
26 | import spatialdata
27 | """
28 |
29 |
30 | class TimeMapRaster:
31 |     """Time the `map_raster` operation on images of increasing size."""
32 |
33 |     params = [100, 1000, 10_000]
34 |     param_names = ["length"]
35 |
36 |     def setup(self, length):
37 |         self.sdata = cluster_blobs(length=length)
38 |
39 |     def teardown(self, _):
40 |         del self.sdata
41 |
42 |     def time_map_blocks(self, _):
43 |         sd.map_raster(self.sdata["blobs_image"], lambda x: x + 1)
44 |
45 |
46 | class TimeQueries:
47 |     params = ([100, 1_000, 10_000], [True, False], [100, 1_000])
48 |     param_names = ["length", "filter_table", "n_transcripts_per_cell"]
49 |
50 |     def setup(self, length, filter_table, n_transcripts_per_cell):
51 |         import shapely
52 |
53 |         self.sdata = cluster_blobs(length=length, n_transcripts_per_cell=n_transcripts_per_cell)
54 |         self.polygon = shapely.box(0, 0, length // 2, length // 2)
55 |
56 |     def teardown(self, length, filter_table, n_transcripts_per_cell):
57 |         del self.sdata
58 |
59 |     def time_query_bounding_box(self, length, filter_table, n_transcripts_per_cell):
60 |         self.sdata.query.bounding_box(
61 |             axes=["x", "y"],
62 |             min_coordinate=[0, 0],
63 |             max_coordinate=[length // 2, length // 2],
64 |             target_coordinate_system="global",
65 |             filter_table=filter_table,
66 |         )
67 |
68 |     def time_query_polygon_box(self, length, filter_table, n_transcripts_per_cell):
69 |         sd.polygon_query(
70 |             self.sdata,
71 |             self.polygon,
72 |             target_coordinate_system="global",
73 |             filter_table=filter_table,
74 |         )
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= python3 -msphinx
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
22 | clean:
23 | 	rm -r "$(BUILDDIR)"
24 | 	rm -r "generated"
25 |
--------------------------------------------------------------------------------
/docs/_static/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/docs/_static/.gitkeep
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* TODO: this is a duplicate of the code from tutorials/notebooks/_static/css/custom.css; ideally we would remove the following code and import it from that path */
2 | .custom-card {
3 |     border: 1px solid #dfe2e5;
4 |     border-radius: 6px;
5 |     padding-top: 1rem;
6 |     padding-bottom: 1rem;
7 |     padding-left: 1rem; /* Add internal padding on the left side */
8 |     padding-right: 1rem; /* Add internal padding on the right side */
9 |     box-shadow: 0 1px 2px rgba(0, 0, 0, 0.075);
10 |     margin-bottom: 1.5rem;
11 |     height: 280px;
12 | }
13 |
--------------------------------------------------------------------------------
/docs/_static/img/spatialdata_horizontal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/docs/_static/img/spatialdata_horizontal.png
--------------------------------------------------------------------------------
/docs/_templates/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/docs/_templates/.gitkeep
--------------------------------------------------------------------------------
/docs/_templates/autosummary/base.rst:
--------------------------------------------------------------------------------
1 | :github_url: {{ fullname }}
2 |
3 | {% extends "!autosummary/base.rst" %}
4 |
5 | .. http://www.sphinx-doc.org/en/stable/ext/autosummary.html#customizing-templates
6 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | Attributes table
12 | ~~~~~~~~~~~~~~~~~~
13 |
14 | .. autosummary::
15 | {% for item in attributes %}
16 |     ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | Methods table
24 | ~~~~~~~~~~~~~
25 |
26 | .. autosummary::
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 |     ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
35 | {% block attributes_documentation %}
36 | {% if attributes %}
37 | Attributes
38 | ~~~~~~~~~~~
39 |
40 | {% for item in attributes %}
41 |
42 | .. autoattribute:: {{ [objname, item] | join(".") }}
43 | {%- endfor %}
44 |
45 | {% endif %}
46 | {% endblock %}
47 |
48 | {% block methods_documentation %}
49 | {% if methods %}
50 | Methods
51 | ~~~~~~~
52 |
53 | {% for item in methods %}
54 | {%- if item != '__init__' %}
55 |
56 | .. automethod:: {{ [objname, item] | join(".") }}
57 | {%- endif -%}
58 | {%- endfor %}
59 |
60 | {% endif %}
61 | {% endblock %}
62 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/function.rst:
--------------------------------------------------------------------------------
1 | :github_url: {{ fullname }}
2 |
3 | {{ fullname | escape | underline}}
4 |
5 | .. autofunction:: {{ fullname }}
6 |
--------------------------------------------------------------------------------
/docs/api.md:
--------------------------------------------------------------------------------
1 | # API
2 |
3 | ```{toctree}
4 | :maxdepth: 1
5 |
6 | api/SpatialData.md
7 | api/io.md
8 | api/operations.md
9 | api/transformations.md
10 | api/transformations_utils.md
11 | api/datasets.md
12 | api/dataloader.md
13 | api/models.md
14 | api/models_utils.md
15 | api/testing.md
16 | api/data_formats.md
17 | ```
18 |
--------------------------------------------------------------------------------
/docs/api/SpatialData.md:
--------------------------------------------------------------------------------
1 | # SpatialData object
2 |
3 | ```{eval-rst}
4 | .. currentmodule:: spatialdata
5 |
6 | .. autoclass:: SpatialData
7 | ```
8 |
--------------------------------------------------------------------------------
/docs/api/data_formats.md:
--------------------------------------------------------------------------------
1 | # Data formats (advanced)
2 |
3 | The SpatialData format is defined as a set of versioned subclasses of `spatialdata._io.format.SpatialDataFormat`, one per type of element.
4 | These classes are useful to ensure backward compatibility whenever a major version change is introduced. We also provide pointers to the latest format.
5 |
6 | ```{eval-rst}
7 | .. currentmodule:: spatialdata._io.format
8 |
9 | .. autoclass:: CurrentRasterFormat
10 | .. autoclass:: RasterFormatV01
11 | .. autoclass:: CurrentShapesFormat
12 | .. autoclass:: ShapesFormatV01
13 | .. autoclass:: ShapesFormatV02
14 | .. autoclass:: CurrentPointsFormat
15 | .. autoclass:: PointsFormatV01
16 | .. autoclass:: CurrentTablesFormat
17 | .. autoclass:: TablesFormatV01
18 | ```
19 |
--------------------------------------------------------------------------------
/docs/api/dataloader.md:
--------------------------------------------------------------------------------
1 | # Data Loaders
2 |
3 | ```{eval-rst}
4 | .. currentmodule:: spatialdata.dataloader
5 |
6 | .. autoclass:: ImageTilesDataset
7 | ```
8 |
--------------------------------------------------------------------------------
/docs/api/datasets.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 | Small convenience datasets
4 |
5 | ```{eval-rst}
6 | .. currentmodule:: spatialdata.datasets
7 |
8 | .. autofunction:: blobs
9 | .. autofunction:: blobs_annotating_element
10 | .. autofunction:: raccoon
11 | ```
12 |
--------------------------------------------------------------------------------
/docs/api/io.md:
--------------------------------------------------------------------------------
1 | # Input/Output
2 |
3 | To read the data from a specific technology (e.g., Xenium, MERSCOPE, ...), you can
4 | use any of the [spatialdata-io readers](https://spatialdata.scverse.org/projects/io/en/stable/api.html).
5 |
6 | ```{eval-rst}
7 | .. currentmodule:: spatialdata
8 |
9 | .. autofunction:: read_zarr
10 | .. autofunction:: save_transformations
11 | .. autofunction:: get_dask_backing_files
12 | ```
13 |
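As a minimal sketch (the path below is a placeholder), a SpatialData Zarr store is read back with `read_zarr`:

```python
import spatialdata as sd

# load a SpatialData object from a .zarr store on disk
sdata = sd.read_zarr("path/to/data.zarr")
```
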
--------------------------------------------------------------------------------
/docs/api/models.md:
--------------------------------------------------------------------------------
1 | # Models
2 |
3 | The elements (building-blocks) that constitute `SpatialData`.
4 |
5 | ```{eval-rst}
6 | .. currentmodule:: spatialdata.models
7 |
8 | .. autoclass:: Image2DModel
9 | .. autoclass:: Image3DModel
10 | .. autoclass:: Labels2DModel
11 | .. autoclass:: Labels3DModel
12 | .. autoclass:: ShapesModel
13 | .. autoclass:: PointsModel
14 | .. autoclass:: TableModel
15 | ```
16 |
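Each model exposes a `parse()` classmethod that validates raw data and attaches the metadata (axes, coordinate transformations) required by SpatialData. A minimal sketch for `Image2DModel`; the array content here is arbitrary:

```python
import numpy as np
from spatialdata.models import Image2DModel

rng = np.random.default_rng(0)
raw = rng.random((3, 64, 64))  # a (c, y, x) image with 3 channels

# returns a validated, metadata-annotated image element
image = Image2DModel.parse(raw, dims=("c", "y", "x"))
```
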
--------------------------------------------------------------------------------
/docs/api/models_utils.md:
--------------------------------------------------------------------------------
1 | # Models utils
2 |
3 | ```{eval-rst}
4 | .. currentmodule:: spatialdata.models
5 |
6 | .. autofunction:: get_model
7 | .. autodata:: SpatialElement
8 | .. autofunction:: get_axes_names
9 | .. autofunction:: get_spatial_axes
10 | .. autofunction:: points_geopandas_to_dask_dataframe
11 | .. autofunction:: points_dask_dataframe_to_geopandas
12 | .. autofunction:: get_channel_names
13 | .. autofunction:: set_channel_names
14 | .. autofunction:: force_2d
15 | ```
16 |
--------------------------------------------------------------------------------
/docs/api/operations.md:
--------------------------------------------------------------------------------
1 | # Operations
2 |
3 | Operations on `SpatialData` objects.
4 |
5 | ```{eval-rst}
6 | .. module:: spatialdata
7 |
8 | .. autofunction:: bounding_box_query
9 | .. autofunction:: polygon_query
10 | .. autofunction:: get_values
11 | .. autofunction:: get_element_instances
12 | .. autofunction:: get_extent
13 | .. autofunction:: get_centroids
14 | .. autofunction:: join_spatialelement_table
15 | .. autofunction:: match_element_to_table
16 | .. autofunction:: match_table_to_element
17 | .. autofunction:: match_sdata_to_table
18 | .. autofunction:: concatenate
19 | .. autofunction:: transform
20 | .. autofunction:: rasterize
21 | .. autofunction:: rasterize_bins
22 | .. autofunction:: rasterize_bins_link_table_to_labels
23 | .. autofunction:: to_circles
24 | .. autofunction:: to_polygons
25 | .. autofunction:: aggregate
26 | .. autofunction:: map_raster
27 | .. autofunction:: unpad_raster
28 | .. autofunction:: relabel_sequential
29 | .. autofunction:: are_extents_equal
30 | .. autofunction:: deepcopy
31 | .. autofunction:: get_pyramid_levels
32 | .. autofunction:: sanitize_name
33 | .. autofunction:: sanitize_table
34 | ```
35 |
--------------------------------------------------------------------------------
/docs/api/testing.md:
--------------------------------------------------------------------------------
1 | # Testing utilities
2 |
3 | ```{eval-rst}
4 | .. currentmodule:: spatialdata.testing
5 |
6 | .. autofunction:: assert_spatial_data_objects_are_identical
7 | .. autofunction:: assert_elements_are_identical
8 | .. autofunction:: assert_elements_dict_are_identical
9 | ```
10 |
--------------------------------------------------------------------------------
/docs/api/transformations.md:
--------------------------------------------------------------------------------
1 | # Transformations
2 |
3 | The transformations that can be defined between elements and coordinate systems in `SpatialData`.
4 |
5 | ```{eval-rst}
6 | .. currentmodule:: spatialdata.transformations
7 |
8 | .. autoclass:: BaseTransformation
9 | .. autoclass:: Identity
10 | .. autoclass:: MapAxis
11 | .. autoclass:: Translation
12 | .. autoclass:: Scale
13 | .. autoclass:: Affine
14 | .. autoclass:: Sequence
15 | ```
16 |
--------------------------------------------------------------------------------
/docs/api/transformations_utils.md:
--------------------------------------------------------------------------------
1 | # Transformations utils
2 |
3 | ```{eval-rst}
4 | .. currentmodule:: spatialdata.transformations
5 |
6 | .. autofunction:: get_transformation
7 | .. autofunction:: set_transformation
8 | .. autofunction:: remove_transformation
9 | .. autofunction:: get_transformation_between_coordinate_systems
10 | .. autofunction:: get_transformation_between_landmarks
11 | .. autofunction:: align_elements_using_landmarks
12 | .. autofunction:: remove_transformations_to_coordinate_system
13 | ```
14 |
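A minimal sketch of reading and writing a transformation on an element of the example `blobs` dataset (the scale factors are arbitrary):

```python
from spatialdata.datasets import blobs
from spatialdata.transformations import Scale, get_transformation, set_transformation

sdata = blobs()
element = sdata["blobs_image"]

# assign a 2x scaling to the "global" coordinate system and read it back
set_transformation(element, Scale([2.0, 2.0], axes=("y", "x")), to_coordinate_system="global")
print(get_transformation(element, to_coordinate_system="global"))
```
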
--------------------------------------------------------------------------------
/docs/changelog.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | Please refer directly to the [Releases](https://github.com/scverse/spatialdata/releases) section on GitHub, where you can find curated release notes for each release.
4 | For developers, please consult the [contributing guide](https://github.com/scverse/spatialdata/blob/main/docs/contributing.md), which explains how to keep the release notes up-to-date at each release.
5 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 | import sys
9 | from datetime import datetime
10 | from importlib.metadata import metadata
11 | from pathlib import Path
12 |
13 | HERE = Path(__file__).parent
14 | sys.path.insert(0, str(HERE / "extensions"))
15 |
16 |
17 | # -- Project information -----------------------------------------------------
18 |
19 | info = metadata("spatialdata")
20 | project_name = info["Name"]
21 | author = info["Author"]
22 | copyright = f"{datetime.now():%Y}, {author}."
23 | version = info["Version"]
24 | # repository_url = f"https://github.com/scverse/{project_name}"
25 |
26 | # The full version, including alpha/beta/rc tags
27 | release = info["Version"]
28 |
29 | bibtex_bibfiles = ["references.bib"]
30 | bibtex_reference_style = "author_year"
31 | templates_path = ["_templates"]
32 | nitpicky = True # Warn about broken links
33 | needs_sphinx = "4.0"
34 |
35 | html_context = {
36 |     "display_github": True,  # Integrate GitHub
37 |     "github_user": "scverse",  # Username
38 |     "github_repo": project_name,  # Repo name
39 |     "github_version": "main",  # Version
40 |     "conf_py_path": "/docs/",  # Path in the checkout to the docs root
41 | }
42 |
43 | # -- General configuration ---------------------------------------------------
44 |
45 | # Add any Sphinx extension module names here, as strings.
46 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
47 | extensions = [
48 |     "myst_nb",
49 |     "sphinx_copybutton",
50 |     "sphinx.ext.autodoc",
51 |     "sphinx.ext.intersphinx",
52 |     "sphinx.ext.autosummary",
53 |     "sphinx.ext.napoleon",
54 |     "sphinxcontrib.bibtex",
55 |     "sphinx_autodoc_typehints",
56 |     "sphinx.ext.mathjax",
57 |     "IPython.sphinxext.ipython_console_highlighting",
58 |     "sphinx_design",
59 |     *[p.stem for p in (HERE / "extensions").glob("*.py")],
60 | ]
61 |
62 | autodoc_default_options = {
63 |     "members": True,
64 |     "inherited-members": True,
65 |     "show-inheritance": True,
66 | }
67 |
68 | autosummary_generate = True
69 | autodoc_process_signature = True
70 | autodoc_member_order = "groupwise"
71 | default_role = "literal"
72 | napoleon_google_docstring = False
73 | napoleon_numpy_docstring = True
74 | napoleon_include_init_with_doc = False
75 | napoleon_use_rtype = True  # having a separate entry generally helps readability
76 | napoleon_use_param = True
77 | myst_heading_anchors = 3  # create anchors for h1-h3
78 | myst_enable_extensions = [
79 |     "amsmath",
80 |     "colon_fence",
81 |     "deflist",
82 |     "dollarmath",
83 |     "html_image",
84 |     "html_admonition",
85 | ]
86 | myst_url_schemes = ("http", "https", "mailto")
87 | nb_output_stderr = "remove"
88 | nb_execution_mode = "off"
89 | nb_merge_streams = True
90 | typehints_defaults = "braces"
91 |
92 | source_suffix = {
93 |     ".rst": "restructuredtext",
94 |     ".ipynb": "myst-nb",
95 |     ".myst": "myst-nb",
96 | }
97 |
98 | intersphinx_mapping = {
99 |     "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
100 |     "numpy": ("https://numpy.org/doc/stable/", None),
101 |     "geopandas": ("https://geopandas.org/en/stable/", None),
102 |     "xarray": ("https://docs.xarray.dev/en/stable/", None),
103 |     "datatree": ("https://datatree.readthedocs.io/en/latest/", None),
104 |     "dask": ("https://docs.dask.org/en/latest/", None),
105 |     "shapely": ("https://shapely.readthedocs.io/en/stable", None),
106 | }
107 |
108 |
109 | # List of patterns, relative to source directory, that match files and
110 | # directories to ignore when looking for source files.
111 | # This pattern also affects html_static_path and html_extra_path.
112 | exclude_patterns = [
113 |     "_build",
114 |     "Thumbs.db",
115 |     "**.ipynb_checkpoints",
116 |     "tutorials/notebooks/index.md",
117 |     "tutorials/notebooks/README.md",
118 |     "tutorials/notebooks/references.md",
119 |     "tutorials/notebooks/notebooks/paper_reproducibility/*",
120 |     "tutorials/notebooks/notebooks/developers_resources/storage_format/*.ipynb",
121 |     "tutorials/notebooks/notebooks/developers_resources/storage_format/Readme.md",
122 |     "tutorials/notebooks/notebooks/examples/technology_stereoseq.ipynb",
123 |     "tutorials/notebooks/notebooks/examples/technology_curio.ipynb",
124 |     "tutorials/notebooks/notebooks/examples/stereoseq_data/*",
125 | ]
126 | # Ignore warnings.
127 | nitpicky = False  # TODO: solve upstream.
128 | # nitpick_ignore = [
129 | #     ("py:class", "spatial_image.SpatialImage"),
130 | #     ("py:class", "multiscale_spatial_image.multiscale_spatial_image.MultiscaleSpatialImage"),
131 | # ]
132 | # no solution yet (7.4.7); using the workaround shown here: https://github.com/sphinx-doc/sphinx/issues/12589
133 | suppress_warnings = [
134 |     "autosummary.import_cycle",
135 | ]
137 |
138 |
139 | # -- Options for HTML output -------------------------------------------------
140 |
141 | # The theme to use for HTML and HTML Help pages. See the documentation for
142 | # a list of builtin themes.
143 | #
144 | html_theme = "sphinx_book_theme"
145 | # html_theme = "sphinx_rtd_theme"
146 | html_static_path = ["_static"]
147 | html_title = project_name
148 | html_logo = "_static/img/spatialdata_horizontal.png"
149 |
150 | html_theme_options = {
151 |     "navigation_with_keys": True,
152 |     "show_toc_level": 4,
153 |     # "repository_url": repository_url,
154 |     # "use_repository_button": True,
155 | }
156 |
157 | pygments_style = "default"
158 |
159 | nitpick_ignore = [
160 |     # If building the documentation fails because of a missing link that is outside your control,
161 |     # you can add an exception to this list.
162 |     ("py:class", "igraph.Graph"),
163 | ]
164 |
165 |
166 | def setup(app):
167 |     """App setup hook."""
168 |     app.add_config_value(
169 |         "recommonmark_config",
170 |         {
171 |             "auto_toc_tree_section": "Contents",
172 |             "enable_auto_toc_tree": True,
173 |             "enable_math": True,
174 |             "enable_inline_math": False,
175 |             "enable_eval_rst": True,
176 |         },
177 |         True,
178 |     )
179 |     app.add_css_file("css/custom.css")
180 |
--------------------------------------------------------------------------------
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | import re
4 |
5 | from sphinx.application import Sphinx
6 | from sphinx.ext.napoleon import NumpyDocstring
7 |
8 |
9 | def _process_return(lines):
10 |     for line in lines:
11 |         m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line)
12 |         if m:
13 |             # Once this is in scanpydoc, we can use the fancy hover stuff
14 |             yield f"-{m['param']} (:class:`~{m['type']}`)"
15 |         else:
16 |             yield line
17 |
18 |
19 | def _parse_returns_section(self, section):
20 |     lines_raw = list(_process_return(self._dedent(self._consume_to_next_section())))
21 |     lines = self._format_block(":returns: ", lines_raw)
22 |     if lines and lines[-1]:
23 |         lines.append("")
24 |     return lines
25 |
26 |
27 | def setup(app: Sphinx):
28 |     """Set app."""
29 |     NumpyDocstring._parse_returns_section = _parse_returns_section
30 |
--------------------------------------------------------------------------------
/docs/glossary.md:
--------------------------------------------------------------------------------
1 | # Glossary
2 |
3 | ## IO
4 |
5 | IO means input/output. For instance, `spatialdata-io` is about reading and/or writing data.
6 |
7 | ## NGFF
8 |
9 | [NGFF (Next-Generation File Format)](https://ngff.openmicroscopy.org/latest/) is a specification for storing multi-dimensional, large-scale imaging and spatial data efficiently. Developed by the OME (Open Microscopy Environment), it supports features like multi-resolution images, enabling flexible data access, high performance, and compatibility with cloud-based storage. NGFF is designed to handle modern microscopy, biomedical imaging, and other high-dimensional datasets in a scalable, standardized way, often using formats like Zarr for storage.
10 |
11 | ## OME
12 |
13 | OME stands for Open Microscopy Environment and is a consortium of universities, research labs, industry and developers producing open-source software and format standards for microscopy data. It developed, among others, the OME-NGFF specification.
14 |
15 | ## OME-NGFF
16 |
17 | See NGFF.
18 |
19 | ## OME-Zarr
20 |
21 | An implementation of the OME-NGFF specification using the Zarr format. The SpatialData Zarr format (see below) is an extension of the OME-Zarr format.
22 |
23 | ## Raster / Rasterization
24 |
25 | Raster data represents images using a grid of pixels, where each pixel contains a specific value representing information such as color, intensity, or another attribute.
26 |
27 | Rasterization is the process of converting data from a vector format (see definition below) into a raster format of pixels (i.e., an image).
28 |
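As a minimal sketch (the element name, extent, and resolution are illustrative, and `sdata` is assumed to be an existing `SpatialData` object), vector data can be rasterized with the `rasterize()` function exported by `spatialdata`:

```python
import spatialdata as sd

# Rasterize a hypothetical shapes element "cells" onto a 1000x1000-unit region
# of the "global" coordinate system, at 1 pixel per unit.
raster = sd.rasterize(
    sdata["cells"],
    axes=("x", "y"),
    min_coordinate=[0, 0],
    max_coordinate=[1000, 1000],
    target_coordinate_system="global",
    target_unit_to_pixels=1.0,
)
```
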
29 | ## ROI
30 |
31 | An ROI (_Region of Interest_) is a specific subset of a dataset that highlights an area of particular relevance, such as a niche, cell, or tissue location. For example, an ROI may be defined as a polygon outlining a lymphoid structure.
32 |
33 | ## Spatial query
34 |
35 | A spatial query subsets the data based on the location of the spatial elements. For example, subsetting data with a bounding box query selects elements, and the corresponding tabular annotations, within a defined rectangular region, while a polygon query selects data within a specified shape.
36 |
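As a minimal sketch (the coordinates are illustrative and `sdata` is assumed to be an existing `SpatialData` object), a bounding box query can be performed with the `bounding_box_query()` function exported by `spatialdata`:

```python
import spatialdata as sd

# Select the elements (and the matching table rows) inside the rectangle
# spanning (0, 0) to (500, 500) in the "global" coordinate system.
subset = sd.bounding_box_query(
    sdata,
    axes=("x", "y"),
    min_coordinate=[0, 0],
    max_coordinate=[500, 500],
    target_coordinate_system="global",
)
```
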
37 | ## SpatialData Zarr format
38 |
39 | An extension of the OME-Zarr format used to represent SpatialData objects on disk. Our aim is to converge the SpatialData Zarr format with the OME-Zarr format, by adapting to future versions of the NGFF specification and by contributing to the development of new features in NGFF. The OME-Zarr format was developed for bioimaging data; an extension of it is therefore necessary to accommodate spatial omics data.
40 |
41 | ## SpatialElements
42 |
43 | _SpatialElements_ are the building blocks of _SpatialData_ datasets, and are split into multiple categories: `images`, `labels`, `shapes` and `points`. SpatialElements are not special classes, but are instead standard scientific Python classes (e.g., `xarray.DataArray`, `geopandas.GeoDataFrame`) with specified metadata. The metadata provides the necessary information for displaying and integrating the different elements (e.g., coordinate systems and coordinate transformations). SpatialElements can be annotated by `tables`, which are represented as `anndata.AnnData` objects. Tables are a building block of SpatialData, but are not considered SpatialElements since they do not contain spatial information.
44 |
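As a minimal sketch (the array shape and dimension names are illustrative), an image element is a plain `xarray.DataArray` parsed and validated by one of the models:

```python
import numpy as np
from spatialdata.models import Image2DModel

# Parse a (c, y, x) NumPy array into a valid image element; the result is a
# standard xarray.DataArray carrying SpatialData metadata (e.g., transformations).
image = Image2DModel.parse(np.zeros((3, 64, 64)), dims=("c", "y", "x"))
```
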
45 | ## Tables
46 |
47 | A building block of SpatialData, represented as `anndata.AnnData` objects. Tables are used to store tabular data, such as gene expression values, cell metadata, or other annotations. Tables can be associated with SpatialElements to provide additional information about the spatial data, such as cell type annotations or gene expression levels.
48 |
49 | ## Vector / Vectorization
50 |
51 | Vector data represents spatial information using geometric shapes such as points, circles, or polygons, each defined by mathematical coordinates.
52 |
53 | Vectorization is the process of converting raster data (i.e. pixel-based images) into vector format. For instance, a "raster" cell boundary can be "vectorized" into a polygon represented by the coordinates of the vertices of the polygon.
54 |
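For instance (a minimal sketch; the element name is illustrative and `sdata` is assumed to be an existing `SpatialData` object), labels can be vectorized with the `to_polygons()` function exported by `spatialdata`:

```python
import spatialdata as sd

# Convert a hypothetical labels element "segmentation" (a pixel mask) into a
# GeoDataFrame of polygons, one geometry per label.
polygons = sd.to_polygons(sdata["segmentation"])
```
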
55 | ## Zarr storage
56 |
57 | Zarr is a format for storing multi-dimensional arrays, designed for efficient, scalable access to large datasets. It supports chunking (splitting data into smaller, manageable pieces) and compression, enabling fast reading and writing of data, even for very large arrays. Zarr is particularly useful for distributed and parallel computing, allowing access to subsets of data without loading the entire dataset into memory. A `SpatialData` object can be stored as a Zarr store (`.zarr` directory).
58 |
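A minimal round trip (the path is illustrative and `sdata` is assumed to be an existing `SpatialData` object):

```python
import spatialdata as sd

sdata.write("data.zarr")  # store the SpatialData object as a Zarr store
sdata_again = sd.read_zarr("data.zarr")  # read it back, lazily where possible
```
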
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ```{eval-rst}
2 | .. image:: _static/img/spatialdata_horizontal.png
3 | :class: dark-light p-2
4 | :alt: SpatialData banner
5 | ```
6 |
7 | # An open and universal framework for processing spatial omics data.
8 |
9 | SpatialData is a data framework that comprises a FAIR storage format and a collection of Python libraries for performant access, alignment, and processing of uni- and multi-modal spatial omics datasets. This page provides documentation on how to install, use, and extend the core `spatialdata` library. See the links below to learn more about other packages in the SpatialData ecosystem.
10 |
11 | - `spatialdata-io`: load data from common spatial omics technologies into `spatialdata` ([repository][spatialdata-io-repo], [documentation][spatialdata-io-docs]).
12 | - `spatialdata-plot`: static plotting library for `spatialdata` ([repository][spatialdata-plot-repo], [documentation][spatialdata-plot-docs]).
13 | - `napari-spatialdata`: napari plugin for interactive exploration and annotation of `spatialdata` ([repository][napari-spatialdata-repo], [documentation][napari-spatialdata-docs]).
14 |
15 | Please see our publication {cite}`marconatoSpatialDataOpenUniversal2024` for citation and to learn more.
16 |
17 | [//]: # "numfocus-fiscal-sponsor-attribution"
18 |
19 | spatialdata is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/).
20 | If you like scverse® and want to support our mission, please consider making a tax-deductible [donation](https://numfocus.org/donate-to-scverse) to help the project pay for developer time, professional services, travel, workshops, and a variety of other needs.
21 |
22 |
30 |
31 | ```{eval-rst}
32 | .. note::
33 | This library is currently under active development. We may make changes to the API between versions as the community provides feedback. To ensure reproducibility, please make note of the version you are developing against.
34 | ```
35 |
36 | ```{eval-rst}
37 | .. card:: Installation
38 | :link: installation
39 | :link-type: doc
40 |
41 | Learn how to install ``spatialdata``.
42 |
43 | .. card:: Tutorials
44 | :link: tutorials/notebooks/notebooks
45 | :link-type: doc
46 |
47 | Learn how to use ``spatialdata`` with hands-on examples.
48 |
49 | .. card:: API
50 | :link: api
51 | :link-type: doc
52 |
53 | Find detailed documentation of ``spatialdata``.
54 |
55 | .. card:: Datasets
56 | :link: tutorials/notebooks/datasets/README
57 | :link-type: doc
58 |
59 | Example datasets from 8 different technologies.
60 |
61 | .. card:: Design document
62 | :link: design_doc
63 | :link-type: doc
64 |
65 | Learn about the design approach behind ``spatialdata``.
66 |
67 | .. card:: Contributing
68 | :link: contributing
69 | :link-type: doc
70 |
71 | Learn how to contribute to ``spatialdata``.
72 |
73 | ```
74 |
75 | ```{toctree}
76 | :hidden: true
77 | :maxdepth: 1
78 |
79 | installation.md
80 | api.md
81 | tutorials/notebooks/notebooks.md
82 | tutorials/notebooks/datasets/README.md
83 | glossary.md
84 | design_doc.md
85 | contributing.md
86 | changelog.md
87 | references.md
88 | ```
89 |
90 |
91 |
92 | [napari-spatialdata-repo]: https://github.com/scverse/napari-spatialdata
93 | [spatialdata-io-repo]: https://github.com/scverse/spatialdata-io
94 | [spatialdata-plot-repo]: https://github.com/scverse/spatialdata-plot
95 | [napari-spatialdata-docs]: https://spatialdata.scverse.org/projects/napari/en/stable/notebooks/spatialdata.html
96 | [spatialdata-io-docs]: https://spatialdata.scverse.org/projects/io/en/stable/
97 | [spatialdata-plot-docs]: https://spatialdata.scverse.org/projects/plot/en/stable/api.html
98 |
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | `spatialdata` requires Python version >= 3.10 to run; installation takes a few minutes on a standard desktop computer.
4 |
5 | ## PyPI
6 |
7 | Install `spatialdata` by running:
8 |
9 | ```bash
10 | pip install spatialdata
11 | ```
12 |
13 | ## Visualization and readers
14 |
15 | The SpatialData ecosystem is designed to work with the following packages:
16 |
17 | - [spatialdata-io][]: `spatialdata` readers and converters for common spatial omics technologies.
18 | - [spatialdata-plot][]: Static plotting library for `spatialdata`.
19 | - [napari-spatialdata][]: napari plugin for `spatialdata`.
20 |
21 | They can be installed with:
22 |
23 | ```bash
24 | pip install "spatialdata[extra]"
25 | ```
26 |
27 | ## Additional dependencies
28 |
29 | To use the `PyTorch` dataloader in `spatialdata`, `torch` needs to be installed. This can be done with:
30 |
31 | ```bash
32 | pip install "spatialdata[torch]"
33 | ```
34 |
35 | ## Development version
36 |
37 | To install `spatialdata` from GitHub, run:
38 |
39 | ```bash
40 | pip install git+https://github.com/scverse/spatialdata
41 | ```
42 |
43 | Alternatively, you can clone the repository (or a fork of it if you are contributing) and perform an editable install with:
44 |
45 | ```bash
46 | pip install -e .
47 | ```
48 |
49 | This is the recommended way to install the package if you want to contribute to the code. To subsequently update the package, you can use `git pull`.
50 |
51 | ### A note on editable install
52 |
53 | If you perform an editable install of `spatialdata` and then install `spatialdata-plot`, `spatialdata-io` or `napari-spatialdata`, they may automatically override the installation of `spatialdata` with the version from PyPI.
54 |
55 | To check if this happened you can run
56 |
57 | ```bash
58 | python -c "import spatialdata; print(spatialdata.__path__)"
59 | ```
60 |
61 | If you get a path that contains `site-packages`, then your editable installation has been overridden and you need to reinstall the package by rerunning `pip install -e .` in the cloned `spatialdata` repository.
62 |
63 | ## Conda
64 |
65 | You can install the `spatialdata`, `spatialdata-io`, `spatialdata-plot` and `napari-spatialdata` packages from the `conda-forge` channel using:
66 |
67 | ```bash
68 | mamba install -c conda-forge spatialdata spatialdata-io spatialdata-plot napari-spatialdata
69 | ```
70 |
71 | Update: currently (Feb 2025), the latest versions of the SpatialData ecosystem packages are not available on `conda-forge`, because particular versions of some of our dependencies (and of their dependencies) are unavailable there. We are waiting for this to be unlocked; the latest versions are always available via PyPI.
72 |
73 | ## Docker
74 |
77 | A `Dockerfile` is available in the repository; the image that can be built from it contains `spatialdata` (with `torch`), `spatialdata-io` and `spatialdata-plot` (not `napari-spatialdata`); the libraries are installed from PyPI.
78 |
79 | To build the image, run:
80 |
81 | ```bash
82 | # this is for Apple Silicon machines; if you are not using such a machine, you can omit the --build-arg
83 | docker build --build-arg TARGETPLATFORM=linux/arm64 --tag spatialdata .
84 | docker run -it spatialdata
85 | ```
86 |
87 | We also publish images automatically via GitHub Actions; you can see the [list of available images here](https://github.com/scverse/spatialdata/pkgs/container/spatialdata/versions).
88 |
89 | Once you have the image name, you can pull and run it with:
90 |
91 | ```bash
92 | docker pull ghcr.io/scverse/spatialdata:spatialdata0.3.0_spatialdata-io0.1.7_spatialdata-plot0.2.9
93 | docker run -it ghcr.io/scverse/spatialdata:spatialdata0.3.0_spatialdata-io0.1.7_spatialdata-plot0.2.9
94 | ```
95 |
96 |
97 |
98 | [napari-spatialdata]: https://github.com/scverse/napari-spatialdata
99 | [spatialdata-io]: https://github.com/scverse/spatialdata-io
100 | [spatialdata-plot]: https://github.com/scverse/spatialdata-plot
101 |
--------------------------------------------------------------------------------
/docs/references.bib:
--------------------------------------------------------------------------------
1 | @article{Wolf2018,
2 | author = {Wolf, F. Alexander
3 | and Angerer, Philipp
4 | and Theis, Fabian J.},
5 | title = {SCANPY: large-scale single-cell gene expression data analysis},
6 | journal = {Genome Biology},
7 | year = {2018},
8 | month = {Feb},
9 | day = {06},
10 | volume = {19},
11 | number = {1},
12 | pages = {15},
13 | abstract = {Scanpy is a scalable toolkit for analyzing single-cell gene expression data. It includes methods for preprocessing, visualization, clustering, pseudotime and trajectory inference, differential expression testing, and simulation of gene regulatory networks. Its Python-based implementation efficiently deals with data sets of more than one million cells (https://github.com/theislab/Scanpy). Along with Scanpy, we present AnnData, a generic class for handling annotated data matrices (https://github.com/theislab/anndata).},
14 | issn = {1474-760X},
15 | doi = {10.1186/s13059-017-1382-0},
16 | url = {https://doi.org/10.1186/s13059-017-1382-0}
17 | }
18 | @article {Marconato2023.05.05.539647,
19 | author = {Luca Marconato and Giovanni Palla and Kevin A. Yamauchi and Isaac Virshup and Elyas Heidari and Tim Treis and Marcella Toth and Rahul B. Shrestha and Harald V{\"o}hringer and Wolfgang Huber and Moritz Gerstung and Josh Moore and Fabian J. Theis and Oliver Stegle},
20 | title = {SpatialData: an open and universal data framework for spatial omics},
21 | elocation-id = {2023.05.05.539647},
22 | year = {2023},
23 | doi = {10.1101/2023.05.05.539647},
24 | publisher = {Cold Spring Harbor Laboratory},
25 | abstract = {Spatially resolved omics technologies are transforming our understanding of biological tissues. However, handling uni- and multi-modal spatial omics datasets remains a challenge owing to large volumes of data, heterogeneous data types and the lack of unified spatially-aware data structures. Here, we introduce SpatialData, a framework that establishes a unified and extensible multi-platform file-format, lazy representation of larger-than-memory data, transformations, and alignment to common coordinate systems. SpatialData facilitates spatial annotations and cross-modal aggregation and analysis, the utility of which is illustrated via multiple vignettes, including integrative analysis on a multi-modal Xenium and Visium breast cancer study.Competing Interest StatementJ.M. holds equity in Glencoe Software which builds products based on OME-NGFF. F.J.T. consults for Immunai Inc., Singularity Bio B.V., CytoReason Ltd, Cellarity, and Omniscope Ltd, and has ownership interest in Dermagnostix GmbH and Cellarity.},
26 | URL = {https://www.biorxiv.org/content/early/2023/05/08/2023.05.05.539647},
27 | eprint = {https://www.biorxiv.org/content/early/2023/05/08/2023.05.05.539647.full.pdf},
28 | journal = {bioRxiv}
29 | }
30 |
31 | @article{marconatoSpatialDataOpenUniversal2024,
32 | title = {{{SpatialData}}: An Open and Universal Data Framework for Spatial Omics},
33 | author = {Marconato, Luca and Palla, Giovanni and Yamauchi, Kevin A. and Virshup, Isaac and Heidari, Elyas and Treis, Tim and Vierdag, Wouter-Michiel and Toth, Marcella and Stockhaus, Sonja and Shrestha, Rahul B. and Rombaut, Benjamin and Pollaris, Lotte and Lehner, Laurens and V{\"o}hringer, Harald and Kats, Ilia and Saeys, Yvan and Saka, Sinem K. and Huber, Wolfgang and Gerstung, Moritz and Moore, Josh and Theis, Fabian J. and Stegle, Oliver},
34 | year = {2024},
35 | month = mar,
36 | journal = {Nature Methods},
37 | issn = {1548-7105},
38 | doi = {10.1038/s41592-024-02212-x},
39 | abstract = {Spatially resolved omics technologies are transforming our understanding of biological tissues. However, the handling of uni- and multimodal spatial omics datasets remains a challenge owing to large data volumes, heterogeneity of data types and the lack of flexible, spatially aware data structures. Here we introduce SpatialData, a framework that establishes a unified and extensible multiplatform file-format, lazy representation of larger-than-memory data, transformations and alignment to common coordinate systems. SpatialData facilitates spatial annotations and cross-modal aggregation and analysis, the utility of which is illustrated in the context of multiple vignettes, including integrative analysis on a multimodal Xenium and Visium breast cancer study.}
40 | }
41 |
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
1 | # References
2 |
3 | ```{bibliography}
4 | :cited:
5 | ```
6 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "hatchling.build"
3 | requires = ["hatchling", "hatch-vcs"]
4 |
5 |
6 | [project]
7 | name = "spatialdata"
8 | description = "Spatial data format."
9 | authors = [
10 | {name = "scverse"},
11 | ]
12 | maintainers = [
13 | {name = "scverse", email = "giov.pll@gmail.com"},
14 | ]
15 | urls.Documentation = "https://spatialdata.scverse.org/en/latest"
16 | urls.Source = "https://github.com/scverse/spatialdata.git"
17 | urls.Home-page = "https://github.com/scverse/spatialdata.git"
18 | requires-python = ">=3.10, <3.13" # include 3.13 once multiscale-spatial-image conflicts are resolved
19 | dynamic= [
20 | "version" # allow version to be set by git tags
21 | ]
22 | license = {file = "LICENSE"}
23 | readme = "README.md"
24 | dependencies = [
25 | "anndata>=0.9.1",
26 | "click",
27 | "dask-image",
28 | "dask>=2024.4.1,<=2024.11.2",
29 | "datashader",
30 | "fsspec",
31 | "geopandas>=0.14",
32 | "multiscale_spatial_image>=2.0.2",
33 | "networkx",
34 | "numba>=0.55.0",
35 | "numpy",
36 | "ome_zarr>=0.8.4",
37 | "pandas",
38 | "pooch",
39 | "pyarrow",
40 | "rich",
41 | "setuptools",
42 | "shapely>=2.0.1",
43 | "spatial_image>=1.1.0",
44 | "scikit-image",
45 | "scipy",
46 | "typing_extensions>=4.8.0",
47 | "xarray>=2024.10.0",
48 | "xarray-schema",
49 | "xarray-spatial>=0.3.5",
50 | "xarray-dataclasses>=1.8.0",
51 | "zarr<3",
52 | ]
53 |
54 | [project.optional-dependencies]
55 | dev = [
56 | "bump2version",
57 | ]
58 | test = [
59 | "pytest",
60 | "pytest-cov",
61 | "pytest-mock",
62 | "torch",
63 | ]
64 | docs = [
65 | "sphinx>=4.5",
66 | "sphinx-autobuild",
67 | "sphinx-book-theme>=1.0.0",
68 | "myst-nb",
69 | "sphinxcontrib-bibtex>=1.0.0",
70 | "sphinx-autodoc-typehints",
71 | "sphinx-design",
72 | # For notebooks
73 | "ipython>=8.6.0",
74 | "sphinx-copybutton",
75 | "sphinx-pytest",
76 | ]
77 | benchmark = [
78 | "asv",
79 | ]
80 | torch = [
81 | "torch"
82 | ]
83 | extra = [
84 | "napari-spatialdata[all]",
85 | "spatialdata-plot",
86 | "spatialdata-io",
87 | ]
88 |
89 | [tool.coverage.run]
90 | source = ["spatialdata"]
91 | omit = [
92 | "**/test_*.py",
93 | ]
94 |
95 | [tool.pytest.ini_options]
96 | testpaths = ["tests"]
97 | xfail_strict = true
98 | addopts = [
99 | # "-Werror", # if 3rd party libs raise DeprecationWarnings, just use filterwarnings below
100 | "--import-mode=importlib", # allow using test files with same name
101 | "-s" # print output from tests
102 | ]
103 | # info on how to use this https://stackoverflow.com/questions/57925071/how-do-i-avoid-getting-deprecationwarning-from-inside-dependencies-with-pytest
104 | filterwarnings = [
105 | # "ignore:.*U.*mode is deprecated:DeprecationWarning",
106 | ]
107 |
108 | [tool.jupytext]
109 | formats = "ipynb,md"
110 |
111 | [tool.hatch.build.targets.wheel]
112 | packages = ['src/spatialdata']
113 |
114 | [tool.hatch.version]
115 | source = "vcs"
116 |
117 | [tool.hatch.build.hooks.vcs]
118 | version-file = "_version.py"
119 |
120 | [tool.hatch.metadata]
121 | allow-direct-references = true
122 |
123 | [tool.ruff]
124 | exclude = [
125 | ".git",
126 | ".tox",
127 | "__pycache__",
128 | "build",
129 | "docs/_build",
130 | "dist",
131 | "setup.py",
132 |
133 | ]
134 | line-length = 120
135 | target-version = "py310"
136 |
137 | [tool.ruff.lint]
138 | ignore = [
139 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
140 | "E731",
141 | # allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation
142 | "E741",
143 | # Missing docstring in public package
144 | "D104",
145 | # Missing docstring in public module
146 | "D100",
147 | # Missing docstring in __init__
148 | "D107",
149 | # Missing docstring in magic method
150 | "D105",
151 | # Do not perform function calls in argument defaults.
152 | "B008",
155 | ]
156 | select = [
157 | "D", # flake8-docstrings
158 | "I", # isort
159 | "E", # pycodestyle
160 | "F", # pyflakes
161 | "W", # pycodestyle
162 | "Q", # flake8-quotes
163 | "SIM", # flake8-simplify
164 | "TID", # flake-8-tidy-imports
165 | "NPY", # NumPy-specific rules
166 | "PT", # flake8-pytest-style
167 | "B", # flake8-bugbear
168 | "UP", # pyupgrade
169 | "C4", # flake8-comprehensions
170 | "BLE", # flake8-blind-except
171 | "T20", # flake8-print
172 | "RET", # flake8-raise
173 | "PGH", # pygrep-hooks
174 | ]
175 | unfixable = ["B", "C4", "UP", "BLE", "T20", "RET"]
176 |
177 | [tool.ruff.lint.pydocstyle]
178 | convention = "numpy"
179 |
180 | [tool.ruff.lint.per-file-ignores]
181 | "tests/*" = ["D", "PT", "B024"]
182 | "*/__init__.py" = ["F401", "D104", "D107", "E402"]
183 | "docs/*" = ["D","B","E","A"]
184 | "src/spatialdata/transformations/transformations.py" = ["D101","D102", "D106", "B024", "T201", "RET504", "UP006", "UP007"]
185 | "src/spatialdata/transformations/operations.py" = ["D101","D102", "D106", "B024","D401", "T201", "RET504", "RET506", "RET505", "RET504", "UP006", "UP007"]
186 | "src/spatialdata/transformations/ngff/*.py" = ["D101","D102", "D106", "D401", "E501","RET506", "RET505", "RET504", "UP006", "UP007"]
187 | "src/spatialdata/transformations/*" = ["RET", "D", "UP006", "UP007"]
188 | "src/spatialdata/models/models.py" = ["D101", "B026"]
189 | "src/spatialdata/dataloader/datasets.py" = ["D101"]
190 | "tests/test_models/test_models.py" = ["NPY002"]
191 | "tests/conftest.py"= ["E402"]
192 | "benchmarks/*" = ["ALL"]
193 |
194 |
195 | # pyupgrade typing rewrite TODO: remove at some point from per-file ignore
196 | # "UP006", "UP007"
197 |
--------------------------------------------------------------------------------
/src/spatialdata/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import dask
4 |
5 | dask.config.set({"dataframe.query-planning": False})
6 | import dask.dataframe as dd
7 |
8 | # Setting `dataframe.query-planning` to False is effective only if run before `dask.dataframe` is initialized. If
9 | # the user has already initialized `dask.dataframe`, `DASK_EXPR_ENABLED` would be set to `True`.
10 | # Here we check that this has not happened.
11 | if hasattr(dd, "DASK_EXPR_ENABLED") and dd.DASK_EXPR_ENABLED:
12 | raise RuntimeError(
13 | "Unsupported backend: dask-expr has been detected as the backend of dask.dataframe. Please "
14 | "use:\nimport dask\ndask.config.set({'dataframe.query-planning': False})\nbefore importing "
15 | "dask.dataframe to disable dask-expr. The support is being worked on, for more information please see"
16 | "https://github.com/scverse/spatialdata/pull/570"
17 | )
18 | from importlib.metadata import version
19 |
20 | __version__ = version("spatialdata")
21 |
22 | __all__ = [
23 | "models",
24 | "transformations",
25 | "datasets",
26 | "dataloader",
27 | "concatenate",
28 | "rasterize",
29 | "rasterize_bins",
30 | "rasterize_bins_link_table_to_labels",
31 | "to_circles",
32 | "to_polygons",
33 | "transform",
34 | "aggregate",
35 | "bounding_box_query",
36 | "polygon_query",
37 | "get_element_annotators",
38 | "get_element_instances",
39 | "get_values",
40 | "join_spatialelement_table",
41 | "match_element_to_table",
42 | "match_table_to_element",
43 | "match_sdata_to_table",
44 | "SpatialData",
45 | "get_extent",
46 | "get_centroids",
47 | "read_zarr",
48 | "unpad_raster",
49 | "get_pyramid_levels",
50 | "save_transformations",
51 | "get_dask_backing_files",
52 | "are_extents_equal",
53 | "relabel_sequential",
54 | "map_raster",
55 | "deepcopy",
56 | "sanitize_table",
57 | "sanitize_name",
58 | ]
59 |
60 | from spatialdata import dataloader, datasets, models, transformations
61 | from spatialdata._core._deepcopy import deepcopy
62 | from spatialdata._core._utils import sanitize_name, sanitize_table
63 | from spatialdata._core.centroids import get_centroids
64 | from spatialdata._core.concatenate import concatenate
65 | from spatialdata._core.data_extent import are_extents_equal, get_extent
66 | from spatialdata._core.operations.aggregate import aggregate
67 | from spatialdata._core.operations.map import map_raster, relabel_sequential
68 | from spatialdata._core.operations.rasterize import rasterize
69 | from spatialdata._core.operations.rasterize_bins import rasterize_bins, rasterize_bins_link_table_to_labels
70 | from spatialdata._core.operations.transform import transform
71 | from spatialdata._core.operations.vectorize import to_circles, to_polygons
72 | from spatialdata._core.query._utils import get_bounding_box_corners
73 | from spatialdata._core.query.relational_query import (
74 | get_element_annotators,
75 | get_element_instances,
76 | get_values,
77 | join_spatialelement_table,
78 | match_element_to_table,
79 | match_sdata_to_table,
80 | match_table_to_element,
81 | )
82 | from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query
83 | from spatialdata._core.spatialdata import SpatialData
84 | from spatialdata._io._utils import get_dask_backing_files, save_transformations
85 | from spatialdata._io.format import SpatialDataFormat
86 | from spatialdata._io.io_zarr import read_zarr
87 | from spatialdata._utils import get_pyramid_levels, unpad_raster
88 |
--------------------------------------------------------------------------------
/src/spatialdata/__main__.py:
--------------------------------------------------------------------------------
1 | """
2 | The CLI Interaction module.
3 |
4 | This module provides command line interface (CLI) interactions for the SpatialData library, allowing users to perform
5 | various operations through a terminal. Currently, it implements the "peek" function, which allows users to inspect
6 | the contents of a SpatialData .zarr dataset. Additional CLI functionalities will be implemented in the future.
7 | """
8 |
9 | from typing import Literal
10 |
11 | import click
12 |
13 |
14 | @click.command(help="Peek inside the SpatialData .zarr dataset")
15 | @click.argument("path", default=False, type=str)
16 | @click.argument("selection", type=click.Choice(["images", "labels", "points", "shapes", "table"]), nargs=-1)
17 | def peek(path: str, selection: tuple[Literal["images", "labels", "points", "shapes", "table"]]) -> None:
18 | """
19 | Peek inside the SpatialData .zarr dataset.
20 |
21 | This function takes a path to a local or remote .zarr dataset, reads and prints
22 | its contents using the SpatialData library. If any ValueError is raised, it is caught and printed to the
23 | terminal along with a help message.
24 |
25 | Parameters
26 | ----------
27 | path
28 | The path to the .zarr dataset to be inspected.
29 | selection
30 | Optional, a list of keys (among images, labels, points, shapes, table) to load only a subset of the dataset.
31 | Example: `python -m spatialdata peek data.zarr images labels`
32 | """
33 | import spatialdata as sd
34 |
35 | try:
36 | sdata = sd.SpatialData.read(path, selection=selection)
37 | print(sdata) # noqa: T201
38 | except ValueError as e:
39 | # checking if a valid path was provided is difficult given the various ways in which
40 | # a possibly remote path and storage access options can be specified
41 | # so we just catch the ValueError and print a help message
42 | print(e) # noqa: T201
43 | print( # noqa: T201
44 | f"Error: .zarr storage not found at {path}. Please specify a valid OME-NGFF spatial data (.zarr) file. "
45 | "Examples "
46 | '"python -m spatialdata peek data.zarr"'
47 | '"python -m spatialdata peek https://remote/.../data.zarr labels table"'
48 | )
49 |
50 |
51 | @click.group()
52 | def cli() -> None:
53 | """Provide the main Click command group.
54 |
55 | This function serves as the main entry point for the command-line interface. It creates a Click command
56 | group and adds the various cli commands to it.
57 | """
58 |
59 |
60 | cli.add_command(peek)
61 |
62 |
63 | def main() -> None:
64 | """Initialize and run the command-line interface.
65 |
66 | This function initializes the Click command group and runs the command-line interface, processing user
67 | input and executing the appropriate commands.
68 | """
69 | cli()
70 |
71 |
72 | if __name__ == "__main__":
73 | main()
74 |
--------------------------------------------------------------------------------
/src/spatialdata/_bridges/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/src/spatialdata/_bridges/__init__.py
--------------------------------------------------------------------------------
/src/spatialdata/_core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/src/spatialdata/_core/__init__.py
--------------------------------------------------------------------------------
/src/spatialdata/_core/_deepcopy.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from copy import deepcopy as _deepcopy
4 | from functools import singledispatch
5 |
6 | from anndata import AnnData
7 | from dask.array.core import Array as DaskArray
8 | from dask.array.core import from_array
9 | from dask.dataframe import DataFrame as DaskDataFrame
10 | from geopandas import GeoDataFrame
11 | from xarray import DataArray, DataTree
12 |
13 | from spatialdata._core.spatialdata import SpatialData
14 | from spatialdata.models._utils import SpatialElement
15 | from spatialdata.models.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, get_model
16 |
17 |
18 | @singledispatch
19 | def deepcopy(element: SpatialData | SpatialElement | AnnData) -> SpatialData | SpatialElement | AnnData:
20 | """
21 | Deepcopy a SpatialData or SpatialElement object.
22 |
23 | Deepcopy will load the data in memory. Using this function for large Dask-backed objects is discouraged. In that
24 | case, please save the SpatialData object to a different disk location and read it back again.
25 |
26 | Parameters
27 | ----------
28 | element
29 | The SpatialData or SpatialElement object to deepcopy
30 |
31 | Returns
32 | -------
33 | A deepcopy of the SpatialData or SpatialElement object
34 |
35 | Notes
36 | -----
37 | The order of the columns for a deepcopied points element may differ from the original one; please see more here:
38 | https://github.com/scverse/spatialdata/issues/486
39 | """
40 | raise RuntimeError(f"Wrong type for deepcopy: {type(element)}")
41 |
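# Example usage (illustrative; `sdata` and the element name are assumptions):
#     from spatialdata import deepcopy
#     sdata_copy = deepcopy(sdata)              # copy a whole SpatialData object
#     element_copy = deepcopy(sdata["cells"])   # copy a single element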
42 |
43 | # In the implementations below, when the data is loaded from Dask, we first use compute() and then we deepcopy the data.
44 | # This leads to double copying the data, but since we expect the data to be small, this is acceptable.
45 | @deepcopy.register(SpatialData)
46 | def _(sdata: SpatialData) -> SpatialData:
47 | elements_dict = {}
48 | for _, element_name, element in sdata.gen_elements():
49 | elements_dict[element_name] = deepcopy(element)
50 | deepcopied_attrs = _deepcopy(sdata.attrs)
51 | return SpatialData.from_elements_dict(elements_dict, attrs=deepcopied_attrs)
52 |
53 |
54 | @deepcopy.register(DataArray)
55 | def _(element: DataArray) -> DataArray:
56 | model = get_model(element)
57 | if isinstance(element.data, DaskArray):
58 | element = element.compute()
59 | if model in [Image2DModel, Image3DModel]:
60 | return model.parse(element.copy(deep=True), c_coords=element["c"]) # type: ignore[call-arg]
61 | assert model in [Labels2DModel, Labels3DModel]
62 | return model.parse(element.copy(deep=True))
63 |
64 |
65 | @deepcopy.register(DataTree)
66 | def _(element: DataTree) -> DataTree:
67 | # TODO: now that multiscale_spatial_image 1.0.0 is supported, this code can probably be simplified. Check
68 | # https://github.com/scverse/spatialdata/pull/587/files#diff-c74ebf49cb8cbddcfaec213defae041010f2043cfddbded24175025b6764ef79
69 | # to understand the original motivation.
70 | model = get_model(element)
71 | for key in element:
72 | ds = element[key].ds
73 | assert len(ds) == 1
74 | variable = next(iter(ds))
75 | if isinstance(element[key][variable].data, DaskArray):
76 | element[key][variable] = element[key][variable].compute()
77 | msi = element.copy(deep=True)
78 | for key in msi:
79 | ds = msi[key].ds
80 | variable = next(iter(ds))
81 | msi[key][variable].data = from_array(msi[key][variable].data)
82 | element[key][variable].data = from_array(element[key][variable].data)
83 | assert model in [Image2DModel, Image3DModel, Labels2DModel, Labels3DModel]
84 | model().validate(msi)
85 | return msi
86 |
87 |
88 | @deepcopy.register(GeoDataFrame)
89 | def _(gdf: GeoDataFrame) -> GeoDataFrame:
90 | new_gdf = _deepcopy(gdf)
91 | # temporary fix for https://github.com/scverse/spatialdata/issues/286.
92 | new_gdf.attrs = _deepcopy(gdf.attrs)
93 | return new_gdf
94 |
95 |
96 | @deepcopy.register(DaskDataFrame)
97 | def _(df: DaskDataFrame) -> DaskDataFrame:
98 | # bug: the parser may change the order of the columns
99 | new_ddf = PointsModel.parse(df.compute().copy(deep=True))
100 | # the problem is not .copy(deep=True), but the parser, which discards some metadata https://github.com/scverse/spatialdata/issues/503#issuecomment-2015275322
101 | new_ddf.attrs = _deepcopy(df.attrs)
102 | return new_ddf
103 |
104 |
105 | @deepcopy.register(AnnData)
106 | def _(adata: AnnData) -> AnnData:
107 | return adata.copy()
108 |
--------------------------------------------------------------------------------
/src/spatialdata/_core/_elements.py:
--------------------------------------------------------------------------------
1 | """SpatialData elements."""
2 |
3 | from __future__ import annotations
4 |
5 | from collections import UserDict
6 | from collections.abc import Iterable, KeysView, ValuesView
7 | from typing import TypeVar
8 | from warnings import warn
9 |
10 | from anndata import AnnData
11 | from dask.dataframe import DataFrame as DaskDataFrame
12 | from geopandas import GeoDataFrame
13 | from xarray import DataArray, DataTree
14 |
15 | from spatialdata._core.validation import check_key_is_case_insensitively_unique, check_valid_name
16 | from spatialdata._types import Raster_T
17 | from spatialdata.models import (
18 | Image2DModel,
19 | Image3DModel,
20 | Labels2DModel,
21 | Labels3DModel,
22 | PointsModel,
23 | ShapesModel,
24 | TableModel,
25 | get_axes_names,
26 | get_model,
27 | )
28 |
29 | T = TypeVar("T")
30 |
31 |
32 | class Elements(UserDict[str, T]):
33 | def __init__(self, shared_keys: set[str | None]) -> None:
34 | self._shared_keys = shared_keys
35 | super().__init__()
36 |
37 | def _add_shared_key(self, key: str) -> None:
38 | self._shared_keys.add(key)
39 |
40 | def _remove_shared_key(self, key: str) -> None:
41 | self._shared_keys.remove(key)
42 |
43 | @staticmethod
44 | def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | None]) -> None:
45 | check_valid_name(key)
46 | if key in element_keys:
47 | warn(f"Key `{key}` already exists. Overwriting it in-memory.", UserWarning, stacklevel=2)
48 | else:
49 | try:
50 | check_key_is_case_insensitively_unique(key, shared_keys)
51 | except ValueError as e:
52 | # Validation raises ValueError, but inappropriate mapping key must raise KeyError.
53 | raise KeyError(*e.args) from e
54 |
55 | def __setitem__(self, key: str, value: T) -> None:
56 | self._add_shared_key(key)
57 | super().__setitem__(key, value)
58 |
59 | def __delitem__(self, key: str) -> None:
60 | self._remove_shared_key(key)
61 | super().__delitem__(key)
62 |
63 | def keys(self) -> KeysView[str]:
64 | """Return the keys of the Elements."""
65 | return self.data.keys()
66 |
67 | def values(self) -> ValuesView[T]:
68 | """Return the values of the Elements."""
69 | return self.data.values()
70 |
71 |
72 | class Images(Elements[DataArray | DataTree]):
73 | def __setitem__(self, key: str, value: Raster_T) -> None:
74 | self._check_key(key, self.keys(), self._shared_keys)
75 | schema = get_model(value)
76 | if schema not in (Image2DModel, Image3DModel):
77 | raise TypeError(f"Unknown element type with schema: {schema!r}.")
78 | ndim = len(get_axes_names(value))
79 | if ndim == 3:
80 | Image2DModel().validate(value)
81 | super().__setitem__(key, value)
82 | elif ndim == 4:
83 | Image3DModel().validate(value)
84 | super().__setitem__(key, value)
85 | else:
86 | NotImplementedError("TODO: implement for ndim > 4.")
87 |
88 |
89 | class Labels(Elements[DataArray | DataTree]):
90 | def __setitem__(self, key: str, value: Raster_T) -> None:
91 | self._check_key(key, self.keys(), self._shared_keys)
92 | schema = get_model(value)
93 | if schema not in (Labels2DModel, Labels3DModel):
94 | raise TypeError(f"Unknown element type with schema: {schema!r}.")
95 | ndim = len(get_axes_names(value))
96 | if ndim == 2:
97 | Labels2DModel().validate(value)
98 | super().__setitem__(key, value)
99 | elif ndim == 3:
100 | Labels3DModel().validate(value)
101 | super().__setitem__(key, value)
102 | else:
103 | NotImplementedError("TODO: implement for ndim > 3.")
104 |
105 |
106 | class Shapes(Elements[GeoDataFrame]):
107 | def __setitem__(self, key: str, value: GeoDataFrame) -> None:
108 | self._check_key(key, self.keys(), self._shared_keys)
109 | schema = get_model(value)
110 | if schema != ShapesModel:
111 | raise TypeError(f"Unknown element type with schema: {schema!r}.")
112 | ShapesModel().validate(value)
113 | super().__setitem__(key, value)
114 |
115 |
116 | class Points(Elements[DaskDataFrame]):
117 | def __setitem__(self, key: str, value: DaskDataFrame) -> None:
118 | self._check_key(key, self.keys(), self._shared_keys)
119 | schema = get_model(value)
120 | if schema != PointsModel:
121 | raise TypeError(f"Unknown element type with schema: {schema!r}.")
122 | PointsModel().validate(value)
123 | super().__setitem__(key, value)
124 |
125 |
126 | class Tables(Elements[AnnData]):
127 | def __setitem__(self, key: str, value: AnnData) -> None:
128 | self._check_key(key, self.keys(), self._shared_keys)
129 | schema = get_model(value)
130 | if schema != TableModel:
131 | raise TypeError(f"Unknown element type with schema: {schema!r}.")
132 | TableModel().validate(value)
133 | super().__setitem__(key, value)
134 |
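# Example (illustrative): assigning to these containers triggers validation, e.g.
#     sdata.images["image"] = Image2DModel.parse(arr, dims=("c", "y", "x"))
# raises if the assigned value does not conform to the image schema.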
--------------------------------------------------------------------------------
/src/spatialdata/_core/_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Iterable
4 |
5 | from anndata import AnnData
6 |
7 | from spatialdata._core.spatialdata import SpatialData
8 |
9 |
10 | def _find_common_table_keys(sdatas: Iterable[SpatialData]) -> set[str]:
11 | """
12 | Find table keys present in more than one SpatialData object.
13 |
14 | Parameters
15 | ----------
16 | sdatas
17 | An `Iterable` of SpatialData objects.
18 |
19 | Returns
20 | -------
21 | A set of common keys that are present in the tables of more than one SpatialData object.
22 | """
23 | common_keys: set[str] = set()
24 |
25 | for sdata in sdatas:
26 | if len(common_keys) == 0:
27 | common_keys = set(sdata.tables.keys())
28 | else:
29 | common_keys.intersection_update(sdata.tables.keys())
30 |
31 | return common_keys
32 |
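# Example (illustrative): if `sdata1` and `sdata2` both contain a table named
# "cells", then _find_common_table_keys([sdata1, sdata2]) returns {"cells"}.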
33 |
34 | def sanitize_name(name: str, is_dataframe_column: bool = False) -> str:
35 | """
36 | Sanitize a name to comply with SpatialData naming rules.
37 |
38 | This function converts invalid names into valid ones by:
39 | 1. Converting to string if not already
40 | 2. Removing invalid characters
41 | 3. Handling special cases like "__" prefix
42 | 4. Ensuring the name is not empty
43 | 5. Handling special cases for dataframe columns
44 |
45 | See a discussion on the naming rules, and how to avoid naming collisions, here:
46 | https://github.com/scverse/spatialdata/discussions/707
47 |
48 | Parameters
49 | ----------
50 | name
51 | The name to sanitize
52 | is_dataframe_column
53 | Whether this name is for a dataframe column (additional restrictions apply)
54 |
55 | Returns
56 | -------
57 | A sanitized version of the name that complies with SpatialData naming rules. If a
58 | sanitized name cannot be generated, it returns "unnamed".
59 |
60 | Examples
61 | --------
62 | >>> sanitize_name("my@invalid#name")
63 | 'my_invalid_name'
64 | >>> sanitize_name("__private")
65 | '_private'
66 | >>> sanitize_name("_index", is_dataframe_column=True)
67 | 'index'
68 | """
69 | # Convert to string if not already
70 | name = str(name)
71 |
72 | # Handle empty string case
73 | if not name:
74 | return "unnamed"
75 |
76 | # Handle special cases
77 | if name in {".", ".."}:
78 | return "unnamed"
79 |
80 | sanitized = "".join(char if char.isalnum() or char in "_-." else "_" for char in name)
81 |
82 | # strip leading underscores while the name still starts with a double underscore
83 | while sanitized.startswith("__"):
84 | sanitized = sanitized[1:]
85 |
86 | if is_dataframe_column and sanitized == "_index":
87 | return "index"
88 |
89 | # Ensure we don't end up with an empty string after sanitization
90 | return sanitized or "unnamed"
91 |
92 |
93 | def sanitize_table(data: AnnData, inplace: bool = True) -> AnnData | None:
94 | """
95 | Sanitize all keys in an AnnData table to comply with SpatialData naming rules.
96 |
97 | This function sanitizes all keys in obs, var, obsm, obsp, varm, varp, uns, and layers
98 | while maintaining case-insensitive uniqueness. It can either modify the table in-place
99 | or return a new sanitized copy.
100 |
101 | See a discussion on the naming rules here:
102 | https://github.com/scverse/spatialdata/discussions/707
103 |
104 | Parameters
105 | ----------
106 | data
107 | The AnnData table to sanitize
108 | inplace
109 | Whether to modify the table in-place or return a new copy
110 |
111 | Returns
112 | -------
113 | If inplace is False, returns a new AnnData object with sanitized keys.
114 | If inplace is True, returns None as the original object is modified.
115 |
116 | Examples
117 | --------
118 | >>> import anndata as ad
119 | >>> import pandas as pd
120 | >>> adata = ad.AnnData(obs=pd.DataFrame({"@invalid#": [1, 2]}))
121 | >>> sanitized = sanitize_table(adata, inplace=False)  # create a sanitized copy
122 | >>> print(sanitized.obs.columns)
123 | Index(['_invalid_'], dtype='object')
124 | >>> sanitize_table(adata)  # or sanitize in-place (the default)
125 | >>> print(adata.obs.columns)
126 | Index(['_invalid_'], dtype='object')
127 |
128 | """
129 | import copy
130 | from collections import defaultdict
131 |
132 | # Create a deep copy if not modifying in-place
133 | sanitized = data if inplace else copy.deepcopy(data)
134 |
135 | # Track used names to maintain case-insensitive uniqueness
136 | used_names_lower: dict[str, set[str]] = defaultdict(set)
137 |
138 | def get_unique_name(name: str, attr: str, is_dataframe_column: bool = False) -> str:
139 | base_name = sanitize_name(name, is_dataframe_column)
140 | normalized_base = base_name.lower()
141 |
142 | # If this exact name is already used, add a number
143 | if normalized_base in used_names_lower[attr]:
144 | counter = 1
145 | while f"{base_name}_{counter}".lower() in used_names_lower[attr]:
146 | counter += 1
147 | base_name = f"{base_name}_{counter}"
148 |
149 | used_names_lower[attr].add(base_name.lower())
150 | return base_name
151 |
152 | # Handle obs and var (dataframe columns)
153 | for attr in ("obs", "var"):
154 | df = getattr(sanitized, attr)
155 | new_columns = {old: get_unique_name(old, attr, is_dataframe_column=True) for old in df.columns}
156 | df.rename(columns=new_columns, inplace=True)
157 |
158 | # Handle other attributes
159 | for attr in ("obsm", "obsp", "varm", "varp", "uns", "layers"):
160 | d = getattr(sanitized, attr)
161 | new_keys = {old: get_unique_name(old, attr) for old in d}
162 | # Create new dictionary with sanitized keys
163 | new_dict = {new_keys[old]: value for old, value in d.items()}
164 | setattr(sanitized, attr, new_dict)
165 |
166 | return None if inplace else sanitized
167 |
--------------------------------------------------------------------------------
/src/spatialdata/_core/centroids.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections import defaultdict
4 | from functools import singledispatch
5 |
6 | import dask.array as da
7 | import pandas as pd
8 | import xarray as xr
9 | from dask.dataframe import DataFrame as DaskDataFrame
10 | from geopandas import GeoDataFrame
11 | from shapely import MultiPolygon, Point, Polygon
12 | from xarray import DataArray, DataTree
13 |
14 | from spatialdata._core.operations.transform import transform
15 | from spatialdata.models import get_axes_names
16 | from spatialdata.models._utils import SpatialElement
17 | from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, get_model
18 | from spatialdata.transformations.operations import get_transformation
19 | from spatialdata.transformations.transformations import BaseTransformation
20 |
21 | BoundingBoxDescription = dict[str, tuple[float, float]]
22 |
23 |
24 | def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> None:
25 | d = get_transformation(e, get_all=True)
26 | assert isinstance(d, dict)
27 | assert coordinate_system in d, (
28 | f"No transformation to coordinate system {coordinate_system} is available for the given element.\n"
29 | f"Available coordinate systems: {list(d.keys())}"
30 | )
31 |
32 |
33 | @singledispatch
34 | def get_centroids(
35 | e: SpatialElement,
36 | coordinate_system: str = "global",
37 | return_background: bool = False,
38 | ) -> DaskDataFrame:
39 | """
40 | Get the centroids of the geometries contained in a SpatialElement, as a new Points element.
41 |
42 | Parameters
43 | ----------
44 | e
45 | The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported.
46 | coordinate_system
47 | The coordinate system in which the centroids are computed.
48 | return_background
49 | If True, the centroid of the background label (0) is included in the output.
50 |
51 | Notes
52 | -----
53 | For :class:`~shapely.MultiPolygon`s, the centroids are the average of the centroids of the polygons that constitute
54 | each :class:`~shapely.MultiPolygon`.
55 | """
56 | raise ValueError(f"The object type {type(e)} is not supported.")
57 |
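# Example usage (illustrative; "segmentation" is a hypothetical labels element):
#     from spatialdata import get_centroids
#     centroids = get_centroids(sdata["segmentation"], coordinate_system="global")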
58 |
59 | def _get_centroids_for_axis(xdata: xr.DataArray, axis: str) -> pd.DataFrame:
60 | """
61 | Compute the component "axis" of the centroid of each label as a weighted average of the xarray coordinates.
62 |
63 | Parameters
64 | ----------
65 | xdata
66 | The xarray DataArray containing the labels.
67 | axis
68 | The axis for which the centroids are computed.
69 |
70 | Returns
71 | -------
72 | pd.DataFrame
73 | A DataFrame containing one column, named after "axis", with the centroids of the labels along that axis.
74 | The index of the DataFrame is the collection of label values, sorted in ascending order.
75 | """
76 | centroids: dict[int, float] = defaultdict(float)
77 | for i in xdata[axis]:
78 | portion = xdata.sel(**{axis: i}).data
79 | u = da.unique(portion, return_counts=True)
80 | labels_values = u[0].compute()
81 | counts = u[1].compute()
82 | for j in range(len(labels_values)):
83 | label_value = labels_values[j]
84 | count = counts[j]
85 | centroids[label_value] += count * i.values.item()
86 |
87 | all_labels_values, all_labels_counts = da.unique(xdata.data, return_counts=True)
88 | all_labels = dict(zip(all_labels_values.compute(), all_labels_counts.compute(), strict=True))
89 | for label_value in centroids:
90 | centroids[label_value] /= all_labels[label_value]
91 | centroids = dict(sorted(centroids.items(), key=lambda x: x[0]))
92 | return pd.DataFrame({axis: centroids.values()}, index=list(centroids.keys()))
93 |
94 |
95 | @get_centroids.register(DataArray)
96 | @get_centroids.register(DataTree)
97 | def _(
98 | e: DataArray | DataTree,
99 | coordinate_system: str = "global",
100 | return_background: bool = False,
101 | ) -> DaskDataFrame:
102 | """Get the centroids of a Labels element (2D or 3D)."""
103 | model = get_model(e)
104 | if model not in [Labels2DModel, Labels3DModel]:
105 | raise ValueError("Expected a `Labels` element. Found an `Image` instead.")
106 | _validate_coordinate_system(e, coordinate_system)
107 |
108 | if isinstance(e, DataTree):
109 | assert len(e["scale0"]) == 1
110 | e = next(iter(e["scale0"].values()))
111 |
112 | dfs = []
113 | for axis in get_axes_names(e):
114 | dfs.append(_get_centroids_for_axis(e, axis))
115 | df = pd.concat(dfs, axis=1)
116 | if not return_background and 0 in df.index:
117 | df = df.drop(index=0) # drop the background label
118 | t = get_transformation(e, coordinate_system)
119 | centroids = PointsModel.parse(df, transformations={coordinate_system: t})
120 | return transform(centroids, to_coordinate_system=coordinate_system)
121 |
122 |
123 | @get_centroids.register(GeoDataFrame)
124 | def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
125 | """Get the centroids of a Shapes element (circles or polygons/multipolygons)."""
126 | _validate_coordinate_system(e, coordinate_system)
127 | t = get_transformation(e, coordinate_system)
128 | assert isinstance(t, BaseTransformation)
129 | # separate points from (multi-)polygons
130 | first_geometry = e["geometry"].iloc[0]
131 | if isinstance(first_geometry, Point):
132 | xy = e.geometry.get_coordinates().values
133 | else:
134 | assert isinstance(first_geometry, Polygon | MultiPolygon), (
135 | f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or"
136 | f" Polygons/MultiPolygons. Found {type(first_geometry)} instead."
137 | )
138 | xy = e.centroid.get_coordinates().values
139 | xy_df = pd.DataFrame(xy, columns=["x", "y"], index=e.index.copy())
140 | points = PointsModel.parse(xy_df, transformations={coordinate_system: t})
141 | return transform(points, to_coordinate_system=coordinate_system)
142 |
143 |
144 | @get_centroids.register(DaskDataFrame)
145 | def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
146 | """Get the centroids of a Points element."""
147 | _validate_coordinate_system(e, coordinate_system)
148 | axes = get_axes_names(e)
149 | assert axes in [("x", "y"), ("x", "y", "z")]
150 | coords = e[list(axes)].compute()
151 | t = get_transformation(e, coordinate_system)
152 | assert isinstance(t, BaseTransformation)
153 | centroids = PointsModel.parse(coords, transformations={coordinate_system: t})
154 | return transform(centroids, to_coordinate_system=coordinate_system)
155 |
156 |
157 |
158 |
--------------------------------------------------------------------------------
/src/spatialdata/_core/operations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/src/spatialdata/_core/operations/__init__.py
--------------------------------------------------------------------------------
/src/spatialdata/_core/operations/_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | from xarray import DataArray, DataTree
6 |
7 | from spatialdata.models import SpatialElement, get_axes_names, get_spatial_axes
8 |
9 | if TYPE_CHECKING:
10 | from spatialdata._core.spatialdata import SpatialData
11 |
12 |
13 | def transform_to_data_extent(
14 | sdata: SpatialData,
15 | coordinate_system: str,
16 | maintain_positioning: bool = True,
17 | target_unit_to_pixels: float | None = None,
18 | target_width: float | None = None,
19 | target_height: float | None = None,
20 | target_depth: float | None = None,
21 | ) -> SpatialData:
22 | """
23 | Transform the spatial data to match the data extent, so that pixels and vector coordinates correspond.
24 |
25 | Given a selected coordinate system, this function will transform the spatial data in that coordinate system, and
26 | will resample images, so that the pixels and vector coordinates correspond.
27 | In other words, the vector coordinate (x, y) (or (x, y, z)) will correspond to the pixel (y, x) (or (z, y, x)).
28 |
29 | When `maintain_positioning` is `False`, each transformation will be set to Identity. When `maintain_positioning` is
30 | `True` (default value), each element of the data will also have a transformation that will maintain the positioning
31 | of the element, as it was before calling this function.
32 | Note that in this case the correspondence between pixels and vector coordinates is true in the intrinsic coordinate
33 | system, not in the target coordinate system.
34 |
35 | Parameters
36 | ----------
37 | sdata
38 | The spatial data to transform.
39 | coordinate_system
40 | The coordinate system to use to compute the extent and to transform the data to.
41 | maintain_positioning
42 | If `True`, the transformation will maintain the positioning of the elements, as it was before calling this
43 | function. If `False`, each transformation will be set to Identity.
44 | target_unit_to_pixels
45 | The required number of pixels per unit (units in the target coordinate system) of the data that will be
46 | produced.
47 | target_width
48 | The width of the data extent, in pixels, for the data that will be produced.
49 | target_height
50 | The height of the data extent, in pixels, for the data that will be produced.
51 | target_depth
52 | The depth of the data extent, in pixels, for the data that will be produced.
53 |
54 | Returns
55 | -------
56 | SpatialData
57 | The transformed spatial data with downscaled and padded images and adjusted vector coordinates; all the
58 | transformations will be set to Identity and the coordinates of the vector data will be aligned to the pixel
59 | coordinates.
60 |
61 | Notes
62 | -----
63 | - The data extent is the smallest rectangle that contains all the images and geometries.
64 | - DataTree objects (multiscale images) will be converted to DataArray (single-scale images) objects.
65 | - This helper function will be deprecated when https://github.com/scverse/spatialdata/issues/308 is closed,
66 | as its functionality will then be covered by `transform_to_coordinate_system()`.
67 | """
68 | from spatialdata._core.data_extent import get_extent
69 | from spatialdata._core.operations.rasterize import _compute_target_dimensions, rasterize
70 | from spatialdata._core.spatialdata import SpatialData
71 | from spatialdata.transformations.operations import get_transformation, set_transformation
72 | from spatialdata.transformations.transformations import BaseTransformation, Identity, Scale, Sequence, Translation
73 |
74 | sdata = sdata.filter_by_coordinate_system(coordinate_system=coordinate_system)
75 | # calling transform_to_coordinate_system will likely decrease the resolution, let's use rasterize() instead
76 | sdata_vector = SpatialData(shapes=dict(sdata.shapes), points=dict(sdata.points))
77 | sdata_raster = SpatialData(images=dict(sdata.images), labels=dict(sdata.labels))
78 | sdata_vector_transformed = sdata_vector.transform_to_coordinate_system(coordinate_system)
79 |
80 | data_extent = get_extent(sdata, coordinate_system=coordinate_system)
81 | data_extent_axes = tuple(data_extent.keys())
82 | translation_to_origin = Translation([-data_extent[ax][0] for ax in data_extent_axes], axes=data_extent_axes)
83 |
84 | sizes = [data_extent[ax][1] - data_extent[ax][0] for ax in data_extent_axes]
85 | target_width, target_height, target_depth = _compute_target_dimensions(
86 | spatial_axes=data_extent_axes,
87 | min_coordinate=[0 for _ in data_extent_axes],
88 | max_coordinate=sizes,
89 | target_unit_to_pixels=target_unit_to_pixels,
90 | target_width=target_width,
91 | target_height=target_height,
92 | target_depth=target_depth,
93 | )
94 | scale_to_target_d = {
95 | "x": target_width / sizes[data_extent_axes.index("x")],
96 | "y": target_height / sizes[data_extent_axes.index("y")],
97 | }
98 | if target_depth is not None:
99 | scale_to_target_d["z"] = target_depth / sizes[data_extent_axes.index("z")]
100 | scale_to_target = Scale([scale_to_target_d[ax] for ax in data_extent_axes], axes=data_extent_axes)
101 |
102 | for el in sdata_vector_transformed._gen_spatial_element_values():
103 | t = get_transformation(el, to_coordinate_system=coordinate_system)
104 | assert isinstance(t, BaseTransformation)
105 | sequence = Sequence([t, translation_to_origin, scale_to_target])
106 | set_transformation(el, transformation=sequence, to_coordinate_system=coordinate_system)
107 | sdata_vector_transformed_inplace = sdata_vector_transformed.transform_to_coordinate_system(
108 | coordinate_system, maintain_positioning=True
109 | )
110 |
111 | sdata_to_return_elements = {
112 | **sdata_vector_transformed_inplace.shapes,
113 | **sdata_vector_transformed_inplace.points,
114 | }
115 |
116 | for _, element_name, element in sdata_raster.gen_spatial_elements():
117 | element_axes = get_spatial_axes(get_axes_names(element))
118 | if isinstance(element, DataArray | DataTree):
119 | rasterized = rasterize(
120 | element,
121 | axes=element_axes,
122 | min_coordinate=[data_extent[ax][0] for ax in element_axes],
123 | max_coordinate=[data_extent[ax][1] for ax in element_axes],
124 | target_coordinate_system=coordinate_system,
125 | target_unit_to_pixels=None,
126 | target_width=target_width,
127 | target_height=None,
128 | target_depth=None,
129 | return_regions_as_labels=True,
130 | )
131 | sdata_to_return_elements[element_name] = rasterized
132 | else:
133 | sdata_to_return_elements[element_name] = element
134 | if not maintain_positioning:
135 | for el in sdata_to_return_elements.values():
136 | set_transformation(el, transformation={coordinate_system: Identity()}, set_all=True)
137 | for k, v in sdata.tables.items():
138 | sdata_to_return_elements[k] = v.copy()
139 | return SpatialData.from_elements_dict(sdata_to_return_elements, attrs=sdata.attrs)
140 |
141 |
142 | def _parse_element(
143 | element: str | SpatialElement, sdata: SpatialData | None, element_var_name: str, sdata_var_name: str
144 | ) -> SpatialElement:
145 | if not ((sdata is not None and isinstance(element, str)) ^ (not isinstance(element, str))):
146 | raise ValueError(
147 | f"To specify the {element_var_name!r} SpatialElement, please do one of the following: "
148 | f"- either pass a SpatialElement to the {element_var_name!r} parameter (and keep "
149 | f"`{sdata_var_name}` = None);"
150 | f"- either `{sdata_var_name}` needs to be a SpatialData object, and {element_var_name!r} needs "
151 | f"to be the string name of the element."
152 | )
153 | if sdata is not None:
154 | assert isinstance(element, str)
155 | return sdata[element]
156 | assert element is not None
157 | return element
158 |
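A minimal usage sketch of the two calling conventions accepted by `_parse_element`; the element name "blobs_circles" and the in-scope SpatialData object `sdata` are illustrative assumptions:

    # either resolve the element by name from a SpatialData object...
    el = _parse_element(element="blobs_circles", sdata=sdata, element_var_name="element", sdata_var_name="sdata")
    # ...or pass the SpatialElement directly and keep sdata=None
    el = _parse_element(element=sdata["blobs_circles"], sdata=None, element_var_name="element", sdata_var_name="sdata")
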
--------------------------------------------------------------------------------
/src/spatialdata/_core/query/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/src/spatialdata/_core/query/__init__.py
--------------------------------------------------------------------------------
/src/spatialdata/_docs.py:
--------------------------------------------------------------------------------
1 | # from https://stackoverflow.com/questions/10307696/how-to-put-a-variable-into-python-docstring
2 | from collections.abc import Callable
3 | from typing import Any, TypeVar
4 |
5 | T = TypeVar("T")
6 |
7 |
8 | def docstring_parameter(*args: Any, **kwargs: Any) -> Callable[[T], T]:
9 | def dec(obj: T) -> T:
10 | if obj.__doc__:
11 | obj.__doc__ = obj.__doc__.format(*args, **kwargs)
12 | return obj
13 |
14 | return dec
15 |
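A minimal usage sketch of the decorator above; the function `crop` and the `{unit}` placeholder are illustrative. The placeholder is interpolated once, at decoration time:

    from spatialdata._docs import docstring_parameter

    @docstring_parameter(unit="pixels")
    def crop(size: int) -> None:
        """Crop to `size`, measured in {unit}."""

    assert crop.__doc__ is not None and "pixels" in crop.__doc__
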
--------------------------------------------------------------------------------
/src/spatialdata/_io/__init__.py:
--------------------------------------------------------------------------------
1 | from spatialdata._io._utils import get_dask_backing_files
2 | from spatialdata._io.format import SpatialDataFormat
3 | from spatialdata._io.io_points import write_points
4 | from spatialdata._io.io_raster import write_image, write_labels
5 | from spatialdata._io.io_shapes import write_shapes
6 | from spatialdata._io.io_table import write_table
7 |
8 | __all__ = [
9 | "write_image",
10 | "write_labels",
11 | "write_points",
12 | "write_shapes",
13 | "write_table",
14 | "SpatialDataFormat",
15 | "get_dask_backing_files",
16 | ]
17 |
--------------------------------------------------------------------------------
/src/spatialdata/_io/io_points.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections.abc import MutableMapping
3 | from pathlib import Path
4 |
5 | import zarr
6 | from dask.dataframe import DataFrame as DaskDataFrame
7 | from dask.dataframe import read_parquet
8 | from ome_zarr.format import Format
9 |
10 | from spatialdata._io._utils import (
11 | _get_transformations_from_ngff_dict,
12 | _write_metadata,
13 | overwrite_coordinate_transformations_non_raster,
14 | )
15 | from spatialdata._io.format import CurrentPointsFormat, PointsFormats, _parse_version
16 | from spatialdata.models import get_axes_names
17 | from spatialdata.transformations._utils import (
18 | _get_transformations,
19 | _set_transformations,
20 | )
21 |
22 |
23 | def _read_points(
24 | store: str | Path | MutableMapping | zarr.Group, # type: ignore[type-arg]
25 | ) -> DaskDataFrame:
26 | """Read points from a zarr store."""
27 | assert isinstance(store, str | Path)
28 | f = zarr.open(store, mode="r")
29 |
30 | version = _parse_version(f, expect_attrs_key=True)
31 | assert version is not None
32 | format = PointsFormats[version]
33 |
34 | path = os.path.join(f._store.path, f.path, "points.parquet")
35 | # cache on remote file needed for parquet reader to work
36 | # TODO: allow reading in the metadata without caching all the data
37 | points = read_parquet("simplecache::" + path if path.startswith("http") else path)
38 | assert isinstance(points, DaskDataFrame)
39 |
40 | transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"])
41 | _set_transformations(points, transformations)
42 |
43 | attrs = format.attrs_from_dict(f.attrs.asdict())
44 | if len(attrs):
45 | points.attrs["spatialdata_attrs"] = attrs
46 | return points
47 |
48 |
49 | def write_points(
50 | points: DaskDataFrame,
51 | group: zarr.Group,
52 | name: str,
53 | group_type: str = "ngff:points",
54 | format: Format = CurrentPointsFormat(),
55 | ) -> None:
56 | axes = get_axes_names(points)
57 | t = _get_transformations(points)
58 |
59 | points_groups = group.require_group(name)
60 | path = Path(points_groups._store.path) / points_groups.path / "points.parquet"
61 |
62 | # The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is
63 | # 'category', it checks whether the categories of this column are known. If not, it explicitly converts the
64 | # categories to known categories using 'c.cat.as_known()' and assigns the transformed Series back to the original
65 | # DataFrame. This step is crucial when the number of categories exceeds 127, as pyarrow defaults to int8 for
66 | # unknown categories, which can only hold values from -128 to 127.
67 | for column_name in points.columns:
68 | c = points[column_name]
69 | if c.dtype == "category" and not c.cat.known:
70 | c = c.cat.as_known()
71 | points[column_name] = c
72 |
73 | points.to_parquet(path)
74 |
75 | attrs = format.attrs_to_dict(points.attrs)
76 | attrs["version"] = format.spatialdata_format_version
77 |
78 | _write_metadata(
79 | points_groups,
80 | group_type=group_type,
81 | axes=list(axes),
82 | attrs=attrs,
83 | )
84 | assert t is not None
85 | overwrite_coordinate_transformations_non_raster(group=points_groups, axes=axes, transformations=t)
86 |
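A minimal sketch of the categorical pitfall that the loop above guards against; the column name and values are illustrative:

    import dask.dataframe as dd
    import pandas as pd

    pdf = pd.DataFrame({"gene": [f"g{i}" for i in range(200)]})
    ddf = dd.from_pandas(pdf, npartitions=2)
    ddf["gene"] = ddf["gene"].astype("category")  # categories are "unknown": dask has not scanned the data
    assert not ddf["gene"].cat.known
    ddf["gene"] = ddf["gene"].cat.as_known()  # computes the categories, so pyarrow can size the index dtype correctly
    assert ddf["gene"].cat.known
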
--------------------------------------------------------------------------------
/src/spatialdata/_io/io_shapes.py:
--------------------------------------------------------------------------------
1 | from collections.abc import MutableMapping
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | import zarr
6 | from geopandas import GeoDataFrame, read_parquet
7 | from ome_zarr.format import Format
8 | from shapely import from_ragged_array, to_ragged_array
9 |
10 | from spatialdata._io._utils import (
11 | _get_transformations_from_ngff_dict,
12 | _write_metadata,
13 | overwrite_coordinate_transformations_non_raster,
14 | )
15 | from spatialdata._io.format import (
16 | CurrentShapesFormat,
17 | ShapesFormats,
18 | ShapesFormatV01,
19 | ShapesFormatV02,
20 | _parse_version,
21 | )
22 | from spatialdata.models import ShapesModel, get_axes_names
23 | from spatialdata.transformations._utils import (
24 | _get_transformations,
25 | _set_transformations,
26 | )
27 |
28 |
29 | def _read_shapes(
30 | store: str | Path | MutableMapping | zarr.Group, # type: ignore[type-arg]
31 | ) -> GeoDataFrame:
32 | """Read shapes from a zarr store."""
33 | assert isinstance(store, str | Path)
34 | f = zarr.open(store, mode="r")
35 | version = _parse_version(f, expect_attrs_key=True)
36 | assert version is not None
37 | format = ShapesFormats[version]
38 |
39 | if isinstance(format, ShapesFormatV01):
40 | coords = np.array(f["coords"])
41 | index = np.array(f["Index"])
42 | typ = format.attrs_from_dict(f.attrs.asdict())
43 | if typ.name == "POINT":
44 | radius = np.array(f["radius"])
45 | geometry = from_ragged_array(typ, coords)
46 | geo_df = GeoDataFrame({"geometry": geometry, "radius": radius}, index=index)
47 | else:
48 | offsets_keys = [k for k in f if k.startswith("offset")]
49 | offsets = tuple(np.array(f[k]).flatten() for k in offsets_keys)
50 | geometry = from_ragged_array(typ, coords, offsets)
51 | geo_df = GeoDataFrame({"geometry": geometry}, index=index)
52 | elif isinstance(format, ShapesFormatV02):
53 | path = Path(f._store.path) / f.path / "shapes.parquet"
54 | geo_df = read_parquet(path)
55 | else:
56 | raise ValueError(
57 | f"Unsupported shapes format {format} from version {version}. Please update the spatialdata library."
58 | )
59 |
60 | transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"])
61 | _set_transformations(geo_df, transformations)
62 | return geo_df
63 |
64 |
65 | def write_shapes(
66 | shapes: GeoDataFrame,
67 | group: zarr.Group,
68 | name: str,
69 | group_type: str = "ngff:shapes",
70 | format: Format = CurrentShapesFormat(),
71 | ) -> None:
72 | import numcodecs
73 |
74 | axes = get_axes_names(shapes)
75 | t = _get_transformations(shapes)
76 |
77 | shapes_group = group.require_group(name)
78 |
79 | if isinstance(format, ShapesFormatV01):
80 | geometry, coords, offsets = to_ragged_array(shapes.geometry)
81 | shapes_group.create_dataset(name="coords", data=coords)
82 | for i, o in enumerate(offsets):
83 | shapes_group.create_dataset(name=f"offset{i}", data=o)
84 | if shapes.index.dtype.kind == "U" or shapes.index.dtype.kind == "O":
85 | shapes_group.create_dataset(
86 | name="Index", data=shapes.index.values, dtype=object, object_codec=numcodecs.VLenUTF8()
87 | )
88 | else:
89 | shapes_group.create_dataset(name="Index", data=shapes.index.values)
90 | if geometry.name == "POINT":
91 | shapes_group.create_dataset(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values)
92 |
93 | attrs = format.attrs_to_dict(geometry)
94 | attrs["version"] = format.spatialdata_format_version
95 | elif isinstance(format, ShapesFormatV02):
96 | path = Path(shapes_group._store.path) / shapes_group.path / "shapes.parquet"
97 | shapes.to_parquet(path)
98 |
99 | attrs = format.attrs_to_dict(shapes.attrs)
100 | attrs["version"] = format.spatialdata_format_version
101 | else:
102 | raise ValueError(f"Unsupported format version {format.version}. Please update the spatialdata library.")
103 |
104 | _write_metadata(
105 | shapes_group,
106 | group_type=group_type,
107 | axes=list(axes),
108 | attrs=attrs,
109 | )
110 | assert t is not None
111 | overwrite_coordinate_transformations_non_raster(group=shapes_group, axes=axes, transformations=t)
112 |
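A minimal sketch of the ragged-array encoding used by ShapesFormatV01, round-tripping a single illustrative polygon:

    from shapely import Polygon, from_ragged_array, to_ragged_array

    triangle = Polygon([(0, 0), (1, 0), (1, 1)])
    geometry_type, coords, offsets = to_ragged_array([triangle])
    roundtrip = from_ragged_array(geometry_type, coords, offsets)
    assert roundtrip[0].equals(triangle)
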
--------------------------------------------------------------------------------
/src/spatialdata/_io/io_table.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from json import JSONDecodeError
5 | from typing import Literal
6 |
7 | import numpy as np
8 | import zarr
9 | from anndata import AnnData
10 | from anndata import read_zarr as read_anndata_zarr
11 | from anndata._io.specs import write_elem as write_adata
12 | from ome_zarr.format import Format
13 | from zarr.errors import ArrayNotFoundError
14 |
15 | from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors
16 | from spatialdata._io.format import CurrentTablesFormat, TablesFormats, _parse_version
17 | from spatialdata._logging import logger
18 | from spatialdata.models import TableModel
19 |
20 |
21 | def _read_table(
22 | zarr_store_path: str,
23 | group: zarr.Group,
24 | subgroup: zarr.Group,
25 | tables: dict[str, AnnData],
26 | on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR,
27 | ) -> dict[str, AnnData]:
28 | """
29 | Read in tables in the tables Zarr.group of a SpatialData Zarr store.
30 |
31 | Parameters
32 | ----------
33 | zarr_store_path
34 | The path to the Zarr store.
35 | group
36 | The parent group containing the subgroup.
37 | subgroup
38 | The subgroup containing the tables.
39 | tables
40 | A dictionary of tables.
41 | on_bad_files
42 | Specifies what to do upon encountering a bad file, e.g. corrupted, invalid or missing files.
43 |
44 | Returns
45 | -------
46 | The modified dictionary with the tables.
47 | """
48 | count = 0
49 | for table_name in subgroup:
50 | f_elem = subgroup[table_name]
51 | f_elem_store = os.path.join(zarr_store_path, f_elem.path)
52 |
53 | with handle_read_errors(
54 | on_bad_files=on_bad_files,
55 | location=f"{subgroup.path}/{table_name}",
56 | exc_types=(JSONDecodeError, KeyError, ValueError, ArrayNotFoundError),
57 | ):
58 | tables[table_name] = read_anndata_zarr(f_elem_store)
59 |
60 | f = zarr.open(f_elem_store, mode="r")
61 | version = _parse_version(f, expect_attrs_key=False)
62 | assert version is not None
63 | # since we have just one table format, we currently read it but do not use it; if we ever change the format
64 | # we can rename the two _ to format and implement the per-format read logic (as we do for shapes)
65 | _ = TablesFormats[version]
66 | f.store.close()
67 |
68 | # # replace with format from above
69 | # version = "0.1"
70 | # format = TablesFormats[version]
71 | if TableModel.ATTRS_KEY in tables[table_name].uns:
72 | # fill in any missing attributes that have been omitted because their value was None
73 | attrs = tables[table_name].uns[TableModel.ATTRS_KEY]
74 | if "region" not in attrs:
75 | attrs["region"] = None
76 | if "region_key" not in attrs:
77 | attrs["region_key"] = None
78 | if "instance_key" not in attrs:
79 | attrs["instance_key"] = None
80 | # fix type for region
81 | if "region" in attrs and isinstance(attrs["region"], np.ndarray):
82 | attrs["region"] = attrs["region"].tolist()
83 |
84 | count += 1
85 |
86 | logger.debug(f"Found {count} elements in {subgroup}")
87 | return tables
88 |
89 |
90 | def write_table(
91 | table: AnnData,
92 | group: zarr.Group,
93 | name: str,
94 | group_type: str = "ngff:regions_table",
95 | format: Format = CurrentTablesFormat(),
96 | ) -> None:
97 | if TableModel.ATTRS_KEY in table.uns:
98 | region = table.uns["spatialdata_attrs"]["region"]
99 | region_key = table.uns["spatialdata_attrs"].get("region_key", None)
100 | instance_key = table.uns["spatialdata_attrs"].get("instance_key", None)
101 | format.validate_table(table, region_key, instance_key)
102 | else:
103 | region, region_key, instance_key = (None, None, None)
104 | write_adata(group, name, table) # creates group[name]
105 | tables_group = group[name]
106 | tables_group.attrs["spatialdata-encoding-type"] = group_type
107 | tables_group.attrs["region"] = region
108 | tables_group.attrs["region_key"] = region_key
109 | tables_group.attrs["instance_key"] = instance_key
110 | tables_group.attrs["version"] = format.spatialdata_format_version
111 |
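A minimal usage sketch for `write_table`; the store path and the table contents are illustrative, and since the table carries no `spatialdata_attrs`, the region metadata is written as None:

    import zarr
    from anndata import AnnData
    from spatialdata._io.io_table import write_table

    root = zarr.open_group("example.zarr", mode="w")
    tables_group = root.require_group("tables")
    write_table(AnnData(shape=(3, 2)), group=tables_group, name="table")
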
--------------------------------------------------------------------------------
/src/spatialdata/_logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 |
5 | def _setup_logger() -> "logging.Logger":
6 | from rich.console import Console
7 | from rich.logging import RichHandler
8 |
9 | logger = logging.getLogger(__name__)
10 | level = os.environ.get("LOGLEVEL", logging.INFO)
11 | logger.setLevel(level=level)
12 | console = Console(force_terminal=True)
13 | if console.is_jupyter is True:
14 | console.is_jupyter = False
15 | ch = RichHandler(show_path=False, console=console, show_time=logger.level == logging.DEBUG)
16 | logger.addHandler(ch)
17 |
18 | # this prevents double outputs
19 | logger.propagate = False
20 | return logger
21 |
22 |
23 | logger = _setup_logger()
24 |
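A minimal usage sketch; the LOGLEVEL environment variable is read once, when the module is first imported:

    # $ LOGLEVEL=DEBUG python script.py
    from spatialdata._logging import logger

    logger.info("always shown at the default level")
    logger.debug("only shown when LOGLEVEL=DEBUG")
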
--------------------------------------------------------------------------------
/src/spatialdata/_types.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import numpy as np
4 | from xarray import DataArray, DataTree
5 |
6 | __all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T"]
7 |
8 | from numpy.typing import DTypeLike, NDArray
9 |
10 | ArrayLike = NDArray[np.floating[Any]]
11 | IntArrayLike = NDArray[np.integer[Any]]
12 |
13 | Raster_T = DataArray | DataTree
14 | ColorLike = tuple[float, ...] | str
15 |
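A minimal sketch showing that the PEP 604 union aliases above also work for runtime checks (Python >= 3.10); the array contents are illustrative:

    from xarray import DataArray

    from spatialdata._types import Raster_T

    image = DataArray([[1.0, 2.0]], dims=("y", "x"))
    assert isinstance(image, Raster_T)
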
--------------------------------------------------------------------------------
/src/spatialdata/config.py:
--------------------------------------------------------------------------------
1 | # chunk sizes bigger than this value (bytes) can trigger a compression error
2 | # https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276
3 | # so if we detect this during parsing/validation we raise a warning
4 | LARGE_CHUNK_THRESHOLD_BYTES = 2147483647  # 2**31 - 1, the maximum value of a signed 32-bit integer
5 |
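A minimal sketch of the kind of check this threshold enables; the chunk shape and dtype are illustrative:

    import numpy as np

    from spatialdata.config import LARGE_CHUNK_THRESHOLD_BYTES

    chunk_shape = (1, 16384, 16384)
    chunk_nbytes = int(np.prod(chunk_shape)) * np.dtype(np.uint64).itemsize
    if chunk_nbytes > LARGE_CHUNK_THRESHOLD_BYTES:
        print("chunk exceeds the safe compression threshold; consider smaller chunks")
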
--------------------------------------------------------------------------------
/src/spatialdata/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from spatialdata.dataloader.datasets import ImageTilesDataset
3 | except ImportError:
4 | ImageTilesDataset = None # type: ignore[assignment, misc]
5 |
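A minimal sketch of how downstream code can react to the optional dependency being absent (the dataset requires the optional torch dependency); the error message is illustrative:

    from spatialdata.dataloader import ImageTilesDataset

    if ImageTilesDataset is None:
        raise ImportError("ImageTilesDataset requires the optional torch dependency")
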
--------------------------------------------------------------------------------
/src/spatialdata/io/__init__.py:
--------------------------------------------------------------------------------
1 | """Experimental bridge to the spatialdata_io package."""
2 |
3 | try:
4 | from spatialdata_io import * # noqa: F403
5 | except ImportError as e:
6 | raise ImportError(
7 | "To access spatialdata.io, `spatialdata_io` must be installed, e.g. via `pip install spatialdata-io`. "
8 | f"Original exception: {e}"
9 | ) from e
10 |
--------------------------------------------------------------------------------
/src/spatialdata/models/__init__.py:
--------------------------------------------------------------------------------
1 | from spatialdata._core.validation import check_target_region_column_symmetry
2 | from spatialdata.models._utils import (
3 | C,
4 | SpatialElement,
5 | X,
6 | Y,
7 | Z,
8 | force_2d,
9 | get_axes_names,
10 | get_channel_names,
11 | get_channels,
12 | get_spatial_axes,
13 | points_dask_dataframe_to_geopandas,
14 | points_geopandas_to_dask_dataframe,
15 | set_channel_names,
16 | validate_axes,
17 | validate_axis_name,
18 | )
19 | from spatialdata.models.models import (
20 | Image2DModel,
21 | Image3DModel,
22 | Labels2DModel,
23 | Labels3DModel,
24 | PointsModel,
26 | RasterSchema, ShapesModel,
26 | TableModel,
27 | get_model,
28 | get_table_keys,
29 | )
30 |
31 | __all__ = [
32 | "Labels2DModel",
33 | "Labels3DModel",
34 | "Image2DModel",
35 | "Image3DModel",
36 | "ShapesModel",
37 | "PointsModel",
38 | "TableModel",
39 | "get_model",
40 | "SpatialElement",
41 | "get_spatial_axes",
42 | "validate_axes",
43 | "validate_axis_name",
44 | "X",
45 | "Y",
46 | "Z",
47 | "C",
48 | "get_axes_names",
49 | "points_geopandas_to_dask_dataframe",
50 | "points_dask_dataframe_to_geopandas",
51 | "check_target_region_column_symmetry",
52 | "get_table_keys",
53 | "get_channels",
54 | "get_channel_names",
55 | "set_channel_names",
56 | "force_2d",
57 | "RasterSchema",
58 | ]
59 |
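A minimal usage sketch of the models exported above; the array contents are illustrative:

    import numpy as np

    from spatialdata.models import Image2DModel, get_model

    image = Image2DModel.parse(np.zeros((3, 64, 64), dtype=np.uint8), dims=("c", "y", "x"))
    assert get_model(image) is Image2DModel
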
--------------------------------------------------------------------------------
/src/spatialdata/testing.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from anndata import AnnData
4 | from anndata.tests.helpers import assert_equal as assert_anndata_equal
5 | from dask.dataframe import DataFrame as DaskDataFrame
6 | from dask.dataframe.tests.test_dataframe import assert_eq as assert_dask_dataframe_equal
7 | from geopandas import GeoDataFrame
8 | from geopandas.testing import assert_geodataframe_equal
9 | from xarray import DataArray, DataTree
10 | from xarray.testing import assert_equal
11 |
12 | from spatialdata import SpatialData
13 | from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables
14 | from spatialdata.models import PointsModel
15 | from spatialdata.models._utils import SpatialElement
16 | from spatialdata.transformations.operations import get_transformation
17 |
18 | _Elements = Images | Labels | Shapes | Points | Tables
19 |
20 |
21 | def assert_elements_dict_are_identical(
22 | elements0: _Elements,
23 | elements1: _Elements,
24 | check_transformations: bool = True,
25 | check_metadata: bool = True,
26 | ) -> None:
27 | """
28 | Compare two dictionaries of elements and assert that they are identical (except for the order of the keys).
29 |
30 | The dictionaries of elements can be obtained from a SpatialData object using the `.shapes`, `.labels`, `.points`,
31 | `.images` and `.tables` properties.
32 |
33 | Parameters
34 | ----------
35 | elements0
36 | The first dictionary of elements.
37 | elements1
38 | The second dictionary of elements.
39 | check_transformations
40 | Whether to check if the transformations are identical, for each element.
41 | check_metadata
42 | Whether to check if the metadata is identical, for each element.
43 |
44 | Returns
45 | -------
46 | None
47 |
48 | Raises
49 | ------
50 | AssertionError
51 | If the two dictionaries of elements are not identical.
52 |
53 | Notes
54 | -----
55 | Please see
56 | :func:`spatialdata.testing.assert_spatial_data_objects_are_identical` for additional comments.
57 | """
58 | assert set(elements0.keys()) == set(elements1.keys())
59 | for k in elements0:
60 | element0 = elements0[k]
61 | element1 = elements1[k]
62 | assert_elements_are_identical(
63 | element0,
64 | element1,
65 | check_transformations=check_transformations,
66 | check_metadata=check_metadata,
67 | )
68 |
69 |
70 | def assert_elements_are_identical(
71 | element0: SpatialElement | AnnData,
72 | element1: SpatialElement | AnnData,
73 | check_transformations: bool = True,
74 | check_metadata: bool = True,
75 | ) -> None:
76 | """
77 | Compare two elements (two SpatialElements or two tables) and assert that they are identical.
78 |
79 | Parameters
80 | ----------
81 | element0
82 | The first element.
83 | element1
84 | The second element.
85 | check_transformations
86 | Whether to check if the transformations of the elements are identical.
87 | check_metadata
88 | Whether to check if the metadata of the elements is identical.
89 |
90 | Returns
91 | -------
92 | None
93 |
94 | Raises
95 | ------
96 | AssertionError
97 | If the two elements are not identical.
98 |
99 | Notes
100 | -----
101 | Please see
102 | :func:`spatialdata.testing.assert_spatial_data_objects_are_identical` for additional comments.
103 | """
104 | assert type(element0) is type(element1)
105 | if check_transformations and not check_metadata:
106 | raise ValueError("check_transformations cannot be True if check_metadata is False")
107 |
108 | # compare transformations (only for SpatialElements)
109 | if not isinstance(element0, AnnData):
110 | transformations0 = get_transformation(element0, get_all=True)
111 | transformations1 = get_transformation(element1, get_all=True)
112 | assert isinstance(transformations0, dict)
113 | assert isinstance(transformations1, dict)
114 | if check_transformations:
115 | assert transformations0.keys() == transformations1.keys()
116 | for key in transformations0:
117 | assert transformations0[key] == transformations1[key], (
118 | f"transformations0[{key}] != transformations1[{key}]"
119 | )
120 |
121 | # compare the elements
122 | if isinstance(element0, AnnData):
123 | assert_anndata_equal(element0, element1)
124 | elif isinstance(element0, DataArray | DataTree):
125 | assert_equal(element0, element1)
126 | elif isinstance(element0, GeoDataFrame):
127 | assert_geodataframe_equal(element0, element1, check_less_precise=True)
128 | else:
129 | assert isinstance(element0, DaskDataFrame)
130 | assert_dask_dataframe_equal(element0, element1, check_divisions=False)
131 | if PointsModel.ATTRS_KEY in element0.attrs or PointsModel.ATTRS_KEY in element1.attrs:
132 | assert element0.attrs[PointsModel.ATTRS_KEY] == element1.attrs[PointsModel.ATTRS_KEY]
133 |
134 |
135 | def assert_spatial_data_objects_are_identical(
136 | sdata0: SpatialData,
137 | sdata1: SpatialData,
138 | check_transformations: bool = True,
139 | check_metadata: bool = True,
140 | ) -> None:
141 | """
142 | Compare two SpatialData objects and assert that they are identical.
143 |
144 | Parameters
145 | ----------
146 | sdata0
147 | The first SpatialData object.
148 | sdata1
149 | The second SpatialData object.
150 | check_transformations
151 | Whether to check if the transformations are identical, for each element.
152 | check_metadata
153 | Whether to check if the metadata is identical, for each element.
154 |
155 | Returns
156 | -------
157 | None
158 |
159 | Raises
160 | ------
161 | AssertionError
162 | If the two SpatialData objects are not identical.
163 |
164 | Notes
165 | -----
166 | If `check_metadata` is `True` but `check_transformations` is `False`, the metadata will be compared with
167 | the exclusion of the transformations.
168 |
169 | With the current implementation, the transformations Translation([1.0, 2.0],
170 | axes=('x', 'y')) and Translation([2.0, 1.0], axes=('y', 'x')) are considered different.
171 | A quick way to avoid an error in this case is to pass `check_transformations=False`.
172 | """
173 | # comparing the sets of element names is not a full comparison, but it's sufficient here
174 | element_names0 = [element_name for _, element_name, _ in sdata0.gen_elements()]
175 | element_names1 = [element_name for _, element_name, _ in sdata1.gen_elements()]
176 | assert len(set(element_names0)) == len(element_names0)
177 | assert len(set(element_names1)) == len(element_names1)
178 | assert set(sdata0.coordinate_systems) == set(sdata1.coordinate_systems)
179 | for element_name in element_names0:
180 | element0 = sdata0[element_name]
181 | element1 = sdata1[element_name]
182 | assert_elements_are_identical(
183 | element0,
184 | element1,
185 | check_transformations=check_transformations,
186 | check_metadata=check_metadata,
187 | )
188 |
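A minimal usage sketch, assuming the `blobs` demo dataset is available; a deep copy must compare as identical to the original:

    from spatialdata._core._deepcopy import deepcopy as _deepcopy
    from spatialdata.datasets import blobs
    from spatialdata.testing import assert_spatial_data_objects_are_identical

    sdata = blobs()
    assert_spatial_data_objects_are_identical(sdata, _deepcopy(sdata))
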
--------------------------------------------------------------------------------
/src/spatialdata/transformations/__init__.py:
--------------------------------------------------------------------------------
1 | from spatialdata.transformations.operations import (
2 | align_elements_using_landmarks,
3 | get_transformation,
4 | get_transformation_between_coordinate_systems,
5 | get_transformation_between_landmarks,
6 | remove_transformation,
7 | remove_transformations_to_coordinate_system,
8 | set_transformation,
9 | )
10 | from spatialdata.transformations.transformations import (
11 | Affine,
12 | BaseTransformation,
13 | Identity,
14 | MapAxis,
15 | Scale,
16 | Sequence,
17 | Translation,
18 | )
19 |
20 | __all__ = [
21 | "BaseTransformation",
22 | "Identity",
23 | "MapAxis",
24 | "Translation",
25 | "Scale",
26 | "Affine",
27 | "Sequence",
28 | "get_transformation",
29 | "set_transformation",
30 | "remove_transformation",
31 | "get_transformation_between_coordinate_systems",
32 | "get_transformation_between_landmarks",
33 | "align_elements_using_landmarks",
34 | "remove_transformations_to_coordinate_system",
35 | ]
36 |
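A minimal sketch composing the exported transformations; `Sequence` applies its steps in order (here: scale, then translate):

    from spatialdata.transformations import Scale, Sequence, Translation

    scale = Scale([2.0, 2.0], axes=("x", "y"))
    shift = Translation([10.0, 0.0], axes=("x", "y"))
    composed = Sequence([scale, shift])
    matrix = composed.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))  # 3x3 affine matrix
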
--------------------------------------------------------------------------------
/src/spatialdata/transformations/ngff/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/src/spatialdata/transformations/ngff/__init__.py
--------------------------------------------------------------------------------
/src/spatialdata/transformations/ngff/_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import copy
4 |
5 | from spatialdata.models import C, X, Y, Z
6 | from spatialdata.transformations.ngff.ngff_coordinate_system import (
7 | NgffAxis,
8 | NgffCoordinateSystem,
9 | )
10 |
11 | __all__ = ["get_default_coordinate_system"]
12 |
13 | # "unit" is a default placeholder value. This is not suported by NGFF so the user should replace it before saving
14 | x_axis = NgffAxis(name=X, type="space", unit="unit")
15 | y_axis = NgffAxis(name=Y, type="space", unit="unit")
16 | z_axis = NgffAxis(name=Z, type="space", unit="unit")
17 | c_axis = NgffAxis(name=C, type="channel")
18 | x_cs = NgffCoordinateSystem(name="x", axes=[x_axis])
19 | y_cs = NgffCoordinateSystem(name="y", axes=[y_axis])
20 | z_cs = NgffCoordinateSystem(name="z", axes=[z_axis])
21 | c_cs = NgffCoordinateSystem(name="c", axes=[c_axis])
22 | xy_cs = NgffCoordinateSystem(name="xy", axes=[x_axis, y_axis])
23 | xyz_cs = NgffCoordinateSystem(name="xyz", axes=[x_axis, y_axis, z_axis])
24 | yx_cs = NgffCoordinateSystem(name="yx", axes=[y_axis, x_axis])
25 | zyx_cs = NgffCoordinateSystem(name="zyx", axes=[z_axis, y_axis, x_axis])
26 | cyx_cs = NgffCoordinateSystem(name="cyx", axes=[c_axis, y_axis, x_axis])
27 | czyx_cs = NgffCoordinateSystem(name="czyx", axes=[c_axis, z_axis, y_axis, x_axis])
28 | _DEFAULT_COORDINATE_SYSTEM = {
29 | (X,): x_cs,
30 | (Y,): y_cs,
31 | (Z,): z_cs,
32 | (C,): c_cs,
33 | (X, Y): xy_cs,
34 | (X, Y, Z): xyz_cs,
35 | (Y, X): yx_cs,
36 | (Z, Y, X): zyx_cs,
37 | (C, Y, X): cyx_cs,
38 | (C, Z, Y, X): czyx_cs,
39 | }
40 |
41 |
42 | def get_default_coordinate_system(dims: tuple[str, ...]) -> NgffCoordinateSystem:
43 | """
44 | Get the default coordinate system.
45 |
46 | Parameters
47 | ----------
48 | dims
49 | The dimension names to get the corresponding axes of the default coordinate system for.
50 | Names should be in ['x', 'y', 'z', 'c'].
51 |
52 | """
53 | axes = []
54 | for c in dims:
55 | if c == X:
56 | axes.append(copy.deepcopy(x_axis))
57 | elif c == Y:
58 | axes.append(copy.deepcopy(y_axis))
59 | elif c == Z:
60 | axes.append(copy.deepcopy(z_axis))
61 | elif c == C:
62 | axes.append(copy.deepcopy(c_axis))
63 | else:
64 | raise ValueError(f"Invalid dimension: {c}")
65 | return NgffCoordinateSystem(name="".join(dims), axes=axes)
66 |
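A minimal usage sketch; the coordinate system name is the concatenation of the dimension names:

    from spatialdata.transformations.ngff._utils import get_default_coordinate_system

    cs = get_default_coordinate_system(("c", "y", "x"))
    assert cs.name == "cyx"
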
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/tests/__init__.py
--------------------------------------------------------------------------------
/tests/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/tests/core/__init__.py
--------------------------------------------------------------------------------
/tests/core/operations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/tests/core/operations/__init__.py
--------------------------------------------------------------------------------
/tests/core/operations/test_vectorize.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import pytest
5 | from geopandas import GeoDataFrame
6 | from shapely import MultiPoint, Point, Polygon
7 | from shapely.ops import unary_union
8 | from skimage.draw import polygon
9 |
10 | from spatialdata._core.operations.vectorize import (
11 | _vectorize_mask,
12 | to_circles,
13 | to_polygons,
14 | )
15 | from spatialdata.datasets import blobs
16 | from spatialdata.models.models import ShapesModel
17 | from spatialdata.testing import assert_elements_are_identical
18 |
19 | # each of the tests operates on different elements, hence we can initialize the data once without conflicts
20 | sdata = blobs()
21 |
22 |
23 | # conversion from labels
24 | @pytest.mark.parametrize("is_multiscale", [False, True])
25 | def test_labels_2d_to_circles(is_multiscale: bool) -> None:
26 | key = "blobs" + ("_multiscale" if is_multiscale else "") + "_labels"
27 | element = sdata[key]
28 | new_circles = to_circles(element)
29 |
30 | assert np.isclose(new_circles.loc[1].geometry.x, 330.59258152354386)
31 | assert np.isclose(new_circles.loc[1].geometry.y, 78.85026897788404)
32 | assert np.isclose(new_circles.loc[1].radius, 69.229993)
33 | assert 7 not in new_circles.index
34 |
35 |
36 | @pytest.mark.parametrize("is_multiscale", [False, True])
37 | def test_labels_2d_to_polygons(is_multiscale: bool) -> None:
38 | key = "blobs" + ("_multiscale" if is_multiscale else "") + "_labels"
39 | element = sdata[key]
40 | new_polygons = to_polygons(element)
41 |
42 | assert 7 not in new_polygons.index
43 |
44 | unique, counts = np.unique(sdata["blobs_labels"].compute().data, return_counts=True)
45 | new_polygons.loc[unique[1:], "pixel_count"] = counts[1:]
46 |
47 | assert ((new_polygons.area - new_polygons.pixel_count) / new_polygons.pixel_count < 0.01).all()
48 |
49 |
50 | def test_chunked_labels_2d_to_polygons() -> None:
51 | no_chunks_polygons = to_polygons(sdata["blobs_labels"])
52 |
53 | sdata["blobs_labels_chunked"] = sdata["blobs_labels"].copy()
54 | sdata["blobs_labels_chunked"].data = sdata["blobs_labels_chunked"].data.rechunk((200, 200))
55 |
56 | chunks_polygons = to_polygons(sdata["blobs_labels_chunked"])
57 |
58 | union = chunks_polygons.union(no_chunks_polygons)
59 |
60 | assert (no_chunks_polygons.area == union.area).all()
61 |
62 |
63 | # conversion from circles
64 | def test_circles_to_circles() -> None:
65 | element = sdata["blobs_circles"]
66 | new_circles = to_circles(element)
67 | assert_elements_are_identical(element, new_circles)
68 |
69 |
70 | def test_circles_to_polygons() -> None:
71 | element = sdata["blobs_circles"]
72 | polygons = to_polygons(element, buffer_resolution=1000)
73 | areas = element.radius**2 * math.pi
74 | assert np.allclose(polygons.area, areas)
75 |
76 |
77 | # conversion from polygons/multipolygons
78 | def test_polygons_to_circles() -> None:
79 | element = sdata["blobs_polygons"].iloc[:2]
80 | new_circles = to_circles(element)
81 |
82 | data = {
83 | "geometry": [
84 | Point(315.8120722406787, 220.18894606643332),
85 | Point(270.1386975678398, 417.8747936281634),
86 | ],
87 | "radius": [16.608781, 17.541365],
88 | }
89 | expected = ShapesModel.parse(GeoDataFrame(data, geometry="geometry"))
90 |
91 | assert_elements_are_identical(new_circles, expected)
92 |
93 |
94 | def test_multipolygons_to_circles() -> None:
95 | element = sdata["blobs_multipolygons"]
96 | new_circles = to_circles(element)
97 |
98 | data = {
99 | "geometry": [
100 | Point(340.37951022629096, 250.76310705786318),
101 | Point(337.1680699150594, 316.39984581697314),
102 | ],
103 | "radius": [23.488363, 19.059285],
104 | }
105 | expected = ShapesModel.parse(GeoDataFrame(data, geometry="geometry"))
106 | assert_elements_are_identical(new_circles, expected)
107 |
108 |
109 | def test_polygons_multipolygons_to_polygons() -> None:
110 | polygons = sdata["blobs_multipolygons"]
111 | assert polygons is to_polygons(polygons)
112 |
113 |
114 | # conversion from points
115 | def test_points_to_circles() -> None:
116 | element = sdata["blobs_points"]
117 | with pytest.raises(RuntimeError, match="`radius` must either be provided, either be a column"):
118 | to_circles(element)
119 | circles = to_circles(element, radius=1)
120 | x = circles.geometry.x
121 | y = circles.geometry.y
122 | assert np.array_equal(element["x"], x)
123 | assert np.array_equal(element["y"], y)
124 | assert np.array_equal(np.ones_like(x), circles["radius"])
125 |
126 |
127 | def test_points_to_polygons() -> None:
128 | with pytest.raises(RuntimeError, match="Cannot convert points to polygons"):
129 | to_polygons(sdata["blobs_points"])
130 |
131 |
132 | # conversion from images (invalid)
133 | def test_images_to_circles() -> None:
134 | with pytest.raises(RuntimeError, match=r"Cannot apply to_circles\(\) to images"):
135 | to_circles(sdata["blobs_image"])
136 |
137 |
138 | def test_images_to_polygons() -> None:
139 | with pytest.raises(RuntimeError, match=r"Cannot apply to_polygons\(\) to images"):
140 | to_polygons(sdata["blobs_image"])
141 |
142 |
143 | # conversion from other types (invalid)
144 | def test_invalid_geodataframe_to_circles() -> None:
145 | gdf = GeoDataFrame(geometry=[MultiPoint([[0, 0], [1, 1]])])
146 | with pytest.raises(RuntimeError, match="Unsupported geometry type"):
147 | to_circles(gdf)
148 |
149 |
150 | def test_invalid_geodataframe_to_polygons() -> None:
151 | gdf = GeoDataFrame(geometry=[MultiPoint([[0, 0], [1, 1]])])
152 | with pytest.raises(RuntimeError, match="Unsupported geometry type"):
153 | to_polygons(gdf)
154 |
155 |
156 | def test_vectorize_mask_almost_invertible() -> None:
157 | cell = Polygon([[10, 10], [30, 40], [90, 50], [100, 20]])
158 | image_shape = (70, 120)
159 |
160 | rasterized_image = np.zeros(image_shape, dtype=np.int8)
161 | x, y = cell.exterior.coords.xy
162 | rr, cc = polygon(y, x, image_shape)
163 | rasterized_image[rr, cc] = 1
164 |
165 | new_cell = _vectorize_mask(rasterized_image)
166 | new_cell = unary_union(new_cell.geometry)
167 |
168 | assert new_cell.intersection(cell).area / new_cell.union(cell).area > 0.97
169 |
170 |
171 | def test_label_column_vectorize_mask() -> None:
172 | assert "label" in _vectorize_mask(np.array([0]))
173 | assert "label" in _vectorize_mask(np.array([[0, 1], [1, 1]]))
174 |
--------------------------------------------------------------------------------
/tests/core/query/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/tests/core/query/__init__.py
--------------------------------------------------------------------------------
/tests/core/query/test_relational_query_match_sdata_to_table.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from spatialdata import SpatialData, concatenate, match_sdata_to_table
4 | from spatialdata.datasets import blobs_annotating_element
5 |
6 |
7 | def _make_test_data() -> SpatialData:
8 | sdata1 = blobs_annotating_element("blobs_polygons")
9 | sdata2 = blobs_annotating_element("blobs_polygons")
10 | sdata = concatenate({"sdata1": sdata1, "sdata2": sdata2}, concatenate_tables=True)
11 | sdata["table"].obs["value"] = list(range(sdata["table"].obs.shape[0]))
12 | return sdata
13 |
14 |
15 | # constructing the example data; let's use a global variable as we can reuse the same object on most tests
16 | # without having to recreate it
17 | sdata = _make_test_data()
18 |
19 |
20 | def test_match_sdata_to_table_filter_specific_instances():
21 | """
22 | Filter to keep only specific instances. Note that it works even when the table annotates multiple elements.
23 | """
24 | matched = match_sdata_to_table(
25 | sdata,
26 | table=sdata["table"][sdata["table"].obs.instance_id.isin([1, 2])],
27 | table_name="table",
28 | )
29 | assert len(matched["table"]) == 4
30 | assert "blobs_polygons-sdata1" in matched
31 | assert "blobs_polygons-sdata2" in matched
32 |
33 |
34 | def test_match_sdata_to_table_filter_specific_instances_element():
35 | """
36 | Filter to keep only specific instances, in a specific element.
37 | """
38 | matched = match_sdata_to_table(
39 | sdata,
40 | table=sdata["table"][
41 | sdata["table"].obs.instance_id.isin([1, 2]) & (sdata["table"].obs.region == "blobs_polygons-sdata1")
42 | ],
43 | table_name="table",
44 | )
45 | assert len(matched["table"]) == 2
46 | assert "blobs_polygons-sdata1" in matched
47 | assert "blobs_polygons-sdata2" not in matched
48 |
49 |
50 | def test_match_sdata_to_table_filter_by_threshold():
51 | """
52 | Filter by a threshold on a value column, in a specific element.
53 | """
54 | matched = match_sdata_to_table(
55 | sdata,
56 | table=sdata["table"][sdata["table"].obs.query('value < 5 and region == "blobs_polygons-sdata1"').index],
57 | table_name="table",
58 | )
59 | assert len(matched["table"]) == 5
60 | assert "blobs_polygons-sdata1" in matched
61 | assert "blobs_polygons-sdata2" not in matched
62 |
63 |
64 | def test_match_sdata_to_table_subset_certain_obs():
65 | """
66 | Subset to certain obs (we could also subset to certain var or layer).
67 | """
68 | matched = match_sdata_to_table(
69 | sdata,
70 | table=sdata["table"][[0, 1, 2, 3]],
71 | table_name="table",
72 | )
73 | assert len(matched["table"]) == 4
74 | assert "blobs_polygons-sdata1" in matched
75 | assert "blobs_polygons-sdata2" not in matched
76 |
77 |
78 | def test_match_sdata_to_table_shapes_and_points():
79 | """
80 | The function works both for shapes (examples above) and points.
80 | Changes the target of the table to points.
82 | """
83 | sdata = _make_test_data()
84 | sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "points"))
85 | sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category")
86 | sdata.set_table_annotates_spatialelement(
87 | table_name="table",
88 | region=["blobs_points-sdata1", "blobs_points-sdata2"],
89 | region_key="region",
90 | instance_key="instance_id",
91 | )
92 |
93 | matched = match_sdata_to_table(
94 | sdata,
95 | table=sdata["table"],
96 | table_name="table",
97 | )
98 |
99 | assert len(matched["table"]) == 10
100 | assert "blobs_points-sdata1" in matched
101 | assert "blobs_points-sdata2" in matched
102 | assert "blobs_polygons-sdata1" not in matched
103 |
104 |
105 | def test_match_sdata_to_table_match_labels_error():
106 | """
107 | match_sdata_to_table() uses the join operations, so when trying to match labels the warning will be raised by the
108 | join.
109 | """
110 | sdata = _make_test_data()
111 | sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "labels"))
112 | sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category")
113 | sdata.set_table_annotates_spatialelement(
114 | table_name="table",
115 | region=["blobs_labels-sdata1", "blobs_labels-sdata2"],
116 | region_key="region",
117 | instance_key="instance_id",
118 | )
119 |
120 | with pytest.warns(
121 | UserWarning,
122 | match="Element type `labels` not supported for 'right' join. Skipping ",
123 | ):
124 | matched = match_sdata_to_table(
125 | sdata,
126 | table=sdata["table"],
127 | table_name="table",
128 | )
129 |
130 | assert len(matched["table"]) == 10
131 | assert "blobs_labels-sdata1" in matched
132 | assert "blobs_labels-sdata2" in matched
133 | assert "blobs_points-sdata1" not in matched
134 |
135 |
136 | def test_match_sdata_to_table_no_table_argument():
137 | """
138 | If no table argument is passed, the table_name argument will be used to match the table.
139 | """
140 | matched = match_sdata_to_table(sdata=sdata, table_name="table")
141 |
142 | assert len(matched["table"]) == 10
143 | assert "blobs_polygons-sdata1" in matched
144 | assert "blobs_polygons-sdata2" in matched
145 |
--------------------------------------------------------------------------------
/tests/core/test_centroids.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | from anndata import AnnData
7 | from numpy.random import default_rng
8 |
9 | from spatialdata._core.centroids import get_centroids
10 | from spatialdata._core.query.relational_query import get_element_instances
11 | from spatialdata.models import Labels2DModel, Labels3DModel, PointsModel, TableModel, get_axes_names
12 | from spatialdata.transformations import Affine, Identity, get_transformation, set_transformation
13 |
14 | RNG = default_rng(42)
15 |
16 |
17 | def _get_affine() -> Affine:
18 | theta: float = math.pi / 18
19 | k = 10.0
20 | return Affine(
21 | [
22 | [2 * math.cos(theta), 2 * math.sin(-theta), -1000 / k],
23 | [2 * math.sin(theta), 2 * math.cos(theta), 300 / k],
24 | [0, 0, 1],
25 | ],
26 | input_axes=("x", "y"),
27 | output_axes=("x", "y"),
28 | )
29 |
30 |
31 | affine = _get_affine()
32 |
33 |
34 | @pytest.mark.parametrize("coordinate_system", ["global", "aligned"])
35 | @pytest.mark.parametrize("is_3d", [False, True])
36 | def test_get_centroids_points(points, coordinate_system: str, is_3d: bool):
37 | element = points["points_0"].compute()
38 | element.index = np.arange(len(element)) + 10
39 | element = PointsModel.parse(element)
40 |
41 | # by default, the coordinate system is global and the points are 2D; let's modify the points as instructed by the
42 | # test arguments
43 | if coordinate_system == "aligned":
44 | set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system)
45 | if is_3d:
46 | element["z"] = element["x"]
47 |
48 | axes = get_axes_names(element)
49 | centroids = get_centroids(element, coordinate_system=coordinate_system)
50 |
51 | # the axes of the centroids should be the same as the axes of the element
52 | assert centroids.columns.tolist() == list(axes)
53 |
54 | # check the index is preserved
55 | assert np.array_equal(centroids.index.values, element.index.values)
56 |
57 | # the centroids should not contain extra columns
58 | assert "genes" in element.columns and "genes" not in centroids.columns
59 |
60 | # the centroids transformation to the target coordinate system should be an Identity because the transformation has
61 | # already been applied
62 | assert get_transformation(centroids, to_coordinate_system=coordinate_system) == Identity()
63 |
64 | # let's check the values
65 | if coordinate_system == "global":
66 | assert np.array_equal(centroids.compute().values, element[list(axes)].compute().values)
67 | else:
68 | matrix = affine.to_affine_matrix(input_axes=axes, output_axes=axes)
69 | centroids_untransformed = element[list(axes)].compute().values
70 | n = len(axes)
71 | centroids_transformed = np.dot(centroids_untransformed, matrix[:n, :n].T) + matrix[:n, n]
72 | assert np.allclose(centroids.compute().values, centroids_transformed)
73 |
74 |
75 | @pytest.mark.parametrize("coordinate_system", ["global", "aligned"])
76 | @pytest.mark.parametrize("shapes_name", ["circles", "poly", "multipoly"])
77 | def test_get_centroids_shapes(shapes, coordinate_system: str, shapes_name: str):
78 | element = shapes[shapes_name]
79 | element.index = np.arange(len(element)) + 10
80 |
81 | if coordinate_system == "aligned":
82 | set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system)
83 | centroids = get_centroids(element, coordinate_system=coordinate_system)
84 |
85 | assert np.array_equal(centroids.index.values, element.index.values)
86 |
87 | if shapes_name == "circles":
88 | xy = element.geometry.get_coordinates().values
89 | else:
90 | assert shapes_name in ["poly", "multipoly"]
91 | xy = element.geometry.centroid.get_coordinates().values
92 |
93 | if coordinate_system == "global":
94 | assert np.array_equal(centroids.compute().values, xy)
95 | else:
96 | matrix = affine.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
97 | centroids_transformed = np.dot(xy, matrix[:2, :2].T) + matrix[:2, 2]
98 | assert np.allclose(centroids.compute().values, centroids_transformed)
99 |
100 |
101 | @pytest.mark.parametrize("coordinate_system", ["global", "aligned"])
102 | @pytest.mark.parametrize("is_multiscale", [False, True])
103 | @pytest.mark.parametrize("is_3d", [False, True])
104 | @pytest.mark.parametrize("return_background", [False, True])
105 | def test_get_centroids_labels(
106 | labels, coordinate_system: str, is_multiscale: bool, is_3d: bool, return_background: bool
107 | ):
108 | scale_factors = [2] if is_multiscale else None
109 | if is_3d:
110 | model = Labels3DModel
111 | array = np.array(
112 | [
113 | [
114 | [0, 0, 10, 10],
115 | [0, 0, 10, 10],
116 | ],
117 | [
118 | [20, 20, 10, 10],
119 | [20, 20, 10, 10],
120 | ],
121 | ]
122 | )
123 | expected_centroids = pd.DataFrame(
124 | {
125 | "x": [1, 3, 1],
126 | "y": [1, 1.0, 1],
127 | "z": [0.5, 1, 1.5],
128 | },
129 | index=[0, 1, 2],
130 | )
131 | if not return_background:
132 | expected_centroids = expected_centroids.drop(index=0)
133 | else:
134 | array = np.array(
135 | [
136 | [10, 10, 10, 10],
137 | [20, 20, 20, 20],
138 | [20, 20, 20, 20],
139 | [20, 20, 20, 20],
140 | ]
141 | )
142 | model = Labels2DModel
143 | expected_centroids = pd.DataFrame(
144 | {
145 | "x": [2, 2],
146 | "y": [0.5, 2.5],
147 | },
148 | index=[1, 2],
149 | )
150 | element = model.parse(array, scale_factors=scale_factors)
151 |
152 | if coordinate_system == "aligned":
153 | set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system)
154 | centroids = get_centroids(element, coordinate_system=coordinate_system, return_background=return_background)
155 |
156 | labels_indices = get_element_instances(element, return_background=return_background)
157 | assert np.array_equal(centroids.index.values, labels_indices)
158 |
159 | if not return_background:
160 | assert 0 not in centroids.index
161 |
162 | if coordinate_system == "global":
163 | assert np.array_equal(centroids.compute().values, expected_centroids.values)
164 | else:
165 | axes = get_axes_names(element)
166 | n = len(axes)
167 | # the axes from the labels have 'x' last, but we want it first to manually transform the points, so we sort
168 | matrix = affine.to_affine_matrix(input_axes=sorted(axes), output_axes=sorted(axes))
169 | centroids_transformed = np.dot(expected_centroids.values, matrix[:n, :n].T) + matrix[:n, n]
170 | assert np.allclose(centroids.compute().values, centroids_transformed)
171 |
172 |
173 | def test_get_centroids_invalid_element(images):
174 | # cannot compute centroids for images
175 | with pytest.raises(ValueError, match="Expected a `Labels` element. Found an `Image` instead."):
176 | get_centroids(images["image2d"])
177 |
178 | # cannot compute centroids for tables
179 | N = 10
180 | adata = TableModel.parse(
181 | AnnData(X=RNG.random((N, N)), obs={"region": ["dummy" for _ in range(N)], "instance_id": np.arange(N)}),
182 | region="dummy",
183 | region_key="region",
184 | instance_key="instance_id",
185 | )
186 | with pytest.raises(ValueError, match="The object type is not supported."):
187 | get_centroids(adata)
188 |
189 |
190 | def test_get_centroids_invalid_coordinate_system(points):
191 | with pytest.raises(AssertionError, match="No transformation to coordinate system"):
192 | get_centroids(points["points_0"], coordinate_system="invalid")
193 |
--------------------------------------------------------------------------------
/tests/core/test_deepcopy.py:
--------------------------------------------------------------------------------
1 | from pandas.testing import assert_frame_equal
2 |
3 | from spatialdata import SpatialData
4 | from spatialdata._core._deepcopy import deepcopy as _deepcopy
5 | from spatialdata.testing import assert_spatial_data_objects_are_identical
6 |
7 |
8 | def test_deepcopy(full_sdata):
9 | to_delete = []
10 | for element_type, element_name in to_delete:
11 | del getattr(full_sdata, element_type)[element_name]
12 |
13 | copied = _deepcopy(full_sdata)
14 | # we first compute() the data in-place, then deepcopy and then we make the data lazy again; if the last step is
15 | # missing, calling _deepcopy() again on the original data would fail. Here we check for that.
16 | copied_again = _deepcopy(full_sdata)
17 |
18 | # workaround for https://github.com/scverse/spatialdata/issues/486
19 | for _, element_name, _ in full_sdata.gen_elements():
20 | assert full_sdata[element_name] is not copied[element_name]
21 | assert full_sdata[element_name] is not copied_again[element_name]
22 | assert copied[element_name] is not copied_again[element_name]
23 |
24 | p0_0 = full_sdata["points_0"].compute()
25 | columns = list(p0_0.columns)
26 | p0_1 = full_sdata["points_0_1"].compute()[columns]
27 |
28 | p1_0 = copied["points_0"].compute()[columns]
29 | p1_1 = copied["points_0_1"].compute()[columns]
30 |
31 | p2_0 = copied_again["points_0"].compute()[columns]
32 | p2_1 = copied_again["points_0_1"].compute()[columns]
33 |
34 | assert_frame_equal(p0_0, p1_0)
35 | assert_frame_equal(p0_1, p1_1)
36 | assert_frame_equal(p0_0, p2_0)
37 | assert_frame_equal(p0_1, p2_1)
38 |
39 | del full_sdata.points["points_0"]
40 | del full_sdata.points["points_0_1"]
41 | del copied.points["points_0"]
42 | del copied.points["points_0_1"]
43 | del copied_again.points["points_0"]
44 | del copied_again.points["points_0_1"]
45 | # end workaround
46 |
47 | assert_spatial_data_objects_are_identical(full_sdata, copied)
48 | assert_spatial_data_objects_are_identical(full_sdata, copied_again)
49 |
50 |
51 | def test_deepcopy_attrs(points: SpatialData) -> None:
52 | some_attrs = {"a": {"b": 0}}
53 | points.attrs = some_attrs
54 |
55 | # before deepcopy
56 | sub_points = points.subset(["points_0"])
57 | assert sub_points.attrs is some_attrs
58 | assert sub_points.attrs["a"] is some_attrs["a"]
59 |
60 | # after deepcopy
61 | sub_points_deepcopy = _deepcopy(sub_points)
62 | assert sub_points_deepcopy.attrs is not some_attrs
63 | assert sub_points_deepcopy.attrs["a"] is not some_attrs["a"]
64 |
--------------------------------------------------------------------------------
/tests/core/test_get_attrs.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 | from spatialdata.datasets import blobs
5 |
6 |
7 | @pytest.fixture
8 | def sdata_attrs():
9 | sdata = blobs()
10 | sdata.attrs["test"] = {"a": {"b": 12}, "c": 8}
11 | return sdata
12 |
13 |
14 | def test_get_attrs_as_is(sdata_attrs):
15 | result = sdata_attrs.get_attrs(key="test", return_as=None, flatten=False)
16 | expected = {"a": {"b": 12}, "c": 8}
17 | assert result == expected
18 |
19 |
20 | def test_get_attrs_as_dict_flatten(sdata_attrs):
21 | result = sdata_attrs.get_attrs(key="test", return_as="dict", flatten=True)
22 | expected = {"a_b": 12, "c": 8}
23 | assert result == expected
24 |
25 |
26 | def test_get_attrs_as_json_flatten_false(sdata_attrs):
27 | result = sdata_attrs.get_attrs(key="test", return_as="json", flatten=False)
28 | expected = '{"a": {"b": 12}, "c": 8}'
29 | assert result == expected
30 |
31 |
32 | def test_get_attrs_as_json_flatten_true(sdata_attrs):
33 | result = sdata_attrs.get_attrs(key="test", return_as="json", flatten=True)
34 | expected = '{"a_b": 12, "c": 8}'
35 | assert result == expected
36 |
37 |
38 | def test_get_attrs_as_dataframe_flatten_false(sdata_attrs):
39 | result = sdata_attrs.get_attrs(key="test", return_as="df", flatten=False)
40 | expected = pd.DataFrame([{"a": {"b": 12}, "c": 8}])
41 | pd.testing.assert_frame_equal(result, expected)
42 |
43 |
44 | def test_get_attrs_as_dataframe_flatten_true(sdata_attrs):
45 | result = sdata_attrs.get_attrs(key="test", return_as="df", flatten=True)
46 | expected = pd.DataFrame([{"a_b": 12, "c": 8}])
47 | pd.testing.assert_frame_equal(result, expected)
48 |
49 |
50 | # test invalid cases
51 | def test_invalid_key(sdata_attrs):
52 | with pytest.raises(KeyError, match="was not found in sdata.attrs"):
53 | sdata_attrs.get_attrs(key="non_existent_key")
54 |
55 |
56 | def test_invalid_return_as_value(sdata_attrs):
57 | with pytest.raises(ValueError, match="Invalid 'return_as' value"):
58 | sdata_attrs.get_attrs(key="test", return_as="invalid_option")
59 |
60 |
61 | def test_non_string_key(sdata_attrs):
62 | with pytest.raises(TypeError, match="The key must be a string."):
63 | sdata_attrs.get_attrs(key=123)
64 |
65 |
66 | def test_non_string_sep(sdata_attrs):
67 | with pytest.raises(TypeError, match="Parameter 'sep_for_nested_keys' must be a string."):
68 | sdata_attrs.get_attrs(key="test", sep=123)
69 |
70 |
71 | def test_empty_attrs():
72 | sdata = blobs()
73 | with pytest.raises(KeyError, match="was not found in sdata.attrs."):
74 | sdata.get_attrs(key="test")
75 |
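# A minimal sketch of the flattening behavior exercised above (an illustration,
# not the library's implementation): nested keys are joined with the separator,
# so {"a": {"b": 12}, "c": 8} flattens to {"a_b": 12, "c": 8}.
def flatten_sketch(d: dict, sep: str = "_", prefix: str = "") -> dict:
    out: dict = {}
    for key, value in d.items():
        full_key = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            out.update(flatten_sketch(value, sep=sep, prefix=full_key))
        else:
            out[full_key] = value
    return out

assert flatten_sketch({"a": {"b": 12}, "c": 8}) == {"a_b": 12, "c": 8}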
--------------------------------------------------------------------------------
/tests/core/test_validation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from spatialdata._core.validation import ValidationError, raise_validation_errors
4 |
5 |
6 | def test_raise_validation_errors():
7 | with pytest.raises(expected_exception=ValidationError, match="Some errors happened") as actual_exc_info:
8 | ...
9 | with raise_validation_errors("Some errors happened", exc_type=ValueError) as collect_error:
10 | with collect_error(expected_exception=TypeError):
11 | raise TypeError("Another error type")
12 | for key, value in {"first": 1, "second": 2, "third": 3}.items():
13 | with collect_error(location=key):
14 | if value % 2 != 0:
15 | raise ValueError("Odd value encountered")
16 | actual_message = str(actual_exc_info.value)
17 | assert "Another error" in actual_message
18 | assert "first" in actual_message
19 | assert "second" not in actual_message
20 | assert "third" in actual_message
21 |
22 |
23 | def test_raise_validation_errors_does_not_catch_other_errors():
24 | with pytest.raises(expected_exception=RuntimeError, match="Not to be collected"):
25 | ...
26 | with raise_validation_errors(exc_type=ValueError) as collect_error:
27 | ...
28 | with collect_error:
29 | raise RuntimeError("Not to be collected as ValidationError")
30 |
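# An illustrative sketch of the collect-and-raise pattern these tests rely on
# (semantics inferred from the assertions above, not spatialdata's actual
# implementation): exceptions of the configured type raised inside a collect
# block are accumulated together with their location, while any other
# exception type propagates immediately.
from contextlib import contextmanager


class ErrorCollectorSketch:
    def __init__(self, exc_type: type[Exception] = Exception) -> None:
        self.exc_type = exc_type
        self.errors: list[tuple[str, Exception]] = []

    @contextmanager
    def collect(self, location: str = ""):
        try:
            yield
        except self.exc_type as e:  # anything else propagates unchanged
            self.errors.append((location, e))


collector = ErrorCollectorSketch(exc_type=ValueError)
for key, value in {"first": 1, "second": 2, "third": 3}.items():
    with collector.collect(location=key):
        if value % 2 != 0:
            raise ValueError("Odd value encountered")
assert [loc for loc, _ in collector.errors] == ["first", "third"]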
--------------------------------------------------------------------------------
/tests/data/multipolygon.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "GeometryCollection",
3 | "geometries": [
4 | {
5 | "type": "MultiPolygon",
6 | "coordinates": [
7 | [
8 | [
9 | [40.0, 40.0],
10 | [20.0, 45.0],
11 | [45.0, 30.0],
12 | [40.0, 40.0]
13 | ]
14 | ],
15 | [
16 | [
17 | [20.0, 35.0],
18 | [10.0, 30.0],
19 | [10.0, 10.0],
20 | [30.0, 5.0],
21 | [45.0, 20.0],
22 | [20.0, 35.0]
23 | ],
24 | [
25 | [30.0, 20.0],
26 | [20.0, 15.0],
27 | [20.0, 25.0],
28 | [30.0, 20.0]
29 | ]
30 | ]
31 | ]
32 | },
33 | {
34 | "type": "MultiPolygon",
35 | "coordinates": [
36 | [
37 | [
38 | [40.0, 40.0],
39 | [20.0, 45.0],
40 | [45.0, 30.0],
41 | [40.0, 40.0]
42 | ]
43 | ],
44 | [
45 | [
46 | [30.0, 20.0],
47 | [20.0, 15.0],
48 | [20.0, 25.0],
49 | [30.0, 20.0]
50 | ]
51 | ]
52 | ]
53 | }
54 | ]
55 | }
56 |
--------------------------------------------------------------------------------
/tests/data/points.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "GeometryCollection",
3 | "geometries": [
4 | {
5 | "type": "Point",
6 | "coordinates": [10.0, 2.0]
7 | },
8 | {
9 | "type": "Point",
10 | "coordinates": [5.0, 2.0]
11 | }
12 | ]
13 | }
14 |
--------------------------------------------------------------------------------
/tests/data/polygon.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "GeometryCollection",
3 | "geometries": [
4 | {
5 | "type": "Polygon",
6 | "coordinates": [
7 | [
8 | [40.0, 40.0],
9 | [20.0, 45.0],
10 | [45.0, 30.0],
11 | [40.0, 40.0]
12 | ]
13 | ]
14 | },
15 | {
16 | "type": "Polygon",
17 | "coordinates": [
18 | [
19 | [40.0, 50.0],
20 | [20.0, 15.0],
21 | [45.0, 50.0],
22 | [40.0, 50.0]
23 | ]
24 | ]
25 | }
26 | ]
27 | }
28 |
--------------------------------------------------------------------------------
/tests/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from spatialdata.dataloader.datasets import ImageTilesDataset
3 | except ImportError as e:
4 | _error: str | None = str(e)
5 | else:
6 | _error = None
7 |
8 | __all__ = ["ImageTilesDataset"]
9 |
--------------------------------------------------------------------------------
/tests/dataloader/test_datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from spatialdata.dataloader import ImageTilesDataset
5 | from spatialdata.datasets import blobs_annotating_element
6 |
7 |
8 | class TestImageTilesDataset:
9 | @pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"])
10 | @pytest.mark.parametrize(
11 | "regions_element",
12 | ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"],
13 | )
14 | @pytest.mark.parametrize("table", [True, False])
15 | def test_validation(self, sdata_blobs, image_element: str, regions_element: str, table: bool):
16 | if table:
17 | sdata = blobs_annotating_element(regions_element)
18 | else:
19 | sdata = sdata_blobs
20 | del sdata_blobs.tables["table"]
21 | _ = ImageTilesDataset(
22 | sdata=sdata,
23 | regions_to_images={regions_element: image_element},
24 | regions_to_coordinate_systems={regions_element: "global"},
25 | table_name="table" if table else None,
26 | return_annotations="instance_id" if table else None,
27 | )
28 |
29 | @pytest.mark.parametrize(
30 | "regions_element",
31 | ["blobs_circles", "blobs_polygons", "blobs_multipolygons", "blobs_labels", "blobs_multiscale_labels"],
32 | )
33 | @pytest.mark.parametrize("rasterize", [True, False])
34 | def test_default(self, sdata_blobs, regions_element, rasterize):
35 | rasterize_kwargs = {"target_unit_to_pixels": 2} if rasterize else {}
36 |
37 | sdata = blobs_annotating_element(regions_element)
38 | ds = ImageTilesDataset(
39 | sdata=sdata,
40 | rasterize=rasterize,
41 | regions_to_images={regions_element: "blobs_image"},
42 | regions_to_coordinate_systems={regions_element: "global"},
43 | rasterize_kwargs=rasterize_kwargs,
44 | table_name="table",
45 | )
46 |
47 | sdata_tile = ds[0]
48 | tile = next(iter(sdata_tile.images.values()))
49 |
50 | if regions_element == "blobs_circles":
51 | if rasterize:
52 | assert tile.shape == (3, 20, 20)
53 | else:
54 | assert tile.shape == (3, 10, 10)
55 | elif regions_element == "blobs_polygons":
56 | if rasterize:
57 | assert tile.shape == (3, 6, 6)
58 | else:
59 | assert tile.shape == (3, 3, 3)
60 | elif regions_element == "blobs_multipolygons":
61 | if rasterize:
62 | assert tile.shape == (3, 9, 9)
63 | else:
64 | assert tile.shape == (3, 5, 4)
65 | elif regions_element in ("blobs_labels", "blobs_multiscale_labels"):
66 | if rasterize:
67 | assert tile.shape == (3, 16, 16)
68 | else:
69 | assert tile.shape == (3, 8, 8)
70 | else:
71 | raise ValueError(f"Unexpected regions_element: {regions_element}")
72 |
73 | # the extent has units in pixels, so it should match the tile shape
74 | if rasterize:
75 | assert round(ds.tiles_coords.extent.unique()[0] * 2) == tile.shape[1]
76 | else:
77 | # here we have a tolerance of 1 pixel because the size of the tile depends on the values of the centroids
78 | # and of the extent, and here we keep the test simple.
79 | # For example, if the centroid is 0.5 and the extent is 0.1, the tile will be 1 pixel since the extent will
80 | # span 0.4 to 0.6; but if the centroid is 0.95, then the tile will be 2 pixels
81 | assert np.ceil(ds.tiles_coords.extent.unique()[0]) in [tile.shape[1], tile.shape[1] + 1]
82 | assert np.all(sdata_tile["table"].obs.columns == ds.sdata["table"].obs.columns)
83 | assert list(sdata_tile.images.keys())[0] == "blobs_image"
84 |
85 | @pytest.mark.parametrize(
86 | "regions_element",
87 | ["blobs_circles", "blobs_polygons", "blobs_multipolygons", "blobs_labels", "blobs_multiscale_labels"],
88 | )
89 | @pytest.mark.parametrize("return_annot", [None, "region", ["region", "instance_id"]])
90 | def test_return_annot(self, sdata_blobs, regions_element, return_annot):
91 | sdata = blobs_annotating_element(regions_element)
92 | ds = ImageTilesDataset(
93 | sdata=sdata,
94 | regions_to_images={regions_element: "blobs_image"},
95 | regions_to_coordinate_systems={regions_element: "global"},
96 | return_annotations=return_annot,
97 | table_name="table",
98 | )
99 | if return_annot is None:
100 | sdata_tile = ds[0]
101 | tile = sdata_tile["blobs_image"]
102 | else:
103 | tile, annot = ds[0]
104 | if regions_element == "blobs_circles":
105 | assert tile.shape == (3, 10, 10)
106 | elif regions_element == "blobs_polygons":
107 | assert tile.shape == (3, 3, 3)
108 | elif regions_element == "blobs_multipolygons":
109 | assert tile.shape == (3, 5, 4)
110 | elif regions_element in ("blobs_labels", "blobs_multiscale_labels"):
111 | assert tile.shape == (3, 8, 8)
112 | else:
113 | raise ValueError(f"Unexpected regions_element: {regions_element}")
114 | # the extent has units in pixels, so it should match the tile shape
115 | # see the comment in the test above explaining why we have a tolerance of 1 pixel
116 | assert np.ceil(ds.tiles_coords.extent.unique()[0]) in [tile.shape[1], tile.shape[1] + 1]
117 | if return_annot is not None:
118 | return_annot = [return_annot] if isinstance(return_annot, str) else return_annot
119 | assert annot.shape[1] == len(return_annot)
120 |
121 | @pytest.mark.parametrize("rasterize", [True, False])
122 | @pytest.mark.parametrize("return_annot", [None, "region"])
123 | def test_multiscale_images(self, sdata_blobs, rasterize: bool, return_annot):
124 | sdata = blobs_annotating_element("blobs_circles")
125 | ds = ImageTilesDataset(
126 | sdata=sdata,
127 | regions_to_images={"blobs_circles": "blobs_multiscale_image"},
128 | regions_to_coordinate_systems={"blobs_circles": "global"},
129 | rasterize=rasterize,
130 | return_annotations=return_annot,
131 | table_name="table" if return_annot is not None else None,
132 | rasterize_kwargs={"target_unit_to_pixels": 1} if rasterize else None,
133 | )
134 | if return_annot is None:
135 | sdata_tile = ds[0]
136 | tile = sdata_tile["blobs_multiscale_image"]
137 | else:
138 | tile, annot = ds[0]
139 | assert tile.shape == (3, 10, 10)
140 |
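# A hedged usage sketch: ImageTilesDataset implements the torch-style Dataset
# protocol (__len__/__getitem__, as used via ds[0] above), so, assuming torch
# is installed, it can be wrapped in a standard DataLoader. The identity
# collate_fn is an assumption made here because, by default, items are
# SpatialData objects, which torch's default collate cannot stack.
from torch.utils.data import DataLoader

from spatialdata.dataloader import ImageTilesDataset
from spatialdata.datasets import blobs_annotating_element

sdata = blobs_annotating_element("blobs_circles")
ds = ImageTilesDataset(
    sdata=sdata,
    regions_to_images={"blobs_circles": "blobs_image"},
    regions_to_coordinate_systems={"blobs_circles": "global"},
    table_name="table",
)
loader = DataLoader(ds, batch_size=2, collate_fn=lambda batch: batch)
first_batch = next(iter(loader))  # a list of per-tile SpatialData objects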
--------------------------------------------------------------------------------
/tests/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/tests/datasets/__init__.py
--------------------------------------------------------------------------------
/tests/datasets/test_datasets.py:
--------------------------------------------------------------------------------
1 | from spatialdata.datasets import blobs, raccoon
2 |
3 |
4 | def test_datasets() -> None:
5 | extra_cs = "test"
6 | sdata_blobs = blobs(extra_coord_system=extra_cs)
7 |
8 | assert len(sdata_blobs["table"]) == 26
9 | assert len(sdata_blobs.shapes["blobs_circles"]) == 5
10 | assert len(sdata_blobs.shapes["blobs_polygons"]) == 5
11 | assert len(sdata_blobs.shapes["blobs_multipolygons"]) == 2
12 | assert len(sdata_blobs.points["blobs_points"].compute()) == 200
13 | assert sdata_blobs.images["blobs_image"].shape == (3, 512, 512)
14 | assert len(sdata_blobs.images["blobs_multiscale_image"]) == 3
15 | assert sdata_blobs.labels["blobs_labels"].shape == (512, 512)
16 | assert len(sdata_blobs.labels["blobs_multiscale_labels"]) == 3
17 | assert extra_cs in sdata_blobs.coordinate_systems
18 | # this catches this bug: https://github.com/scverse/spatialdata/issues/269
19 | _ = str(sdata_blobs)
20 |
21 | sdata_raccoon = raccoon()
22 | assert "table" not in sdata_raccoon.tables
23 | assert len(sdata_raccoon.shapes["circles"]) == 4
24 | assert sdata_raccoon.images["raccoon"].shape == (3, 768, 1024)
25 | assert sdata_raccoon.labels["segmentation"].shape == (768, 1024)
26 | _ = str(sdata_raccoon)
27 |
--------------------------------------------------------------------------------
/tests/io/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/tests/io/__init__.py
--------------------------------------------------------------------------------
/tests/io/test_format.py:
--------------------------------------------------------------------------------
1 | import json
2 | import tempfile
3 | from pathlib import Path
4 | from typing import Any
5 |
6 | import pytest
7 | from shapely import GeometryType
8 |
9 | from spatialdata._io.format import (
10 | CurrentPointsFormat,
11 | CurrentShapesFormat,
12 | RasterFormatV01,
13 | RasterFormatV02,
14 | ShapesFormatV01,
15 | SpatialDataFormat,
16 | )
17 | from spatialdata.models import PointsModel, ShapesModel
18 |
19 | Points_f = CurrentPointsFormat()
20 | Shapes_f = CurrentShapesFormat()
21 |
22 |
23 | class TestFormat:
24 | """Test format."""
25 |
26 | @pytest.mark.parametrize("attrs_key", [PointsModel.ATTRS_KEY])
27 | @pytest.mark.parametrize("feature_key", [None, PointsModel.FEATURE_KEY])
28 | @pytest.mark.parametrize("instance_key", [None, PointsModel.INSTANCE_KEY])
29 | def test_format_points(
30 | self,
31 | attrs_key: str | None,
32 | feature_key: str | None,
33 | instance_key: str | None,
34 | ) -> None:
35 | metadata: dict[str, Any] = {attrs_key: {"version": Points_f.spatialdata_format_version}}
36 | format_metadata: dict[str, Any] = {attrs_key: {}}
37 | if feature_key is not None:
38 | metadata[attrs_key][feature_key] = "target"
39 | if instance_key is not None:
40 | metadata[attrs_key][instance_key] = "cell_id"
41 | format_metadata[attrs_key] = Points_f.attrs_from_dict(metadata)
42 | metadata[attrs_key].pop("version")
43 | assert metadata[attrs_key] == Points_f.attrs_to_dict(format_metadata)
44 | if feature_key is None and instance_key is None:
45 | assert len(format_metadata[attrs_key]) == len(metadata[attrs_key]) == 0
46 |
47 | @pytest.mark.parametrize("attrs_key", [ShapesModel.ATTRS_KEY])
48 | @pytest.mark.parametrize("geos_key", [ShapesModel.GEOS_KEY])
49 | @pytest.mark.parametrize("type_key", [ShapesModel.TYPE_KEY])
50 | @pytest.mark.parametrize("name_key", [ShapesModel.NAME_KEY])
51 | @pytest.mark.parametrize("shapes_type", [0, 3, 6])
52 | def test_format_shapes_v1(
53 | self,
54 | attrs_key: str,
55 | geos_key: str,
56 | type_key: str,
57 | name_key: str,
58 | shapes_type: int,
59 | ) -> None:
60 | shapes_dict = {
61 | 0: "POINT",
62 | 3: "POLYGON",
63 | 6: "MULTIPOLYGON",
64 | }
65 | metadata: dict[str, Any] = {attrs_key: {"version": ShapesFormatV01().spatialdata_format_version}}
66 | format_metadata: dict[str, Any] = {attrs_key: {}}
67 | metadata[attrs_key][geos_key] = {}
68 | metadata[attrs_key][geos_key][type_key] = shapes_type
69 | metadata[attrs_key][geos_key][name_key] = shapes_dict[shapes_type]
70 | format_metadata[attrs_key] = ShapesFormatV01().attrs_from_dict(metadata)
71 | metadata[attrs_key].pop("version")
72 | geometry = GeometryType(metadata[attrs_key][geos_key][type_key])
73 | assert metadata[attrs_key] == ShapesFormatV01().attrs_to_dict(geometry)
74 |
75 | @pytest.mark.parametrize("attrs_key", [ShapesModel.ATTRS_KEY])
76 | def test_format_shapes_v2(
77 | self,
78 | attrs_key: str,
79 | ) -> None:
80 | # not testing anything, maybe remove
81 | metadata: dict[str, Any] = {attrs_key: {"version": Shapes_f.spatialdata_format_version}}
82 | metadata[attrs_key].pop("version")
83 | assert metadata[attrs_key] == Shapes_f.attrs_to_dict({})
84 |
85 | @pytest.mark.parametrize("format", [RasterFormatV01, RasterFormatV02])
86 | def test_format_raster_v1_v2(self, images, format: type[SpatialDataFormat]) -> None:
87 | with tempfile.TemporaryDirectory() as tmpdir:
88 | images.write(Path(tmpdir) / "images.zarr", format=format())
89 | zattrs_file = Path(tmpdir) / "images.zarr/images/image2d/.zattrs"
90 | with open(zattrs_file) as infile:
91 | zattrs = json.load(infile)
92 | ngff_version = zattrs["multiscales"][0]["version"]
93 | if format == RasterFormatV01:
94 | assert ngff_version == "0.4"
95 | else:
96 | assert format == RasterFormatV02
97 | assert ngff_version == "0.4-dev-spatialdata"
98 |
--------------------------------------------------------------------------------
/tests/io/test_pyramids_performance.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import TYPE_CHECKING
3 |
4 | import dask
5 | import dask.array
6 | import numpy as np
7 | import pytest
8 | import xarray as xr
9 | import zarr
10 |
11 | from spatialdata import SpatialData
12 | from spatialdata._io import write_image
13 | from spatialdata._io.format import CurrentRasterFormat
14 | from spatialdata.models import Image2DModel
15 |
16 | if TYPE_CHECKING:
17 | import _pytest.fixtures
18 |
19 |
20 | @pytest.fixture
21 | def sdata_with_image(request: "_pytest.fixtures.SubRequest", tmp_path: Path) -> SpatialData:
22 | params = request.param if request.param is not None else {}
23 | width = params.get("width", 2048)
24 | chunksize = params.get("chunk_size", 1024)
25 | scale_factors = params.get("scale_factors", (2,))
26 | # Create a disk-backed Dask array for scale 0.
27 | npg = np.random.default_rng(0)
28 | array = npg.integers(low=0, high=2**16, size=(1, width, width))
29 | array_path = tmp_path / "image.zarr"
30 | dask.array.from_array(array).rechunk(chunksize).to_zarr(array_path)
31 | array_backed = dask.array.from_zarr(array_path)
32 | # Create an in-memory SpatialData with disk-backed scale 0.
33 | image = Image2DModel.parse(array_backed, dims=("c", "y", "x"), scale_factors=scale_factors, chunks=chunksize)
34 | return SpatialData(images={"image": image})
35 |
36 |
37 | def count_chunks(array: xr.DataArray | xr.Dataset | xr.DataTree) -> int:
38 | if isinstance(array, xr.DataTree):
39 | array = array.ds
40 | # From `chunksizes`, we get only the number of chunks per axis.
41 | # By multiplying them, we get the total number of chunks in 2D/3D.
42 | return np.prod([len(chunk_sizes) for chunk_sizes in array.chunksizes.values()])
43 |
44 |
45 | @pytest.mark.parametrize(
46 | ("sdata_with_image",),
47 | [
48 | ({"width": 32, "chunk_size": 16, "scale_factors": (2,)},),
49 | ({"width": 64, "chunk_size": 16, "scale_factors": (2, 2)},),
50 | ({"width": 128, "chunk_size": 16, "scale_factors": (2, 2, 2)},),
51 | ({"width": 256, "chunk_size": 16, "scale_factors": (2, 2, 2, 2)},),
52 | ],
53 | indirect=["sdata_with_image"],
54 | )
55 | def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_path: Path, mocker):
56 | # Writing multiscale images with several pyramid levels should be efficient.
57 | # Specifically, it should not read the input image more often than necessary
58 | # (see issue https://github.com/scverse/spatialdata/issues/577).
59 | # Instead of measuring the time (which would have high variation if not using big datasets),
60 | # we watch the number of read and write accesses and compare to the theoretical number.
61 | zarr_chunk_write_spy = mocker.spy(zarr.core.Array, "__setitem__")
62 | zarr_chunk_read_spy = mocker.spy(zarr.core.Array, "__getitem__")
63 |
64 | image_name, image = next(iter(sdata_with_image.images.items()))
65 | element_type_group = zarr.group(store=tmp_path / "sdata.zarr", path="/images")
66 |
67 | write_image(
68 | image=image,
69 | group=element_type_group,
70 | name=image_name,
71 | format=CurrentRasterFormat(),
72 | )
73 |
74 | # The number of chunks of scale level 0
75 | num_chunks_scale0 = count_chunks(image.scale0 if isinstance(image, xr.DataTree) else image)
76 | # The total number of chunks of all scale levels
77 | num_chunks_all_scales = (
78 | sum(count_chunks(pyramid) for pyramid in image.children.values())
79 | if isinstance(image, xr.DataTree)
80 | else count_chunks(image)
81 | )
82 |
83 | actual_num_chunk_writes = zarr_chunk_write_spy.call_count
84 | actual_num_chunk_reads = zarr_chunk_read_spy.call_count
85 | assert actual_num_chunk_writes == num_chunks_all_scales.item()
86 | assert actual_num_chunk_reads == num_chunks_scale0.item()
87 |
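# Worked numbers for one parametrization above (width=64, chunk_size=16,
# scale_factors=(2, 2)), assuming every scale level is chunked at 16 pixels as
# requested via Image2DModel.parse(..., chunks=chunksize):
#   scale0: (1, 64, 64) -> 1 * 4 * 4 = 16 chunks
#   scale1: (1, 32, 32) -> 1 * 2 * 2 = 4 chunks
#   scale2: (1, 16, 16) -> 1 * 1 * 1 = 1 chunk
# so the spies should record 16 + 4 + 1 = 21 chunk writes and only 16 chunk
# reads, i.e. each scale-0 chunk is read exactly once.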
--------------------------------------------------------------------------------
/tests/io/test_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from contextlib import nullcontext
4 |
5 | import dask.dataframe as dd
6 | import pytest
7 |
8 | from spatialdata import read_zarr
9 | from spatialdata._io._utils import get_dask_backing_files, handle_read_errors
10 |
11 |
12 | def test_backing_files_points(points):
13 | """Test the ability to identify the backing files of a dask dataframe by examining its computational graph."""
14 | with tempfile.TemporaryDirectory() as tmp_dir:
15 | f0 = os.path.join(tmp_dir, "points0.zarr")
16 | f1 = os.path.join(tmp_dir, "points1.zarr")
17 | points.write(f0)
18 | points.write(f1)
19 | points0 = read_zarr(f0)
20 | points1 = read_zarr(f1)
21 | p0 = points0.points["points_0"]
22 | p1 = points1.points["points_0"]
23 | p2 = dd.concat([p0, p1], axis=0)
24 | files = get_dask_backing_files(p2)
25 | expected_zarr_locations_legacy = [
26 | os.path.realpath(os.path.join(f, "points/points_0/points.parquet")) for f in [f0, f1]
27 | ]
28 | expected_zarr_locations_new = [
29 | os.path.realpath(os.path.join(f, "points/points_0/points.parquet/part.0.parquet")) for f in [f0, f1]
30 | ]
31 | assert set(files) == set(expected_zarr_locations_legacy) or set(files) == set(expected_zarr_locations_new)
32 |
33 |
34 | def test_backing_files_images(images):
35 | """
36 | Test the ability to identify the backing files of single-scale and multiscale images by examining their
37 | computational graph
38 | """
39 | with tempfile.TemporaryDirectory() as tmp_dir:
40 | f0 = os.path.join(tmp_dir, "images0.zarr")
41 | f1 = os.path.join(tmp_dir, "images1.zarr")
42 | images.write(f0)
43 | images.write(f1)
44 | images0 = read_zarr(f0)
45 | images1 = read_zarr(f1)
46 |
47 | # single scale
48 | im0 = images0.images["image2d"]
49 | im1 = images1.images["image2d"]
50 | im2 = im0 + im1
51 | files = get_dask_backing_files(im2)
52 | expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d")) for f in [f0, f1]]
53 | assert set(files) == set(expected_zarr_locations)
54 |
55 | # multiscale
56 | im3 = images0.images["image2d_multiscale"]
57 | im4 = images1.images["image2d_multiscale"]
58 | im5 = im3 + im4
59 | files = get_dask_backing_files(im5)
60 | expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d_multiscale")) for f in [f0, f1]]
61 | assert set(files) == set(expected_zarr_locations)
62 |
63 |
64 | # TODO: this test is very similar to the one above; unify them or delete this todo
65 | def test_backing_files_labels(labels):
66 | """
67 | Test the ability to identify the backing files of single-scale and multiscale labels by examining their
68 | computational graph
69 | """
70 | with tempfile.TemporaryDirectory() as tmp_dir:
71 | f0 = os.path.join(tmp_dir, "labels0.zarr")
72 | f1 = os.path.join(tmp_dir, "labels1.zarr")
73 | labels.write(f0)
74 | labels.write(f1)
75 | labels0 = read_zarr(f0)
76 | labels1 = read_zarr(f1)
77 |
78 | # single scale
79 | im0 = labels0.labels["labels2d"]
80 | im1 = labels1.labels["labels2d"]
81 | im2 = im0 + im1
82 | files = get_dask_backing_files(im2)
83 | expected_zarr_locations = [os.path.realpath(os.path.join(f, "labels/labels2d")) for f in [f0, f1]]
84 | assert set(files) == set(expected_zarr_locations)
85 |
86 | # multiscale
87 | im3 = labels0.labels["labels2d_multiscale"]
88 | im4 = labels1.labels["labels2d_multiscale"]
89 | im5 = im3 + im4
90 | files = get_dask_backing_files(im5)
91 | expected_zarr_locations = [os.path.realpath(os.path.join(f, "labels/labels2d_multiscale")) for f in [f0, f1]]
92 | assert set(files) == set(expected_zarr_locations)
93 |
94 |
95 | def test_backing_files_combining_points_and_images(points, images):
96 | """
97 | Test the ability to identify the backing files of an object that depends on both dask dataframes and dask
98 | arrays by examining its computational graph
99 | """
100 | with tempfile.TemporaryDirectory() as tmp_dir:
101 | f0 = os.path.join(tmp_dir, "points0.zarr")
102 | f1 = os.path.join(tmp_dir, "images1.zarr")
103 | points.write(f0)
104 | images.write(f1)
105 | points0 = read_zarr(f0)
106 | images1 = read_zarr(f1)
107 |
108 | p0 = points0.points["points_0"]
109 | im1 = images1.images["image2d"]
110 | v = p0["x"].loc[0].values
111 | v.compute_chunk_sizes()
112 | im2 = v + im1
113 | files = get_dask_backing_files(im2)
114 | expected_zarr_locations_old = [
115 | os.path.realpath(os.path.join(f0, "points/points_0/points.parquet")),
116 | os.path.realpath(os.path.join(f1, "images/image2d")),
117 | ]
118 | expected_zarr_locations_new = [
119 | os.path.realpath(os.path.join(f0, "points/points_0/points.parquet/part.0.parquet")),
120 | os.path.realpath(os.path.join(f1, "images/image2d")),
121 | ]
122 | assert set(files) == set(expected_zarr_locations_old) or set(files) == set(expected_zarr_locations_new)
123 |
124 |
125 | @pytest.mark.parametrize(
126 | ("on_bad_files", "actual_error", "expectation"),
127 | [
128 | ("error", None, nullcontext()),
129 | ("error", KeyError("key"), pytest.raises(KeyError)),
130 | ("warn", None, nullcontext()),
131 | ("warn", KeyError("key"), pytest.warns(UserWarning, match="location: KeyError")),
132 | ("warn", RuntimeError("unhandled"), pytest.raises(RuntimeError)),
133 | ],
134 | )
135 | def test_handle_read_errors(on_bad_files: str, actual_error: Exception, expectation):
136 | with expectation: # noqa: SIM117
137 | with handle_read_errors(on_bad_files=on_bad_files, location="location", exc_types=KeyError):
138 | if actual_error is not None:
139 | raise actual_error
140 |
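# A behavioral sketch of the contract encoded by the parametrization above
# (simplified; not the actual spatialdata implementation): with
# on_bad_files="error" the listed exc_types propagate, with "warn" they are
# downgraded to a UserWarning mentioning the location, and exception types
# outside exc_types always propagate.
import warnings
from contextlib import contextmanager


@contextmanager
def handle_read_errors_sketch(on_bad_files: str, location: str, exc_types):
    try:
        yield
    except exc_types as e:
        if on_bad_files == "warn":
            warnings.warn(f"{location}: {type(e).__name__}", UserWarning, stacklevel=2)
        else:
            raise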
--------------------------------------------------------------------------------
/tests/io/test_versions.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from pathlib import Path
3 |
4 | from spatialdata import read_zarr
5 | from spatialdata._io.format import ShapesFormatV01, ShapesFormatV02
6 | from spatialdata.testing import assert_spatial_data_objects_are_identical
7 |
8 |
9 | def test_shapes_v1_to_v2(shapes):
10 | with tempfile.TemporaryDirectory() as tmpdir:
11 | f0 = Path(tmpdir) / "data0.zarr"
12 | f1 = Path(tmpdir) / "data1.zarr"
13 |
14 | # write shapes in version 1
15 | shapes.write(f0, format=ShapesFormatV01())
16 |
17 | # reading from v1 works
18 | shapes_read = read_zarr(f0)
19 |
20 | assert_spatial_data_objects_are_identical(shapes, shapes_read)
21 |
22 | # write shapes using the v2 version
23 | shapes_read.write(f1, format=ShapesFormatV02())
24 |
25 | # read again
26 | shapes_read = read_zarr(f1)
27 |
28 | assert_spatial_data_objects_are_identical(shapes, shapes_read)
29 |
--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/tests/models/__init__.py
--------------------------------------------------------------------------------
/tests/transformations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/tests/transformations/__init__.py
--------------------------------------------------------------------------------
/tests/transformations/ngff/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/tests/transformations/ngff/__init__.py
--------------------------------------------------------------------------------
/tests/transformations/ngff/conftest.py:
--------------------------------------------------------------------------------
1 | from spatialdata.transformations.ngff.ngff_coordinate_system import (
2 | NgffAxis,
3 | NgffCoordinateSystem,
4 | )
5 |
6 | x_axis = NgffAxis(name="x", type="space", unit="micrometer")
7 | y_axis = NgffAxis(name="y", type="space", unit="micrometer")
8 | z_axis = NgffAxis(name="z", type="space", unit="micrometer")
9 | c_axis = NgffAxis(name="c", type="channel")
10 | x_cs = NgffCoordinateSystem(name="x", axes=[x_axis])
11 | y_cs = NgffCoordinateSystem(name="y", axes=[y_axis])
12 | z_cs = NgffCoordinateSystem(name="z", axes=[z_axis])
13 | c_cs = NgffCoordinateSystem(name="c", axes=[c_axis])
14 | xy_cs = NgffCoordinateSystem(name="xy", axes=[x_axis, y_axis])
15 | xyz_cs = NgffCoordinateSystem(name="xyz", axes=[x_axis, y_axis, z_axis])
16 | xyc_cs = NgffCoordinateSystem(name="xyc", axes=[x_axis, y_axis, c_axis])
17 | xyzc_cs = NgffCoordinateSystem(name="xyzc", axes=[x_axis, y_axis, z_axis, c_axis])
18 | yx_cs = NgffCoordinateSystem(name="yx", axes=[y_axis, x_axis])
19 | zyx_cs = NgffCoordinateSystem(name="zyx", axes=[z_axis, y_axis, x_axis])
20 | cyx_cs = NgffCoordinateSystem(name="cyx", axes=[c_axis, y_axis, x_axis])
21 | czyx_cs = NgffCoordinateSystem(name="czyx", axes=[c_axis, z_axis, y_axis, x_axis])
22 |
--------------------------------------------------------------------------------
/tests/transformations/ngff/test_ngff_coordinate_system.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 |
4 | import pytest
5 |
6 | from spatialdata.transformations.ngff.ngff_coordinate_system import (
7 | NgffAxis,
8 | NgffCoordinateSystem,
9 | )
10 |
11 | input_dict = {
12 | "name": "volume_micrometers",
13 | "axes": [
14 | {"name": "x", "type": "space", "unit": "micrometer"},
15 | {"name": "y", "type": "space", "unit": "micrometer"},
16 | {"name": "z", "type": "space", "unit": "micrometer"},
17 | ],
18 | }
19 |
20 |
21 | def test_coordinate_system_instantiation_and_properties():
22 | coord_sys = NgffCoordinateSystem.from_dict(input_dict)
23 | assert coord_sys.name == "volume_micrometers"
24 | assert [ax.name for ax in coord_sys._axes] == ["x", "y", "z"]
25 | assert coord_sys.axes_names == ("x", "y", "z")
26 |
27 | assert [ax.type for ax in coord_sys._axes] == ["space", "space", "space"]
28 | assert coord_sys.axes_types == ("space", "space", "space")
29 |
30 | output_dict = coord_sys.to_dict()
31 | assert input_dict == output_dict
32 |
33 | axes = [
34 | NgffAxis(name="x", type="space", unit="micrometer"),
35 | NgffAxis(name="y", type="space", unit="micrometer"),
36 | NgffAxis(name="z", type="space", unit="micrometer"),
37 | ]
38 | coord_manual = NgffCoordinateSystem(
39 | name="volume_micrometers",
40 | axes=axes,
41 | )
42 |
43 | assert coord_manual.to_dict() == coord_sys.to_dict()
44 | with pytest.raises(ValueError):
45 | NgffCoordinateSystem(
46 | name="non unique axes names",
47 | axes=[
48 | NgffAxis(name="x", type="space", unit="micrometer"),
49 | NgffAxis(name="x", type="space", unit="micrometer"),
50 | ],
51 | )
52 |
53 |
54 | def test_coordinate_system_exceptions():
55 | input_dict1 = copy.deepcopy(input_dict)
56 | input_dict1["axes"][0].pop("name")
57 | coord_sys = NgffCoordinateSystem(name="test")
58 | with pytest.raises(ValueError):
59 | coord_sys.from_dict(input_dict1)
60 |
61 | input_dict2 = copy.deepcopy(input_dict)
62 | input_dict2["axes"][0].pop("type")
63 | coord_sys = NgffCoordinateSystem(name="test")
64 | with pytest.raises(ValueError):
65 | coord_sys.from_dict(input_dict2)
66 |
67 | # not testing the cases with axis.type in ['channel', 'array'] as all the more complex checks are handled by the
68 | # validation schema
69 |
70 |
71 | def test_coordinate_system_roundtrip():
72 | input_json = json.dumps(input_dict)
73 | cs = NgffCoordinateSystem.from_json(input_json)
74 | output_json = cs.to_json()
75 | assert input_json == output_json
76 | cs2 = NgffCoordinateSystem.from_json(output_json)
77 | assert cs == cs2
78 |
79 |
80 | def test_repr():
81 | cs = NgffCoordinateSystem(
82 | "some coordinate system",
83 | [
84 | NgffAxis("X", "space", "micrometers"),
85 | NgffAxis("Y", "space", "meters"),
86 | NgffAxis("T", "time"),
87 | ],
88 | )
89 | expected = (
90 | "NgffCoordinateSystem('some coordinate system',"
91 | + " [NgffAxis('X', 'space', 'micrometers'),"
92 | + " NgffAxis('Y', 'space', 'meters'), NgffAxis('T', 'time')])"
93 | )
94 | as_str = repr(cs)
95 |
96 | assert as_str == expected
97 |
98 |
99 | def test_equal_up_to_the_units():
100 | cs1 = NgffCoordinateSystem(
101 | "some coordinate system",
102 | [
103 | NgffAxis("X", "space", "micrometers"),
104 | NgffAxis("Y", "space", "meters"),
105 | NgffAxis("T", "time"),
106 | ],
107 | )
108 | cs2 = NgffCoordinateSystem(
109 | "some coordinate systema",
110 | [
111 | NgffAxis("X", "space", "micrometers"),
112 | NgffAxis("Y", "space", "meters"),
113 | NgffAxis("T", "time"),
114 | ],
115 | )
116 | cs3 = NgffCoordinateSystem(
117 | "some coordinate system",
118 | [
119 | NgffAxis("X", "space", "gigameters"),
120 | NgffAxis("Y", "space", ""),
121 | NgffAxis("T", "time"),
122 | ],
123 | )
124 |
125 | assert cs1.equal_up_to_the_units(cs1)
126 | assert not cs1.equal_up_to_the_units(cs2)
127 | assert cs1.equal_up_to_the_units(cs3)
128 |
129 |
130 | def test_subset_coordinate_system():
131 | cs = NgffCoordinateSystem(
132 | "some coordinate system",
133 | [
134 | NgffAxis("X", "space", "micrometers"),
135 | NgffAxis("Y", "space", "meters"),
136 | NgffAxis("Z", "space", "meters"),
137 | NgffAxis("T", "time"),
138 | ],
139 | )
140 | cs0 = cs.subset(["X", "Z"])
141 | cs1 = cs.subset(["X", "Y"], new_name="XY")
142 | assert cs0 == NgffCoordinateSystem(
143 | "some coordinate system_subset ['X', 'Z']",
144 | [
145 | NgffAxis("X", "space", "micrometers"),
146 | NgffAxis("Z", "space", "meters"),
147 | ],
148 | )
149 | assert cs1 == NgffCoordinateSystem(
150 | "XY",
151 | [
152 | NgffAxis("X", "space", "micrometers"),
153 | NgffAxis("Y", "space", "meters"),
154 | ],
155 | )
156 |
157 |
158 | def test_merge_coordinate_systems():
159 | cs0 = NgffCoordinateSystem(
160 | "cs0",
161 | [
162 | NgffAxis("X", "space", "micrometers"),
163 | NgffAxis("Y", "space", "meters"),
164 | ],
165 | )
166 | cs1 = NgffCoordinateSystem(
167 | "cs1",
168 | [
169 | NgffAxis("X", "space", "micrometers"),
170 | ],
171 | )
172 | cs2 = NgffCoordinateSystem(
173 | "cs2",
174 | [
175 | NgffAxis("X", "space", "meters"),
176 | NgffAxis("Y", "space", "meters"),
177 | ],
178 | )
179 | cs3 = NgffCoordinateSystem(
180 | "cs3",
181 | [
182 | NgffAxis("Z", "space", "micrometers"),
183 | ],
184 | )
185 | assert cs0.merge(cs0, cs1) == NgffCoordinateSystem(
186 | "cs0_merged_cs1",
187 | [
188 | NgffAxis("X", "space", "micrometers"),
189 | NgffAxis("Y", "space", "meters"),
190 | ],
191 | )
192 | with pytest.raises(ValueError):
193 | NgffCoordinateSystem.merge(cs0, cs2)
194 | assert NgffCoordinateSystem.merge(cs0, cs3) == NgffCoordinateSystem(
195 | "cs0_merged_cs3",
196 | [
197 | NgffAxis("X", "space", "micrometers"),
198 | NgffAxis("Y", "space", "meters"),
199 | NgffAxis("Z", "space", "micrometers"),
200 | ],
201 | )
202 |
--------------------------------------------------------------------------------
/tests/transformations/test_transformations_utils.py:
--------------------------------------------------------------------------------
1 | from spatialdata.transformations._utils import convert_transformations_to_affine
2 | from spatialdata.transformations.operations import get_transformation, set_transformation
3 | from spatialdata.transformations.transformations import Affine, Scale, Sequence, Translation
4 |
5 |
6 | def test_convert_transformations_to_affine(full_sdata):
7 | translation = Translation([1, 2, 3], axes=("x", "y", "z"))
8 | scale = Scale([1, 2, 3], axes=("x", "y", "z"))
9 | sequence = Sequence([translation, scale])
10 | for _, _, element in full_sdata.gen_spatial_elements():
11 | set_transformation(element, transformation=sequence, to_coordinate_system="test")
12 | convert_transformations_to_affine(full_sdata, "test")
13 | for _, _, element in full_sdata.gen_spatial_elements():
14 | t = get_transformation(element, "test")
15 | assert isinstance(t, Affine)
16 |
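# Worked numbers for the sequence above, assuming Sequence applies its
# transformations in the listed order (translate by t=(1, 2, 3), then scale by
# s=(1, 2, 3)), i.e. p' = S @ (p + t). The expected dense affine on (x, y, z):
#   [1 0 0 1]   x' = 1 * (x + 1)
#   [0 2 0 4]   y' = 2 * (y + 2)
#   [0 0 3 9]   z' = 3 * (z + 3)
#   [0 0 0 1]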
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/spatialdata/872f8c219fda0e348af056143af9be61a34e8ca5/tests/utils/__init__.py
--------------------------------------------------------------------------------
/tests/utils/test_element_utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import dask_image.ndinterp
4 | import pytest
5 | import xarray
6 | from xarray import DataArray, DataTree
7 |
8 | from spatialdata._utils import unpad_raster
9 | from spatialdata.models import get_model
10 | from spatialdata.transformations import Affine
11 |
12 |
13 | def _pad_raster(data: DataArray, axes: tuple[str, ...]) -> DataArray:
14 | new_shape = tuple([data.shape[i] * (2 if axes[i] != "c" else 1) for i in range(len(data.shape))])
15 | x = data.shape[axes.index("x")]
16 | y = data.shape[axes.index("y")]
17 | affine = Affine(
18 | [
19 | [1, 0, -x / 2.0],
20 | [0, 1, -y / 2.0],
21 | [0, 0, 1],
22 | ],
23 | input_axes=("x", "y"),
24 | output_axes=("x", "y"),
25 | )
26 | matrix = affine.to_affine_matrix(input_axes=axes, output_axes=axes)
27 | return dask_image.ndinterp.affine_transform(data, matrix, output_shape=new_shape)
28 |
29 |
30 | @pytest.mark.ci_only
31 | def test_unpad_raster(images, labels) -> None:
32 | for raster in itertools.chain(images.images.values(), labels.labels.values()):
33 | schema = get_model(raster)
34 | if isinstance(raster, DataArray):
35 | data = raster
36 | elif isinstance(raster, DataTree):
37 | d = dict(raster["scale0"])
38 | assert len(d) == 1
39 | data = next(iter(d.values()))
40 | else:
41 | raise ValueError(f"Unknown type: {type(raster)}")
42 | padded = _pad_raster(data.data, data.dims)
43 | if isinstance(raster, DataArray):
44 | padded = schema.parse(padded, dims=data.dims, c_coords=data.coords.get("c", None))
45 | elif isinstance(raster, DataTree):
46 | # some arbitrary scaling factors
47 | padded = schema.parse(padded, dims=data.dims, scale_factors=[2, 2], c_coords=data.coords.get("c", None))
48 | else:
49 | raise ValueError(f"Unknown type: {type(raster)}")
50 | unpadded = unpad_raster(padded)
51 | if isinstance(raster, DataArray):
52 | xarray.testing.assert_equal(raster, unpadded)
53 | elif isinstance(raster, DataTree):
54 | d0 = dict(raster["scale0"])
55 | assert len(d0) == 1
56 | d1 = dict(unpadded["scale0"])
57 | assert len(d1) == 1
58 | xarray.testing.assert_equal(next(iter(d0.values())), next(iter(d1.values())))
59 | else:
60 | raise ValueError(f"Unknown type: {type(raster)}")
61 |
--------------------------------------------------------------------------------
/tests/utils/test_sanitize.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | from anndata import AnnData
7 |
8 | from spatialdata import SpatialData
9 | from spatialdata._core._utils import sanitize_name, sanitize_table
10 |
11 |
12 | @pytest.fixture
13 | def invalid_table() -> AnnData:
14 | """AnnData with invalid obs column names to test basic sanitization."""
15 | return AnnData(
16 | obs=pd.DataFrame(
17 | {
18 | "@invalid#": [1, 2],
19 | "valid_name": [3, 4],
20 | "__private": [5, 6],
21 | }
22 | )
23 | )
24 |
25 |
26 | @pytest.fixture
27 | def invalid_table_with_index() -> AnnData:
28 | """AnnData with a name requiring whitespace→underscore and a dataframe index column."""
29 | return AnnData(
30 | obs=pd.DataFrame(
31 | {
32 | "invalid name": [1, 2],
33 | "_index": [3, 4],
34 | }
35 | )
36 | )
37 |
38 |
39 | # -----------------------------------------------------------------------------
40 | # sanitize_name tests
41 | # -----------------------------------------------------------------------------
42 |
43 |
44 | @pytest.mark.parametrize(
45 | "raw,expected",
46 | [
47 | ("valid_name", "valid_name"),
48 | ("valid-name", "valid-name"),
49 | ("valid.name", "valid.name"),
50 | ("invalid@name", "invalid_name"),
51 | ("invalid#name", "invalid_name"),
52 | ("invalid name", "invalid_name"),
53 | ("", "unnamed"),
54 | (".", "unnamed"),
55 | ("..", "unnamed"),
56 | ("__", "_"),
57 | ("___", "_"),
58 | ("____#@$@", "_"),
59 | ("__private", "_private"),
60 | ],
61 | )
62 | def test_sanitize_name_strips_special_chars(raw, expected):
63 | assert sanitize_name(raw) == expected
64 |
65 |
66 | @pytest.mark.parametrize(
67 | "raw,is_df_col,expected",
68 | [
69 | ("_index", True, "index"),
70 | ("_index", False, "_index"),
71 | ("valid@column", True, "valid_column"),
72 | ("__private", True, "_private"),
73 | ],
74 | )
75 | def test_sanitize_name_dataframe_column(raw, is_df_col, expected):
76 | assert sanitize_name(raw, is_dataframe_column=is_df_col) == expected
77 |
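# An illustrative reimplementation (an assumption pinned down by the
# parametrizations above, not the actual sanitize_name source): alphanumerics
# plus "-", ".", "_" survive, everything else becomes "_", a leading run of
# underscores collapses to one, empty or dot-only names become "unnamed", and
# a dataframe column may not be named "_index".
import re


def sanitize_name_sketch(name: str, is_dataframe_column: bool = False) -> str:
    cleaned = re.sub(r"[^A-Za-z0-9\-_.]", "_", name)
    cleaned = re.sub(r"^_+", "_", cleaned)  # collapse leading underscores
    if is_dataframe_column and cleaned == "_index":
        cleaned = "index"
    if cleaned == "" or set(cleaned) == {"."}:
        return "unnamed"
    return cleaned


assert sanitize_name_sketch("invalid@name") == "invalid_name"
assert sanitize_name_sketch("____#@$@") == "_"
assert sanitize_name_sketch("_index", is_dataframe_column=True) == "index"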
78 |
79 | # -----------------------------------------------------------------------------
80 | # sanitize_table basic behaviors
81 | # -----------------------------------------------------------------------------
82 |
83 |
84 | def test_sanitize_table_basic_columns(invalid_table, invalid_table_with_index):
85 | ad1 = sanitize_table(invalid_table, inplace=False)
86 | assert isinstance(ad1, AnnData)
87 | assert list(ad1.obs.columns) == ["_invalid_", "valid_name", "_private"]
88 |
89 | ad2 = sanitize_table(invalid_table_with_index, inplace=False)
90 | assert list(ad2.obs.columns) == ["invalid_name", "index"]
91 |
92 | # original fixture remains unchanged
93 | assert list(invalid_table.obs.columns) == ["@invalid#", "valid_name", "__private"]
94 |
95 |
96 | def test_sanitize_table_inplace_copy(invalid_table):
97 | ad = invalid_table.copy()
98 | sanitize_table(ad) # inplace=True is now default
99 | assert list(ad.obs.columns) == ["_invalid_", "valid_name", "_private"]
100 |
101 |
102 | def test_sanitize_table_case_insensitive_collisions():
103 | obs = pd.DataFrame(
104 | {
105 | "Column1": [1, 2],
106 | "column1": [3, 4],
107 | "COLUMN1": [5, 6],
108 | }
109 | )
110 | ad = AnnData(obs=obs)
111 | sanitized = sanitize_table(ad, inplace=False)
112 | cols = list(sanitized.obs.columns)
113 | assert sorted(cols) == sorted(["Column1", "column1_1", "COLUMN1_2"])
114 |
115 |
116 | def test_sanitize_table_whitespace_collision():
117 | """Ensure 'a b' → 'a_b' doesn't collide silently with existing 'a_b'."""
118 | obs = pd.DataFrame({"a b": [1], "a_b": [2]})
119 | ad = AnnData(obs=obs)
120 | sanitized = sanitize_table(ad, inplace=False)
121 | cols = list(sanitized.obs.columns)
122 | assert "a_b" in cols
123 | assert "a_b_1" in cols
124 |
125 |
126 | # -----------------------------------------------------------------------------
127 | # sanitize_table attribute-specific tests
128 | # -----------------------------------------------------------------------------
129 |
130 |
131 | def test_sanitize_table_obs_and_obs_columns():
132 | ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]}))
133 | sanitized = sanitize_table(ad, inplace=False)
134 | assert list(sanitized.obs.columns) == ["_col"]
135 |
136 |
137 | def test_sanitize_table_obsm_and_obsp():
138 | ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]}))
139 | ad.obsm["@col"] = np.array([[1, 2], [3, 4]])
140 | ad.obsp["bad name"] = np.array([[1, 2], [3, 4]])
141 | sanitized = sanitize_table(ad, inplace=False)
142 | assert list(sanitized.obsm.keys()) == ["_col"]
143 | assert list(sanitized.obsp.keys()) == ["bad_name"]
144 |
145 |
146 | def test_sanitize_table_varm_and_varp():
147 | ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"]))
148 | ad.varm["__priv"] = np.array([[1, 2], [3, 4]])
149 | ad.varp["_index"] = np.array([[1, 2], [3, 4]])
150 | sanitized = sanitize_table(ad, inplace=False)
151 | assert list(sanitized.varm.keys()) == ["_priv"]
152 | assert list(sanitized.varp.keys()) == ["_index"]
153 |
154 |
155 | def test_sanitize_table_uns_and_layers():
156 | ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"]))
157 | ad.uns["bad@key"] = "val"
158 | ad.layers["bad#layer"] = np.array([[0, 1], [1, 0]])
159 | sanitized = sanitize_table(ad, inplace=False)
160 | assert list(sanitized.uns.keys()) == ["bad_key"]
161 | assert list(sanitized.layers.keys()) == ["bad_layer"]
162 |
163 |
164 | def test_sanitize_table_empty_returns_empty():
165 | ad = AnnData()
166 | sanitized = sanitize_table(ad, inplace=False)
167 | assert isinstance(sanitized, AnnData)
168 | assert sanitized.obs.empty
169 | assert sanitized.var.empty
170 |
171 |
172 | def test_sanitize_table_preserves_underlying_data():
173 | ad = AnnData(obs=pd.DataFrame({"@invalid#": [1, 2], "valid": [3, 4]}))
174 | ad.obsm["@invalid#"] = np.array([[1, 2], [3, 4]])
175 | ad.uns["invalid@key"] = "value"
176 | sanitized = sanitize_table(ad, inplace=False)
177 | assert sanitized.obs["_invalid_"].tolist() == [1, 2]
178 | assert sanitized.obs["valid"].tolist() == [3, 4]
179 | assert np.array_equal(sanitized.obsm["_invalid_"], np.array([[1, 2], [3, 4]]))
180 | assert sanitized.uns["invalid_key"] == "value"
181 |
182 |
183 | # -----------------------------------------------------------------------------
184 | # SpatialData integration
185 | # -----------------------------------------------------------------------------
186 |
187 |
188 | def test_sanitize_table_in_spatialdata_sanitized_fixture(invalid_table, invalid_table_with_index):
189 | table1 = invalid_table.copy()
190 | table2 = invalid_table_with_index.copy()
191 | sanitize_table(table1)
192 | sanitize_table(table2)
193 | sdata_sanitized_tables = SpatialData(tables={"table1": table1, "table2": table2})
194 |
195 | t1 = sdata_sanitized_tables.tables["table1"]
196 | t2 = sdata_sanitized_tables.tables["table2"]
197 | assert list(t1.obs.columns) == ["_invalid_", "valid_name", "_private"]
198 | assert list(t2.obs.columns) == ["invalid_name", "index"]
199 |
200 |
201 | def test_spatialdata_retains_other_elements(full_sdata):
202 | # Add another sanitized table into an existing full_sdata
203 | tbl = AnnData(obs=pd.DataFrame({"@foo#": [1, 2], "bar": [3, 4]}))
204 | sanitize_table(tbl)
205 | full_sdata.tables["new_table"] = tbl
206 |
207 | # Verify columns and presence of other SpatialData attributes
208 | assert list(full_sdata.tables["new_table"].obs.columns) == ["_foo_", "bar"]
209 |
--------------------------------------------------------------------------------
/tests/utils/test_testing.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 | import pytest
5 | from xarray import DataArray, DataTree
6 |
7 | from spatialdata import SpatialData, deepcopy
8 | from spatialdata.models import (
9 | Image2DModel,
10 | Image3DModel,
11 | Labels2DModel,
12 | Labels3DModel,
13 | PointsModel,
14 | ShapesModel,
15 | TableModel,
16 | get_model,
17 | )
18 | from spatialdata.testing import assert_elements_are_identical, assert_spatial_data_objects_are_identical
19 | from spatialdata.transformations import Scale, set_transformation
20 |
21 | scale = Scale([1.0], axes=("x",))
22 |
23 |
24 | def _change_metadata_points(sdata: SpatialData, element_name: str, attrs: bool, transformations: bool) -> None:
25 | element = sdata[element_name]
26 | if attrs:
27 | # deliberately incorrect new values, just so they differ from the originals
28 | element.attrs[PointsModel.ATTRS_KEY][PointsModel.FEATURE_KEY] = "a"
29 | element.attrs[PointsModel.ATTRS_KEY][PointsModel.INSTANCE_KEY] = "b"
30 | if transformations:
31 | set_transformation(element, copy.deepcopy(scale))
32 |
33 |
34 | def _change_metadata_shapes(sdata: SpatialData, element_name: str) -> None:
35 | set_transformation(sdata[element_name], copy.deepcopy(scale))
36 |
37 |
38 | def _change_metadata_tables(sdata: SpatialData, element_name: str) -> None:
39 | element = sdata[element_name]
40 | # deliberately incorrect new values, just so they differ from the originals
41 | element.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = "circles"
42 | element.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] = "a"
43 | element.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] = "b"
44 |
45 |
46 | def _change_metadata_image(sdata: SpatialData, element_name: str, coords: bool, transformations: bool) -> None:
47 | if coords:
48 | if isinstance(sdata[element_name], DataArray):
49 | sdata[element_name] = sdata[element_name].assign_coords({"c": np.array(["m", "l", "b"])})
50 | else:
51 | assert isinstance(sdata[element_name], DataTree)
52 |
53 | dt = sdata[element_name].assign_coords({"c": np.array(["m", "l", "b"])})
54 | sdata[element_name] = dt
55 | if transformations:
56 | set_transformation(sdata[element_name], copy.deepcopy(scale))
57 |
58 |
59 | def _change_metadata_labels(sdata: SpatialData, element_name: str) -> None:
60 | set_transformation(sdata[element_name], copy.deepcopy(scale))
61 |
62 |
63 | def test_assert_elements_are_identical_metadata(full_sdata):
64 | assert_spatial_data_objects_are_identical(full_sdata, full_sdata)
65 |
66 | copied = deepcopy(full_sdata)
67 | assert_spatial_data_objects_are_identical(full_sdata, copied)
68 |
69 | to_iter = list(copied.gen_elements())
70 | for _, element_name, element in to_iter:
71 | if get_model(element) in (Image2DModel, Image3DModel):
72 | if not isinstance(copied[element_name], DataTree):
73 | assert_elements_are_identical(full_sdata[element_name], copied[element_name])
74 | _change_metadata_image(copied, element_name, coords=True, transformations=False)
75 | with pytest.raises(AssertionError):
76 | assert_elements_are_identical(full_sdata[element_name], copied[element_name])
77 | elif get_model(element) in (Labels2DModel, Labels3DModel):
78 | if not isinstance(copied[element_name], DataTree):
79 | assert_elements_are_identical(full_sdata[element_name], copied[element_name])
80 | _change_metadata_labels(copied, element_name)
81 | with pytest.raises(AssertionError):
82 | assert_elements_are_identical(full_sdata[element_name], copied[element_name])
83 | elif get_model(element) == PointsModel:
84 | assert_elements_are_identical(full_sdata[element_name], copied[element_name])
85 | _change_metadata_points(copied, element_name, attrs=True, transformations=False)
86 | with pytest.raises(AssertionError):
87 | assert_elements_are_identical(full_sdata[element_name], copied[element_name])
88 | elif get_model(element) == ShapesModel:
89 | assert_elements_are_identical(full_sdata[element_name], copied[element_name])
90 | _change_metadata_shapes(copied, element_name)
91 | with pytest.raises(AssertionError):
92 | assert_elements_are_identical(full_sdata[element_name], copied[element_name])
93 | else:
94 | assert get_model(element) == TableModel
95 | assert_elements_are_identical(full_sdata[element_name], copied[element_name])
96 | _change_metadata_tables(copied, element_name)
97 | with pytest.raises(AssertionError):
98 | assert_elements_are_identical(full_sdata[element_name], copied[element_name])
99 |
100 | with pytest.raises(AssertionError):
101 | assert_spatial_data_objects_are_identical(full_sdata, copied)
102 |
--------------------------------------------------------------------------------