├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── api │ ├── chunks.md │ ├── multiscale.md │ ├── reducers.md │ └── util.md └── index.md ├── mkdocs.yml ├── pyproject.toml ├── src └── xarray_multiscale │ ├── __about__.py │ ├── __init__.py │ ├── chunks.py │ ├── multiscale.py │ ├── py.typed │ ├── reducers.py │ └── util.py └── tests ├── __init__.py ├── test_chunks.py ├── test_docs.py ├── test_multiscale.py ├── test_reducers.py └── test_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | # Packages 4 | *.egg 5 | !/tests/**/*.egg 6 | /*.egg-info 7 | *.egg-info/ 8 | /dist/* 9 | build 10 | _build 11 | .cache 12 | *.so 13 | 14 | # Installer logs 15 | pip-log.txt 16 | 17 | # Unit test / coverage reports 18 | .coverage* 19 | .tox 20 | .pytest_cache 21 | 22 | .DS_Store 23 | .idea/* 24 | .python-version 25 | .vscode/* 26 | 27 | /test.py 28 | /test_*.* 29 | 30 | /setup.cfg 31 | MANIFEST.in 32 | /setup.py 33 | /docs/site/* 34 | /tests/fixtures/simple_project/setup.py 35 | /tests/fixtures/project_with_extras/setup.py 36 | .mypy_cache 37 | 38 | .venv 39 | /releases/* 40 | pip-wheel-metadata 41 | /poetry.toml 42 | 43 | # numba 44 | */__pycache__/* 45 | 46 | # docs 47 | site -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_commit_msg: "chore: update pre-commit hooks" 3 | autofix_commit_msg: "style: pre-commit fixes" 4 | autofix_prs: false 5 | default_stages: [commit, push] 6 | default_language_version: 7 | python: python3 8 | repos: 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | rev: 'v0.5.4' 11 | hooks: 12 | - id: ruff 13 | args: ["--fix", "--show-fixes"] 14 | - id: ruff-format 15 | - repo: https://github.com/codespell-project/codespell 16 | rev: v2.3.0 17 | hooks: 18 | - id: codespell 19 | args: ["-L", 
"ba,ihs,kake,nd,noe,nwo,te,fo,zar", "-S", "fixture"] 20 | - repo: https://github.com/pre-commit/pre-commit-hooks 21 | rev: v4.6.0 22 | hooks: 23 | - id: check-yaml 24 | - repo: https://github.com/pre-commit/mirrors-mypy 25 | rev: v1.11.0 26 | hooks: 27 | - id: mypy 28 | files: src 29 | additional_dependencies: 30 | - numpy 31 | - typing_extensions 32 | # Tests 33 | - pytest -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Howard Hughes Medical Institute 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # xarray-multiscale 2 | 3 | Simple tools for creating multiscale representations of large images. 4 | 5 | ## Installation 6 | 7 | `pip install xarray-multiscale` 8 | 9 | ## Motivation 10 | 11 | Many image processing applications benefit from representing images at multiple scales (also known as [image pyramids](https://en.wikipedia.org/wiki/Pyramid_(image_processing)). This package provides tools for generating lazy multiscale representations of N-dimensional data using [`xarray`](http://xarray.pydata.org/en/stable/) to ensure that the downsampled images have the correct coordinates. 12 | 13 | Why are coordinates important for this application? Because a downsampled image is typically scaled and *translated* relative to the source image. Without a coordinate-aware representation of the data, the scaling and translation information is easily lost. 
14 | 15 | 16 | ## Usage 17 | 18 | Generate a multiscale representation of a numpy array: 19 | 20 | ```python 21 | from xarray_multiscale import multiscale, windowed_mean 22 | import numpy as np 23 | 24 | data = np.arange(4) 25 | print(*multiscale(data, windowed_mean, 2), sep='\n') 26 | """ 27 | Size: 32B 28 | array([0, 1, 2, 3]) 29 | Coordinates: 30 | * dim_0 (dim_0) float64 32B 0.0 1.0 2.0 3.0 31 | 32 | Size: 16B 33 | array([0, 2]) 34 | Coordinates: 35 | * dim_0 (dim_0) float64 16B 0.5 2.5 36 | """ 37 | ``` 38 | 39 | read more in the [project documentation](https://JaneliaSciComp.github.io/xarray-multiscale/). 40 | -------------------------------------------------------------------------------- /docs/api/chunks.md: -------------------------------------------------------------------------------- 1 | ::: xarray_multiscale.chunks -------------------------------------------------------------------------------- /docs/api/multiscale.md: -------------------------------------------------------------------------------- 1 | ::: xarray_multiscale.multiscale -------------------------------------------------------------------------------- /docs/api/reducers.md: -------------------------------------------------------------------------------- 1 | ::: xarray_multiscale.reducers -------------------------------------------------------------------------------- /docs/api/util.md: -------------------------------------------------------------------------------- 1 | ::: xarray_multiscale.util -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # xarray-multiscale 2 | 3 | Simple tools for creating multiscale representations of large images. 
4 | 5 | ## Installation 6 | 7 | `pip install xarray-multiscale` 8 | 9 | ## Motivation 10 | 11 | Many image processing applications benefit from representing images at multiple scales (also known as [image pyramids](https://en.wikipedia.org/wiki/Pyramid_(image_processing)). This package provides tools for generating lazy multiscale representations of N-dimensional data using [`xarray`](http://xarray.pydata.org/en/stable/) to ensure that the downsampled images have the correct coordinates. 12 | 13 | ### Coordinates matter when you downsample images 14 | 15 | It's obvious that downsampling an image applies a scaling transformation, i.e. downsampling increases the distance between image samples. This is the whole purpose of downsampling the image. But it is less obvious that most downsampling operations also apply a *translation transformation* -- downsampling an image (generally) shifts the origin of the output relative to the source. 16 | 17 | In signal processing terms, image downsampling combines an image filtering step (blending local intensities) with a resampling step (sampling intensities at a set of positions in the signal). When you resample an image, you get to choose which points to resample on, and the best choice for most simple downsampling routines is to resample on points that are slightly translated relative to the original image. For simple windowed downsampling, this means that the first element of the downsampled image lies 18 | at the center (i.e., the mean) of the coordinate values of the window. 
19 | 20 | We can illustrate this with some simple examples: 21 | 22 | ``` 23 | 2x windowed downsampling, in one dimension: 24 | 25 | source coordinates: | 0 | 1 | 26 | downsampled coordinates: | 0.5 | 27 | ``` 28 | 29 | ``` 30 | 3x windowed downsampling, in two dimensions: 31 | 32 | source coordinates: | (0,0) | (0,1) | (0,2) | 33 | | (1,0) | (1,1) | (1,2) | 34 | | (2,0) | (2,1) | (2,2) | 35 | 36 | downsampled coordinates: | | 37 | | (1,1) | 38 | | | 39 | 40 | ``` 41 | 42 | Another way of thinking about this is that if you downsample an arbitrarily large image to a single value, then the only sensible place to localize that value is at the center of the image. Thus, incrementally downsampling slightly shifts the downsampled image toward that point. 43 | 44 | Why should you care? If you work with images where the coordinates matter (for example, images recorded from scientific instruments), then you should care about keeping track of those coordinates. Tools like numpy or scikit-image make it very easy to ignore the coordinates of your image. These tools model images as simple arrays, and from the array perspective `data[0,0]` and `downsampled_data[0,0]` lie on the same position in space because they take the same array index. However, `downsampled_data[0,0]` is almost certainly shifted relative to `data[0,0]`. Coordinate-blind tools like `scikit-image` force you to track the coordinates on your own, which is a recipe for mistakes. This is the value of `xarray`. By explicitly modelling coordinates alongside data values, `xarray` ensures that you never lose track of where your data comes from, which is why `xarray-multiscale` uses it. 45 | 46 | ### Who needs this 47 | 48 | The library `xarray` already supports basic downsampling routines via the [`DataArray.coarsen`](https://docs.xarray.dev/en/stable/user-guide/computation.html#coarsen-large-arrays) API. So if you use `xarray` and just need to compute a windowed mean, then you may not need `xarray-multiscale` at all. 
But the `DataArray.coarsen` API does not 49 | allow users to provide their own downsampling functions; If you need something like [windowed mode](./api/reducers.md#xarray_multiscale.reducers.windowed_mode) downsampling, or something you wrote yourself, then `xarray-multiscale` should be useful to you. 50 | 51 | 52 | ## Usage 53 | 54 | Generate a multiscale representation of a numpy array: 55 | 56 | ```python 57 | from xarray_multiscale import multiscale, windowed_mean 58 | import numpy as np 59 | 60 | data = np.arange(4) 61 | print(*multiscale(data, windowed_mean, 2), sep='\n') 62 | """ 63 | Size: 32B 64 | array([0, 1, 2, 3]) 65 | Coordinates: 66 | * dim_0 (dim_0) float64 32B 0.0 1.0 2.0 3.0 67 | 68 | Size: 16B 69 | array([0, 2]) 70 | Coordinates: 71 | * dim_0 (dim_0) float64 16B 0.5 2.5 72 | """ 73 | ``` 74 | 75 | 76 | By default, the values of the downsampled arrays are cast to the same data type as the input. This behavior can be changed with the ``preserve_dtype`` keyword argument to ``multiscale``: 77 | 78 | ```python 79 | from xarray_multiscale import multiscale, windowed_mean 80 | import numpy as np 81 | 82 | data = np.arange(4) 83 | print(*multiscale(data, windowed_mean, 2, preserve_dtype=False), sep="\n") 84 | """ 85 | Size: 32B 86 | array([0, 1, 2, 3]) 87 | Coordinates: 88 | * dim_0 (dim_0) float64 32B 0.0 1.0 2.0 3.0 89 | 90 | Size: 16B 91 | array([0.5, 2.5]) 92 | Coordinates: 93 | * dim_0 (dim_0) float64 16B 0.5 2.5 94 | """ 95 | ``` 96 | 97 | Anisotropic downsampling is supported: 98 | 99 | ```python 100 | from xarray_multiscale import multiscale, windowed_mean 101 | import numpy as np 102 | 103 | data = np.arange(16).reshape((4,4)) 104 | print(*multiscale(data, windowed_mean, (1,2)), sep="\n") 105 | """ 106 | Size: 128B 107 | array([[ 0, 1, 2, 3], 108 | [ 4, 5, 6, 7], 109 | [ 8, 9, 10, 11], 110 | [12, 13, 14, 15]]) 111 | Coordinates: 112 | * dim_0 (dim_0) float64 32B 0.0 1.0 2.0 3.0 113 | * dim_1 (dim_1) float64 32B 0.0 1.0 2.0 3.0 114 | 115 | Size: 
64B 116 | array([[ 0, 2], 117 | [ 4, 6], 118 | [ 8, 10], 119 | [12, 14]]) 120 | Coordinates: 121 | * dim_0 (dim_0) float64 32B 0.0 1.0 2.0 3.0 122 | * dim_1 (dim_1) float64 16B 0.5 2.5 123 | """ 124 | ``` 125 | 126 | 127 | Note that `multiscale` returns an `xarray.DataArray`. 128 | The `multiscale` function also accepts `DataArray` objects: 129 | 130 | ```python 131 | from xarray_multiscale import multiscale, windowed_mean 132 | from xarray import DataArray 133 | import numpy as np 134 | 135 | data = np.arange(16).reshape((4,4)) 136 | coords = (DataArray(np.arange(data.shape[0]), dims=('y',), attrs={'units' : 'm'}), 137 | DataArray(np.arange(data.shape[0]), dims=('x',), attrs={'units' : 'm'})) 138 | 139 | arr = DataArray(data, coords) 140 | print(*multiscale(arr, windowed_mean, (2,2)), sep="\n") 141 | """ 142 | Size: 128B 143 | array([[ 0, 1, 2, 3], 144 | [ 4, 5, 6, 7], 145 | [ 8, 9, 10, 11], 146 | [12, 13, 14, 15]]) 147 | Coordinates: 148 | * y (y) int64 32B 0 1 2 3 149 | * x (x) int64 32B 0 1 2 3 150 | 151 | Size: 32B 152 | array([[ 2, 4], 153 | [10, 12]]) 154 | Coordinates: 155 | * y (y) float64 16B 0.5 2.5 156 | * x (x) float64 16B 0.5 2.5 157 | """ 158 | ``` 159 | 160 | Dask arrays work too. Note the control over output chunks via the ``chunks`` keyword argument. 
161 | 162 | ```python 163 | from xarray_multiscale import multiscale, windowed_mean 164 | import dask.array as da 165 | 166 | arr = da.random.randint(0, 255, (10,10,10)) 167 | print(*multiscale(arr, windowed_mean, 2, chunks=2), sep="\n") 168 | """ 169 | Size: 8kB 170 | dask.array 171 | Coordinates: 172 | * dim_0 (dim_0) float64 80B 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 173 | * dim_1 (dim_1) float64 80B 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 174 | * dim_2 (dim_2) float64 80B 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 175 | 176 | Size: 1kB 177 | dask.array 178 | Coordinates: 179 | * dim_0 (dim_0) float64 40B 0.5 2.5 4.5 6.5 8.5 180 | * dim_1 (dim_1) float64 40B 0.5 2.5 4.5 6.5 8.5 181 | * dim_2 (dim_2) float64 40B 0.5 2.5 4.5 6.5 8.5 182 | 183 | Size: 64B 184 | dask.array 185 | Coordinates: 186 | * dim_0 (dim_0) float64 16B 1.5 5.5 187 | * dim_1 (dim_1) float64 16B 1.5 5.5 188 | * dim_2 (dim_2) float64 16B 1.5 5.5 189 | """ 190 | ``` 191 | 192 | ### Caveats 193 | 194 | - Arrays that are not evenly divisible by the downsampling factors will be trimmed as needed. If this behavior is undesirable, consider padding your array appropriately prior to downsampling. 195 | - For chunked arrays (e.g., dask arrays), the current implementation divides the input data into *contiguous* chunks. This means that attempting to use downsampling schemes based on sliding windowed smoothing will produce edge artifacts. 196 | - The [`multiscale`](api/multiscale/#xarray_multiscale.multiscale.multiscale) function in this library stops downsampling when a subqsequent downsampling operation would create an array where any axis has length 1. This is because I work with bioimaging data, and for bioimaging data we often want to represent the coordinates of an array with a translation transformation and a scaling transformation. But scaling and translation cannot both be estimated for a single point, so `multiscale` avoids creating arrays with singleton dimensions. 
197 | 198 | ### Development 199 | 200 | This project is developed using [`hatch`](https://hatch.pypa.io/latest/). 201 | Run tests with `hatch run test:pytest`. 202 | Serve docs with `hatch run docs:serve`. 203 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: xarray-ome-ngff 2 | site_url: https://janeliascicomp.github.io/xarray-multiscale/ 3 | site_author: Davis Bennett 4 | site_description: >- 5 | Multiscale images via Xarray. 6 | 7 | # Repository 8 | repo_name: janeliascicomp/xarray-multiscale 9 | repo_url: https://github.com/janeliascicomp/xarray-multiscale 10 | 11 | # Copyright 12 | copyright: Copyright © 2016 - 2024 HHMI / Janelia 13 | watch: [src] 14 | theme: 15 | features: 16 | - navigation.expand 17 | - content.code.annotate 18 | name: material 19 | palette: 20 | # Palette toggle for light mode 21 | - scheme: default 22 | toggle: 23 | icon: material/brightness-7 24 | name: Switch to dark mode 25 | 26 | # Palette toggle for dark mode 27 | - scheme: slate 28 | toggle: 29 | icon: material/brightness-4 30 | name: Switch to light mode 31 | 32 | nav: 33 | - About: index.md 34 | - API: 35 | - multiscale: api/multiscale.md 36 | - reducers: api/reducers.md 37 | - chunks: api/chunks.md 38 | - util: api/util.md 39 | 40 | plugins: 41 | - mkdocstrings: 42 | handlers: 43 | python: 44 | options: 45 | docstring_style: numpy 46 | members_order: source 47 | separate_signature: true 48 | filters: ["!^_"] 49 | docstring_options: 50 | ignore_init_summary: true 51 | merge_init_into_class: true 52 | 53 | markdown_extensions: 54 | - pymdownx.highlight: 55 | anchor_linenums: true 56 | line_spans: __span 57 | pygments_lang_class: true 58 | - pymdownx.inlinehilite 59 | - pymdownx.snippets 60 | - pymdownx.superfences 61 | - toc: 62 | baselevel: 2 63 | toc_depth: 4 64 | permalink: "#" 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "xarray-multiscale" 7 | dynamic = ["version"] 8 | description = '' 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | license = "MIT" 12 | keywords = [] 13 | authors = [ 14 | { name = "Davis Vann Bennett", email = "davis.v.bennett@gmail.com" }, 15 | ] 16 | classifiers = [ 17 | "Development Status :: 4 - Beta", 18 | "Programming Language :: Python", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | "Programming Language :: Python :: Implementation :: CPython", 24 | "Programming Language :: Python :: Implementation :: PyPy", 25 | ] 26 | dependencies = [ 27 | "xarray >=2022.03.0", 28 | "scipy >=1.5.4", 29 | "numpy >=1.19.4", 30 | "dask >=2020.12.0" 31 | ] 32 | 33 | [project.urls] 34 | Documentation = "https://github.com/janelia-scicomp/xarray-multiscale#readme" 35 | Issues = "https://github.com/janelia-scicomp/xarray-multiscale/issues" 36 | Source = "https://github.com/janelia-scicomp/xarray-multiscale" 37 | 38 | [tool.hatch.version] 39 | path = "src/xarray_multiscale/__about__.py" 40 | 41 | [tool.hatch.envs.test] 42 | dependencies = [ 43 | "coverage", 44 | "pytest", 45 | "pytest-cov", 46 | "pytest-examples == 0.0.12" 47 | ] 48 | 49 | [[tool.hatch.envs.test.matrix]] 50 | python = ["3.9", "3.10", "3.11", "3.12"] 51 | 52 | [tool.hatch.envs.docs] 53 | dependencies = [ 54 | "mkdocs-material == 9.5.30", 55 | "mkdocstrings[python] == 0.25.1", 56 | ] 57 | 58 | [tool.hatch.envs.types] 59 | extra-dependencies = [ 60 | "mypy>=1.0.0", 61 | ] 62 | [tool.hatch.envs.types.scripts] 63 | check = "mypy --install-types --non-interactive 
{args:src/xarray_multiscale tests}" 64 | 65 | [tool.coverage.run] 66 | source_pkgs = ["xarray_multiscale", "tests"] 67 | branch = true 68 | parallel = true 69 | omit = [ 70 | "src/xarray_multiscale/__about__.py", 71 | ] 72 | 73 | [tool.coverage.paths] 74 | xarray_multiscale = ["src/xarray_multiscale", "*/xarray-multiscale/src/xarray_multiscale"] 75 | tests = ["tests", "*/xarray-multiscale/tests"] 76 | 77 | [tool.coverage.report] 78 | exclude_lines = [ 79 | "no cov", 80 | "if __name__ == .__main__.:", 81 | "if TYPE_CHECKING:", 82 | ] 83 | 84 | [tool.ruff] 85 | line-length = 100 86 | src = ["src"] 87 | force-exclude = true 88 | extend-exclude = [ 89 | ".bzr", 90 | ".direnv", 91 | ".eggs", 92 | ".git", 93 | ".mypy_cache", 94 | ".nox", 95 | ".pants.d", 96 | ".ruff_cache", 97 | ".venv", 98 | "__pypackages__", 99 | "_build", 100 | "buck-out", 101 | "build", 102 | "dist", 103 | "venv", 104 | "docs", 105 | ] 106 | 107 | [tool.ruff.lint] 108 | extend-select = [ 109 | "B", # flake8-bugbear 110 | "I", # isort 111 | "ISC", 112 | "UP", # pyupgrade 113 | "RSE", 114 | "RUF", 115 | ] 116 | ignore = [ 117 | "RUF005", 118 | ] 119 | 120 | [tool.mypy] 121 | python_version = "3.10" 122 | ignore_missing_imports = true 123 | namespace_packages = false 124 | 125 | strict = true 126 | warn_unreachable = true 127 | 128 | enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] -------------------------------------------------------------------------------- /src/xarray_multiscale/__about__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2024-present Davis Vann Bennett 2 | # 3 | # SPDX-License-Identifier: MIT 4 | __version__ = "1.1.1" 5 | -------------------------------------------------------------------------------- /src/xarray_multiscale/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .multiscale import 
downscale, multiscale 4 | from .reducers import ( 5 | windowed_max, 6 | windowed_mean, 7 | windowed_min, 8 | windowed_rank, 9 | ) 10 | 11 | __all__ = [ 12 | "downscale", 13 | "multiscale", 14 | "windowed_mean", 15 | "windowed_mode", 16 | "windowed_max", 17 | "windowed_min", 18 | "windowed_rank", 19 | ] 20 | -------------------------------------------------------------------------------- /src/xarray_multiscale/chunks.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Hashable, cast 4 | 5 | if TYPE_CHECKING: 6 | from typing import Sequence 7 | 8 | import dask.array as da 9 | import toolz as tz 10 | import xarray 11 | from dask.array.routines import aligned_coarsen_chunks 12 | from xarray.core.utils import is_dict_like 13 | 14 | 15 | def normalize_chunks( 16 | array: xarray.DataArray, chunk_size: str | int | Sequence[int] | dict[Hashable, int] 17 | ) -> dict[Hashable, int]: 18 | """ 19 | Given an `xarray.DataArray`, normalize a chunk size against that array. 20 | 21 | Parameters 22 | ---------- 23 | array: xarray.DataArray 24 | An `xarray.DataArray`. 25 | chunk_size: Union[str, int, Sequence[int], dict[Hashable, int]] 26 | A specification of a chunk size. 27 | 28 | Returns 29 | ------- 30 | dict[Hashable, int] 31 | An xarray-compatible specification of chunk sizes. 32 | """ 33 | _chunk_size: str | int | Sequence[int] | dict[Hashable, int] 34 | if not isinstance(chunk_size, (int, str, dict)): 35 | if len(chunk_size) != array.ndim: 36 | msg = msg = f"Incorrect number of chunks. Got {len(chunk_size)}, expected {array.ndim}." 
37 | raise ValueError(msg) 38 | 39 | if is_dict_like(chunk_size): 40 | # dask's normalize chunks routine assumes dict inputs have integer 41 | # keys, so convert dim names to the corresponding integers 42 | chunk_size = cast(dict[Hashable, int], chunk_size) 43 | if len(chunk_size.keys() - set(array.dims)) > 0: 44 | extra: set[Hashable] = chunk_size.keys() - set(array.dims) 45 | msg = f"Keys of chunksize must be a subset of array dims. Got extraneous keys: {extra}." 46 | raise ValueError(msg) 47 | _chunk_size = dict(zip(range(array.ndim), map(tz.first, array.chunks))) 48 | _chunk_size.update({array.get_axis_num(d): c for d, c in chunk_size.items()}) 49 | else: 50 | _chunk_size = chunk_size 51 | 52 | new_chunks: tuple[int, ...] = tuple( 53 | map( 54 | tz.first, 55 | da.core.normalize_chunks( 56 | _chunk_size, 57 | array.shape, 58 | dtype=array.dtype, 59 | previous_chunks=array.data.chunksize, 60 | ), 61 | ) 62 | ) 63 | 64 | return {dim: new_chunks[array.get_axis_num(dim)] for dim in array.dims} 65 | 66 | 67 | def align_chunks(array: da.core.Array, scale_factors: Sequence[int]) -> da.core.Array: 68 | """ 69 | Ensure that all chunks of a dask array are divisible by scale_factors, rechunking the array 70 | if necessary. 
71 | """ 72 | new_chunks = {} 73 | for idx, factor in enumerate(scale_factors): 74 | aligned = aligned_coarsen_chunks(array.chunks[idx], factor) 75 | if aligned != array.chunks[idx]: 76 | new_chunks[idx] = aligned 77 | if new_chunks: 78 | array = array.rechunk(new_chunks) 79 | return array 80 | -------------------------------------------------------------------------------- /src/xarray_multiscale/multiscale.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Literal 4 | 5 | if TYPE_CHECKING: 6 | from typing import Any, Callable, Hashable, Sequence, TypeAlias 7 | 8 | import numpy as np 9 | import numpy.typing as npt 10 | import xarray 11 | from dask.array.core import Array 12 | from dask.base import tokenize 13 | from dask.core import flatten 14 | from dask.highlevelgraph import HighLevelGraph 15 | from dask.utils import apply 16 | 17 | from xarray_multiscale.chunks import align_chunks, normalize_chunks 18 | from xarray_multiscale.reducers import WindowedReducer 19 | from xarray_multiscale.util import adjust_shape, broadcast_to_rank, logn 20 | 21 | ChunkOption: TypeAlias = Literal["preserve", "auto"] 22 | 23 | 24 | def _default_namer(idx: int) -> str: 25 | """ 26 | The default naming function. Takes an integer index and prepends "s" in front of it. 27 | """ 28 | return f"s{idx}" 29 | 30 | 31 | def multiscale( 32 | array: npt.NDArray[Any], 33 | reduction: WindowedReducer, 34 | scale_factors: Sequence[int] | int, 35 | preserve_dtype: bool = True, 36 | chunks: ChunkOption | Sequence[int] | dict[Hashable, int] = "preserve", 37 | chained: bool = True, 38 | namer: Callable[[int], str] = _default_namer, 39 | **kwargs: Any, 40 | ) -> list[xarray.DataArray]: 41 | """ 42 | Generate a coordinate-aware multiscale representation of an array. 43 | 44 | Parameters 45 | ---------- 46 | array : Array-like, e.g. Numpy array, Dask array 47 | The array to be downscaled. 
48 | 49 | reduction : callable 50 | A function that aggregates chunks of data over windows. 51 | See `xarray_multiscale.reducers.WindowedReducer` for the expected 52 | signature of this callable. 53 | 54 | scale_factors : int or sequence of ints 55 | The desired downscaling factors, one for each axis, or a single 56 | value for all axes. 57 | 58 | preserve_dtype : bool, default=True 59 | If True, output arrays are all cast to the same data type as the 60 | input array. If False, output arrays will have data type determined 61 | by the output of the reduction function. 62 | 63 | chunks : sequence or dict of ints, or the string "preserve" (default) 64 | Set the chunking of the output arrays. Applies only to dask arrays. 65 | If `chunks` is set to "preserve" (the default), then chunk sizes will 66 | decrease with each level of downsampling. Otherwise, this argument is 67 | passed to `xarray_multiscale.chunks.normalize_chunks`. 68 | 69 | Otherwise, this keyword argument will be passed to the 70 | `xarray.DataArray.chunk` method for each output array, 71 | producing a list of arrays with the same chunk size. 72 | Note that rechunking can be computationally expensive 73 | for arrays with many chunks. 74 | 75 | chained : bool, default=True 76 | If True (default), the nth downscaled array is generated by 77 | applying the reduction function on the n-1th downscaled array with 78 | the user-supplied `scale_factors`. This means that the nth 79 | downscaled array directly depends on the n-1th downscaled array. 80 | Note that nonlinear reductions like the windowed mode may give 81 | inaccurate results with `chained` set to True. 82 | 83 | If False, the nth downscaled array is generated by applying the 84 | reduction function on the 0th downscaled array 85 | (i.e., the input array) with the `scale_factors` raised to the nth 86 | power. This means that the nth downscaled array directly depends 87 | on the input array. 
88 | 89 | namer : callable, defaults to `_default_namer` 90 | A function for naming the output arrays. This function should take an integer 91 | index and return a string. The default function simply prepends the string 92 | representation of the integer with the character "s". 93 | 94 | **kwargs: Any 95 | Additional keyword arguments that will be passed to the reduction function. 96 | 97 | Returns 98 | ------- 99 | result : list[xarray.DataArray] 100 | The first element of this list is the input array, converted to an 101 | `xarray.DataArray`. Each subsequent element of the list is 102 | the result of downsampling the previous element of the list. 103 | 104 | The `coords` attributes of these DataArrays track the changing 105 | offset and scale induced by the downsampling operation. 106 | 107 | Examples 108 | -------- 109 | >>> import numpy as np 110 | >>> from xarray_multiscale import multiscale 111 | >>> from xarray_multiscale.reducers import windowed_mean 112 | >>> multiscale(np.arange(4), windowed_mean, 2) 113 | [ 114 | array([0, 1, 2, 3]) 115 | Coordinates: 116 | * dim_0 (dim_0) float64 0.0 1.0 2.0 3.0, 117 | array([0, 2]) 118 | Coordinates: 119 | * dim_0 (dim_0) float64 0.5 2.5] 120 | """ 121 | scale_factors = broadcast_to_rank(scale_factors, array.ndim) 122 | darray = to_dataarray(array, name=namer(0)) 123 | 124 | levels = range(1, downsampling_depth(darray.shape, scale_factors)) 125 | 126 | result: list[xarray.DataArray] = [darray] 127 | for level in levels: 128 | if chained: 129 | scale = scale_factors 130 | source = result[-1] 131 | else: 132 | scale = tuple(s**level for s in scale_factors) 133 | source = result[0] 134 | downscaled = downscale(source, reduction, scale, preserve_dtype, **kwargs) 135 | downscaled.name = namer(level) 136 | result.append(downscaled) 137 | 138 | if darray.chunks is not None and chunks != "preserve": 139 | new_chunks = [normalize_chunks(r, chunks) for r in result] 140 | result = [r.chunk(ch) for r, ch in zip(result, 
new_chunks)] 141 | 142 | return result 143 | 144 | 145 | def to_dataarray(array: Any, name: str | None = None) -> xarray.DataArray: 146 | """ 147 | Convert the input to an `xarray.DataArray` if it is not already one. 148 | """ 149 | if isinstance(array, xarray.DataArray): 150 | data = array.data 151 | dims = array.dims 152 | # ensure that key order matches dimension order 153 | coords = {d: array.coords[d] for d in dims} 154 | attrs = array.attrs 155 | else: 156 | data = array 157 | dims = tuple(f"dim_{d}" for d in range(data.ndim)) 158 | coords = { 159 | dim: xarray.DataArray(np.arange(shape, dtype="float"), dims=dim) 160 | for dim, shape in zip(dims, array.shape) 161 | } 162 | attrs = {} 163 | 164 | result = xarray.DataArray(data=data, coords=coords, dims=dims, attrs=attrs, name=name) 165 | return result 166 | 167 | 168 | def downscale_dask( 169 | array: Any, 170 | reduction: WindowedReducer, 171 | scale_factors: Sequence[int], 172 | **kwargs: Any, 173 | ) -> Any: 174 | """ 175 | Downscale a dask array. 176 | """ 177 | if not np.all((np.array(array.shape) % np.array(scale_factors)) == 0): 178 | msg = f"Coarsening factors {scale_factors} do not align with array shape {array.shape}." 
179 | raise ValueError(msg) 180 | 181 | array = align_chunks(array, scale_factors) 182 | name: str = "downscale-" + tokenize(reduction, array, scale_factors) 183 | dsk = { 184 | (name,) + key[1:]: (apply, reduction, [key, scale_factors], kwargs) 185 | for key in flatten(array.__dask_keys__()) 186 | } 187 | chunks = tuple( 188 | tuple(int(size // scale_factors[axis]) for size in sizes) 189 | for axis, sizes in enumerate(array.chunks) 190 | ) 191 | 192 | meta = reduction(np.empty(scale_factors, dtype=array.dtype), scale_factors, **kwargs) 193 | graph = HighLevelGraph.from_collections(name, dsk, dependencies=[array]) 194 | return Array(graph, name, chunks, meta=meta) 195 | 196 | 197 | def downscale( 198 | array: xarray.DataArray, 199 | reduction: WindowedReducer, 200 | scale_factors: Sequence[int], 201 | preserve_dtype: bool = True, 202 | **kwargs: Any, 203 | ) -> xarray.DataArray: 204 | to_downscale = adjust_shape(array, scale_factors) 205 | if to_downscale.chunks is not None: 206 | downscaled_data = downscale_dask(to_downscale.data, reduction, scale_factors, **kwargs) 207 | else: 208 | downscaled_data = reduction(to_downscale.data, scale_factors, **kwargs) 209 | if preserve_dtype: 210 | downscaled_data = downscaled_data.astype(array.dtype) 211 | downscaled_coords = downscale_coords(to_downscale, scale_factors) 212 | return xarray.DataArray(downscaled_data, downscaled_coords, attrs=array.attrs, dims=array.dims) 213 | 214 | 215 | def downscale_coords(array: xarray.DataArray, scale_factors: Sequence[int]) -> dict[Hashable, Any]: 216 | """ 217 | Downscale coordinates by taking the windowed mean of each coordinate array. 
218 | """ 219 | new_coords = {} 220 | for ( 221 | coord_name, 222 | coord, 223 | ) in array.coords.items(): 224 | coarsening_dims = { 225 | d: scale_factors[idx] for idx, d in enumerate(array.dims) if d in coord.dims 226 | } 227 | new_coords[coord_name] = coord.coarsen(coarsening_dims).mean() 228 | return new_coords 229 | 230 | 231 | def downsampling_depth(shape: Sequence[int], scale_factors: Sequence[int]) -> int: 232 | """ 233 | For a shape and a sequence of scale factors, calculate the 234 | number of downsampling operations that must be performed to produce 235 | a downsampled shape with at least one singleton value. 236 | 237 | If any element of `scale_factors` is greater than the 238 | corresponding shape, this function returns 0. 239 | 240 | If all `scale_factors` are 1, this function returns 0. 241 | 242 | Parameters 243 | ---------- 244 | shape: Sequence[int] 245 | An array shape. 246 | 247 | scale_factors : Sequence[int] 248 | Downsampling factors. 249 | 250 | Examples 251 | -------- 252 | >>> downsampling_depth((8,), (2,)) 253 | 3 254 | >>> downsampling_depth((8,2), (2,2)) 255 | 1 256 | >>> downsampling_depth((7,), (2,)) 257 | 2 258 | """ 259 | if len(shape) != len(scale_factors): 260 | msg = ( 261 | "The shape and scale_factors parameters do not have the same length." 
262 | f"Shape={shape} has length {len(shape)}, " 263 | f"but scale_factors={scale_factors} has length {len(scale_factors)}" 264 | ) 265 | raise ValueError(msg) 266 | 267 | _scale_factors = np.array(scale_factors).astype("int") 268 | _shape = np.array(shape).astype("int") 269 | valid = _scale_factors > 1 270 | if not valid.any(): 271 | result = 0 272 | else: 273 | depths = np.floor(logn(_shape[valid], _scale_factors[valid])) 274 | result = min(depths.astype("int")) 275 | return result 276 | -------------------------------------------------------------------------------- /src/xarray_multiscale/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaneliaSciComp/xarray-multiscale/2282409c2c9aac9f80b6c09073bd1310683fd012/src/xarray_multiscale/py.typed -------------------------------------------------------------------------------- /src/xarray_multiscale/reducers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Protocol 4 | 5 | if TYPE_CHECKING: 6 | from typing import Any, Sequence 7 | 8 | import numpy.typing as npt 9 | 10 | import math 11 | from functools import reduce 12 | from itertools import combinations 13 | 14 | import numpy as np 15 | from scipy.stats import mode 16 | 17 | 18 | class WindowedReducer(Protocol): 19 | def __call__( 20 | self, array: npt.NDArray[Any], window_size: Sequence[int], **kwargs: Any 21 | ) -> npt.NDArray[Any]: ... 22 | 23 | 24 | def reshape_windowed(array: npt.NDArray[Any], window_size: tuple[int, ...]) -> npt.NDArray[Any]: 25 | """ 26 | Reshape an array to support windowed operations. New 27 | dimensions will be added to the array, one for each element of 28 | `window_size`. 29 | 30 | Parameters 31 | ---------- 32 | array: Array-like, e.g. Numpy array, Dask array 33 | The array to be reshaped. The array must have a ``reshape`` method. 
def reshape_windowed(array: npt.NDArray[Any], window_size: tuple[int, ...]) -> npt.NDArray[Any]:
    """
    Reshape an array to support windowed operations. New
    dimensions will be added to the array, one for each element of
    `window_size`.

    Parameters
    ----------
    array: Array-like, e.g. Numpy array, Dask array
        The array to be reshaped. The array must have a ``reshape`` method.

    window_size: Tuple of ints
        The window size. The length of ``window_size`` must match the
        dimensionality of ``array``.

    Returns
    -------
    The input array reshaped with extra dimensions.
    E.g., for an ``array`` with shape ``(10, 2)``,
    ``reshape_windowed(array, (2, 2))`` returns
    output with shape ``(5, 2, 1, 2)``.

    Examples
    --------
    >>> import numpy as np
    >>> from xarray_multiscale.reducers import reshape_windowed
    >>> data = np.arange(12).reshape(3, 4)
    >>> reshaped = reshape_windowed(data, (1, 2))
    >>> reshaped.shape
    (3, 1, 2, 2)
    """
    if len(window_size) != array.ndim:
        # Single-line message: the previous triple-quoted f-string embedded
        # source indentation inside the exception text.
        msg = (
            "Length of window_size must match array dimensionality. "
            f"Got {len(window_size)}, expected {array.ndim}"
        )
        raise ValueError(msg)
    # Each axis of length s becomes a pair of axes with lengths (s // f, f).
    new_shape: tuple[int, ...] = ()
    for s, f in zip(array.shape, window_size):
        new_shape += (s // f, f)
    return array.reshape(new_shape)


def windowed_mean(
    array: npt.NDArray[Any], window_size: tuple[int, ...], **kwargs: Any
) -> npt.NDArray[Any]:
    """
    Compute the windowed mean of an array.

    Parameters
    ----------
    array: Array-like, e.g. Numpy array, Dask array
        The array to be downscaled. The array must have
        ``reshape`` and ``mean`` methods that obey the
        ``np.reshape`` and ``np.mean`` APIs.

    window_size: Tuple of ints
        The window to use for aggregations. The array is partitioned into
        non-overlapping regions with size equal to ``window_size``, and the
        values in each window are aggregated to generate the result.

    **kwargs: dict, optional
        Extra keyword arguments passed to ``array.mean``

    Returns
    -------
    Array-like
        The result of the windowed mean. The length of each axis of this array
        will be a fraction of the input. The datatype is determined by the
        behavior of ``array.mean`` given the kwargs (if any) passed to it.

    Notes
    -----
    This function works by first reshaping the array to have an extra
    axis per element of ``window_size``, then computing the
    mean along those extra axes.

    See ``xarray_multiscale.reducers.reshape_windowed`` for the
    implementation of the array reshaping routine.

    Examples
    --------
    >>> import numpy as np
    >>> from xarray_multiscale.reducers import windowed_mean
    >>> data = np.arange(16).reshape(4, 4)
    >>> windowed_mean(data, (2, 2))
    array([[ 2.5,  4.5],
           [10.5, 12.5]])
    """
    reshaped = reshape_windowed(array, window_size)
    # The window axes are the odd-numbered axes after reshaping.
    result: npt.NDArray[Any] = reshaped.mean(axis=tuple(range(1, reshaped.ndim, 2)), **kwargs)
    return result


def windowed_max(
    array: npt.NDArray[Any], window_size: tuple[int, ...], **kwargs: Any
) -> npt.NDArray[Any]:
    """
    Compute the windowed maximum of an array.

    Parameters
    ----------
    array: Array-like, e.g. Numpy array, Dask array
        The array to be downscaled. The array must have ``reshape`` and
        ``max`` methods.

    window_size: Tuple of ints
        The window to use for aggregations. The array is partitioned into
        non-overlapping regions with size equal to ``window_size``, and the
        values in each window are aggregated to generate the result.

    **kwargs: dict, optional
        Extra keyword arguments passed to ``array.max``

    Returns
    -------
    Array-like
        The result of the windowed max. The length of each axis of this array
        will be a fraction of the input. The datatype of the return value
        will be the same as the input.

    Notes
    -----
    This function works by first reshaping the array to have an extra
    axis per element of ``window_size``, then computing the
    max along those extra axes.

    See ``xarray_multiscale.reducers.reshape_windowed`` for
    the implementation of the array reshaping routine.

    Examples
    --------
    >>> import numpy as np
    >>> from xarray_multiscale.reducers import windowed_max
    >>> data = np.arange(16).reshape(4, 4)
    >>> windowed_max(data, (2, 2))
    array([[ 5,  7],
           [13, 15]])
    """
    reshaped = reshape_windowed(array, window_size)
    result: npt.NDArray[Any] = reshaped.max(axis=tuple(range(1, reshaped.ndim, 2)), **kwargs)
    return result


def windowed_min(
    array: npt.NDArray[Any], window_size: tuple[int, ...], **kwargs: Any
) -> npt.NDArray[Any]:
    """
    Compute the windowed minimum of an array.

    Parameters
    ----------
    array: Array-like, e.g. Numpy array, Dask array
        The array to be downscaled. The array must have ``reshape`` and
        ``min`` methods.

    window_size: Tuple of ints
        The window to use for aggregations. The array is partitioned into
        non-overlapping regions with size equal to ``window_size``, and the
        values in each window are aggregated to generate the result.

    **kwargs: dict, optional
        Extra keyword arguments passed to ``array.min``

    Returns
    -------
    Array-like
        The result of the windowed min. The length of each axis of this array
        will be a fraction of the input. The datatype of the return value
        will be the same as the input.

    Notes
    -----
    This function works by first reshaping the array to have an extra
    axis per element of ``window_size``, then computing the
    min along those extra axes.

    See ``xarray_multiscale.reducers.reshape_windowed``
    for the implementation of the array reshaping routine.

    Examples
    --------
    >>> import numpy as np
    >>> from xarray_multiscale.reducers import windowed_min
    >>> data = np.arange(16).reshape(4, 4)
    >>> windowed_min(data, (2, 2))
    array([[ 0,  2],
           [ 8, 10]])
    """
    reshaped = reshape_windowed(array, window_size)
    result: npt.NDArray[Any] = reshaped.min(axis=tuple(range(1, reshaped.ndim, 2)), **kwargs)
    return result


def windowed_mode(array: npt.NDArray[Any], window_size: tuple[int, ...]) -> npt.NDArray[Any]:
    """
    Compute the windowed mode of an array using either
    `windowed_mode_countless` or `windowed_mode_scipy`.
    Input will be coerced to a numpy array.

    Parameters
    ----------
    array: Array-like, e.g. Numpy array, Dask array
        The array to be downscaled. The array must have a ``reshape``
        method.

    window_size: Tuple of ints
        The window to use for aggregation. The array is partitioned into
        non-overlapping regions with size equal to ``window_size``, and the
        values in each window are aggregated to generate the result.
        If the product of the elements of ``window_size`` is 16 or less, then
        ``windowed_mode_countless`` will be used. Otherwise,
        ``windowed_mode_scipy`` is used. This is a speculative cutoff based
        on the documentation of the countless algorithm used in
        ``windowed_mode_countless`` which was created by William Silversmith.

    Returns
    -------
    Numpy array
        The result of the windowed mode. The length of each axis of this array
        will be a fraction of the input.

    Examples
    --------
    >>> import numpy as np
    >>> from xarray_multiscale.reducers import windowed_mode
    >>> data = np.arange(16).reshape(4, 4)
    >>> windowed_mode(data, (2, 2))
    array([[ 0,  2],
           [ 8, 10]])
    """
    if np.prod(window_size) <= 16:
        return windowed_mode_countless(array, window_size)
    else:
        return windowed_mode_scipy(array, window_size)


def windowed_mode_scipy(array: npt.NDArray[Any], window_size: tuple[int, ...]) -> npt.NDArray[Any]:
    """
    Compute the windowed mode of an array using scipy.stats.mode.
    Input will be coerced to a numpy array.

    Parameters
    ----------
    array: Array-like, e.g. Numpy array, Dask array
        The array to be downscaled. The array must have a ``reshape``
        method.

    window_size: Tuple of ints
        The window to use for aggregation. The array is partitioned into
        non-overlapping regions with size equal to ``window_size``, and the
        values in each window are aggregated to generate the result.

    Returns
    -------
    Numpy array
        The result of the windowed mode. The length of each axis of this array
        will be a fraction of the input.

    Notes
    -----
    This function wraps ``scipy.stats.mode``.

    Examples
    --------
    >>> import numpy as np
    >>> from xarray_multiscale.reducers import windowed_mode_scipy
    >>> data = np.arange(16).reshape(4, 4)
    >>> windowed_mode_scipy(data, (2, 2))
    array([[ 0,  2],
           [ 8, 10]])
    """
    reshaped = reshape_windowed(array, window_size)
    # Move all window axes to the end, then flatten them into one axis so a
    # single call to `mode` reduces every window at once.
    transposed_shape = tuple(range(0, reshaped.ndim, 2)) + tuple(range(1, reshaped.ndim, 2))
    transposed = reshaped.transpose(transposed_shape)
    collapsed = transposed.reshape(tuple(reshaped.shape[slice(0, None, 2)]) + (-1,))
    result: npt.NDArray[Any] = mode(collapsed, axis=collapsed.ndim - 1, keepdims=False).mode
    return result


def _pick(a: npt.NDArray[Any], b: npt.NDArray[Any]) -> Any:
    # Elementwise: a where a == b, else 0. Helper for the countless algorithm.
    return a * (a == b)


def _lor(a: npt.NDArray[Any], b: npt.NDArray[Any]) -> Any:
    # Elementwise "logical or" over label values: a where nonzero, else b.
    return a + (a == 0) * b


def windowed_mode_countless(
    array: npt.NDArray[Any], window_size: tuple[int, ...]
) -> npt.NDArray[Any]:
    """
    countless downsamples labeled images (segmentations)
    by finding the mode using vectorized instructions.
    It is ill advised to use this O(2^N-1) time algorithm
    and O(NCN/2) space for N > about 16 tops.
    This means it's useful for the following kinds of downsampling.
    This could be implemented for higher performance in
    C/Cython more simply, but at least this is easily
    portable.
    2x2x1 (N=4), 2x2x2 (N=8), 4x4x1 (N=16), 3x2x1 (N=6)
    and various other configurations of a similar nature.
    c.f. https://medium.com/@willsilversmith/countless-3d-vectorized-2x-downsampling-of-labeled-volume-images-using-python-and-numpy-59d686c2f75

    This function has been modified from the original
    to avoid mutation of the input argument.

    Parameters
    ----------
    array: Numpy array
        The array to be downscaled.

    window_size: Tuple of ints
        The window size. The length of ``window_size`` must match the
        dimensionality of ``array``.

    """
    sections = []

    mode_of = reduce(lambda x, y: x * y, window_size)
    majority = int(math.ceil(float(mode_of) / 2))

    # Shift values up by one so 0 can act as the "no match" sentinel in
    # _pick/_lor; the shift is undone at the end.
    for offset in np.ndindex(window_size):
        part = 1 + array[tuple(np.s_[o::f] for o, f in zip(offset, window_size))]
        sections.append(part)

    subproblems: list[dict[tuple[int, int], npt.ArrayLike]] = [{}, {}]
    results2 = None
    for x, y in combinations(range(len(sections) - 1), 2):
        res = _pick(sections[x], sections[y])
        subproblems[0][(x, y)] = res
        if results2 is not None:
            results2 = _lor(results2, res)  # type: ignore[unreachable]
        else:
            results2 = res

    results = [results2]
    for r in range(3, majority + 1):
        r_results = None
        for combo in combinations(range(len(sections)), r):
            res = _pick(subproblems[0][combo[:-1]], sections[combo[-1]])  # type: ignore[index, arg-type]

            if combo[-1] != len(sections) - 1:
                subproblems[1][combo] = res  # type: ignore[index]

            if r_results is not None:
                r_results = _lor(r_results, res)  # type: ignore[unreachable]
            else:
                r_results = res
        results.append(r_results)
        subproblems[0] = subproblems[1]
        subproblems[1] = {}

    results.reverse()
    final_result: npt.NDArray[Any] = _lor(reduce(_lor, results), sections[-1]) - 1  # type: ignore[arg-type]

    return final_result


def windowed_rank(
    array: npt.NDArray[Any], window_size: tuple[int, ...], rank: int = -1
) -> npt.NDArray[Any]:
    """
    Compute the windowed rank order filter of an array.
    Input will be coerced to a numpy array.

    Parameters
    ----------
    array: Array-like, e.g. Numpy array, Dask array
        The array to be downscaled. The array must have a ``reshape``
        method.

    window_size: tuple[int, ...]
        The window to use for aggregation. The array is partitioned into
        non-overlapping regions with size equal to ``window_size``, and the
        values in each window are sorted to generate the result.

    rank: int, default=-1
        The index to take from the sorted values in each window. If non-negative, then
        rank must be between 0 and the product of the elements of ``window_size`` minus one,
        (inclusive).
        Rank may be negative, in which case it denotes an index relative to the end of the sorted
        values following normal python indexing rules.
        E.g., when rank is -1 (the default), this takes the maximum value of each window.

    Returns
    -------
    Numpy array
        The result of the windowed rank filter. The length of each axis of this array
        will be a fraction of the input.

    Examples
    --------
    >>> import numpy as np
    >>> from xarray_multiscale.reducers import windowed_rank
    >>> data = np.arange(16).reshape(4, 4)
    >>> windowed_rank(data, (2, 2), -2)
    array([[ 4,  6],
           [12, 14]])
    """
    max_rank = np.prod(window_size) - 1
    if rank > max_rank or rank < -max_rank - 1:
        # Join into one string: the old code built a tuple of strings, so the
        # raised ValueError rendered as a tuple repr instead of a message.
        msg = (
            f"Invalid rank: {rank} for window_size: {window_size}. "
            f"If rank is negative, it must be between -1 and {-max_rank - 1}, inclusive. "
            f"If rank is non-negative, it must be between 0 and {max_rank}, inclusive."
        )
        raise ValueError(msg)
    reshaped = reshape_windowed(array, window_size)
    # Move the window axes to the end and flatten them, then sort each
    # window once and select the requested order statistic.
    transposed_shape = tuple(range(0, reshaped.ndim, 2)) + tuple(range(1, reshaped.ndim, 2))
    transposed = reshaped.transpose(transposed_shape)
    collapsed = transposed.reshape(tuple(reshaped.shape[slice(0, None, 2)]) + (-1,))
    result: npt.NDArray[Any] = np.take(np.sort(collapsed, axis=-1), rank, axis=-1)
    return result
# --- src/xarray_multiscale/util.py ---
from typing import Dict, Sequence, Tuple, Union

import numpy as np
import numpy.typing as npt
from xarray import DataArray


def adjust_shape(array: DataArray, scale_factors: Sequence[int]) -> DataArray:
    """
    Crop an array so that its dimensions are evenly
    divisible by a set of integers.

    Parameters
    ----------
    array : DataArray
        Array that will be cropped.

    scale_factors : Sequence of ints
        One factor per dimension of ``array``. Any axis whose length is not
        a multiple of its factor has its trailing elements cropped; arrays
        that already align are returned unchanged.

    Returns
    -------
    DataArray
    """
    result = array
    misalignment = np.any(np.mod(array.shape, scale_factors))
    if misalignment:
        # Crop each axis down to the nearest multiple of its scale factor.
        new_shape = np.subtract(array.shape, np.mod(array.shape, scale_factors))
        result = array.isel({d: slice(s) for d, s in zip(array.dims, new_shape)})
    return result


def logn(x: npt.ArrayLike, n: npt.ArrayLike) -> npt.NDArray[np.float64]:
    """
    Compute the logarithm of x base n.

    Parameters
    ----------
    x : float or int.
    n: float or int.

    Returns
    -------
    float
        np.log(x) / np.log(n)

    """
    result: npt.NDArray[np.float64] = np.log(x) / np.log(n)
    return result


def broadcast_to_rank(
    value: Union[int, Sequence[int], Dict[int, int]], rank: int
) -> Tuple[int, ...]:
    """
    Broadcast ``value`` to a tuple of ``rank`` integers.

    An int is repeated ``rank`` times; a sequence must already have length
    ``rank``; a dict maps axis index to value, with axes missing from the
    dict defaulting to 1.

    Raises
    ------
    ValueError
        If ``value`` is not an int, sequence, or dict; if a sequence has the
        wrong length; or if any resulting element is not an integer.
    """
    result_dict = {}
    if isinstance(value, int):
        result_dict = {k: value for k in range(rank)}
    elif isinstance(value, Sequence):
        if not (len(value) == rank):
            raise ValueError(f"Length of value {len(value)} must match rank: {rank}")
        else:
            result_dict = {k: v for k, v in enumerate(value)}
    elif isinstance(value, dict):
        for dim in range(rank):
            result_dict[dim] = value.get(dim, 1)
    else:
        # Join into one string: the old code built a tuple of strings, so the
        # raised ValueError rendered as a tuple repr instead of a message.
        msg = (  # type: ignore[unreachable]
            "The first argument must be an integer, a sequence of integers, "
            f"or a dictionary of integers. Got {type(value)}"
        )
        raise ValueError(msg)
    result = tuple(result_dict.values())
    typecheck = tuple(isinstance(val, int) for val in result)
    if not all(typecheck):
        bad_values = tuple(result[idx] for idx, val in enumerate(typecheck) if not val)
        msg = (
            "All elements of the first argument of this function must be integers. "
            f"Got non-integer values: {bad_values}"
        )
        raise ValueError(msg)

    return result


# --- tests/test_chunks.py ---
import dask.array as da
from xarray import DataArray

from xarray_multiscale.chunks import align_chunks, normalize_chunks


def test_normalize_chunks():
    data1 = DataArray(da.zeros((4, 6), chunks=(1, 1)))
    assert normalize_chunks(data1, {"dim_0": 2, "dim_1": 1}) == {"dim_0": 2, "dim_1": 1}

    data2 = DataArray(da.zeros((4, 6), chunks=(1, 1)), dims=("a", "b"))
    assert normalize_chunks(data2, {"a": 2, "b": 1}) == {"a": 2, "b": 1}

    # -1 means "one chunk spanning the whole axis"
    data3 = DataArray(da.zeros((4, 6), chunks=(1, 1)), dims=("a", "b"))
    assert normalize_chunks(data3, {"a": -1, "b": -1}) == {"a": 4, "b": 6}

    # omitted dims keep their existing chunking
    data4 = DataArray(da.zeros((4, 6), chunks=(1, 1)), dims=("a", "b"))
    assert normalize_chunks(data4, {"a": -1}) == {"a": 4, "b": 1}


def test_align_chunks():
    data = da.arange(10, chunks=1)
    rechunked = align_chunks(data, scale_factors=(2,))
    assert rechunked.chunks == ((2,) * 5,)

    data = da.arange(10, chunks=2)
    rechunked = align_chunks(data, scale_factors=(2,))
    assert rechunked.chunks == ((2,) * 5,)

    data = da.arange(10, chunks=(1, 1, 3, 5))
    rechunked = align_chunks(data, scale_factors=(2,))
    assert rechunked.chunks == (
        (
            2,
            2,
            2,
            4,
        ),
    )
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | from pytest_examples import CodeExample, EvalExample, find_examples 5 | 6 | 7 | @pytest.mark.parametrize("example", find_examples("docs"), ids=str) 8 | def test_docstrings(example: CodeExample, eval_example: EvalExample): 9 | if "test=skip" not in example.prefix_tags(): 10 | eval_example.run_print_check(example) 11 | else: 12 | pytest.skip() 13 | -------------------------------------------------------------------------------- /tests/test_multiscale.py: -------------------------------------------------------------------------------- 1 | import dask.array as da 2 | import numpy as np 3 | import pytest 4 | from src.xarray_multiscale.reducers import windowed_rank 5 | from xarray import DataArray 6 | from xarray.testing import assert_equal 7 | 8 | from xarray_multiscale.multiscale import ( 9 | adjust_shape, 10 | downsampling_depth, 11 | downscale, 12 | downscale_coords, 13 | downscale_dask, 14 | multiscale, 15 | ) 16 | from xarray_multiscale.reducers import windowed_mean 17 | 18 | 19 | def test_downscale_depth(): 20 | assert downsampling_depth((1,), (1,)) == 0 21 | assert downsampling_depth((2,), (3,)) == 0 22 | assert downsampling_depth((2, 1), (2, 1)) == 1 23 | assert downsampling_depth((2, 2, 2), (2, 2, 2)) == 1 24 | assert downsampling_depth((1, 2, 2), (2, 2, 2)) == 0 25 | assert downsampling_depth((4, 4, 4), (2, 2, 2)) == 2 26 | assert downsampling_depth((4, 2, 2), (2, 2, 2)) == 1 27 | assert downsampling_depth((5, 2, 2), (2, 2, 2)) == 1 28 | assert downsampling_depth((7, 2, 2), (2, 2, 2)) == 1 29 | assert downsampling_depth((1500, 5495, 5200), (2, 2, 2)) == 10 30 | 31 | 32 | @pytest.mark.parametrize(("size", "scale"), ((10, 2), (11, 2), ((10, 11), (2, 3)))) 33 | def test_adjust_shape(size, scale): 34 | arr = DataArray(np.zeros(size)) 35 | scale_array = np.array(scale) 36 | old_shape_array = np.array(arr.shape) 37 | 38 | 
cropped = adjust_shape(arr, scale) 39 | new_shape_array = np.array(cropped.shape) 40 | if np.all((old_shape_array % scale_array) == 0): 41 | assert np.array_equal(new_shape_array, old_shape_array) 42 | else: 43 | assert np.array_equal(new_shape_array, old_shape_array - (old_shape_array % scale_array)) 44 | 45 | 46 | def test_downscale_2d(): 47 | scale = (2, 1) 48 | 49 | data = DataArray( 50 | np.array([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype="uint8"), 51 | ) 52 | answer = DataArray(np.array([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]])) 53 | downscaled = downscale(data, windowed_mean, scale, preserve_dtype=False) 54 | downscaled_old_dtype = downscale(data, windowed_mean, scale, preserve_dtype=True) 55 | assert np.array_equal(downscaled, answer) 56 | assert np.array_equal( 57 | downscaled_old_dtype, 58 | answer.astype(data.dtype), 59 | ) 60 | 61 | 62 | def test_downscale_coords(): 63 | data = DataArray(np.zeros((10, 10)), dims=("x", "y"), coords={"x": np.arange(10)}) 64 | scale_factors = (2, 1) 65 | downscaled = downscale_coords(data, scale_factors) 66 | answer = {"x": data["x"].coarsen({"x": scale_factors[0]}).mean()} 67 | 68 | assert downscaled.keys() == answer.keys() 69 | for k in downscaled: 70 | assert_equal(answer[k], downscaled[k]) 71 | 72 | data = DataArray( 73 | np.zeros((10, 10)), 74 | dims=("x", "y"), 75 | coords={"x": np.arange(10), "y": 5 + np.arange(10)}, 76 | ) 77 | scale_factors = (2, 1) 78 | downscaled = downscale_coords(data, scale_factors) 79 | answer = { 80 | "x": data["x"].coarsen({"x": scale_factors[0]}).mean(), 81 | "y": data["y"].coarsen({"y": scale_factors[1]}).mean(), 82 | } 83 | 84 | assert downscaled.keys() == answer.keys() 85 | for k in downscaled: 86 | assert_equal(answer[k], downscaled[k]) 87 | 88 | data = DataArray( 89 | np.zeros((10, 10)), 90 | dims=("x", "y"), 91 | coords={"x": np.arange(10), "y": 5 + np.arange(10), "foo": 5}, 92 | ) 93 | scale_factors = (2, 2) 94 | downscaled = downscale_coords(data, 
scale_factors) 95 | answer = { 96 | "x": data["x"].coarsen({"x": scale_factors[0]}).mean(), 97 | "y": data["y"].coarsen({"y": scale_factors[1]}).mean(), 98 | "foo": data["foo"], 99 | } 100 | 101 | assert downscaled.keys() == answer.keys() 102 | for k in downscaled: 103 | assert_equal(answer[k], downscaled[k]) 104 | 105 | 106 | def test_invalid_multiscale(): 107 | with pytest.raises(ValueError): 108 | downscale_dask(np.arange(10), windowed_mean, (3,)) 109 | with pytest.raises(ValueError): 110 | downscale_dask(np.arange(16).reshape(4, 4), windowed_mean, (3, 3)) 111 | 112 | 113 | @pytest.mark.parametrize("chained", (True, False)) 114 | @pytest.mark.parametrize("ndim", (1, 2, 3, 4)) 115 | def test_multiscale(ndim: int, chained: bool): 116 | chunks = (2,) * ndim 117 | shape = (9,) * ndim 118 | cropslice = tuple(slice(s) for s in shape) 119 | cell = np.zeros(np.prod(chunks)).astype("float") 120 | cell[0] = 1 121 | cell = cell.reshape(*chunks) 122 | base_array = np.tile(cell, np.ceil(np.divide(shape, chunks)).astype("int"))[cropslice] 123 | 124 | pyr = multiscale(base_array, windowed_mean, 2, chained=chained) 125 | assert [p.shape for p in pyr] == [shape, (4,) * ndim, (2,) * ndim] 126 | 127 | # check that the first multiscale array is identical to the input data 128 | assert np.array_equal(pyr[0].data, base_array) 129 | 130 | 131 | @pytest.mark.parametrize("rank", (-1, 0, 1)) 132 | def test_multiscale_rank_kwargs(rank: int): 133 | data = np.arange(16) 134 | window_size = (4,) 135 | pyr = multiscale(data, windowed_rank, window_size, rank=rank) 136 | assert np.array_equal(pyr[1].data, windowed_rank(data, window_size=window_size, rank=rank)) 137 | 138 | 139 | def test_chunking(): 140 | ndim = 3 141 | shape = (16,) * ndim 142 | chunks = (4,) * ndim 143 | base_array = da.zeros(shape, chunks=chunks) 144 | reducer = windowed_mean 145 | scale_factors = (2,) * ndim 146 | 147 | multi = multiscale(base_array, reducer, scale_factors) 148 | expected_chunks = [ 149 | 
# NOTE(review): this span is a multi-file dump covering the tail of
# tests/test_multiscale.py, the whole of tests/test_reducers.py, and the
# whole of tests/test_util.py. A test function truncated at the top of the
# span (its `def` line lies before the visible region) is omitted rather
# than guessed at.


def test_coords():
    """Coordinates of downscaled arrays must match those produced by
    xarray's own ``coarsen(...).mean()``."""
    dim_names = ("z", "y", "x")
    shape = (16,) * len(dim_names)
    raw = np.random.randint(0, 255, shape, dtype="uint8")

    offsets = (0.0, -10, 10)
    spacings = (1.0, 2.0, 3.0)
    coords = tuple(
        (dim, spacing * (np.arange(length) + offset))
        for dim, spacing, length, offset in zip(dim_names, spacings, raw.shape, offsets)
    )
    source = DataArray(raw, coords=coords)
    reference = source.coarsen({"z": 2, "y": 2, "x": 2}).mean()

    pyramid = multiscale(source, windowed_mean, (2, 2, 2), preserve_dtype=False)

    assert_equal(pyramid[0], source)
    assert_equal(pyramid[1], reference)


@pytest.mark.parametrize("template", ("default", "{}"))
def test_namer(template):
    """Each level of the pyramid is named via the ``namer`` callable."""
    from xarray_multiscale.multiscale import _default_namer

    if template != "default":

        def namer(v):
            return template.format(v)

    else:
        namer = _default_namer

    levels = multiscale(np.arange(16), windowed_mean, 2, namer=namer)
    for index, level in enumerate(levels):
        assert level.name == namer(index)


# --- tests/test_reducers.py ---
from typing import Tuple

import dask.array as da
import numpy as np
import pytest

from xarray_multiscale.reducers import (
    reshape_windowed,
    windowed_max,
    windowed_mean,
    windowed_min,
    windowed_mode,
    windowed_rank,
)


def _check_windowed_reduction(reducer, method_name: str, ndim: int, window_size: int):
    """Compare ``reducer`` against dask's per-block numpy reduction
    (``method_name``) on a random array holding 4 windows per dimension."""
    window = (window_size,) * ndim
    extent = (np.array(window) * 4).tolist()
    data = da.random.randint(0, 255, size=extent, chunks=window)
    observed = reducer(data.compute(), window)
    expected = data.map_blocks(lambda v: getattr(v, method_name)(keepdims=True)).compute()
    assert np.array_equal(observed, expected)


@pytest.mark.parametrize("ndim", (1, 2, 3))
@pytest.mark.parametrize("window_size", (1, 2, 3, 4, 5))
def test_windowed_mean(ndim: int, window_size: int):
    _check_windowed_reduction(windowed_mean, "mean", ndim, window_size)


@pytest.mark.parametrize("ndim", (1, 2, 3))
@pytest.mark.parametrize("window_size", (1, 2, 3, 4, 5))
def test_windowed_max(ndim: int, window_size: int):
    _check_windowed_reduction(windowed_max, "max", ndim, window_size)


@pytest.mark.parametrize("ndim", (1, 2, 3))
@pytest.mark.parametrize("window_size", (1, 2, 3, 4, 5))
def test_windowed_min(ndim: int, window_size: int):
    _check_windowed_reduction(windowed_min, "min", ndim, window_size)


def test_windowed_mode():
    """windowed_mode returns the most common value of each window."""
    flat = np.arange(16) % 3 + np.arange(16) % 2
    expected = np.array([2, 0, 1, 2])
    observed = windowed_mode(flat, (4,))
    # one window has no majority value, so its result is unspecified;
    # compare only the windows with a clear winner
    majority = [0, 2, 3]
    assert np.array_equal(observed[majority], expected[majority])

    grid = np.arange(16).reshape(4, 4) % 3
    assert np.array_equal(windowed_mode(grid, (2, 2)), np.array([[1, 0], [0, 2]]))


def test_windowed_rank():
    """windowed_rank selects the value at a given brightness rank per window."""
    tile = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    volume = np.tile(tile, (2, 2, 2))
    window = (2, 2, 2)

    # rank prod(window) - 2 is the second-brightest voxel of each window
    second_brightest = np.prod(window) - 2
    assert np.array_equal(
        windowed_rank(volume, window, second_brightest),
        np.array([[[7, 7], [7, 7]], [[7, 7], [7, 7]]]),
    )

    # negative ranks index from the other end, python-style: -8 of an
    # 8-element window is the dimmest voxel
    assert np.array_equal(
        windowed_rank(volume, window, -8),
        np.array([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]),
    )

    # ranks outside the window's element count are rejected
    with pytest.raises(ValueError):
        windowed_rank(volume, window, 100)


@pytest.mark.parametrize("windows_per_dim", (1, 2, 3, 4, 5))
@pytest.mark.parametrize(
    "window_size", ((1,), (2,), (1, 2), (2, 2), (2, 2, 2), (1, 2, 3), (3, 3, 3, 3))
)
def test_reshape_windowed(windows_per_dim: int, window_size: Tuple[int, ...]):
    """reshape_windowed interleaves window-index and within-window axes."""
    shape = (windows_per_dim * np.array(window_size)).tolist()
    data = np.arange(np.prod(shape)).reshape(shape)
    windowed = reshape_windowed(data, window_size)

    # a window tuple with the wrong rank is rejected
    with pytest.raises(ValueError):
        reshape_windowed(data, [*window_size, 1])

    # even axes index windows, odd axes index within a window
    assert windowed.shape[0::2] == (windows_per_dim,) * len(window_size)
    assert windowed.shape[1::2] == window_size

    first_window_src = tuple(slice(w) for w in window_size)
    first_window_dst = tuple(
        slice(None) if axis % 2 else slice(0, 1) for axis in range(windowed.ndim)
    )
    # because this is a reshape, if the first window is correct, all the
    # other windows are correct too
    assert np.array_equal(
        data[first_window_src].squeeze(), windowed[first_window_dst].squeeze()
    )


# --- tests/test_util.py ---
from xarray_multiscale.util import broadcast_to_rank


def test_broadcast_to_rank():
    """Ints, tuples, and dicts all broadcast to a rank-length tuple."""
    assert broadcast_to_rank(2, 1) == (2,)
    assert broadcast_to_rank(2, 2) == (2, 2)
    assert broadcast_to_rank((2, 3), 2) == (2, 3)
    assert broadcast_to_rank({0: 2}, 3) == (2, 1, 1)