├── xarray_extras ├── py.typed ├── tests │ ├── __init__.py │ ├── test_numba_extras.py │ ├── test_sort.py │ ├── test_stack.py │ ├── test_cumulatives.py │ ├── test_interpolate.py │ └── test_csv.py ├── kernels │ ├── __init__.py │ ├── cumulatives.py │ ├── csv.py │ ├── np_to_csv_py.py │ ├── np_to_csv.c │ └── interpolate.py ├── duck │ ├── __init__.py │ └── sort.py ├── compat.py ├── __init__.py ├── numba_extras.py ├── stack.py ├── sort.py ├── cumulatives.py ├── csv.py └── interpolate.py ├── doc ├── _static │ ├── .gitignore │ └── style.css ├── _templates │ └── layout.html ├── api │ ├── csv.rst │ ├── stack.rst │ ├── cumulatives.rst │ ├── interpolate.rst │ ├── numba_extras.rst │ └── sort.rst ├── installing.rst ├── develop.rst ├── index.rst ├── whats-new.rst └── conf.py ├── .gitattributes ├── .github ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── pre-commit.yml │ ├── docs.yml │ └── pytest.yml ├── ci ├── requirements-latest.yml ├── requirements-docs.yml ├── requirements-minimal.yml └── requirements-upstream.yml ├── setup.py ├── .readthedocs.yaml ├── MANIFEST.in ├── README.md ├── .gitignore ├── .pre-commit-config.yaml ├── HOW_TO_RELEASE ├── pyproject.toml └── LICENSE /xarray_extras/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /xarray_extras/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doc/_static/.gitignore: -------------------------------------------------------------------------------- 1 | examples*.png 2 | *.log 3 | *.pdf 4 | *.fbd_latexmk 5 | *.aux 6 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # reduce the number of merge conflicts 2 | doc/whats-new.rst merge=union 3 | -------------------------------------------------------------------------------- /doc/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} {% set css_files = css_files + 2 | ["_static/style.css"] %} 3 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | - [ ] Closes #xxxx 2 | - [ ] Tests added / passed 3 | - [ ] Passes `pre-commit run --all-files` 4 | -------------------------------------------------------------------------------- /doc/api/csv.rst: -------------------------------------------------------------------------------- 1 | csv 2 | === 3 | 4 | .. automodule:: xarray_extras.csv 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/api/stack.rst: -------------------------------------------------------------------------------- 1 | stack 2 | ===== 3 | 4 | .. automodule:: xarray_extras.stack 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/api/cumulatives.rst: -------------------------------------------------------------------------------- 1 | cumulatives 2 | =========== 3 | 4 | .. 
automodule:: xarray_extras.cumulatives
 5 |     :members:
 6 |     :undoc-members:
 7 |     :show-inheritance:
 8 | 
--------------------------------------------------------------------------------
/doc/api/interpolate.rst:
--------------------------------------------------------------------------------
 1 | interpolate
 2 | ===========
 3 | 
 4 | .. automodule:: xarray_extras.interpolate
 5 |     :members:
 6 |     :undoc-members:
 7 |     :show-inheritance:
 8 | 
--------------------------------------------------------------------------------
/doc/api/numba_extras.rst:
--------------------------------------------------------------------------------
 1 | numba\_extras
 2 | =============
 3 | 
 4 | .. automodule:: xarray_extras.numba_extras
 5 |     :members:
 6 |     :undoc-members:
 7 |     :show-inheritance:
 8 | 
--------------------------------------------------------------------------------
/xarray_extras/kernels/__init__.py:
--------------------------------------------------------------------------------
 1 | """Pure numpy functions that either run in dask or are used directly.
 2 | Typically invoked through ``xarray.apply_ufunc(dask='parallelized')``.
 3 | """
 4 | 
--------------------------------------------------------------------------------
/xarray_extras/duck/__init__.py:
--------------------------------------------------------------------------------
 1 | """Functions that accept either a numpy array or a dask array, and return
 2 | another of the matching type.
 3 | Typically invoked through ``xarray.apply_ufunc(dask='allowed')``.
 4 | """
 5 | 
--------------------------------------------------------------------------------
/ci/requirements-latest.yml:
--------------------------------------------------------------------------------
 1 | name: xarray-extras
 2 | channels:
 3 |   - conda-forge
 4 | dependencies:
 5 |   - dask
 6 |   - numba
 7 |   - numpy
 8 |   - pandas
 9 |   - pytest
10 |   - pytest-cov
11 |   - scipy
12 |   - xarray
13 | 
--------------------------------------------------------------------------------
/ci/requirements-docs.yml:
--------------------------------------------------------------------------------
 1 | name: xarray-extras-docs
 2 | channels:
 3 |   - conda-forge
 4 | dependencies:
 5 |   - python=3.12
 6 |   - sphinx
 7 |   - sphinx_rtd_theme
 8 |   - gcc_linux-64
 9 |   - dask
10 |   - numba
11 |   - scipy
12 |   - xarray
13 | 
--------------------------------------------------------------------------------
/ci/requirements-minimal.yml:
--------------------------------------------------------------------------------
 1 | name: xarray-extras-minimal
 2 | channels:
 3 |   - conda-forge
 4 | dependencies:
 5 |   - dask=2022.6.0
 6 |   - numba=0.56
 7 |   - numpy=1.23
 8 |   - pandas=1.5
 9 |   - pytest
10 |   - pytest-cov
11 |   - scipy=1.9
12 |   - xarray=2022.11.0
13 | 
--------------------------------------------------------------------------------
/xarray_extras/compat.py:
--------------------------------------------------------------------------------
 1 | from __future__ import annotations
 2 | 
 3 | try:
 4 |     from xarray.namedarray.pycompat import array_type
 5 | except ImportError:  # <2024.2.0
 6 |     from xarray.core.pycompat import array_type  # type: ignore[no-redef]
 7 | 
 8 | dask_array_type = array_type("dask")
 9 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import Extension, setup
 2 | 
 3 | setup(
 4 |     use_scm_version=True,
 5 |     # Compile CPython extensions
 6 |     ext_modules=[
 7 |         Extension(
 8 |             "xarray_extras.kernels.np_to_csv",
["xarray_extras/kernels/np_to_csv.c"] 9 | ) 10 | ], 11 | ) 12 | -------------------------------------------------------------------------------- /xarray_extras/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | try: 4 | __version__ = importlib.metadata.version("xarray_extras") 5 | except importlib.metadata.PackageNotFoundError: # pragma: nocover 6 | # Local copy, not installed with pip 7 | __version__ = "999" 8 | 9 | __all__ = ("__version__",) 10 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: mambaforge-22.9 7 | 8 | conda: 9 | environment: ci/requirements-docs.yml 10 | 11 | python: 12 | install: 13 | - method: pip 14 | path: . 15 | 16 | sphinx: 17 | builder: html 18 | configuration: doc/conf.py 19 | fail_on_warning: false 20 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include HOW_TO_RELEASE 2 | include LICENSE 3 | include *.md 4 | include *.py 5 | recursive-include ci * 6 | recursive-include doc * 7 | include xarray_extras * 8 | prune doc/_build 9 | global-exclude __pycache__ 10 | global-exclude *.pyc 11 | global-exclude .DS_Store 12 | global-exclude .ipynb_checkpoints 13 | global-exclude dask-worker-space 14 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: ["*"] 8 | 9 | jobs: 10 | checks: 11 | name: pre-commit hooks 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: actions/setup-python@v5 16 | - uses: pre-commit/action@v3.0.0 17 | -------------------------------------------------------------------------------- /doc/_static/style.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .wy-side-nav-search > a img.logo, 4 | .wy-side-nav-search .wy-dropdown > a img.logo { 5 | width: 12rem; 6 | } 7 | 8 | .wy-side-nav-search { 9 | background-color: #eee; 10 | } 11 | 12 | .wy-side-nav-search > div.version { 13 | display: none; 14 | } 15 | 16 | .wy-nav-top { 17 | background-color: #555; 18 | } 19 | -------------------------------------------------------------------------------- /ci/requirements-upstream.yml: -------------------------------------------------------------------------------- 1 | name: xarray-extras 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - dask 6 | # - numba # Not compatible with numpy 2 7 | - numpy 8 | - pandas 9 | - pyarrow 10 | - pytest 11 | - pytest-cov 12 | - scipy 13 | - xarray 14 | - pip 15 | - pip: 16 | - git+https://github.com/dask/dask 17 | - git+https://github.com/dask/distributed 18 | - git+https://github.com/pydata/xarray 19 | # numpy, pandas, pyarrow, and scipy are upgraded to nightly builds by pytest.yml 20 | -------------------------------------------------------------------------------- /doc/installing.rst: -------------------------------------------------------------------------------- 1 | .. 
_installing: 2 | 3 | Installation 4 | ============ 5 | 6 | Required dependencies 7 | --------------------- 8 | 9 | - Python 3.8 or later 10 | - `scipy `__ 11 | - `xarray `__ 12 | - `dask `__ 13 | - `numba `__ 14 | - C compiler (only if building from sources) 15 | 16 | Deployment 17 | ---------- 18 | 19 | - With pip: :command:`pip install xarray-extras` 20 | - With `anaconda `_: 21 | :command:`conda install -c conda-forge xarray-extras` 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # xarray_extras 2 | 3 | [![doc-badge](https://github.com/crusaderky/xarray_extras/actions/workflows/docs.yml/badge.svg)](https://github.com/crusaderky/xarray_extras/actions) 4 | [![pre-commit-badge](https://github.com/crusaderky/xarray_extras/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/crusaderky/xarray_extras/actions) 5 | [![pytest-badge](https://github.com/crusaderky/xarray_extras/actions/workflows/pytest.yml/badge.svg)](https://github.com/crusaderky/xarray_extras/actions) 6 | [![codecov-badge](https://codecov.io/gh/crusaderky/xarray_extras/branch/main/graph/badge.svg)](https://codecov.io/gh/crusaderky/xarray_extras/branch/main) 7 | 8 | Advanced / experimental algorithms for xarray 9 | 10 | Full documentation at http://xarray-extras.readthedocs.io/ 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | 4 | # C extensions 5 | *.so 6 | 7 | # Packages 8 | *.egg 9 | *.egg-info 10 | .eggs 11 | dist 12 | build 13 | eggs 14 | parts 15 | var 16 | sdist 17 | develop-eggs 18 | .installed.cfg 19 | lib 20 | lib64 21 | 22 | # Installer logs 23 | pip-log.txt 24 | 25 | # Unit test / coverage reports 26 | .coverage* 27 | coverage.xml 28 | .tox 29 | nosetests.xml 30 | .cache 31 | .mypy_cache/ 32 | .ropeproject/ 33 | .tags* 34 | .testmon* 35 | .pytest_cache 36 | dask-worker-space/ 37 | 38 | # asv environments 39 | .asv 40 | 41 | # Translations 42 | *.mo 43 | 44 | # Mr Developer 45 | .mr.developer.cfg 46 | .project 47 | .pydevproject 48 | 49 | # IDEs 50 | .idea 51 | *.swp 52 | .DS_Store 53 | .vscode/ 54 | 55 | # Sync tools 56 | Icon* 57 | 58 | .ipynb_checkpoints 59 | 60 | doc/_build 61 | Untitled.ipynb 62 | htmlcov/ 63 | -------------------------------------------------------------------------------- /xarray_extras/tests/test_numba_extras.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | pytest.importorskip("numba") # Not available in upstream CI 5 | 6 | from xarray_extras.numba_extras import guvectorize 7 | 8 | DTYPES = [ 9 | # uint needs to appear before signed int: 10 | # https://github.com/numba/numba/issues/2934 11 | "uint8", 12 | "uint16", 13 | "uint32", 14 | "uint64", 15 | "int8", 16 | "int16", 17 | "int32", 18 | "int64", 19 | "float32", 20 | "float64", 21 | "complex64", 22 | "complex128", 23 | ] 24 | 25 | 26 | @guvectorize("{T}[:], {T}[:]", "()->()") 27 | def dumb_copy(x, y): 28 | for i in range(x.size): 29 | y.flat[i] = x.flat[i] 30 | 31 | 32 | @pytest.mark.parametrize("dtype", DTYPES) 33 | def test_guvectorize(dtype): 34 | x = np.arange(3, dtype=dtype) 35 | y = dumb_copy(x) 36 | np.testing.assert_equal(x, y) 37 | assert x.dtype == y.dtype 38 | -------------------------------------------------------------------------------- 
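The test above exercises ``xarray_extras.numba_extras.guvectorize``, which expands the ``{T}`` placeholder into one numba signature per supported dtype. For reference, a minimal hand-written equivalent of the decorator used in the test might look as follows (a sketch only, covering two of the twelve dtypes and assuming numba is importable):

    import numba

    # Rough equivalent of @guvectorize("{T}[:], {T}[:]", "()->()") for two dtypes;
    # the real wrapper generates all twelve signatures and defaults to cache=True.
    @numba.guvectorize(
        ["float32[:], float32[:]", "float64[:], float64[:]"],
        "()->()",
        cache=True,
    )
    def dumb_copy_two_dtypes(x, y):
        for i in range(x.size):
            y.flat[i] = x.flat[i]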
/.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: ["*"] 8 | 9 | defaults: 10 | run: 11 | shell: bash -l {0} 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v4 19 | 20 | - name: Setup Conda Environment 21 | uses: conda-incubator/setup-miniconda@v3 22 | with: 23 | miniforge-version: latest 24 | use-mamba: true 25 | environment-file: ci/requirements-docs.yml 26 | activate-environment: xarray-extras-docs 27 | 28 | - name: Show conda options 29 | run: conda config --show 30 | 31 | - name: conda info 32 | run: conda info 33 | 34 | - name: conda list 35 | run: conda list 36 | 37 | - name: Install 38 | run: python -m pip install --no-deps -e . 39 | 40 | - name: Build docs 41 | run: sphinx-build -n -j auto -b html -d build/doctrees doc build/html 42 | 43 | - uses: actions/upload-artifact@v4 44 | with: 45 | name: xarray_extras-docs 46 | path: build/html 47 | -------------------------------------------------------------------------------- /doc/api/sort.rst: -------------------------------------------------------------------------------- 1 | sort 2 | ==== 3 | 4 | .. automodule:: xarray_extras.sort 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | 10 | An example that uses all of the above functions is *source attribution*. 11 | Given a generic function :math:`y = f(x_{0}, x_{1}, ..., x_{i})`, which is 12 | embarrassingly parallel along a given dimension, one wants to find: 13 | 14 | - the top k elements of y along the dimension 15 | - the elements of all x's that generated the top k elements of y 16 | 17 | .. code:: 18 | 19 | >>> from xarray import DataArray 20 | >>> from xarray_extras.sort import * 21 | >>> x = DataArray([[5, 3, 2, 8, 1], 22 | >>> [0, 7, 1, 3, 2]], dims=['x', 's']) 23 | >>> y = x.sum('x') # y = f(x), embarrassingly parallel among dimension 's' 24 | >>> y 25 | 26 | array([ 5, 10, 3, 11, 3]) 27 | Dimensions without coordinates: s 28 | >>> top_y = topk(y, 3, 's') 29 | >>> top_y 30 | 31 | array([11, 10, 5]) 32 | Dimensions without coordinates: s 33 | >>> top_x = take_along_dim(x, argtopk(y, 3, 's'), 's') 34 | >>> top_x 35 | 36 | array([[8, 3, 5], 37 | [3, 7, 0]]) 38 | Dimensions without coordinates: x, s 39 | -------------------------------------------------------------------------------- /xarray_extras/kernels/cumulatives.py: -------------------------------------------------------------------------------- 1 | """Numba kernels for :mod:`cumulatives`""" 2 | 3 | import numpy as np 4 | 5 | from xarray_extras.numba_extras import guvectorize 6 | 7 | 8 | @guvectorize("{T}[:], intp[:], {T}[:]", "(i),(j)->()") 9 | def compound_sum(x: np.ndarray, c: np.ndarray, y: np.ndarray) -> None: 10 | """y = x[c[0]] + x[c[1]] + ... x[c[n]] 11 | until c[i] != -1 12 | """ 13 | acc = 0 14 | for i in c: 15 | if i == -1: 16 | break 17 | acc += x[i] 18 | y[0] = acc 19 | 20 | 21 | @guvectorize("{T}[:], intp[:], {T}[:]", "(i),(j)->()") 22 | def compound_prod(x: np.ndarray, c: np.ndarray, y: np.ndarray) -> None: 23 | """y = x[c[0]] * x[c[1]] * ... x[c[n]] 24 | until c[i] != -1 25 | """ 26 | acc = 1 27 | for i in c: 28 | if i == -1: 29 | break 30 | acc *= x[i] 31 | y[0] = acc 32 | 33 | 34 | @guvectorize("{T}[:], intp[:], {T}[:]", "(i),(j)->()") 35 | def compound_mean(x: np.ndarray, c: np.ndarray, y: np.ndarray) -> None: 36 | """y = mean(x[c[0]], x[c[1]], ... 
x[c[n]]) 37 | until c[i] != -1 38 | """ 39 | acc = 0 40 | j = 0 # Initialise j explicitly for when x.shape == (0, ) 41 | for j, i in enumerate(c): # noqa: B007 42 | if i == -1: 43 | break 44 | acc += x[i] 45 | else: 46 | # Reached the end of the row 47 | j += 1 48 | y[0] = acc / j 49 | -------------------------------------------------------------------------------- /xarray_extras/numba_extras.py: -------------------------------------------------------------------------------- 1 | """Extensions to numba""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Callable 6 | from typing import Any 7 | 8 | import numba 9 | 10 | _DTYPES = ( 11 | # uint needs to appear before signed int: 12 | # https://github.com/numba/numba/issues/2934 13 | "uint8", 14 | "uint16", 15 | "uint32", 16 | "uint64", 17 | "int8", 18 | "int16", 19 | "int32", 20 | "int64", 21 | "float32", 22 | "float64", 23 | "complex64", 24 | "complex128", 25 | ) 26 | 27 | 28 | def guvectorize( 29 | signature: str, layout: str, **kwargs: Any 30 | ) -> Callable[[Callable], Any]: 31 | """Convenience wrapper around :func:`numba.guvectorize`. 32 | Generate signature for all possible data types and set a few healthy 33 | defaults. 34 | 35 | :param str signature: 36 | numba signature, containing {T} 37 | :param str layout: 38 | as in :func:`numba.guvectorize` 39 | :param kwargs: 40 | passed verbatim to :func:`numba.guvectorize`. 41 | This function changes the default for cache from False to True. 42 | 43 | example:: 44 | 45 | guvectorize("{T}[:], {T}[:]", "(i)->(i)") 46 | 47 | Is the same as:: 48 | 49 | numba.guvectorize([ 50 | "float32[:], float32[:]", 51 | "float64[:], float64[:]", 52 | ... 53 | ], "(i)->(i)", cache=True) 54 | 55 | .. note:: 56 | Discussing upstream fix; see 57 | ``_. 
58 | """ 59 | if "{T}" in signature: 60 | signatures = [signature.format(T=dtype) for dtype in _DTYPES] 61 | else: 62 | signatures = [signature] 63 | kwargs.setdefault("cache", True) 64 | return numba.guvectorize(signatures, layout, **kwargs) 65 | -------------------------------------------------------------------------------- /xarray_extras/stack.py: -------------------------------------------------------------------------------- 1 | """Utilities for stacking/unstacking dimensions""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Hashable 6 | from typing import TypeVar 7 | 8 | import pandas as pd 9 | import xarray 10 | 11 | T = TypeVar("T", xarray.DataArray, xarray.Dataset) 12 | 13 | 14 | def proper_unstack(array: T, dim: Hashable) -> T: 15 | """Work around an issue in xarray that causes the data to be sorted 16 | alphabetically by label on unstack(): 17 | 18 | ``_ 19 | 20 | Also work around issue that causes string labels to be converted to 21 | objects: 22 | 23 | ``_ 24 | 25 | :param array: 26 | xarray.DataArray or xarray.Dataset to unstack 27 | :param str dim: 28 | Name of existing dimension to unstack 29 | :returns: 30 | xarray.DataArray or xarray.Dataset with unstacked dimension 31 | """ 32 | # Regenerate Pandas multi-index to be ordered by first appearance 33 | mindex = array.coords[dim].to_pandas().index 34 | 35 | levels = [] 36 | codes = [] 37 | 38 | for levels_i, codes_i in zip(mindex.levels, mindex.codes): 39 | level_map: dict[Hashable, int] = {} 40 | 41 | for code in codes_i: 42 | if code not in level_map: 43 | level_map[code] = len(level_map) 44 | 45 | levels.append([levels_i[k] for k in level_map]) 46 | codes.append([level_map[k] for k in codes_i]) 47 | 48 | mindex = pd.MultiIndex(levels, codes, names=mindex.names) 49 | array = array.copy() 50 | array.coords[dim] = mindex 51 | 52 | # Invoke builtin unstack 53 | array = array.unstack((dim,)) 54 | 55 | # Convert numpy arrays of Python objects to numpy arrays of C floats, ints, 56 | # strings, etc. 
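    # Illustrative example (values assumed for this comment): a coordinate that
    # comes back from unstack() as np.array(['x0', 'x1'], dtype=object) is
    # round-tripped through a plain Python list below, letting numpy re-infer a
    # native dtype such as '<U2'.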
57 | for name in mindex.names: 58 | if array.coords[name].dtype == object: 59 | array.coords[name] = array.coords[name].values.tolist() 60 | 61 | return array 62 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-added-large-files 6 | - id: check-case-conflict 7 | - id: check-merge-conflict 8 | - id: check-symlinks 9 | - id: check-yaml 10 | - id: debug-statements 11 | - id: end-of-file-fixer 12 | - id: mixed-line-ending 13 | - id: name-tests-test 14 | args: ["--pytest-test-first"] 15 | - id: requirements-txt-fixer 16 | - id: trailing-whitespace 17 | 18 | - repo: https://github.com/rbubley/mirrors-prettier 19 | rev: v3.4.2 20 | hooks: 21 | - id: prettier 22 | types_or: [yaml, markdown, html, css, scss, javascript, json] 23 | args: [--prose-wrap=always] 24 | 25 | - repo: https://github.com/astral-sh/ruff-pre-commit 26 | rev: v0.9.4 27 | hooks: 28 | - id: ruff-format 29 | - id: ruff 30 | args: ["--fix", "--show-fixes"] 31 | 32 | - repo: https://github.com/codespell-project/codespell 33 | rev: v2.4.1 34 | hooks: 35 | - id: codespell 36 | additional_dependencies: 37 | - tomli 38 | 39 | - repo: https://github.com/shellcheck-py/shellcheck-py 40 | rev: "v0.10.0.1" 41 | hooks: 42 | - id: shellcheck 43 | 44 | - repo: https://github.com/abravalheri/validate-pyproject 45 | rev: v0.23 46 | hooks: 47 | - id: validate-pyproject 48 | additional_dependencies: ["validate-pyproject-schema-store[all]"] 49 | 50 | - repo: https://github.com/python-jsonschema/check-jsonschema 51 | rev: "0.31.1" 52 | hooks: 53 | - id: check-dependabot 54 | - id: check-github-workflows 55 | 56 | - repo: https://github.com/pre-commit/mirrors-mypy 57 | rev: v1.14.1 58 | hooks: 59 | - id: mypy 60 | additional_dependencies: 61 | # Libraries exclusively imported under `if TYPE_CHECKING:` 62 | - typing_extensions 63 | # Typed libraries 64 | - numpy 65 | - dask 66 | - xarray 67 | - scipy 68 | -------------------------------------------------------------------------------- /doc/develop.rst: -------------------------------------------------------------------------------- 1 | Development Guidelines 2 | ====================== 3 | 4 | Install 5 | ------- 6 | 7 | 1. Clone this repository with git: 8 | 9 | .. code-block:: bash 10 | 11 | git clone git@github.com:crusaderky/xarray_extras.git 12 | cd xarray_extras 13 | 14 | 2. Install anaconda or miniconda (OS-dependent) 15 | 3. .. code-block:: bash 16 | 17 | conda env create -n xarray_extras --file ci/requirements.yml 18 | conda activate xarray_extras 19 | 20 | 4. Install C compilation stack: 21 | 22 | Linux 23 | .. code-block:: bash 24 | 25 | conda install gcc_linux-64 26 | 27 | MacOSX 28 | .. code-block:: bash 29 | 30 | conda install clang_osx-64 31 | 32 | Windows 33 | You need to manually install the Microsoft C compiler tools. Refer to CPython 34 | documentation. 35 | 36 | 37 | To keep a fork in sync with the upstream source: 38 | 39 | .. code-block:: bash 40 | 41 | cd xarray_extras 42 | git remote add upstream git@github.com:crusaderky/xarray_extras.git 43 | git remote -v 44 | git fetch -a upstream 45 | git checkout main 46 | git pull upstream main 47 | git push origin main 48 | 49 | Test 50 | ---- 51 | 52 | Test using ``py.test``: 53 | 54 | .. 
code-block:: bash 55 | 56 | python setup.py build_ext --inplace 57 | py.test xarray_extras 58 | 59 | Code Formatting 60 | --------------- 61 | 62 | xarray_extras uses several code linters (black, ruff, mypy), which are enforced by CI. 63 | Developers should run them locally before they submit a PR, through the single command 64 | 65 | .. code-block:: bash 66 | 67 | pre-commit run --all-files 68 | 69 | This makes sure that linter versions and options are aligned for all developers. 70 | 71 | Optionally, you may wish to setup the `pre-commit hooks `_ to 72 | run automatically when you make a git commit. This can be done by running: 73 | 74 | .. code-block:: bash 75 | 76 | pre-commit install 77 | 78 | from the root of the xarray_extras repository. Now the code linters will be run each time 79 | you commit changes. You can skip these checks with ``git commit --no-verify`` or with 80 | the short version ``git commit -n``. 81 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | xarray_extras: Advanced algorithms for xarray 2 | ============================================= 3 | This module offers several extensions to `xarray `_, 4 | which could not be included into the main module because they fall into one or 5 | more of the following categories: 6 | 7 | - They're too experimental 8 | - They're too niche 9 | - They introduce major new dependencies (e.g. 10 | `numba `_ or a C compiler) 11 | - They would be better done by doing major rework on multiple packages, and 12 | then one would need to wait for said changes to reach a stable release of 13 | each package - *in the right order*. 14 | 15 | The API of xarray-extras is unstable by definition, as features will be 16 | progressively migrated upwards towards xarray, dask, numpy, pandas, etc. 17 | 18 | Features 19 | -------- 20 | :doc:`api/csv` 21 | Multi-threaded CSV writer, much faster than 22 | :meth:`pandas.DataFrame.to_csv`, with full support for 23 | `dask `_ and 24 | `dask distributed `_. 25 | :doc:`api/cumulatives` 26 | Advanced cumulative sum/productory/mean functions 27 | :doc:`api/interpolate` 28 | dask-optimized n-dimensional spline interpolation 29 | :doc:`api/numba_extras` 30 | Additions to `numba `_ 31 | :doc:`api/sort` 32 | Advanced sort/take functions 33 | :doc:`api/stack` 34 | Tools for stacking/unstacking dimensions 35 | 36 | 37 | Index 38 | ----- 39 | 40 | .. toctree:: 41 | :maxdepth: 1 42 | 43 | installing 44 | develop 45 | whats-new 46 | api/csv 47 | api/cumulatives 48 | api/interpolate 49 | api/numba_extras 50 | api/sort 51 | api/stack 52 | 53 | 54 | Credits 55 | ------- 56 | - :func:`~xarray_extras.stack.proper_unstack` was originally developed by 57 | Legal & General and released to the open source community in 2018. 58 | - All boilerplate is from 59 | `python_project_template `_, 60 | which in turn is from `xarray `_. 61 | 62 | License 63 | ------- 64 | 65 | xarray-extras is available under the open source `Apache License`__. 
66 | 67 | __ http://www.apache.org/licenses/LICENSE-2.0.html 68 | -------------------------------------------------------------------------------- /xarray_extras/sort.py: -------------------------------------------------------------------------------- 1 | """Sorting functions""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Hashable 6 | from typing import TypeVar 7 | 8 | import xarray 9 | 10 | from xarray_extras.duck import sort as duck 11 | 12 | __all__ = ("argtopk", "take_along_dim", "topk") 13 | 14 | 15 | T = TypeVar("T", xarray.DataArray, xarray.Dataset) 16 | TV = TypeVar("TV", xarray.DataArray, xarray.Dataset, xarray.Variable) 17 | 18 | 19 | def topk(a: TV, k: int, dim: Hashable, split_every: int | None = None) -> TV: 20 | """Extract the k largest elements from a on the given dimension, and return 21 | them sorted from largest to smallest. If k is negative, extract the -k 22 | smallest elements instead, and return them sorted from smallest to largest. 23 | 24 | This assumes that ``k`` is small. All results will be returned in a single 25 | chunk along the given axis. 26 | """ 27 | return xarray.apply_ufunc( 28 | duck.topk, 29 | a, 30 | kwargs={"k": k, "split_every": split_every}, 31 | input_core_dims=[[dim]], 32 | output_core_dims=[["__temp_topk__"]], 33 | dask="allowed", 34 | ).rename({"__temp_topk__": dim}) 35 | 36 | 37 | def argtopk(a: TV, k: int, dim: Hashable, split_every: int | None = None) -> TV: 38 | """Extract the indexes of the k largest elements from a on the given 39 | dimension, and return them sorted from largest to smallest. If k is 40 | negative, extract the -k smallest elements instead, and return them 41 | sorted from smallest to largest. 42 | 43 | This assumes that ``k`` is small. All results will be returned in a single 44 | chunk along the given axis. 45 | """ 46 | return xarray.apply_ufunc( 47 | duck.argtopk, 48 | a, 49 | kwargs={"k": k, "split_every": split_every}, 50 | input_core_dims=[[dim]], 51 | output_core_dims=[["__temp_topk__"]], 52 | dask="allowed", 53 | ).rename({"__temp_topk__": dim}) 54 | 55 | 56 | def take_along_dim(a: T, ind: T, dim: Hashable) -> T: 57 | """Use the output of :func:`argtopk` to pick points from a. 
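    For example, mirroring the source-attribution walkthrough in doc/api/sort.rst,
    ``take_along_dim(x, argtopk(y, 3, 's'), 's')`` picks from ``x`` the points that
    produced the three largest values of ``y`` along ``'s'``.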
58 | 59 | :param a: 60 | xarray.DataArray or xarray.Dataset 61 | :param ind: 62 | array of ints, as returned by :func:`argtopk` 63 | :param dim: 64 | dimension along which argtopk was executed 65 | """ 66 | a = a.rename({dim: "__temp_take_along_dim__"}) 67 | 68 | return xarray.apply_ufunc( 69 | duck.take_along_axis, 70 | a, 71 | ind, 72 | input_core_dims=[["__temp_take_along_dim__"], [dim]], 73 | output_core_dims=[[dim]], 74 | dask="allowed", 75 | ) 76 | -------------------------------------------------------------------------------- /xarray_extras/tests/test_sort.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from xarray import DataArray 3 | from xarray.testing import assert_equal 4 | 5 | from xarray_extras.sort import argtopk, take_along_dim, topk 6 | 7 | 8 | @pytest.mark.parametrize("use_dask", [False, True]) 9 | @pytest.mark.parametrize("split_every", [None, 2]) 10 | @pytest.mark.parametrize("transpose", [False, True]) 11 | @pytest.mark.parametrize( 12 | "func,k,expect", 13 | [ 14 | (topk, 3, [[5, 4, 3], [8, 7, 2]]), 15 | (topk, -3, [[1, 2, 3], [0, 1, 2]]), 16 | (argtopk, 3, [[3, 1, 4], [2, 0, 4]]), 17 | (argtopk, -3, [[0, 2, 4], [3, 1, 4]]), 18 | ], 19 | ) 20 | def test_topk_argtopk(use_dask, split_every, transpose, func, k, expect): 21 | a = DataArray( 22 | [[1, 4, 2, 5, 3], [7, 1, 8, 0, 2]], 23 | dims=["y", "x"], 24 | coords={"y": ["y1", "y2"], "x": ["x1", "x2", "x3", "x4", "x5"]}, 25 | ) 26 | 27 | if transpose: 28 | a = a.T 29 | if use_dask: 30 | a = a.chunk(1) 31 | 32 | expect = DataArray(expect, dims=["y", "x"], coords={"y": ["y1", "y2"]}) 33 | actual = func(a, k, "x", split_every=split_every) 34 | assert_equal(expect, actual) 35 | if use_dask: 36 | assert actual.chunks 37 | 38 | 39 | @pytest.mark.parametrize("ind_use_dask", [False, True]) 40 | @pytest.mark.parametrize("a_use_dask", [False, True]) 41 | def test_take_along_dim(a_use_dask, ind_use_dask): 42 | """ind.ndim < a.ndim after broadcast""" 43 | a = DataArray( 44 | [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], 45 | dims=["z", "y", "x"], 46 | coords={"z": ["z1", "z2"], "y": ["y1", "y2"], "x": ["x1", "x2", "x3"]}, 47 | ) 48 | ind = DataArray( 49 | [[[1, 0], [2, 1]], [[2, 0], [2, 2]]], 50 | dims=["w", "y", "x"], 51 | coords={"y": ["y1", "y2"]}, 52 | ) 53 | 54 | expect = DataArray( 55 | [ 56 | [[[2, 1], [3, 1]], [[6, 5], [6, 6]]], 57 | [[[8, 7], [9, 7]], [[12, 11], [12, 12]]], 58 | ], 59 | dims=["z", "y", "w", "x"], 60 | coords={"z": ["z1", "z2"], "y": ["y1", "y2"]}, 61 | ) 62 | 63 | if a_use_dask: 64 | a = a.chunk(1) 65 | if ind_use_dask: 66 | ind = ind.chunk(1) 67 | 68 | actual = take_along_dim(a, ind, "x") 69 | assert_equal(expect, actual) 70 | if a_use_dask or ind_use_dask: 71 | assert actual.chunks 72 | 73 | 74 | @pytest.mark.parametrize("ind_use_dask", [False, True]) 75 | @pytest.mark.parametrize("a_use_dask", [False, True]) 76 | def test_take_along_dim2(a_use_dask, ind_use_dask): 77 | """ind.ndim > a.ndim after broadcast""" 78 | a = DataArray([1, 2, 3], dims=["x"]) 79 | ind = DataArray([[1, 0], [2, 1]], dims=["x", "y"]) 80 | 81 | expect = DataArray([[2, 3], [1, 2]], dims=["y", "x"]) 82 | 83 | if a_use_dask: 84 | a = a.chunk(1) 85 | if ind_use_dask: 86 | ind = ind.chunk(1) 87 | 88 | actual = take_along_dim(a, ind, "x") 89 | assert_equal(expect, actual) 90 | if a_use_dask or ind_use_dask: 91 | assert actual.chunks 92 | -------------------------------------------------------------------------------- /HOW_TO_RELEASE: 
--------------------------------------------------------------------------------
 1 | How to issue a release in 14 easy steps
 2 | 
 3 | Time required: about an hour.
 4 | 
 5 |  1. Ensure your main branch is synced to origin:
 6 |       git pull origin main
 7 |  2. Look over whats-new.rst and the docs. Make sure "What's New" is complete
 8 |     (check the date!) and add a brief summary note describing the release at the
 9 |     top.
10 |  3. If you have any doubts, run the full test suite one final time!
11 |       py.test
12 |  4. On the main branch, commit the release in git:
13 |       git commit -a -m 'Release vX.Y.Z'
14 |  5. Tag the release:
15 |       git tag -a vX.Y.Z -m 'vX.Y.Z'
16 |  6. Push your changes to main:
17 |       git push origin main
18 |       git push origin --tags
19 |  7. Update the stable branch (used by ReadTheDocs) and switch back to main:
20 |       git checkout stable
21 |       git rebase main
22 |       git push origin stable
23 |       git checkout main
24 |     It's OK to force push to 'stable' if necessary.
25 |     We also update the stable branch with `git cherry-pick` for documentation-only
26 |     fixes that apply to the current released version.
27 |  8. Build and test the release package:
28 |       python -m pip install --upgrade build twine
29 |       rm -rf dist
30 |       python -m build
31 |       python -m twine check dist/*
32 |  9. Add a section for the next release to doc/whats-new.rst.
33 | 10. Commit your changes and push to main again:
34 |       git commit -a -m 'Revert to dev version'
35 |       git push origin main
36 |     You're done pushing to main!
37 | 11. Issue the release on GitHub. Open https://github.com/crusaderky/xarray_extras/releases;
38 |     the new release should have automatically appeared. Otherwise, click on
39 |     "Draft a new release" and paste in the latest from whats-new.rst.
40 | 12. Use twine to register and upload the release on PyPI. Be careful, you can't
41 |     take this back!
42 |       twine upload dist/*
43 |     You will need to be listed as a package owner at
44 |     https://pypi.python.org/pypi/xarray_extras for this to work.
45 | 13. Update the docs. Log in to https://readthedocs.org/projects/xarray_extras/versions/
46 |     and switch your new release tag (at the bottom) from "Inactive" to "Active".
47 |     It should now build automatically.
48 |     Make sure that both the new tagged version and 'stable' build successfully.
49 | 14. Update conda-forge.
50 |     14a. Clone https://github.com/conda-forge/xarray_extras-feedstock
51 |     14b. Update the version number and sha256 in meta.yaml.
52 |          You can calculate sha256 with
53 |            sha256sum dist/*
54 |     14c. Double-check dependencies in meta.yaml and update them to match pyproject.toml.
55 |     14d. Submit a pull request.
56 |     14e. Write a comment in the PR:
57 |            @conda-forge-admin, please rerender
58 |          Wait for the rerender commit (it may take a few minutes).
59 |     14f. Wait for CI to pass and merge.
60 |     14g.
The next day, test the conda-forge release 61 | conda search xarray_extras 62 | conda create -n xarray_extras-test xarray_extras 63 | conda activate xarray_extras-test 64 | conda list 65 | python -c 'import xarray_extras' 66 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: ["*"] 8 | workflow_dispatch: # allows you to trigger manually 9 | 10 | # When this workflow is queued, automatically cancel any previous running 11 | # or pending jobs from the same branch 12 | concurrency: 13 | group: tests-${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | defaults: 17 | run: 18 | shell: bash -l {0} 19 | 20 | jobs: 21 | build: 22 | name: 23 | ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.requirements }} 24 | runs-on: ${{ matrix.os }}-latest 25 | strategy: 26 | fail-fast: false 27 | matrix: 28 | os: [ubuntu] 29 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] 30 | requirements: [latest] 31 | include: 32 | # Test on macos and windows (first and last version of python only) 33 | - os: macos 34 | python-version: "3.8" 35 | requirements: latest 36 | - os: macos 37 | python-version: "3.13" 38 | requirements: latest 39 | - os: windows 40 | python-version: "3.8" 41 | requirements: latest 42 | - os: windows 43 | python-version: "3.13" 44 | requirements: latest 45 | # Test on minimal requirements 46 | - os: ubuntu 47 | python-version: "3.8" 48 | requirements: minimal 49 | - os: macos 50 | python-version: "3.8" 51 | requirements: minimal 52 | - os: windows 53 | python-version: "3.8" 54 | requirements: minimal 55 | # Test on nightly builds of requirements 56 | - os: ubuntu 57 | python-version: "3.13" 58 | requirements: upstream 59 | 60 | steps: 61 | - name: Checkout 62 | uses: actions/checkout@v4 63 | with: 64 | fetch-depth: 0 65 | 66 | - name: Setup Conda Environment 67 | uses: conda-incubator/setup-miniconda@v3 68 | with: 69 | miniforge-version: latest 70 | use-mamba: true 71 | python-version: ${{ matrix.python-version }} 72 | environment-file: ci/requirements-${{ matrix.requirements }}.yml 73 | activate-environment: xarray-extras 74 | 75 | - name: Install Linux compile env 76 | if: ${{ matrix.os == 'linux' }} 77 | run: mamba install gcc_linux-64 78 | 79 | - name: Install MacOS compile env 80 | if: ${{ matrix.os == 'macosx' }} 81 | run: mamba install clang_osx-64 82 | 83 | - name: Install nightly builds 84 | if: ${{ matrix.requirements == 'upstream' }} 85 | run: | 86 | mamba uninstall --force numpy pandas scipy pyarrow 87 | python -m pip install --no-deps --pre --prefer-binary \ 88 | --extra-index-url https://pypi.fury.io/arrow-nightlies/ \ 89 | pyarrow 90 | python -m pip install --no-deps --pre \ 91 | -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \ 92 | numpy pandas scipy 93 | 94 | - name: Show conda options 95 | run: conda config --show 96 | 97 | - name: conda info 98 | run: conda info 99 | 100 | - name: conda list 101 | run: conda list 102 | 103 | - name: Install 104 | run: python -m pip install --no-deps -e . 
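        # The editable install above also compiles the np_to_csv C extension
        # declared in setup.py.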
105 | 106 | - name: pytest 107 | run: py.test --verbose --cov=xarray_extras --cov-report=xml 108 | 109 | - name: codecov.io 110 | uses: codecov/codecov-action@v3 111 | -------------------------------------------------------------------------------- /xarray_extras/duck/sort.py: -------------------------------------------------------------------------------- 1 | """Helper functions for :mod:`xarray_extras.sort`, which accept either 2 | numpy arrays or dask arrays. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from typing import TypeVar 8 | 9 | import dask.array as da 10 | import numpy as np 11 | from dask.array.slicing import slice_with_int_dask_array_on_axis 12 | from xarray.core.duck_array_ops import broadcast_to 13 | 14 | T = TypeVar("T", np.ndarray, da.Array) 15 | 16 | 17 | def topk(a: T, k: int, split_every: int | None = None) -> T: 18 | """If a is a :class:`dask.array.Array`, invoke a.topk; else reimplement 19 | the functionality in plain numpy. 20 | """ 21 | if isinstance(a, da.Array): 22 | return a.topk(k, split_every=split_every) 23 | 24 | if abs(k) < a.shape[-1]: 25 | a = np.partition(a, -k) 26 | if k > 0: 27 | a = a[..., -k:] 28 | else: 29 | a = a[..., :-k] 30 | 31 | # Sort the partitioned output 32 | a = np.sort(a) 33 | if k > 0: 34 | # Sort from greatest to smallest 35 | return a[..., ::-1] 36 | return a 37 | 38 | 39 | def argtopk(a: T, k: int, split_every: int | None = None) -> T: 40 | """If a is a :class:`dask.array.Array`, invoke a.argtopk; else reimplement 41 | the functionality in plain numpy. 42 | """ 43 | if isinstance(a, da.Array): 44 | return a.argtopk(k, split_every=split_every) 45 | 46 | idx = np.argpartition(a, -k) 47 | if k > 0: 48 | idx = idx[..., -k:] 49 | else: 50 | idx = idx[..., :-k] 51 | 52 | a = np.take_along_axis(a, idx, axis=-1) 53 | idx = np.take_along_axis(idx, a.argsort(), axis=-1) 54 | if k > 0: 55 | # Sort from greatest to smallest 56 | return idx[..., ::-1] 57 | return idx 58 | 59 | 60 | def take_along_axis( 61 | a: np.ndarray | da.Array, ind: np.ndarray | da.Array 62 | ) -> np.ndarray | da.Array: 63 | """Easily use the outputs of argsort on ND arrays to pick the results.""" 64 | if isinstance(a, np.ndarray) and isinstance(ind, np.ndarray): 65 | a = a.reshape((1,) * (ind.ndim - a.ndim) + a.shape) 66 | ind = ind.reshape((1,) * (a.ndim - ind.ndim) + ind.shape) 67 | return np.take_along_axis(a, ind, axis=-1) 68 | 69 | # a and/or ind are dask arrays. This is not yet implemented upstream. 70 | # Upstream tracker: https://github.com/dask/dask/issues/3663 71 | 72 | # This is going to be an ugly and slow mess, as dask does not support 73 | # fancy indexing. 74 | 75 | # Normalize a and ind. The end result is that a can have more axes than 76 | # ind on the left, but not vice versa, and that all axes except the 77 | # extra ones on the left and the rightmost one (the axis to take 78 | # along) are the same shape. 
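    # Illustrative shapes (assumed for this comment): a.shape == (4, 2, 5) and
    # ind.shape == (2, 3) give common_shape == (2,), a_extra_shape == (4,),
    # and a final result of shape (4, 2, 3).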
79 | if ind.ndim > a.ndim: 80 | a = a.reshape((1,) * (ind.ndim - a.ndim) + a.shape) 81 | common_shape = tuple(np.maximum(a.shape[-ind.ndim : -1], ind.shape[:-1])) 82 | a_extra_shape = a.shape[: -ind.ndim] 83 | a = broadcast_to(a, a_extra_shape + common_shape + a.shape[-1:]) 84 | ind = broadcast_to(ind, common_shape + ind.shape[-1:]) 85 | 86 | # Flatten all common axes onto axis -2 87 | final_shape = a.shape[: -ind.ndim] + ind.shape 88 | ind = ind.reshape(ind.size // ind.shape[-1], ind.shape[-1]) 89 | a = a.reshape(*a_extra_shape, ind.shape[0], a.shape[-1]) 90 | 91 | # Now we have a[..., i, j] and ind[i, j], where i are the flattened 92 | # common axes and j is the axis to take along. 93 | res = [] 94 | 95 | # Cycle a and ind along i, perform 1D slices, and then stack them back 96 | # together 97 | for i in range(ind.shape[0]): 98 | a_i = a[..., i, :] 99 | ind_i = ind[i, :] 100 | 101 | if not isinstance(a_i, da.Array): 102 | a_i = da.from_array(a_i, chunks=a_i.shape) 103 | 104 | if isinstance(ind_i, da.Array): 105 | res_i = slice_with_int_dask_array_on_axis(a_i, ind_i, axis=a_i.ndim - 1) 106 | else: 107 | res_i = a_i[..., ind_i] 108 | res.append(res_i) 109 | 110 | res_arr = da.stack(res, axis=-2) 111 | # Un-flatten axis i 112 | return res_arr.reshape(*final_shape) 113 | -------------------------------------------------------------------------------- /xarray_extras/tests/test_stack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | import xarray 5 | 6 | from xarray_extras.stack import proper_unstack 7 | 8 | # FIXME https://github.com/crusaderky/xarray_extras/issues/33 9 | pytestmark = pytest.mark.filterwarnings( 10 | "ignore:Updating MultiIndexed coordinate .* would corrupt indices", 11 | "ignore:invalid value encountered in cast:RuntimeWarning", 12 | ) 13 | 14 | 15 | def test_proper_unstack_order(): 16 | # Note: using MultiIndex.from_tuples is NOT the same thing as 17 | # round-tripping DataArray.stack().unstack(), as the latter is not 18 | # affected by the re-ordering issue 19 | index = pd.MultiIndex.from_tuples( 20 | [ 21 | ["x1", "first"], 22 | ["x1", "second"], 23 | ["x1", "third"], 24 | ["x1", "fourth"], 25 | ["x0", "first"], 26 | ["x0", "second"], 27 | ["x0", "third"], 28 | ["x0", "fourth"], 29 | ], 30 | names=["x", "count"], 31 | ) 32 | xa = xarray.DataArray(np.arange(8), dims=["dim_0"], coords={"dim_0": index}) 33 | 34 | a = proper_unstack(xa, "dim_0") 35 | b = xarray.DataArray( 36 | [[0, 1, 2, 3], [4, 5, 6, 7]], 37 | dims=["x", "count"], 38 | coords={"x": ["x1", "x0"], "count": ["first", "second", "third", "fourth"]}, 39 | ) 40 | xarray.testing.assert_equal(a, b) 41 | with pytest.raises(AssertionError): 42 | # Order is different 43 | xarray.testing.assert_equal(a, xa.unstack("dim_0")) 44 | 45 | 46 | def test_proper_unstack_dtype(): 47 | """Test that we don't accidentally end up with dtype=O for the coords""" 48 | a = xarray.DataArray( 49 | [[0, 1, 2, 3], [4, 5, 6, 7]], 50 | dims=["r", "c"], 51 | coords={ 52 | "r": pd.to_datetime(["2000/01/01", "2000/01/02"]), 53 | "c": [1, 2, 3, 4], 54 | }, 55 | ) 56 | b = a.stack(s=["r", "c"]) 57 | c = proper_unstack(b, "s") 58 | xarray.testing.assert_equal(a, c) 59 | 60 | 61 | def test_proper_unstack_mixed_coords(): 62 | a = xarray.DataArray( 63 | [[0, 1, 2, 3], [4, 5, 6, 7]], 64 | dims=["r", "c"], 65 | coords={"r": [1, "x0"], "c": [1, 2.2, "3", "fourth"]}, 66 | ) 67 | b = a.stack(s=["r", "c"]) 68 | c = proper_unstack(b, "s") 69 | 
xarray.testing.assert_equal(a, c) 70 | 71 | 72 | def test_proper_unstack_dataset(): 73 | a = xarray.DataArray( 74 | [[1, 2, 3, 4], [5, 6, 7, 8]], 75 | dims=["x", "col"], 76 | coords={ 77 | "x": ["x0", "x1"], 78 | "col": pd.MultiIndex.from_tuples( 79 | [("u0", "v0"), ("u0", "v1"), ("u1", "v0"), ("u1", "v1")], 80 | names=["u", "v"], 81 | ), 82 | }, 83 | ) 84 | xa = xarray.Dataset({"foo": a, "bar": ("w", [1, 2]), "baz": np.pi}) 85 | b = proper_unstack(xa, "col") 86 | c = xarray.DataArray( 87 | [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], 88 | dims=["x", "u", "v"], 89 | coords={"x": ["x0", "x1"], "u": ["u0", "u1"], "v": ["v0", "v1"]}, 90 | ) 91 | d = xarray.Dataset({"foo": c, "bar": ("w", [1, 2]), "baz": np.pi}) 92 | xarray.testing.assert_equal(b, d) 93 | for c in b.coords: 94 | assert b.coords[c].dtype.kind == "U" 95 | 96 | 97 | def test_proper_unstack_other_mi(): 98 | a = xarray.DataArray( 99 | [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]], 100 | dims=["row", "col"], 101 | coords={ 102 | "row": pd.MultiIndex.from_tuples( 103 | [("x0", "w0"), ("x0", "w1"), ("x1", "w0"), ("x1", "w1")], 104 | names=["x", "w"], 105 | ), 106 | "col": pd.MultiIndex.from_tuples( 107 | [("y0", "z0"), ("y0", "z1"), ("y1", "z0"), ("y1", "z1")], 108 | names=["y", "z"], 109 | ), 110 | }, 111 | ) 112 | b = proper_unstack(a, "row") 113 | c = xarray.DataArray( 114 | [[[1, 5], [1, 5]], [[2, 6], [2, 6]], [[3, 7], [3, 7]], [[4, 8], [4, 8]]], 115 | dims=["col", "x", "w"], 116 | coords={ 117 | "col": pd.MultiIndex.from_tuples( 118 | [("y0", "z0"), ("y0", "z1"), ("y1", "z0"), ("y1", "z1")], 119 | names=["y", "z"], 120 | ), 121 | "x": ["x0", "x1"], 122 | "w": ["w0", "w1"], 123 | }, 124 | ) 125 | xarray.testing.assert_equal(b, c) 126 | -------------------------------------------------------------------------------- /xarray_extras/tests/test_cumulatives.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import xarray 4 | from xarray.testing import assert_equal 5 | 6 | pytest.importorskip("numba") # Not available in upstream CI 7 | 8 | import xarray_extras.cumulatives as cum 9 | 10 | # Skip 0 and 1 as they're neutral in addition and multiplication 11 | INPUT = xarray.DataArray( 12 | [[2, 20, 25], [3, 30, 35], [4, 40, 45], [5, 50, 55]], 13 | dims=["t", "s"], 14 | coords={ 15 | "t": np.array( 16 | ["1990-12-30", "2000-12-30", "2005-12-30", "2010-12-30"], dtype="M8[ns]" 17 | ), 18 | "s": ["s1", "s2", "s3"], 19 | }, 20 | ) 21 | 22 | 23 | T_COMPOUND_MATRIX = xarray.DataArray( 24 | np.array( 25 | [ 26 | ["1990-12-30", "NaT", "NaT"], 27 | ["1990-12-30", "2005-12-30", "NaT"], 28 | ["2000-12-30", "1990-12-30", "NaT"], 29 | ["2010-12-30", "1990-12-30", "2005-12-30"], 30 | ], 31 | dtype="M8[ns]", 32 | ), 33 | dims=["t2", "c"], 34 | coords={"t2": [10, 20, 30, 40]}, 35 | ) 36 | 37 | 38 | S_COMPOUND_MATRIX = xarray.DataArray( 39 | [["s3", "s2"], ["s1", ""]], dims=["s2", "c"], coords={"s2": ["foo", "bar"]} 40 | ) 41 | 42 | DTYPES = ( 43 | # There's a bug in numba.guvectorize for u8, u16, u32 44 | # i8 and i16 are too short to store the output 45 | "int32", 46 | "int64", 47 | "uint64", 48 | "float32", 49 | "float64", 50 | "complex64", 51 | "complex128", 52 | ) 53 | 54 | 55 | @pytest.mark.parametrize( 56 | "func, meth", 57 | [ 58 | (cum.compound_sum, "sum"), 59 | (cum.compound_prod, "prod"), 60 | (cum.compound_mean, "mean"), 61 | ], 62 | ) 63 | @pytest.mark.parametrize("dtype", DTYPES) 64 | @pytest.mark.parametrize("use_dask", [False, True]) 65 | def 
test_compound_t(func, meth, dtype, use_dask): 66 | x = INPUT.astype(dtype) 67 | c = T_COMPOUND_MATRIX 68 | expect = xarray.concat( 69 | [ 70 | getattr(x.isel(t=[0]), meth)("t"), 71 | getattr(x.isel(t=[0, 2]), meth)("t"), 72 | getattr(x.isel(t=[1, 0]), meth)("t"), 73 | getattr(x.isel(t=[3, 0, 2]), meth)("t"), 74 | ], 75 | dim="t2", 76 | ).T.astype(dtype) 77 | expect.coords["t2"] = c.coords["t2"] 78 | 79 | if use_dask: 80 | x = x.chunk({"s": 2}) 81 | expect = expect.chunk({"s": 2}) 82 | c = c.chunk() 83 | 84 | actual = func(x, c, "t", "c") 85 | 86 | if use_dask: 87 | assert_equal(expect.compute(), actual.compute()) 88 | else: 89 | assert_equal(expect, actual) 90 | 91 | assert expect.dtype == actual.dtype 92 | assert actual.chunks == expect.chunks 93 | 94 | 95 | @pytest.mark.parametrize( 96 | "func, meth", 97 | [ 98 | (cum.compound_sum, "sum"), 99 | (cum.compound_prod, "prod"), 100 | (cum.compound_mean, "mean"), 101 | ], 102 | ) 103 | @pytest.mark.parametrize("dtype", DTYPES) 104 | @pytest.mark.parametrize("use_dask", [False, True]) 105 | def test_compound_s(func, meth, dtype, use_dask): 106 | x = INPUT.astype(dtype) 107 | c = S_COMPOUND_MATRIX 108 | expect = xarray.concat( 109 | [ 110 | getattr(x.sel(s=["s3", "s2"]), meth)("s"), 111 | getattr(x.sel(s=["s1"]), meth)("s"), 112 | ], 113 | dim="s2", 114 | ).T.astype(dtype) 115 | expect.coords["s2"] = c.coords["s2"] 116 | 117 | if use_dask: 118 | x = x.chunk({"t": 2}) 119 | expect = expect.chunk({"t": 2}) 120 | c = c.chunk() 121 | 122 | actual = func(x, c, "s", "c") 123 | 124 | if use_dask: 125 | assert_equal(expect.compute(), actual.compute()) 126 | else: 127 | assert_equal(expect, actual) 128 | 129 | assert expect.dtype == actual.dtype 130 | assert actual.chunks == expect.chunks 131 | 132 | 133 | @pytest.mark.parametrize("dtype", [float, int, "complex128"]) 134 | @pytest.mark.parametrize("skipna", [False, True, None]) 135 | @pytest.mark.parametrize("use_dask", [False, True]) 136 | def test_cummean(use_dask, skipna, dtype): 137 | x = INPUT.copy(deep=True).astype(dtype) 138 | if dtype in (float, "complex128"): 139 | x[2, 1] = np.nan 140 | 141 | expect = xarray.concat( 142 | [ 143 | x[:1].mean("t", skipna=skipna), 144 | x[:2].mean("t", skipna=skipna), 145 | x[:3].mean("t", skipna=skipna), 146 | x[:4].mean("t", skipna=skipna), 147 | ], 148 | dim="t", 149 | ) 150 | expect.coords["t"] = x.coords["t"] 151 | if use_dask: 152 | x = x.chunk({"s": 2, "t": 3}) 153 | expect = expect.chunk({"s": 2, "t": 3}) 154 | 155 | actual = cum.cummean(x, "t", skipna=skipna) 156 | assert_equal(expect, actual) 157 | assert expect.dtype == actual.dtype 158 | assert actual.chunks == expect.chunks 159 | -------------------------------------------------------------------------------- /xarray_extras/kernels/csv.py: -------------------------------------------------------------------------------- 1 | """dask kernels for :mod:`xarray_extras.csv`""" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from xarray_extras.kernels.np_to_csv_py import snprintcsvd, snprintcsvi 11 | 12 | 13 | def to_csv( 14 | x: np.ndarray, 15 | index: pd.Index, 16 | columns: pd.Index, 17 | first_chunk: bool, 18 | nogil: bool, 19 | kwargs: dict, 20 | ) -> bytes: 21 | """Format x into CSV and encode it to binary 22 | 23 | :param x: 24 | numpy.ndarray with 1 or 2 dimensions 25 | :param pandas.Index index: 26 | row index 27 | :param pandas.Index columns: 28 | column index. None for Series or for DataFrame chunks beyond the first. 
29 | :param bool first_chunk: 30 | True if this is the first chunk; False otherwise 31 | :param bool nogil: 32 | If True, use accelerated C implementation. Several kwargs won't be 33 | processed correctly. If False, use pandas to_csv method (slow, and does 34 | not release the GIL). 35 | :param kwargs: 36 | arguments passed to pandas to_csv methods 37 | :returns: 38 | CSV file contents, encoded in UTF-8 39 | """ 40 | if x.ndim == 1: 41 | assert columns is None 42 | x_pd = pd.Series(x, index) 43 | elif x.ndim == 2: 44 | x_pd = pd.DataFrame(x, index, columns) 45 | else: 46 | # proper ValueError already raised in wrapper 47 | raise AssertionError("unreachable") # pragma: nocover 48 | 49 | encoding = kwargs.pop("encoding", "utf-8") 50 | header = kwargs.pop("header", True) 51 | lineterminator = kwargs.pop("lineterminator", os.linesep) 52 | 53 | if not nogil or not x.size: 54 | out = x_pd.to_csv(header=header, lineterminator=lineterminator, **kwargs) 55 | bout = out.encode(encoding) 56 | if encoding == "utf-16" and not first_chunk: 57 | # utf-16 contains a bang at the beginning of the text. However, 58 | # when concatenating multiple chunks we don't want to replicate it. 59 | assert bout[:2] == b"\xff\xfe" 60 | bout = bout[2:] 61 | return bout 62 | 63 | sep = kwargs.get("sep", ",") 64 | fmt = kwargs.get("float_format") 65 | na_rep = kwargs.get("na_rep", "") 66 | 67 | # Use pandas to format index 68 | if x.ndim == 1: 69 | x_df = x_pd.to_frame() 70 | else: 71 | x_df = x_pd 72 | 73 | index_csv = x_df.iloc[:, :0].to_csv(header=False, lineterminator="\n", **kwargs) 74 | index_csv = index_csv.strip().split("\n") 75 | if len(index_csv) != x.shape[0]: 76 | index_csv = "\n" * x.shape[0] 77 | else: 78 | index_csv = "\n".join(r + sep if r else "" for r in index_csv) + "\n" 79 | 80 | # Invoke C code to format the values. This releases the GIL. 81 | if x.dtype.kind == "i": 82 | body_bytes = snprintcsvi(x, index_csv, sep) 83 | elif x.dtype.kind == "f": 84 | body_bytes = snprintcsvd(x, index_csv, sep, fmt, na_rep) 85 | else: 86 | raise NotImplementedError("only int and float are supported when nogil=True") 87 | 88 | if header is not False: 89 | header_bytes = ( 90 | x_df.iloc[:0, :] 91 | .to_csv(header=header, lineterminator="\n", **kwargs) 92 | .encode("utf-8") 93 | ) 94 | body_bytes = header_bytes + body_bytes 95 | 96 | if encoding not in {"ascii", "utf-8"}: 97 | # Everything is encoded in UTF-8 until this moment. Recode if needed. 98 | body_str = body_bytes.decode("utf-8") 99 | if lineterminator != "\n": 100 | body_str = body_str.replace("\n", lineterminator) 101 | body_bytes = body_str.encode(encoding) 102 | if encoding == "utf-16" and not first_chunk: 103 | # utf-16 contains a bang at the beginning of the text. However, 104 | # when concatenating multiple chunks we don't want to replicate it. 105 | assert body_bytes[:2] == b"\xff\xfe" 106 | body_bytes = body_bytes[2:] 107 | elif lineterminator != "\n": 108 | body_bytes = body_bytes.replace(b"\n", lineterminator.encode("utf-8")) 109 | 110 | return body_bytes 111 | 112 | 113 | def to_file(fname: str, mode: str, data: str | bytes, rr_token: object) -> None: # noqa: ARG001 114 | """Write data to file 115 | 116 | :param fname: 117 | File path on disk 118 | :param mode: 119 | As in 'open' 120 | :param data: 121 | Binary or text data to write 122 | :param rr_token: 123 | Round-robin token passed by to_file from the previous chunk. It 124 | guarantees write order across multiple, otherwise parallel, tasks. This 125 | parameter is only used by the dask scheduler. 
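        In practice this is the return value (``None``) of the previous chunk's
        ``to_file`` task; depending on it forces the writes to happen in order.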
126 | """ 127 | with open(fname, mode) as fh: 128 | fh.write(data) 129 | -------------------------------------------------------------------------------- /xarray_extras/kernels/np_to_csv_py.py: -------------------------------------------------------------------------------- 1 | """Thin ctypes wrapper around :file:`np_to_csv.c`. 2 | This is a helper module of :mod:`xarray_extras.kernels.csv`. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import ctypes 8 | import warnings 9 | 10 | import numpy as np 11 | 12 | from xarray_extras.kernels import np_to_csv # type:ignore[attr-defined] 13 | 14 | with warnings.catch_warnings(): 15 | warnings.simplefilter("ignore", category=DeprecationWarning) 16 | np_to_csv = np.ctypeslib.load_library("np_to_csv", np_to_csv.__file__) 17 | 18 | np_to_csv.snprintcsvd.argtypes = [ 19 | ctypes.c_char_p, # char * buf 20 | ctypes.c_int32, # int bufsize 21 | # const double * array 22 | np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, flags="C_CONTIGUOUS"), 23 | ctypes.c_int32, # int h 24 | ctypes.c_int32, # int w 25 | ctypes.c_char_p, # const char * index 26 | ctypes.c_char_p, # const char * fmt 27 | ctypes.c_bool, # bool trim_zeros 28 | ctypes.c_char_p, # const char * na_rep 29 | ] 30 | np_to_csv.snprintcsvd.restype = ctypes.c_int32 31 | 32 | 33 | np_to_csv.snprintcsvi.argtypes = [ 34 | ctypes.c_char_p, # char * buf 35 | ctypes.c_int32, # int bufsize 36 | # const int64_t * array 37 | np.ctypeslib.ndpointer(dtype=np.int64, ndim=2, flags="C_CONTIGUOUS"), 38 | ctypes.c_int32, # int h 39 | ctypes.c_int32, # int w 40 | ctypes.c_char_p, # const char * index 41 | ctypes.c_char, # char sep 42 | ] 43 | np_to_csv.snprintcsvi.restype = ctypes.c_int32 44 | 45 | 46 | def snprintcsvd( 47 | a: np.ndarray, 48 | index: str, 49 | sep: str = ",", 50 | fmt: str | None = None, 51 | na_rep: str = "", 52 | ) -> bytes: 53 | """Convert array to CSV. 54 | 55 | :param a: 56 | 1D or 2D numpy array of floats 57 | :param str index: 58 | newline-separated list of prefixes for every row of a 59 | :param str sep: 60 | cell separator 61 | :param str fmt: 62 | printf formatting string for a single float number 63 | Set to None to replicate pandas to_csv default behaviour 64 | :param str na_rep: 65 | string representation of NaN 66 | :return: 67 | CSV file contents, binary-encoded in ascii format. 68 | The line terminator is always \n on all OSs. 69 | """ 70 | if a.ndim == 1: 71 | a = a.reshape((-1, 1)) 72 | if a.ndim != 2 or a.dtype.kind != "f": 73 | raise ValueError("Expected 2d numpy array of floats") 74 | a = a.astype(np.float64) 75 | a = np.ascontiguousarray(a) 76 | if len(sep) != 1: 77 | raise ValueError("sep must be exactly 1 character") 78 | bsep = sep.encode("ascii") 79 | 80 | # Test fmt while in Python - much better to get 81 | # an Exception here than a segfault in C! 82 | if fmt is not None: 83 | fmt % 1.23 84 | bfmt = fmt.encode("ascii") + bsep 85 | trim_zeros = False 86 | else: 87 | bfmt = b"%f" + bsep 88 | trim_zeros = True 89 | bna_rep = na_rep.encode("ascii") + bsep 90 | # We're relying on the fact that ascii is a strict subset of UTF-8 91 | bindex = index.encode("utf-8") 92 | 93 | # Blindly try ever-larger bufsizes until it fits 94 | # The first iteration should be sufficient in all but the most 95 | # degenerate cases. 96 | # FIXME: is there a better way? 
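    # Strategy used below: assume at most cellsize bytes per formatted cell and
    # double that estimate until snprintcsvd reports that the whole output fit
    # (nchar < bufsize).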
97 | cellsize = 40 98 | while True: 99 | bufsize = cellsize * a.size + len(bindex) 100 | buf = ctypes.create_string_buffer(bufsize) 101 | nchar = np_to_csv.snprintcsvd( 102 | buf, bufsize, a, a.shape[0], a.shape[1], bindex, bfmt, trim_zeros, bna_rep 103 | ) 104 | if nchar < bufsize: 105 | return bytes(buf[:nchar]) # type: ignore[arg-type] 106 | cellsize *= 2 107 | 108 | 109 | def snprintcsvi(a: np.ndarray, index: str, sep: str = ",") -> bytes: 110 | """Convert array to CSV. 111 | 112 | :param a: 113 | 1D or 2D numpy array of integers 114 | :param str index: 115 | newline-separated list of prefixes for every row of a 116 | :param str sep: 117 | cell separator 118 | :return: 119 | CSV file contents, binary-encoded in ascii format. 120 | The line terminator is always \n on all OSs. 121 | """ 122 | if a.ndim == 1: 123 | a = a.reshape((-1, 1)) 124 | if a.ndim != 2 or a.dtype.kind != "i": 125 | raise ValueError("Expected 2d numpy array of ints") 126 | a = a.astype(np.int64) 127 | a = np.ascontiguousarray(a) 128 | if len(sep) != 1: 129 | raise ValueError("sep must be exactly 1 character") 130 | bsep = sep.encode("ascii") 131 | # We're relying on the fact that ascii is a strict subset of UTF-8 132 | bindex = index.encode("utf-8") 133 | 134 | cellsize = 22 # len('%d' % -2**64) + 1 135 | bufsize = cellsize * a.size + len(bindex) 136 | buf = ctypes.create_string_buffer(bufsize) 137 | nchar = np_to_csv.snprintcsvi(buf, bufsize, a, a.shape[0], a.shape[1], bindex, bsep) 138 | assert nchar < bufsize 139 | return bytes(buf[:nchar]) # type: ignore[arg-type] 140 | -------------------------------------------------------------------------------- /doc/whats-new.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: xarray_extras 2 | 3 | What's New 4 | ========== 5 | 6 | .. _whats-new.0.7.0: 7 | 8 | v0.7.0 (unreleased) 9 | ------------------- 10 | - ``interpolate``: Added support for timedelta coordinates; 11 | fixed support for datetime coordinates other than ``M8[ns]`` in xarray 2025.1.2. 12 | - Added formal support for Python 3.13 (but the previous version works fine too) 13 | 14 | 15 | .. _whats-new.0.6.0: 16 | 17 | v0.6.0 (2024-03-16) 18 | ------------------- 19 | - Bumped minimum version of all dependencies: 20 | 21 | ========== ======== ========= 22 | Dependency v0.5.0 v0.6.0 23 | ========== ======== ========= 24 | python 3.7 3.8 25 | dask 2021.4.0 2022.6.0 26 | numba 0.52 0.56 27 | numpy 1.18 1.23 28 | pandas 1.1 1.5 29 | scipy 1.5 1.9 30 | xarray 0.16 2022.11.0 31 | ========== ======== ========= 32 | 33 | - Added support for Python 3.10, 3.11, and 3.12 34 | - Added support for recent versions of Pandas (tested up to 2.2) and xarray 35 | - Added support for :cls:`pathlib.Path` in function arguments 36 | - Migrated from setup.cfg to pyproject.toml 37 | - Migrated from flake8+isort+pyupgrade to ruff 38 | 39 | 40 | .. 
_whats-new.0.5.0: 41 | 42 | v0.5.0 (2021-12-01) 43 | ------------------- 44 | - Bumped minimum version of all dependencies: 45 | 46 | ========== ====== ======== 47 | Dependency v0.4.2 v0.5.0 48 | ========== ====== ======== 49 | python 3.5 3.7 50 | dask 0.19 2021.4.0 51 | numba 0.34 0.52 52 | numpy 1.13 1.18 53 | pandas 0.21 1.1 54 | scipy 1.0 1.5 55 | xarray 0.10.1 0.16 56 | ========== ====== ======== 57 | 58 | - Added support for Python 3.8 and 3.9 59 | - Removed ``xarray_extras.backports`` module 60 | - Migrated CI to github workflows 61 | - Added code linters: pyupgrade, isort, black 62 | - Run all code linters through pre-commit 63 | - Use setuptools-scm for versioning 64 | - Moved the whole contents of setup.py to setup.cfg 65 | 66 | 67 | .. _whats-new.0.4.2: 68 | 69 | v0.4.2 (2019-06-03) 70 | ------------------- 71 | 72 | - Type annotations 73 | - Mandatory mypy validation in CI 74 | - CI unit tests for Windows now run on Python 3.7 75 | - Compatibility with dask >= 1.1 76 | - Suppress deprecation warnings with pandas >= 0.24 77 | - :func:`~xarray_extras.csv.to_csv` changes: 78 | 79 | - When invoked on a 1-dimensional DataArray, 80 | the default value for the ``index`` parameter has been changed from False to 81 | True, coherently to the default for pandas.Series.to_csv from pandas 0.24. 82 | This applies also to users who have pandas < 0.24 installed. 83 | - support for ``line_terminator`` parameter (all pandas versions); 84 | - fix incorrect line terminator in Windows with pandas >= 0.24 85 | - support for ``compression='infer'`` (all pandas versions) 86 | - support for ``compression`` parameter with pandas < 0.23 87 | 88 | 89 | .. _whats-new.0.4.1: 90 | 91 | v0.4.1 (2019-02-02) 92 | ------------------- 93 | 94 | - Fixed build regression in `readthedocs `_ 95 | 96 | 97 | .. _whats-new.0.4.0: 98 | 99 | v0.4.0 (2019-02-02) 100 | ------------------- 101 | 102 | - Moved ``recursive_diff``, ``recursive_eq`` and ``ncdiff`` 103 | to their own package `recursive_diff `_ 104 | - Fixed bug in :func:`~xarray_extras.stack.proper_unstack` where unstacking 105 | coords with dtype=datetime64 would convert them to integer 106 | - Mandatory flake8 in CI 107 | 108 | 109 | .. _whats-new.0.3.0: 110 | 111 | v0.3.0 (2018-12-13) 112 | ------------------- 113 | 114 | - Changed license to Apache 2.0 115 | - Increased minimum versions: dask >= 0.19, pandas >= 0.21, 116 | xarray >= 0.10.1, pytest >= 3.6 117 | - New function :func:`~xarray_extras.stack.proper_unstack` 118 | - New functions ``recursive_diff`` and ``ecursive_eq`` 119 | - New command-line tool ``ncdiff`` 120 | - Blacklisted Python 3.7 conda-forge builds in CI tests 121 | 122 | 123 | .. _whats-new.0.2.2: 124 | 125 | v0.2.2 (2018-07-24) 126 | ------------------- 127 | 128 | - Fixed segmentation faults in :func:`~xarray_extras.csv.to_csv` 129 | - Added conda-forge travis build 130 | - Blacklisted dask-0.18.2 because of regression in argtopk(split_every=2) 131 | 132 | 133 | .. _whats-new.0.2.1: 134 | 135 | v0.2.1 (2018-07-22) 136 | ------------------- 137 | 138 | - Added parameter nogil=True to :func:`~xarray_extras.csv.to_csv`, which will 139 | switch to a C-accelerated implementation instead of pandas to_csv (albeit 140 | with caveats). Fixed deadlock in to_csv as well as compatibility with dask 141 | distributed. Pandas code (when using nogil=False) is not wrapped by a 142 | subprocess anymore, which means it won't be able to use more than 1 CPU 143 | (but compression can run in pipeline). 
144 | to_csv has lost the ability to write to a buffer - only file paths are 145 | supported now. 146 | - AppVeyor integration 147 | 148 | 149 | .. _whats-new.0.2.0: 150 | 151 | v0.2.0 (2018-07-15) 152 | ------------------- 153 | 154 | - New function :func:`xarray_extras.csv.to_csv` 155 | - Speed up interpolation for k=2 and k=3 156 | - CI: Rigorous tracking of minimum dependency versions 157 | - CI: Explicit support for Python 3.7 158 | 159 | 160 | .. _whats-new.0.1.0: 161 | 162 | v0.1.0 (2018-05-19) 163 | ------------------- 164 | 165 | Initial release. 166 | -------------------------------------------------------------------------------- /xarray_extras/kernels/np_to_csv.c: -------------------------------------------------------------------------------- 1 | /* 2 | * High speed implementation for to_csv() 3 | */ 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #ifdef _WIN32 13 | # define LIBRARY_API __declspec(dllexport) 14 | #else 15 | # define LIBRARY_API 16 | #endif 17 | 18 | // In case of buffer overflow, in Windows, snprintf returns -1 19 | // In Linux, it returns the number of characters that would have been written 20 | #define CHECK_OVERFLOW if (ret < 0 || char_count >= bufsize) return bufsize; 21 | 22 | 23 | static PyMethodDef module_methods[] = { 24 | {NULL, NULL, 0, NULL} 25 | }; 26 | 27 | 28 | PyMODINIT_FUNC PyInit_np_to_csv(void) 29 | { 30 | PyObject *module; 31 | static struct PyModuleDef moduledef = { 32 | PyModuleDef_HEAD_INIT, 33 | "np_to_csv", 34 | "High speed implementation for to_csv()", 35 | -1, 36 | module_methods, 37 | NULL, 38 | NULL, 39 | NULL, 40 | NULL 41 | }; 42 | module = PyModule_Create(&moduledef); 43 | if (!module) return NULL; 44 | 45 | return module; 46 | } 47 | 48 | 49 | /* Convert 2D array of doubles to CSV 50 | * 51 | * buf : output buffer 52 | * bufsize : maximum number of characters that can be written to buf 53 | * array : input numerical data 54 | * h : number of rows in array 55 | * w : number of columns in array 56 | * index : newline-separated list of prefix strings, one per row 57 | * fmt : printf formatting, including the cell separator at the end. 58 | * cell separator must be exactly 1 character. 59 | * trim_zeros : if true, trim trailing zeros after the . beyond the first 60 | * e.g. 1.000 -> 1.0 61 | * na_rep : string representation for NaN, including the cell separator at the end 62 | * 63 | * The line terminator is always \n, regardless of OS. 
64 | */ 65 | LIBRARY_API 66 | int snprintcsvd(char * buf, int bufsize, const double * array, int h, int w, 67 | const char * index, const char * fmt, bool trim_zeros, const char * na_rep) 68 | { 69 | int char_count = 0; 70 | int ret = 0; 71 | int i, j; 72 | 73 | // Move along a single column, printing the value of each row 74 | for (i = 0; i < h; i++) { 75 | // Print row header 76 | while (1) { 77 | CHECK_OVERFLOW; 78 | char c = *(index++); 79 | assert(c != 0); 80 | if (c == '\n') { 81 | break; 82 | } 83 | buf[char_count++] = c; 84 | } 85 | CHECK_OVERFLOW; 86 | 87 | // Print row values 88 | for (j = 0; j < w; j++) { 89 | double n = *(array++); 90 | if (isnan(n)) { 91 | ret = snprintf(buf + char_count, bufsize - char_count, "%s", na_rep); 92 | char_count += ret; 93 | CHECK_OVERFLOW; 94 | } 95 | else { 96 | ret = snprintf(buf + char_count, bufsize - char_count, fmt, n); 97 | char_count += ret; 98 | CHECK_OVERFLOW; 99 | 100 | if (trim_zeros) { 101 | while (char_count > 2 && 102 | buf[char_count - 2] == '0' && 103 | buf[char_count - 3] != '.') { 104 | buf[char_count - 2] = buf[char_count - 1]; 105 | char_count--; 106 | } 107 | } 108 | } 109 | } 110 | // Replace latest column separator with line terminator 111 | buf[char_count - 1] = '\n'; 112 | } 113 | 114 | return char_count; 115 | } 116 | 117 | 118 | /* Convert 2D array of int64's to CSV 119 | * 120 | * buf : output buffer 121 | * bufsize : maximum number of characters that can be written to buf 122 | * array : input numerical data 123 | * h : number of rows in array 124 | * w : number of columns in array 125 | * index : newline-separated list of prefix strings, one per row 126 | * sep : cell separator 127 | * 128 | * The line terminator is always \n, regardless of OS. 129 | */ 130 | LIBRARY_API 131 | int snprintcsvi(char * buf, int bufsize, const int64_t * array, int h, int w, 132 | const char * index, char sep) 133 | { 134 | int char_count = 0; 135 | int ret = 0; 136 | int i, j; 137 | 138 | // '%d' + sep, but for int64_t 139 | char fmt[sizeof(PRId64) + 2]; 140 | fmt[0] = '%'; 141 | strcpy(fmt + 1, PRId64); 142 | fmt[sizeof(PRId64)] = sep; 143 | fmt[sizeof(PRId64) + 1] = 0; 144 | 145 | // Move along a single column, printing the value of each row 146 | for (i = 0; i < h; i++) { 147 | // Print row header 148 | while (1) { 149 | CHECK_OVERFLOW; 150 | char c = *(index++); 151 | if (c == 0 || c == '\n') { 152 | break; 153 | } 154 | buf[char_count++] = c; 155 | } 156 | CHECK_OVERFLOW; 157 | 158 | // Print row values 159 | for (j = 0; j < w; j++) { 160 | ret = snprintf(buf + char_count, bufsize - char_count, fmt, *(array++)); 161 | char_count += ret; 162 | CHECK_OVERFLOW; 163 | } 164 | // Replace latest column separator with line terminator 165 | buf[char_count - 1] = '\n'; 166 | } 167 | 168 | return char_count; 169 | } 170 | -------------------------------------------------------------------------------- /xarray_extras/kernels/interpolate.py: -------------------------------------------------------------------------------- 1 | """dask kernels for :mod:`xarray_extras.interpolate` 2 | 3 | .. 
codeauthor:: Guido Imperiale 4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | from collections.abc import Iterable 9 | from typing import TYPE_CHECKING 10 | 11 | import numpy as np 12 | from scipy.interpolate import BSpline, make_interp_spline 13 | from scipy.interpolate._bsplines import _as_float_array, _augknt, _not_a_knot 14 | 15 | if TYPE_CHECKING: 16 | # TODO Python 3.9 notations 17 | from typing import Tuple, Union 18 | 19 | # TODO import from typing (requires Python 3.10) 20 | from typing_extensions import TypeAlias 21 | 22 | Boundary: TypeAlias = Iterable[Tuple[int, float]] 23 | BCType: TypeAlias = Union[Tuple[Boundary, Boundary], str, None] 24 | 25 | 26 | def _memoryview_safe(x: np.ndarray) -> np.ndarray: 27 | """Make array safe to run in a Cython memoryview-based kernel. These 28 | kernels typically break down with the error ``ValueError: buffer source 29 | array is read-only`` when running in dask distributed. 30 | """ 31 | if not x.flags.writeable: 32 | if not x.flags.owndata: 33 | x = x.copy(order="C") 34 | x.setflags(write=True) 35 | return x 36 | 37 | 38 | def make_interp_knots( 39 | x: np.ndarray, k: int = 3, bc_type: BCType | None = None, check_finite: bool = True 40 | ) -> np.ndarray: 41 | """Compute the knots of the B-spline. 42 | 43 | .. note:: 44 | This is a temporary implementation that should be moved to the main 45 | scipy library - see ``_. 46 | 47 | Parameters 48 | ---------- 49 | x : array_like, shape (n,) 50 | Abscissas. 51 | k : int, optional 52 | B-spline degree. Default is cubic, k=3. 53 | bc_type : 2-tuple or None 54 | Boundary conditions. 55 | Default is None, which means choosing the boundary conditions 56 | automatically. Otherwise, it must be a length-two tuple where the first 57 | element sets the boundary conditions at ``x[0]`` and the second 58 | element sets the boundary conditions at ``x[-1]``. Each of these must 59 | be an iterable of pairs ``(order, value)`` which gives the values of 60 | derivatives of specified orders at the given edge of the interpolation 61 | interval. 62 | check_finite : bool, optional 63 | Whether to check that the input arrays contain only finite numbers. 64 | Disabling may give a performance gain, but may result in problems 65 | (crashes, non-termination) if the inputs do contain infinities or NaNs. 66 | Default is True. 67 | 68 | Returns 69 | ------- 70 | numpy array with size = x.size + k + 1, representing the B-spline knots. 71 | """ 72 | if k < 2 and bc_type is not None: 73 | raise ValueError("Too much info for k<2: bc_type can only be None.") 74 | 75 | x = np.array(x) 76 | if x.ndim != 1 or np.any(x[1:] <= x[:-1]): 77 | raise ValueError("Expect x to be a 1-D sorted array-like.") 78 | 79 | if k == 0: 80 | t = np.r_[x, x[-1]] 81 | elif k == 1: 82 | t = np.r_[x[0], x, x[-1]] 83 | elif bc_type is None: 84 | if k == 2: 85 | # OK, it's a bit ad hoc: Greville sites + omit 86 | # 2nd and 2nd-to-last points, a la not-a-knot 87 | t = (x[1:] + x[:-1]) / 2.0 88 | t = np.r_[(x[0],) * (k + 1), t[1:-1], (x[-1],) * (k + 1)] 89 | else: 90 | t = _not_a_knot(x, k) 91 | else: 92 | t = _augknt(x, k) 93 | 94 | return _as_float_array(t, check_finite) 95 | 96 | 97 | def make_interp_coeffs( 98 | x: np.ndarray, 99 | y: np.ndarray, 100 | k: int = 3, 101 | t: np.ndarray | None = None, 102 | bc_type: BCType = None, 103 | axis: int = 0, 104 | check_finite: bool = True, 105 | ) -> np.ndarray: 106 | """Compute the knots of the B-spline. 107 | 108 | .. 
note:: 109 | This is a temporary implementation that should be moved to the main 110 | scipy library - see ``_. 111 | 112 | See :func:`scipy.interpolate.make_interp_spline` for parameters. 113 | 114 | :param t: 115 | Knots array, as calculated by :func:`make_interp_knots`. 116 | 117 | - For k=0, must always be None (the coefficients are not a function of 118 | the knots). 119 | - For k=1, set to None if t has been calculated by 120 | :func:`make_interp_knots`; pass a vector if it already existed 121 | before. 122 | - For k=2 and k=3, must always pass either the output of 123 | :func:`make_interp_knots` or a pre-generated vector. 124 | """ 125 | x = _memoryview_safe(x) 126 | y = _memoryview_safe(y) 127 | if t is not None: 128 | t = _memoryview_safe(t) 129 | 130 | return make_interp_spline( 131 | x, y, k, t, bc_type=bc_type, axis=axis, check_finite=check_finite 132 | ).c 133 | 134 | 135 | def splev( 136 | x_new: np.ndarray, 137 | t: np.ndarray, 138 | c: np.ndarray, 139 | k: int = 3, 140 | extrapolate: bool | str = True, 141 | ) -> np.ndarray: 142 | """Generate a BSpline object on the fly from knots and coefficients and 143 | evaluate it on x_new. 144 | 145 | See :class:`scipy.interpolate.BSpline` for all parameters. 146 | """ 147 | t = _memoryview_safe(t) 148 | c = _memoryview_safe(c) 149 | x_new = _memoryview_safe(x_new) 150 | spline = BSpline.construct_fast(t, c, k, axis=0, extrapolate=extrapolate) 151 | return spline(x_new) 152 | -------------------------------------------------------------------------------- /xarray_extras/cumulatives.py: -------------------------------------------------------------------------------- 1 | """Advanced cumulative sum/productory/mean functions""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Callable, Hashable 6 | from typing import TypeVar 7 | 8 | import dask.array as da 9 | import numpy as np 10 | import xarray 11 | 12 | from xarray_extras.kernels import cumulatives as kernels 13 | 14 | __all__ = ("compound_mean", "compound_prod", "compound_sum", "cummean") 15 | 16 | 17 | T = TypeVar("T", xarray.DataArray, xarray.Dataset) 18 | TV = TypeVar("TV", xarray.DataArray, xarray.Dataset, xarray.Variable) 19 | 20 | 21 | def cummean(x: T, dim: Hashable, skipna: bool | None = None) -> T: 22 | """ 23 | .. math:: 24 | 25 | y_{i} = mean(x_{0}, x_{1}, ... x_{i}) 26 | 27 | :param x: 28 | :class:`~xarray.DataArray` or :class:`~xarray.Dataset` 29 | :param hashable dim: 30 | dimension along which to calculate the mean 31 | :param bool skipna: 32 | If True, skip missing values (as marked by NaN). By default, only skips 33 | missing values for float dtypes; other dtypes either do not have a 34 | sentinel missing value (int) or skipna=True has not been implemented 35 | (object, datetime64 or timedelta64). 
36 | :returns: 37 | xarray object of the same type, dtype, and shape as x 38 | """ 39 | if skipna is False or (skipna is None and x.dtype.kind not in "fc"): 40 | # n is a simple arange 41 | if x.chunks: 42 | if isinstance(x, xarray.DataArray): 43 | chunks = x.chunks[x.dims.index(dim)] 44 | else: 45 | chunks = x.chunks[dim] 46 | n = da.arange(1, x.sizes[dim] + 1, chunks=chunks) 47 | else: 48 | n = np.arange(1, x.sizes[dim] + 1) 49 | n = xarray.DataArray(n, dims=[dim], coords={dim: x.coords[dim]}) 50 | else: 51 | # heavier computation 52 | n = (~x.isnull()).cumsum((dim,), skipna=False) 53 | 54 | return x.cumsum((dim,), skipna=skipna) / n 55 | 56 | 57 | def compound_sum(x: T, c: xarray.DataArray, xdim: Hashable, cdim: Hashable) -> T: 58 | """Compound sum on arbitrary points of x along dim. 59 | 60 | :param x: 61 | :class:`~xarray.DataArray` or :class:`~xarray.Dataset` containing the 62 | data to be compounded 63 | :param xarray.DataArray c: 64 | array where every row contains elements of x.coords[xdim] and 65 | is used to build a point of the output. 66 | The cells in the row are matched against x.coords[dim] and perform a 67 | sum. If different rows of c require different amounts of points from x, 68 | they must be padded on the right with NaN, NaT, or '' (respectively for 69 | numbers, datetimes, and strings). 70 | :param hashable xdim: 71 | dimension of x to acquire data from. The coord associated to it must be 72 | monotonic ascending. 73 | :param hashable cdim: 74 | dimension of c that represent the vector of points to be compounded for 75 | every point of dim 76 | :returns: 77 | xarray object of the same type and dtype as x, with all dims from x 78 | and c except xdim and cdim. 79 | 80 | example:: 81 | 82 | >>> x = xarray.DataArray( 83 | >>> [10, 20, 30], 84 | >>> dims=['x'], coords={'x': ['foo', 'bar', 'baz']}) 85 | >>> c = xarray.DataArray( 86 | >>> [['foo', 'baz', None], 87 | >>> ['bar', 'baz', 'baz']], 88 | >>> dims=['y', 'c'], coords={'y': ['new1', 'new2']}) 89 | >>> compound_sum(x, c, 'x', 'c') 90 | 91 | array([40, 80]) 92 | Coordinates: 93 | * y (y) T: 99 | """Compound product among arbitrary points of x along dim 100 | See :func:`compound_sum`. 101 | """ 102 | return _compound(x, c, xdim, cdim, kernels.compound_prod) 103 | 104 | 105 | def compound_mean(x: T, c: xarray.DataArray, xdim: Hashable, cdim: Hashable) -> T: 106 | """Compound mean among arbitrary points of x along dim 107 | See :func:`compound_sum`. 108 | """ 109 | return _compound(x, c, xdim, cdim, kernels.compound_mean) 110 | 111 | 112 | def _compound( 113 | x: T, 114 | c: xarray.DataArray, 115 | xdim: Hashable, 116 | cdim: Hashable, 117 | kernel: Callable[[T, xarray.DataArray], T], 118 | ) -> T: 119 | """Implementation of all compound functions 120 | 121 | :param kernel: 122 | numba kernel to apply to (x, idx), where 123 | idx is an array of indices with the same shape as c, 124 | containing the indices along x.coords[xdim] or -1 where c is null. 125 | """ 126 | # Convert coord points to indexes of x.coords[dim] 127 | idx = xarray.DataArray(x.coords[xdim].searchsorted(c), dims=c.dims, coords=c.coords) 128 | # searchsorted(NaN) returns 0; replace it with -1. 129 | # isnull('') returns False. 
We could have asked for None, however 130 | # searchsorted will refuse to compare strings and None's 131 | if c.dtype.kind == "U": 132 | idx = idx.where(c != "", -1) 133 | else: 134 | idx = idx.where(~c.isnull(), -1) 135 | 136 | dtype = x.dtypes if isinstance(x, xarray.Dataset) else x.dtype 137 | 138 | return xarray.apply_ufunc( 139 | kernel, 140 | x, 141 | idx, 142 | input_core_dims=[[xdim], [cdim]], 143 | output_core_dims=[[]], 144 | dask="parallelized", 145 | output_dtypes=[dtype], 146 | ) 147 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "xarray_extras" 3 | authors = [{name = "Guido Imperiale", email = "crusaderky@gmail.com"}] 4 | license = {text = "Apache"} 5 | description = "Advanced / experimental algorithms for xarray" 6 | keywords = ["xarray"] 7 | classifiers = [ 8 | "Development Status :: 3 - Alpha", 9 | "License :: OSI Approved :: Apache Software License", 10 | "Operating System :: OS Independent", 11 | "Intended Audience :: Science/Research", 12 | "Topic :: Scientific/Engineering", 13 | "Programming Language :: Python", 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3.8", 16 | "Programming Language :: Python :: 3.9", 17 | "Programming Language :: Python :: 3.10", 18 | "Programming Language :: Python :: 3.11", 19 | "Programming Language :: Python :: 3.12", 20 | "Programming Language :: Python :: 3.13", 21 | ] 22 | requires-python = ">=3.8" 23 | dependencies = [ 24 | "dask >= 2022.6.0", 25 | "numba >= 0.56", 26 | "numpy >= 1.23", 27 | "pandas >= 1.5", 28 | "scipy >= 1.9", 29 | "xarray >= 2022.11.0", 30 | ] 31 | dynamic = ["version"] 32 | 33 | [project.urls] 34 | Homepage = "https://github.com/crusaderky/xarray_extras" 35 | "Bug Tracker" = "https://github.com/crusaderky/xarray_extras/issues" 36 | Changelog = "https://xarray-extras.readthedocs.io/en/latest/whats-new.html" 37 | 38 | [project.readme] 39 | text = "Advanced / experimental algorithms for xarray" 40 | content-type = "text/x-rst" 41 | 42 | [tool.setuptools] 43 | packages = ["xarray_extras"] 44 | zip-safe = false # https://mypy.readthedocs.io/en/latest/installed_packages.html 45 | include-package-data = true 46 | 47 | [tool.setuptools_scm] 48 | # Use hardcoded version when .git has been removed and this is not a package created 49 | # by sdist. This is the case e.g. of a remote deployment with PyCharm. 
50 | fallback_version = "9999" 51 | 52 | [tool.setuptools.package-data] 53 | xarray_extras = [ 54 | "py.typed", 55 | "tests/data/*", 56 | ] 57 | 58 | [build-system] 59 | requires = [ 60 | "setuptools>=66", 61 | "setuptools_scm[toml]", 62 | ] 63 | build-backend = "setuptools.build_meta" 64 | 65 | [tool.pytest.ini_options] 66 | addopts = "--strict-markers --strict-config -v -r sxfE --color=yes" 67 | xfail_strict = true 68 | python_files = ["test_*.py"] 69 | testpaths = ["xarray_extras/tests"] 70 | filterwarnings = [ 71 | "error", 72 | # FIXME these need to be fixed in xarray 73 | "ignore:__array_wrap__ must accept context and return:DeprecationWarning", 74 | # FIXME these need to be looked at 75 | 'ignore:.*will no longer be implicitly promoted:FutureWarning', 76 | 'ignore:.*updating coordinate .* with a PandasMultiIndex would leave the multi-index level coordinates .* in an inconsistent state:FutureWarning', 77 | # These have been fixed; still needed for Python 3.9 CI 78 | "ignore:__array__ implementation doesn't accept a copy keyword, so passing copy=False failed:DeprecationWarning", 79 | 'ignore:Converting non-nanosecond precision datetime:UserWarning', 80 | 'ignore:Converting non-nanosecond precision timedelta:UserWarning', 81 | ] 82 | 83 | [tool.coverage.report] 84 | show_missing = true 85 | exclude_lines = [ 86 | "pragma: nocover", 87 | "pragma: no cover", 88 | "TYPE_CHECKING", 89 | "except ImportError", 90 | "@overload", 91 | '@(abc\.)?abstractmethod', 92 | '@(numba\.)?jit', 93 | '@(numba\.)?vectorize', 94 | '@(numba\.)?guvectorize', 95 | ] 96 | 97 | [tool.codespell] 98 | ignore-words-list = ["ND"] 99 | 100 | [tool.ruff] 101 | exclude = [".eggs"] 102 | target-version = "py38" 103 | 104 | [tool.ruff.lint] 105 | ignore = [ 106 | "EM101", # Exception must not use a string literal, assign to variable first 107 | "EM102", # Exception must not use an f-string literal, assign to variable first 108 | "N802", # Function name should be lowercase 109 | "N803", # Argument name should be lowercase 110 | "N806", # Variable should be lowercase 111 | "N816", # Variable in global scope should not be mixedCase 112 | "PD901", # Avoid using the generic variable name `df` for DataFrames 113 | "PT006", # Wrong type passed to first argument of `pytest.mark.parametrize`; expected `tuple` 114 | "PLC0414", # Import alias does not rename original package 115 | "PLR0912", # Too many branches 116 | "PLR0913", # Too many arguments in function definition 117 | "PLR2004", # Magic value used in comparison, consider replacing `123` with a constant variable 118 | "PLW2901", # for loop variable overwritten by assignment target 119 | "SIM108", # Use ternary operator instead of if-else block 120 | ] 121 | select = [ 122 | "YTT", # flake8-2020 123 | "B", # flake8-bugbear 124 | "C4", # flake8-comprehensions 125 | "EM", # flake8-errmsg 126 | "EXE", # flake8-executable 127 | "ICN", # flake8-import-conventions 128 | "G", # flake8-logging-format 129 | "PIE", # flake8-pie 130 | "PT", # flake8-pytest-style 131 | "RET", # flake8-return 132 | "SIM", # flake8-simplify 133 | "ARG", # flake8-unused-arguments 134 | "I", # isort 135 | "NPY", # NumPy specific rules 136 | "N", # pep8-naming 137 | "E", # Pycodestyle 138 | "W", # Pycodestyle 139 | "PGH", # pygrep-hooks 140 | "F", # Pyflakes 141 | "PL", # pylint 142 | "UP", # pyupgrade 143 | "RUF", # unused-noqa 144 | "TID", # tidy-ups 145 | "EXE001", # Shebang is present but file is not executable 146 | ] 147 | 148 | [tool.ruff.lint.isort] 149 | known-first-party = ["xarray_extras"] 150 | 
151 | [tool.mypy] 152 | disallow_incomplete_defs = true 153 | disallow_untyped_decorators = true 154 | disallow_untyped_defs = true 155 | ignore_missing_imports = true 156 | no_implicit_optional = true 157 | show_error_codes = true 158 | warn_redundant_casts = true 159 | warn_unused_ignores = true 160 | warn_unreachable = true 161 | 162 | [[tool.mypy.overrides]] 163 | module = ["*.tests.*"] 164 | disallow_untyped_defs = false 165 | -------------------------------------------------------------------------------- /xarray_extras/csv.py: -------------------------------------------------------------------------------- 1 | """Multi-threaded CSV writer, much faster than :meth:`pandas.DataFrame.to_csv`, 2 | with full support for `dask `_ and `dask distributed 3 | `_. 4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | from collections.abc import Callable 9 | from pathlib import Path 10 | from typing import TYPE_CHECKING, Any, cast 11 | 12 | import xarray 13 | from dask.base import tokenize 14 | from dask.delayed import Delayed 15 | from dask.highlevelgraph import HighLevelGraph 16 | 17 | from xarray_extras.kernels import csv as kernels 18 | 19 | __all__ = ("to_csv",) 20 | 21 | if TYPE_CHECKING: 22 | # TODO: remove TYPE_CHECKING (requires dask >=2023.9.1) 23 | from dask.typing import DaskCollection, Key 24 | 25 | 26 | def to_csv( # noqa: PLR0915 27 | x: xarray.DataArray, path: str | Path, *, nogil: bool = True, **kwargs: Any 28 | ) -> Any: 29 | """Print DataArray to CSV. 30 | 31 | When x has numpy backend, this function is functionally equivalent to (but 32 | much) faster than):: 33 | 34 | x.to_pandas().to_csv(path_or_buf, **kwargs) 35 | 36 | When x has dask backend, this function returns a dask delayed object which 37 | will write to the disk only when its .compute() method is invoked. 38 | 39 | Formatting and optional compression are parallelised across all available 40 | CPUs, using one dask task per chunk on the first dimension. Chunks on other 41 | dimensions will be merged ahead of computation. 42 | 43 | :param x: 44 | :class:`~xarray.DataArray` with one or two dimensions 45 | :param path: 46 | Output file path 47 | :param bool nogil: 48 | If True, use accelerated C implementation. Several kwargs won't be 49 | processed correctly (see limitations below). If False, use pandas 50 | to_csv method (slow, and does not release the GIL). 51 | nogil=True exclusively supports float and integer values dtypes (but 52 | the coords can be anything). In case of incompatible dtype, nogil 53 | is automatically switched to False. 54 | :param kwargs: 55 | Passed verbatim to :meth:`pandas.DataFrame.to_csv` or 56 | :meth:`pandas.Series.to_csv` 57 | 58 | **Limitations** 59 | 60 | - Fancy URIs are not (yet) supported. 61 | - compression='zip' is not supported. All other compression methods (gzip, 62 | bz2, xz) are supported. 63 | - When running with nogil=True, the following parameters are ignored: 64 | columns, quoting, quotechar, doublequote, escapechar, chunksize, decimal 65 | 66 | **Distributed computing** 67 | 68 | This function supports `dask distributed`_, with the caveat that all workers 69 | must write to the same shared mountpoint and that the shared filesystem 70 | must strictly guarantee **close-open coherency**, meaning that one must be 71 | able to call write() and then close() on a file descriptor from one host 72 | and then immediately afterwards open() from another host and see the output 73 | from the first host. 
Note that, for performance reasons, most network 74 | filesystems do not enable this feature by default. 75 | 76 | Alternatively, one may write to local mountpoints and then manually collect 77 | and concatenate the partial outputs. 78 | """ 79 | if not isinstance(x, xarray.DataArray): 80 | raise TypeError("first argument must be a DataArray") 81 | 82 | # Health checks 83 | if not isinstance(path, (str, Path)): 84 | raise TypeError("path_or_buf must be a string or a pathlib.Path object") 85 | 86 | path = Path(path) 87 | 88 | if x.ndim not in (1, 2): 89 | raise ValueError( 90 | f"cannot convert arrays with {x.ndim} dimensions into pandas objects" 91 | ) 92 | 93 | if nogil and x.dtype.kind not in "if": 94 | nogil = False 95 | 96 | # Extract row and columns indices 97 | indices = [x.get_index(dim) for dim in x.dims] 98 | if x.ndim == 2: 99 | index, columns = indices 100 | else: 101 | index = indices[0] 102 | columns = None 103 | 104 | compression = kwargs.pop("compression", "infer") 105 | compress = _compress_func(path, compression) 106 | mode = kwargs.pop("mode", "w") 107 | if mode not in "wa": 108 | raise ValueError(f"mode: expected w or a; got {mode!r}") 109 | 110 | # Fast exit for numpy backend 111 | if not x.chunks: 112 | bdata = kernels.to_csv(x.values, index, columns, True, nogil, kwargs) 113 | if compress: 114 | bdata = compress(bdata) 115 | with open(path, mode + "b") as fh: 116 | fh.write(bdata) 117 | return None 118 | 119 | # Merge chunks on all dimensions beyond the first 120 | x = x.chunk({dim: -1 for dim in x.dims[1:]}) 121 | 122 | # Manually define the dask graph 123 | tok = tokenize(x.data, index, columns, compression, path, kwargs) 124 | name1 = "to_csv_encode-" + tok 125 | name2 = "to_csv_compress-" + tok 126 | name3 = "to_csv_write-" + tok 127 | name4 = "to_csv-" + tok 128 | 129 | dsk: dict[Key, Any] = {} 130 | 131 | assert x.chunks 132 | assert x.chunks[0] 133 | offset = 0 134 | for i, size in enumerate(x.chunks[0]): 135 | # Slice index 136 | index_i = index[offset : offset + size] 137 | offset += size 138 | 139 | x_i = (x.data.name, i) + (0,) * (x.ndim - 1) 140 | 141 | # Step 1: convert to CSV and encode to binary blob 142 | if i == 0: 143 | # First chunk: print header 144 | dsk[name1, i] = (kernels.to_csv, x_i, index_i, columns, True, nogil, kwargs) 145 | else: 146 | kwargs_i = kwargs.copy() 147 | kwargs_i["header"] = False 148 | dsk[name1, i] = (kernels.to_csv, x_i, index_i, None, False, nogil, kwargs_i) 149 | 150 | # Step 2 (optional): compress 151 | if compress: 152 | prevname = name2 153 | dsk[name2, i] = compress, (name1, i) 154 | else: 155 | prevname = name1 156 | 157 | # Step 3: write to file 158 | if i == 0: 159 | # First chunk: overwrite file if it already exists 160 | # Convert PosixPath / WindowsPath to str to support dask Client 161 | # on Windows and Worker on Linux or vice versa 162 | dsk[name3, i] = ( 163 | kernels.to_file, 164 | str(path), 165 | mode + "b", 166 | (prevname, i), 167 | None, 168 | ) 169 | else: 170 | # Next chunks: wait for previous chunk to complete and append 171 | dsk[name3, i] = ( 172 | kernels.to_file, 173 | str(path), 174 | "ab", 175 | (prevname, i), 176 | (name3, i - 1), 177 | ) 178 | 179 | # Rename final key 180 | dsk[name4] = dsk.pop((name3, i)) 181 | 182 | hlg = HighLevelGraph.from_collections(name4, dsk, [cast("DaskCollection", x)]) 183 | return Delayed(name4, hlg) 184 | 185 | 186 | def _compress_func( 187 | path: Path, compression: str | None 188 | ) -> Callable[[bytes], bytes] | None: 189 | if compression == "infer": 190 | 
compression = path.suffix[1:].lower() 191 | if compression == "gz": 192 | compression = "gzip" 193 | elif compression == "csv": 194 | compression = None 195 | 196 | if compression is None: 197 | return None 198 | if compression == "gzip": 199 | import gzip 200 | 201 | return gzip.compress 202 | if compression == "bz2": 203 | import bz2 204 | 205 | return bz2.compress 206 | if compression == "xz": 207 | import lzma 208 | 209 | return lzma.compress 210 | if compression == "zip": 211 | raise NotImplementedError("zip compression is not supported") 212 | raise ValueError(f"Unrecognized compression: {compression}") 213 | -------------------------------------------------------------------------------- /xarray_extras/interpolate.py: -------------------------------------------------------------------------------- 1 | """xarray spline interpolation functions""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Hashable 6 | 7 | import dask.array as da 8 | import numpy as np 9 | import xarray 10 | 11 | from xarray_extras.compat import dask_array_type 12 | from xarray_extras.kernels import interpolate as kernels 13 | 14 | __all__ = ("splev", "splrep") 15 | 16 | 17 | def splrep(a: xarray.DataArray, dim: Hashable, k: int = 3) -> xarray.Dataset: 18 | """Calculate the univariate B-spline for an N-dimensional array 19 | 20 | :param xarray.DataArray a: 21 | any :class:`~xarray.DataArray` 22 | :param dim: 23 | dimension of a to be interpolated. ``a.coords[dim]`` must be strictly 24 | monotonic ascending. All int, float (not complex), or datetime dtypes 25 | are supported. 26 | :param int k: 27 | B-spline order: 28 | 29 | = ================== 30 | k interpolation kind 31 | = ================== 32 | 0 nearest neighbour 33 | 1 linear 34 | 2 quadratic 35 | 3 cubic 36 | = ================== 37 | 38 | :returns: 39 | :class:`~xarray.Dataset` with t, c, k (knots, coefficients, order) 40 | variables, the same shape and coords as the input, that can be passed 41 | to :func:`splev`. 42 | 43 | Example:: 44 | 45 | >>> x = np.arange(0, 120, 20) 46 | >>> x = xarray.DataArray(x, dims=['x'], coords={'x': x}) 47 | >>> s = xarray.DataArray(np.linspace(1, 20, 5), dims=['s']) 48 | >>> y = np.exp(-x / s) 49 | >>> x_new = np.arange(0, 120, 1) 50 | >>> tck = splrep(y, 'x') 51 | >>> y_new = splev(x_new, tck) 52 | 53 | **Features** 54 | 55 | - Interpolate a ND array on any arbitrary dimension 56 | - dask supported on both on the interpolated array and x_new 57 | - Supports ND x_new arrays 58 | - The CPU-heavy interpolator generation (:func:`splrep`) is executed only 59 | once and then can be applied to multiple x_new (:func:`splev`) 60 | - memory-efficient 61 | - Can be pickled and used on dask distributed 62 | 63 | **Limitations** 64 | 65 | - Chunks are not supported along dim on the interpolated dimension. 66 | """ 67 | # Make sure that dim is on axis 0 68 | a = a.transpose(dim, *[d for d in a.dims if d != dim]) 69 | x = a.coords[dim].values 70 | 71 | if x.dtype.kind == "M": # datetime 72 | # Same treatment will be applied to x_new. 
73 | # Allow x_new.dtype==M8[D] and x.dtype==M8[ns], or vice versa 74 | x = x.astype("M8[ns]").astype(float) 75 | elif x.dtype.kind == "m": # timedelta 76 | x = x.astype("m8[ns]").astype(float) 77 | 78 | t = kernels.make_interp_knots(x, k, check_finite=False) 79 | if k < 2: 80 | t_c_param = None 81 | else: 82 | t_c_param = t 83 | 84 | if isinstance(a.data, dask_array_type): 85 | from dask.array import map_blocks 86 | 87 | if len(a.data.chunks[0]) > 1: 88 | raise NotImplementedError( 89 | "Unsupported: multiple chunks on interpolation dim" 90 | ) 91 | 92 | c = map_blocks( 93 | kernels.make_interp_coeffs, 94 | x, 95 | a.data, 96 | k=k, 97 | t=t_c_param, 98 | check_finite=False, 99 | dtype=float, 100 | ) 101 | else: 102 | c = kernels.make_interp_coeffs(x, a.data, k=k, t=t_c_param, check_finite=False) 103 | 104 | return xarray.Dataset( 105 | data_vars={ 106 | "t": ("__t__", t), 107 | "c": (a.dims, c), 108 | }, 109 | coords=a.coords, 110 | attrs={ 111 | "spline_dim": dim, 112 | "k": k, 113 | }, 114 | ) 115 | 116 | 117 | def splev( 118 | x_new: object, tck: xarray.Dataset, extrapolate: bool | str = True 119 | ) -> xarray.DataArray: 120 | """Evaluate the B-spline generated with :func:`splrep`. 121 | 122 | :param x_new: 123 | Any :class:`~xarray.DataArray` with any number of dims, not necessarily 124 | the original interpolation dim. 125 | Alternatively, it can be any 1-dimensional array-like; it will be 126 | automatically converted to a :class:`~xarray.DataArray` on the 127 | interpolation dim. 128 | 129 | :param xarray.Dataset tck: 130 | As returned by :func:`splrep`. 131 | It can have been: 132 | 133 | - transposed (not recommended, as performance will 134 | drop if c is not C-contiguous) 135 | - sliced, reordered, or (re)chunked, on any 136 | dim except the interpolation dim 137 | - computed from dask to numpy backend 138 | - round-tripped to disk 139 | 140 | :param extrapolate: 141 | True 142 | Extrapolate the first and last polynomial pieces of b-spline 143 | functions active on the base interval 144 | False 145 | Return NaNs outside of the base interval 146 | 'periodic' 147 | Periodic extrapolation is used 148 | 'clip' 149 | Return y[0] and y[-1] outside of the base interval 150 | 151 | :returns: 152 | :class:`~xarray.DataArray` with all dims of the interpolated array, 153 | minus the interpolation dim, plus all dims of x_new 154 | 155 | See :func:`splrep` for usage example. 
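    Example (a minimal sketch complementing the one in :func:`splrep`; the
    values are illustrative)::

        >>> x = np.arange(0, 120, 20)
        >>> y = xarray.DataArray(np.exp(-x / 50.0), dims=['x'], coords={'x': x})
        >>> tck = splrep(y, 'x')
        >>> # Evaluate on new points, clipping instead of extrapolating
        >>> # outside of the original interval [0, 100]
        >>> y_new = splev([-10, 55, 130], tck, extrapolate='clip')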
156 | """ 157 | # Pre-process x_new into a DataArray 158 | if not isinstance(x_new, xarray.DataArray): 159 | if not isinstance(x_new, dask_array_type): 160 | x_new = np.array(x_new) 161 | if x_new.ndim == 0: 162 | dims = [] 163 | elif x_new.ndim == 1: 164 | dims = [tck.spline_dim] 165 | else: 166 | raise ValueError( 167 | "N-dimensional x_new is only supported if x_new is a DataArray" 168 | ) 169 | x_new = xarray.DataArray(x_new, dims=dims, coords={tck.spline_dim: x_new}) 170 | 171 | dim = tck.spline_dim 172 | t = tck.t 173 | c = tck.c 174 | k = tck.k 175 | 176 | invalid_dims = {*x_new.dims} & {*c.dims} - {dim} 177 | if invalid_dims: 178 | raise ValueError( 179 | "Overlapping dims between interpolated " 180 | "array and x_new: " + ",".join(str(d) for d in invalid_dims) 181 | ) 182 | 183 | if t.shape != (c.sizes[dim] + k + 1,): 184 | raise ValueError("Interpolated dimension has been sliced") 185 | 186 | if x_new.dtype.kind == "M": # datetime 187 | # Note that we're modifying the x_new values, not the x_new coords 188 | x_new = x_new.astype("M8[ns]").astype(float) 189 | elif x_new.dtype.kind == "m": # timedelta 190 | x_new = x_new.astype("m8[ns]").astype(float) 191 | 192 | if extrapolate == "clip": 193 | x = tck.coords[dim].values 194 | 195 | if x.dtype.kind == "M": # datetime 196 | x = x.astype("M8[ns]").astype(float) 197 | elif x.dtype.kind == "m": # timedelta 198 | x = x.astype("m8[ns]").astype(float) 199 | 200 | x_new = np.clip(x_new, x[0].tolist(), x[-1].tolist()) 201 | extrapolate = False 202 | 203 | if c.dims[0] != dim: 204 | c = c.transpose(dim, *[d for d in c.dims if d != dim]) 205 | 206 | if any(isinstance(v.data, dask_array_type) for v in (x_new, t, c)): 207 | if t.chunks and len(t.chunks[0]) > 1: 208 | raise NotImplementedError( 209 | "Unsupported: multiple chunks on interpolation dim" 210 | ) 211 | if c.chunks and len(c.chunks[0]) > 1: 212 | raise NotImplementedError( 213 | "Unsupported: multiple chunks on interpolation dim" 214 | ) 215 | 216 | # omitting t and c 217 | x_new_axes = "abdefghijklm"[: x_new.ndim] 218 | c_axes = "nopqrsuvwxyz"[: c.ndim - 1] 219 | 220 | y_new = da.blockwise( 221 | kernels.splev, 222 | x_new_axes + c_axes, 223 | x_new.data, 224 | x_new_axes, 225 | t.data, 226 | "t", 227 | c.data, 228 | "c" + c_axes, 229 | k=k, 230 | extrapolate=extrapolate, 231 | concatenate=True, 232 | dtype=float, 233 | ) 234 | else: 235 | y_new = kernels.splev( 236 | x_new.values, t.values, c.values, k, extrapolate=extrapolate 237 | ) 238 | 239 | y_new = xarray.DataArray(y_new, dims=x_new.dims + c.dims[1:], coords=x_new.coords) 240 | y_new.coords.update({k: c for k, c in c.coords.items() if dim not in c.dims}) 241 | return y_new 242 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, and 10 | distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by the copyright 13 | owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all other entities 16 | that control, are controlled by, or are under common control with that entity. 
17 | For the purposes of this definition, "control" means (i) the power, direct or 18 | indirect, to cause the direction or management of such entity, whether by 19 | contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the 20 | outstanding shares, or (iii) beneficial ownership of such entity. 21 | 22 | "You" (or "Your") shall mean an individual or Legal Entity exercising 23 | permissions granted by this License. 24 | 25 | "Source" form shall mean the preferred form for making modifications, including 26 | but not limited to software source code, documentation source, and configuration 27 | files. 28 | 29 | "Object" form shall mean any form resulting from mechanical transformation or 30 | translation of a Source form, including but not limited to compiled object code, 31 | generated documentation, and conversions to other media types. 32 | 33 | "Work" shall mean the work of authorship, whether in Source or Object form, made 34 | available under the License, as indicated by a copyright notice that is included 35 | in or attached to the work (an example is provided in the Appendix below). 36 | 37 | "Derivative Works" shall mean any work, whether in Source or Object form, that 38 | is based on (or derived from) the Work and for which the editorial revisions, 39 | annotations, elaborations, or other modifications represent, as a whole, an 40 | original work of authorship. For the purposes of this License, Derivative Works 41 | shall not include works that remain separable from, or merely link (or bind by 42 | name) to the interfaces of, the Work and Derivative Works thereof. 43 | 44 | "Contribution" shall mean any work of authorship, including the original version 45 | of the Work and any modifications or additions to that Work or Derivative Works 46 | thereof, that is intentionally submitted to Licensor for inclusion in the Work 47 | by the copyright owner or by an individual or Legal Entity authorized to submit 48 | on behalf of the copyright owner. For the purposes of this definition, 49 | "submitted" means any form of electronic, verbal, or written communication sent 50 | to the Licensor or its representatives, including but not limited to 51 | communication on electronic mailing lists, source code control systems, and 52 | issue tracking systems that are managed by, or on behalf of, the Licensor for 53 | the purpose of discussing and improving the Work, but excluding communication 54 | that is conspicuously marked or otherwise designated in writing by the copyright 55 | owner as "Not a Contribution." 56 | 57 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf 58 | of whom a Contribution has been received by Licensor and subsequently 59 | incorporated within the Work. 60 | 61 | 2. Grant of Copyright License. 62 | 63 | Subject to the terms and conditions of this License, each Contributor hereby 64 | grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, 65 | irrevocable copyright license to reproduce, prepare Derivative Works of, 66 | publicly display, publicly perform, sublicense, and distribute the Work and such 67 | Derivative Works in Source or Object form. 68 | 69 | 3. Grant of Patent License. 
70 | 71 | Subject to the terms and conditions of this License, each Contributor hereby 72 | grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, 73 | irrevocable (except as stated in this section) patent license to make, have 74 | made, use, offer to sell, sell, import, and otherwise transfer the Work, where 75 | such license applies only to those patent claims licensable by such Contributor 76 | that are necessarily infringed by their Contribution(s) alone or by combination 77 | of their Contribution(s) with the Work to which such Contribution(s) was 78 | submitted. If You institute patent litigation against any entity (including a 79 | cross-claim or counterclaim in a lawsuit) alleging that the Work or a 80 | Contribution incorporated within the Work constitutes direct or contributory 81 | patent infringement, then any patent licenses granted to You under this License 82 | for that Work shall terminate as of the date such litigation is filed. 83 | 84 | 4. Redistribution. 85 | 86 | You may reproduce and distribute copies of the Work or Derivative Works thereof 87 | in any medium, with or without modifications, and in Source or Object form, 88 | provided that You meet the following conditions: 89 | 90 | You must give any other recipients of the Work or Derivative Works a copy of 91 | this License; and 92 | You must cause any modified files to carry prominent notices stating that You 93 | changed the files; and 94 | You must retain, in the Source form of any Derivative Works that You distribute, 95 | all copyright, patent, trademark, and attribution notices from the Source form 96 | of the Work, excluding those notices that do not pertain to any part of the 97 | Derivative Works; and 98 | If the Work includes a "NOTICE" text file as part of its distribution, then any 99 | Derivative Works that You distribute must include a readable copy of the 100 | attribution notices contained within such NOTICE file, excluding those notices 101 | that do not pertain to any part of the Derivative Works, in at least one of the 102 | following places: within a NOTICE text file distributed as part of the 103 | Derivative Works; within the Source form or documentation, if provided along 104 | with the Derivative Works; or, within a display generated by the Derivative 105 | Works, if and wherever such third-party notices normally appear. The contents of 106 | the NOTICE file are for informational purposes only and do not modify the 107 | License. You may add Your own attribution notices within Derivative Works that 108 | You distribute, alongside or as an addendum to the NOTICE text from the Work, 109 | provided that such additional attribution notices cannot be construed as 110 | modifying the License. 111 | You may add Your own copyright statement to Your modifications and may provide 112 | additional or different license terms and conditions for use, reproduction, or 113 | distribution of Your modifications, or for any such Derivative Works as a whole, 114 | provided Your use, reproduction, and distribution of the Work otherwise complies 115 | with the conditions stated in this License. 116 | 117 | 5. Submission of Contributions. 118 | 119 | Unless You explicitly state otherwise, any Contribution intentionally submitted 120 | for inclusion in the Work by You to the Licensor shall be under the terms and 121 | conditions of this License, without any additional terms or conditions. 
122 | Notwithstanding the above, nothing herein shall supersede or modify the terms of 123 | any separate license agreement you may have executed with Licensor regarding 124 | such Contributions. 125 | 126 | 6. Trademarks. 127 | 128 | This License does not grant permission to use the trade names, trademarks, 129 | service marks, or product names of the Licensor, except as required for 130 | reasonable and customary use in describing the origin of the Work and 131 | reproducing the content of the NOTICE file. 132 | 133 | 7. Disclaimer of Warranty. 134 | 135 | Unless required by applicable law or agreed to in writing, Licensor provides the 136 | Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, 137 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, 138 | including, without limitation, any warranties or conditions of TITLE, 139 | NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are 140 | solely responsible for determining the appropriateness of using or 141 | redistributing the Work and assume any risks associated with Your exercise of 142 | permissions under this License. 143 | 144 | 8. Limitation of Liability. 145 | 146 | In no event and under no legal theory, whether in tort (including negligence), 147 | contract, or otherwise, unless required by applicable law (such as deliberate 148 | and grossly negligent acts) or agreed to in writing, shall any Contributor be 149 | liable to You for damages, including any direct, indirect, special, incidental, 150 | or consequential damages of any character arising as a result of this License or 151 | out of the use or inability to use the Work (including but not limited to 152 | damages for loss of goodwill, work stoppage, computer failure or malfunction, or 153 | any and all other commercial damages or losses), even if such Contributor has 154 | been advised of the possibility of such damages. 155 | 156 | 9. Accepting Warranty or Additional Liability. 157 | 158 | While redistributing the Work or Derivative Works thereof, You may choose to 159 | offer, and charge a fee for, acceptance of support, warranty, indemnity, or 160 | other liability obligations and/or rights consistent with this License. However, 161 | in accepting such obligations, You may act only on Your own behalf and on Your 162 | sole responsibility, not on behalf of any other Contributor, and only if You 163 | agree to indemnify, defend, and hold each Contributor harmless for any liability 164 | incurred by, or claims asserted against, such Contributor by reason of your 165 | accepting any such warranty or additional liability. 166 | 167 | END OF TERMS AND CONDITIONS 168 | 169 | APPENDIX: How to apply the Apache License to your work 170 | 171 | To apply the Apache License to your work, attach the following boilerplate 172 | notice, with the fields enclosed by brackets "[]" replaced with your own 173 | identifying information. (Don't include the brackets!) The text should be 174 | enclosed in the appropriate comment syntax for the file format. We also 175 | recommend that a file or class name and description of purpose be included on 176 | the same "printed page" as the copyright notice for easier identification within 177 | third-party archives. 178 | 179 | Copyright [yyyy] [name of copyright owner] 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 
183 | You may obtain a copy of the License at 184 | 185 | http://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 192 | -------------------------------------------------------------------------------- /xarray_extras/tests/test_interpolate.py: -------------------------------------------------------------------------------- 1 | import dask.array as da 2 | import numpy as np 3 | import pytest 4 | from xarray import DataArray, Dataset, apply_ufunc 5 | from xarray.testing import assert_allclose, assert_equal 6 | 7 | import xarray_extras.kernels.interpolate as kernels 8 | from xarray_extras.interpolate import splev, splrep 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "k,expect", 13 | [ 14 | (0, [40, 55, 4]), 15 | (1, [45, 35, 4.5]), 16 | (2, [45, 37.90073529, np.nan]), 17 | (3, [45, 39.69583333, np.nan]), 18 | ], 19 | ) 20 | def test_0d(k, expect): 21 | """ 22 | - Test different orders 23 | - Test unsorted x 24 | - Test what happens when a series contains NaN 25 | """ 26 | y = DataArray( 27 | [[10, 20, 30, 40, 50, 60], [11, 28, 39, 55, 15, -2], [np.nan, 2, 3, 4, 5, 6]], 28 | dims=["y", "x"], 29 | coords={"x": [1, 2, 3, 4, 5, 6], "y": ["y1", "y2", "y3"]}, 30 | ) 31 | tck = splrep(y, "x", k) 32 | expect = DataArray( 33 | expect, dims=["y"], coords={"x": 4.5, "y": ["y1", "y2", "y3"]} 34 | ).astype(float) 35 | assert_allclose(splev(4.5, tck), expect, rtol=0, atol=1e-6) 36 | 37 | 38 | @pytest.mark.parametrize( 39 | "x_new,expect", 40 | [ 41 | # list 42 | ([3.5, 4.5], DataArray([35.0, 45.0], dims=["x"], coords={"x": [3.5, 4.5]})), 43 | # tuple 44 | ((3.5, 4.5), DataArray([35.0, 45.0], dims=["x"], coords={"x": [3.5, 4.5]})), 45 | # np.array 46 | ( 47 | np.array([3.5, 4.5]), 48 | DataArray([35.0, 45.0], dims=["x"], coords={"x": [3.5, 4.5]}), 49 | ), 50 | # da.Array 51 | ( 52 | da.from_array(np.array([3.5, 4.5]), chunks=1), 53 | DataArray([35.0, 45.0], dims=["x"], coords={"x": [3.5, 4.5]}).chunk(1), 54 | ), 55 | # DataArray, same dim as y, no coord 56 | (DataArray([3.5, 4.5], dims=["x"]), DataArray([35.0, 45.0], dims=["x"])), 57 | # DataArray, same dim as y, with coord 58 | ( 59 | DataArray([3.5, 4.5], dims=["x"], coords={"x": [100, 200]}), 60 | DataArray([35.0, 45.0], dims=["x"], coords={"x": [100, 200]}), 61 | ), 62 | # DataArray, different dim as y 63 | (DataArray([3.5, 4.5], dims=["t"]), DataArray([35.0, 45.0], dims=["t"])), 64 | ], 65 | ) 66 | def test_1d(x_new, expect): 67 | """ 68 | - Test 1d case 69 | - Test auto-casting of various types of x_new 70 | """ 71 | y = DataArray( 72 | [10, 20, 30, 40, 50, 60], dims=["x"], coords={"x": [1, 2, 3, 4, 5, 6]} 73 | ) 74 | tck = splrep(y, "x", k=1) 75 | y_new = splev(x_new, tck) 76 | assert_equal(y_new, expect) 77 | assert y_new.chunks == expect.chunks 78 | 79 | 80 | @pytest.mark.parametrize( 81 | "chunk_y,chunk_x_new,expect_chunks_tck,expect_chunks_y_new", 82 | [ 83 | (False, False, {}, None), 84 | (False, True, {}, ((1, 1), (1, 1), (2,))), 85 | (True, False, {"y": (1, 1), "x": (6,)}, ((2,), (2,), (1, 1))), 86 | (True, True, {"y": (1, 1), "x": (6,)}, ((1, 1), (1, 1), (1, 1))), 87 | ], 88 | ) 89 | def test_nd(chunk_y, chunk_x_new, expect_chunks_tck, expect_chunks_y_new): 90 | """ 91 | - Test ND y vs. 
ND x_new 92 | - Test dask 93 | """ 94 | y = DataArray( 95 | [[10, 20, 30, 40, 50, 60], [11, 28, 39, 55, 15, -2]], 96 | dims=["y", "x"], 97 | coords={"x": [1, 2, 3, 4, 5, 6], "y": ["y1", "y2"]}, 98 | ) 99 | x_new = DataArray( 100 | [[3.5, 4.5], [1.5, 5.5]], 101 | dims=["w", "z"], 102 | coords={"w": [100, 200], "z": ["foo", "bar"]}, 103 | ) 104 | expect_tck = Dataset( 105 | data_vars={ 106 | "t": ("__t__", [1.0, 1.0, 1.0, 1.0, 3.0, 4.0, 6.0, 6.0, 6.0, 6.0]), 107 | "c": ( 108 | ("x", "y"), 109 | [ 110 | [10.0, 11.0], 111 | [16.666667, 33.118519], 112 | [26.666667, 20.762963], 113 | [43.333333, 84.003704], 114 | [53.333333, -25.251852], 115 | [60.0, -2], 116 | ], 117 | ), 118 | }, 119 | coords={"x": [1, 2, 3, 4, 5, 6], "y": ["y1", "y2"]}, 120 | attrs={"spline_dim": "x", "k": 3}, 121 | ) 122 | 123 | expect_y_new = DataArray( 124 | [ 125 | [[35.0, 51.0375], [45.0, 39.69583333]], 126 | [[15.0, 22.72083333], [55.0, -3.945833]], 127 | ], 128 | dims=["w", "z", "y"], 129 | coords={ 130 | "w": [100, 200], 131 | "y": ["y1", "y2"], 132 | "z": ["foo", "bar"], 133 | }, 134 | ) 135 | 136 | if chunk_y: 137 | y = y.chunk({"y": 1}) 138 | if chunk_x_new: 139 | x_new = x_new.chunk(1) 140 | 141 | tck = splrep(y, "x", k=3) 142 | assert_allclose(tck.compute(), expect_tck, atol=1e-6, rtol=0) 143 | y_new = splev(x_new, tck) 144 | assert_allclose(y_new.compute(), expect_y_new, atol=1e-6, rtol=0) 145 | 146 | assert tck.chunks == expect_chunks_tck 147 | assert y_new.chunks == expect_chunks_y_new 148 | 149 | 150 | @pytest.mark.parametrize("contiguous", [False, True]) 151 | @pytest.mark.parametrize("transpose", [False, True]) 152 | def test_transpose(transpose, contiguous): 153 | """Test that the interpolation dim does not need to be on axis 0""" 154 | y = DataArray( 155 | [[10, 20], [30, 40], [50, 60]], 156 | dims=["y", "x"], 157 | coords={"x": [1, 2], "y": ["y1", "y2", "y3"]}, 158 | ) 159 | expect = DataArray( 160 | [[15.0, 35.0, 55.0], [10.0, 30.0, 50.0]], 161 | dims=["x", "y"], 162 | coords={"x": [1.5, 1.0], "y": ["y1", "y2", "y3"]}, 163 | ) 164 | 165 | if transpose: 166 | y = y.T 167 | if contiguous: 168 | y = apply_ufunc(np.ascontiguousarray, y) 169 | 170 | tck = splrep(y, "x", 1) 171 | y_new = splev([1.5, 1.0], tck) 172 | assert_equal(expect, y_new) 173 | 174 | 175 | @pytest.mark.parametrize( 176 | "x_new_dtype", [np.uint32, np.uint64, np.int32, np.int64, np.float32] 177 | ) 178 | @pytest.mark.parametrize( 179 | "x_dtype", [np.uint32, np.uint64, np.int32, np.int64, np.float32] 180 | ) 181 | def test_nonfloat(x_dtype, x_new_dtype): 182 | """Test numeric x that isn't float64""" 183 | x = np.array([0, 100]) 184 | 185 | y = DataArray((x * 3).astype(x_dtype), dims=["x"], coords={"x": x.astype(x_dtype)}) 186 | x_new = np.array([50]).astype(x_new_dtype) 187 | expect = DataArray([150.0], dims=["x"], coords={"x": x_new}) 188 | 189 | tck = splrep(y, "x", k=1) 190 | y_new = splev(x_new, tck) 191 | assert_equal(expect, y_new) 192 | 193 | 194 | @pytest.mark.filterwarnings("ignore:Converting non-nanosecond precision datetime ") 195 | @pytest.mark.parametrize("x_new_dtype", ["M8[D]", "M8[s]", "M8[ns]"]) 196 | @pytest.mark.parametrize("x_dtype", ["M8[D]", "M8[s]", "M8[ns]"]) 197 | def test_dates(x_dtype, x_new_dtype): 198 | """ 199 | - Test mismatched date formats on x and x_new 200 | - Test clip extrapolation on test_dates 201 | """ 202 | y = DataArray( 203 | [10, 20], 204 | dims=["x"], 205 | coords={"x": np.array(["2000-01-01", "2001-01-01"], dtype=x_dtype)}, 206 | ) 207 | x_new = np.array(["2000-04-20", "2002-07-28"], 
dtype=x_new_dtype) 208 | expect = DataArray([13.00546448, 20.0], dims=["x"], coords={"x": x_new}) 209 | 210 | tck = splrep(y, "x", k=1) 211 | y_new = splev(x_new, tck, extrapolate="clip") 212 | assert y_new.x.dtype == expect.x.dtype 213 | assert_allclose(expect, y_new, atol=1e-6, rtol=0) 214 | 215 | 216 | @pytest.mark.filterwarnings("ignore:Converting non-nanosecond precision datetime ") 217 | @pytest.mark.parametrize("x_new_dtype", ["m8[D]", "m8[s]", "m8[ns]"]) 218 | @pytest.mark.parametrize("x_dtype", ["m8[D]", "m8[s]", "m8[ns]"]) 219 | def test_timedeltas(x_dtype, x_new_dtype): 220 | """ 221 | - Test mismatched date formats on x and x_new 222 | - Test clip extrapolation on test_dates 223 | """ 224 | y = DataArray( 225 | [10, 20], 226 | dims=["x"], 227 | coords={"x": np.array([30, 50], dtype="m8[D]").astype(x_dtype)}, 228 | ) 229 | x_new = np.array([35, 45], dtype="m8[D]").astype(x_new_dtype) 230 | expect = DataArray([12.5, 17.5], dims=["x"], coords={"x": x_new}) 231 | 232 | tck = splrep(y, "x", k=1) 233 | y_new = splev(x_new, tck, extrapolate="clip") 234 | assert y_new.x.dtype == expect.x.dtype 235 | assert_allclose(expect, y_new, atol=1e-6, rtol=0) 236 | 237 | 238 | @pytest.mark.parametrize( 239 | "extrapolate,expect", 240 | [ 241 | (True, [-1.36401507, -1.22694955]), 242 | (False, [np.nan, np.nan]), 243 | ("clip", [0, np.sin(9)]), 244 | ("periodic", [0.98935825, 0.84147098]), 245 | ], 246 | ) 247 | def test_extrapolate(extrapolate, expect): 248 | """Test all possible extrapolate parameters""" 249 | x = np.arange(10) 250 | y = DataArray(np.sin(x), dims=["x"], coords={"x": x}) 251 | x_new = [-1, 10] 252 | expect = DataArray(expect, dims=["x"], coords={"x": x_new}) 253 | 254 | tck = splrep(y, "x", k=3) 255 | y_new = splev(x_new, tck, extrapolate=extrapolate) 256 | assert_allclose(expect, y_new, atol=1e-6, rtol=0) 257 | 258 | 259 | def test_dim_collision(): 260 | """y and x_new have overlapping dims besides the interpolation dim""" 261 | y = DataArray( 262 | [[10, 20], [11, 28]], dims=["y", "x"], coords={"x": [1, 2], "y": ["y1", "y2"]} 263 | ) 264 | x_new = DataArray([1, 1], dims=["y"]) 265 | tck = splrep(y, "x", 1) 266 | with pytest.raises( 267 | ValueError, 268 | match="Overlapping dims between interpolated array and x_new: y", 269 | ): 270 | splev(x_new, tck) 271 | 272 | 273 | def test_duplicates(): 274 | """x contains non-unique points""" 275 | y = DataArray([10, 20, 30], dims=["x"], coords={"x": [1, 1, 2]}) 276 | with pytest.raises(ValueError, match="Expect x to be a 1-D sorted array-like"): 277 | splrep(y, "x", 1) 278 | 279 | 280 | def test_chunked_x(): 281 | """x is chunked""" 282 | y = DataArray([10, 20], dims=["x"], coords={"x": [1, 2]}).chunk(1) 283 | 284 | with pytest.raises( 285 | NotImplementedError, 286 | match="Unsupported: multiple chunks on interpolation dim", 287 | ): 288 | splrep(y, "x", 1) 289 | 290 | 291 | def test_distributed(): 292 | def ro_array(a): 293 | a = np.array(a) 294 | a.setflags(write=False) 295 | # Return a view of a, so that setting the write flag on the view is not 296 | # enough 297 | return a[:] 298 | 299 | x = ro_array([1.0, 2.0]) 300 | y = ro_array([10.0, 20.0]) 301 | t = kernels.make_interp_knots(x, k=1) 302 | t = ro_array(t) 303 | c = kernels.make_interp_coeffs(x, y, k=1, t=t) 304 | c = ro_array(c) 305 | x_new = ro_array([1.5, 1.8]) 306 | kernels.splev(x_new, t, c, k=1) 307 | -------------------------------------------------------------------------------- /doc/conf.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | # documentation build configuration file, created by 4 | # sphinx-quickstart on Thu Feb 6 18:57:54 2014. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | import datetime 15 | import os 16 | import sys 17 | 18 | import xarray 19 | 20 | import xarray_extras 21 | 22 | print("python exec:", sys.executable) 23 | print("sys.path:", sys.path) 24 | print("xarray_extras version: ", xarray_extras.__version__) 25 | 26 | 27 | # -- General configuration ------------------------------------------------ 28 | 29 | # If your documentation needs a minimal Sphinx version, state it here. 30 | # needs_sphinx = '1.0' 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | extensions = [ 36 | "sphinx.ext.autodoc", 37 | "sphinx.ext.autosummary", 38 | "sphinx.ext.intersphinx", 39 | "sphinx.ext.extlinks", 40 | "sphinx.ext.mathjax", 41 | ] 42 | 43 | extlinks = { 44 | "issue": ("https://github.com/crusaderky/xarray_extras/issues/%s", "GH"), 45 | "pull": ("https://github.com/crusaderky/xarray_extras/pull/%s", "PR"), 46 | } 47 | 48 | autosummary_generate = True 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ["_templates"] 52 | 53 | # The suffix of source filenames. 54 | source_suffix = ".rst" 55 | 56 | # The encoding of source files. 57 | # source_encoding = 'utf-8-sig' 58 | 59 | # The master toctree document. 60 | master_doc = "index" 61 | 62 | # General information about the project. 63 | project = "xarray_extras" 64 | copyright = f"2018-{datetime.datetime.now().year}, xarray_extras Developers" 65 | 66 | # The version info for the project you're documenting, acts as replacement for 67 | # |version| and |release|, also used in various other places throughout the 68 | # built documents. 69 | # 70 | # The short X.Y version. 71 | version = xarray_extras.__version__.split("+")[0] 72 | # The full version, including alpha/beta/rc tags. 73 | release = xarray_extras.__version__ 74 | 75 | # The language for content autogenerated by Sphinx. Refer to documentation 76 | # for a list of supported languages. 77 | # language = None 78 | 79 | # There are two options for replacing |today|: either, you set today to some 80 | # non-false value, then it is used: 81 | # today = '' 82 | # Else, today_fmt is used as the format for a strftime call. 83 | today_fmt = "%Y-%m-%d" 84 | 85 | # List of patterns, relative to source directory, that match files and 86 | # directories to ignore when looking for source files. 87 | exclude_patterns = ["_build"] 88 | 89 | # The reST default role (used for this markup: `text`) to use for all 90 | # documents. 91 | # default_role = None 92 | 93 | # If true, '()' will be appended to :func: etc. cross-reference text. 94 | # add_function_parentheses = True 95 | 96 | # If true, the current module name will be prepended to all description 97 | # unit titles (such as .. function::). 98 | # add_module_names = True 99 | 100 | # If true, sectionauthor and moduleauthor directives will be shown in the 101 | # output. They are ignored by default. 
102 | # show_authors = False 103 | 104 | # The name of the Pygments (syntax highlighting) style to use. 105 | pygments_style = "sphinx" 106 | 107 | # A list of ignored prefixes for module index sorting. 108 | # modindex_common_prefix = [] 109 | 110 | # If true, keep warnings as "system message" paragraphs in the built documents. 111 | # keep_warnings = False 112 | 113 | 114 | # -- Options for HTML output ---------------------------------------------- 115 | 116 | # The theme to use for HTML and HTML Help pages. See the documentation for 117 | # a list of builtin themes. 118 | html_theme = "sphinx_rtd_theme" 119 | 120 | # Theme options are theme-specific and customize the look and feel of a theme 121 | # further. For a list of options available for each theme, see the 122 | # documentation. 123 | html_theme_options = {"logo_only": True} 124 | 125 | # Add any paths that contain custom themes here, relative to this directory. 126 | # html_theme_path = [] 127 | 128 | # The name for this set of Sphinx documents. If None, it defaults to 129 | # " v documentation". 130 | # html_title = None 131 | 132 | # A shorter title for the navigation bar. Default is the same as html_title. 133 | # html_short_title = None 134 | 135 | # The name of an image file (relative to this directory) to place at the top 136 | # of the sidebar. 137 | # html_logo = "_static/logo.png" 138 | 139 | # The name of an image file (within the static path) to use as favicon of the 140 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 141 | # pixels large. 142 | # html_favicon = None 143 | 144 | # Add any paths that contain custom static files (such as style sheets) here, 145 | # relative to this directory. They are copied after the builtin static files, 146 | # so a file named "default.css" will overwrite the builtin "default.css". 147 | html_static_path = ["_static"] 148 | 149 | # Sometimes the savefig directory doesn't exist and needs to be created 150 | # https://github.com/ipython/ipython/issues/8733 151 | # becomes obsolete when we can pin ipython>=5.2; see doc/environment.yml 152 | ipython_savefig_dir = os.path.join( 153 | os.path.dirname(os.path.abspath(__file__)), "_build", "html", "_static" 154 | ) 155 | if not os.path.exists(ipython_savefig_dir): 156 | os.makedirs(ipython_savefig_dir) 157 | 158 | # Add any extra paths that contain custom files (such as robots.txt or 159 | # .htaccess) here, relative to this directory. These files are copied 160 | # directly to the root of the documentation. 161 | # html_extra_path = [] 162 | 163 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 164 | # using the given strftime format. 165 | html_last_updated_fmt = today_fmt 166 | 167 | # If true, SmartyPants will be used to convert quotes and dashes to 168 | # typographically correct entities. 169 | # html_use_smartypants = True 170 | 171 | # Custom sidebar templates, maps document names to template names. 172 | # html_sidebars = {} 173 | 174 | # Additional templates that should be rendered to pages, maps page names to 175 | # template names. 176 | # html_additional_pages = {} 177 | 178 | # If false, no module index is generated. 179 | # html_domain_indices = True 180 | 181 | # If false, no index is generated. 182 | # html_use_index = True 183 | 184 | # If true, the index is split into individual pages for each letter. 185 | # html_split_index = False 186 | 187 | # If true, links to the reST sources are added to the pages. 
188 | # html_show_sourcelink = True 189 | 190 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 191 | # html_show_sphinx = True 192 | 193 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 194 | # html_show_copyright = True 195 | 196 | # If true, an OpenSearch description file will be output, and all pages will 197 | # contain a tag referring to it. The value of this option must be the 198 | # base URL from which the finished HTML is served. 199 | # html_use_opensearch = '' 200 | 201 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 202 | # html_file_suffix = None 203 | 204 | # Output file base name for HTML help builder. 205 | htmlhelp_basename = "xarray_extrasdoc" 206 | 207 | 208 | # -- Options for LaTeX output --------------------------------------------- 209 | 210 | latex_elements: dict[str, str] = { 211 | # The paper size ('letterpaper' or 'a4paper'). 212 | # 'papersize': 'letterpaper', 213 | # The font size ('10pt', '11pt' or '12pt'). 214 | # 'pointsize': '10pt', 215 | # Additional stuff for the LaTeX preamble. 216 | # 'preamble': '', 217 | } 218 | 219 | # Grouping the document tree into LaTeX files. List of tuples 220 | # (source start file, target name, title, 221 | # author, documentclass [howto, manual, or own class]). 222 | latex_documents = [ 223 | ( 224 | "index", 225 | "xarray_extras.tex", 226 | "xarray_extras Documentation", 227 | "xarray_extras Developers", 228 | "manual", 229 | ), 230 | ] 231 | 232 | # The name of an image file (relative to this directory) to place at the top of 233 | # the title page. 234 | # latex_logo = None 235 | 236 | # For "manual" documents, if this is true, then toplevel headings are parts, 237 | # not chapters. 238 | # latex_use_parts = False 239 | 240 | # If true, show page references after internal links. 241 | # latex_show_pagerefs = False 242 | 243 | # If true, show URL addresses after external links. 244 | # latex_show_urls = False 245 | 246 | # Documents to append as an appendix to all manuals. 247 | # latex_appendices = [] 248 | 249 | # If false, no module index is generated. 250 | # latex_domain_indices = True 251 | 252 | 253 | # -- Options for manual page output --------------------------------------- 254 | 255 | # One entry per manual page. List of tuples 256 | # (source start file, name, description, authors, manual section). 257 | man_pages = [ 258 | ( 259 | "index", 260 | "xarray_extras", 261 | "xarray_extras Documentation", 262 | ["xarray_extras Developers"], 263 | 1, 264 | ) 265 | ] 266 | 267 | # If true, show URL addresses after external links. 268 | # man_show_urls = False 269 | 270 | 271 | # -- Options for Texinfo output ------------------------------------------- 272 | 273 | # Grouping the document tree into Texinfo files. List of tuples 274 | # (source start file, target name, title, author, 275 | # dir menu entry, description, category) 276 | texinfo_documents = [ 277 | ( 278 | "index", 279 | "xarray_extras", 280 | "xarray_extras Documentation", 281 | "xarray_extras Developers", 282 | "xarray_extras", 283 | "Advanced / experimental algorithms for xarray", 284 | "Miscellaneous", 285 | ), 286 | ] 287 | 288 | # Documents to append as an appendix to all manuals. 289 | # texinfo_appendices = [] 290 | 291 | # If false, no module index is generated. 292 | # texinfo_domain_indices = True 293 | 294 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 
295 | # texinfo_show_urls = 'footnote' 296 | 297 | # If true, do not generate a @detailmenu in the "Top" node's menu. 298 | # texinfo_no_detailmenu = False 299 | 300 | 301 | # Example configuration for intersphinx: refer to the Python standard library. 302 | intersphinx_mapping = { 303 | "python": ("https://docs.python.org/3/", None), 304 | "dask": ("https://docs.dask.org/en/latest/", None), 305 | "distributed": ("https://distributed.dask.org/en/latest/", None), 306 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), 307 | "numpy": ("https://docs.scipy.org/doc/numpy/", None), 308 | "numba": ("https://numba.pydata.org/numba-doc/latest/", None), 309 | "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), 310 | "xarray": ("https://xarray.pydata.org/en/stable/", None), 311 | } 312 | 313 | # Work around intersphinx issue 314 | xarray.DataArray.__module__ = "xarray" 315 | xarray.Dataset.__module__ = "xarray" 316 | xarray.Variable.__module__ = "xarray" 317 | -------------------------------------------------------------------------------- /xarray_extras/tests/test_csv.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import gzip 3 | import lzma 4 | import pickle 5 | import tempfile 6 | from pathlib import Path 7 | 8 | import dask 9 | import numpy as np 10 | import pytest 11 | import xarray 12 | 13 | from xarray_extras.csv import to_csv 14 | 15 | 16 | def assert_to_csv( 17 | x, chunks, nogil, dtype, open_func=open, float_format="%f", ext="csv", **kwargs 18 | ): 19 | x = x.astype(dtype) 20 | if chunks: 21 | x = x.chunk(chunks) 22 | with tempfile.TemporaryDirectory() as tmp: 23 | x.to_pandas().to_csv(tmp + "/1." + ext, float_format=float_format, **kwargs) 24 | f = to_csv( 25 | x, tmp + "/2." + ext, nogil=nogil, float_format=float_format, **kwargs 26 | ) 27 | dask.compute(f) 28 | 29 | with open_func(tmp + "/1." + ext, "rb") as fh: 30 | d1 = fh.read() 31 | with open_func(tmp + "/2." + ext, "rb") as fh: 32 | d2 = fh.read() 33 | assert d2 == d1 34 | 35 | 36 | def assert_to_csv_with_path_type( 37 | x, 38 | path_type, 39 | chunks, 40 | nogil, 41 | dtype, 42 | open_func=open, 43 | float_format="%f", 44 | ext="csv", 45 | **kwargs, 46 | ): 47 | x = x.astype(dtype) 48 | if chunks: 49 | x = x.chunk(chunks) 50 | with tempfile.TemporaryDirectory() as tmp: 51 | path_1 = tmp + "/1." + ext if path_type == "str" else Path(tmp) / ("1." + ext) 52 | path_2 = tmp + "/2." + ext if path_type == "str" else Path(tmp) / ("2." + ext) 53 | x.to_pandas().to_csv(path_1, float_format=float_format, **kwargs) 54 | f = to_csv(x, path_2, nogil=nogil, float_format=float_format, **kwargs) 55 | dask.compute(f) 56 | 57 | with open_func(tmp + "/1." + ext, "rb") as fh: 58 | d1 = fh.read() 59 | with open_func(tmp + "/2." 
+ ext, "rb") as fh: 60 | d2 = fh.read() 61 | assert d2 == d1 62 | 63 | 64 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 65 | @pytest.mark.parametrize("nogil", [False, True]) 66 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 67 | @pytest.mark.parametrize("header", [False, True]) 68 | @pytest.mark.parametrize("lineterminator", ["\n", "\r\n"]) 69 | def test_series(chunks, nogil, dtype, header, lineterminator): 70 | x = xarray.DataArray([1, 2, 3, 4], dims=["x"], coords={"x": [10, 20, 30, 40]}) 71 | assert_to_csv(x, chunks, nogil, dtype, header=header, lineterminator=lineterminator) 72 | 73 | 74 | @pytest.mark.parametrize("path_type", ["path", "str"]) 75 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 76 | @pytest.mark.parametrize("nogil", [False, True]) 77 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 78 | @pytest.mark.parametrize("header", [False, True]) 79 | @pytest.mark.parametrize("lineterminator", ["\n", "\r\n"]) 80 | def test_series_with_path(path_type, chunks, nogil, dtype, header, lineterminator): 81 | x = xarray.DataArray([1, 2, 3, 4], dims=["x"], coords={"x": [10, 20, 30, 40]}) 82 | assert_to_csv_with_path_type( 83 | x, path_type, chunks, nogil, dtype, header=header, lineterminator=lineterminator 84 | ) 85 | 86 | 87 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 88 | @pytest.mark.parametrize("nogil", [False, True]) 89 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}]) 90 | @pytest.mark.parametrize("lineterminator", ["\n", "\r\n"]) 91 | def test_dataframe(chunks, nogil, dtype, lineterminator): 92 | x = xarray.DataArray( 93 | [[1, 2, 3, 4], [5, 6, 7, 8]], 94 | dims=["r", "c"], 95 | coords={"r": ["a", "b"], "c": [10, 20, 30, 40]}, 96 | ) 97 | assert_to_csv(x, chunks, nogil, dtype, lineterminator=lineterminator) 98 | 99 | 100 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 101 | @pytest.mark.parametrize("nogil", [False, True]) 102 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}]) 103 | def test_multiindex(chunks, nogil, dtype): 104 | x = xarray.DataArray( 105 | [[1, 2], [3, 4]], 106 | dims=["r", "c"], 107 | coords={ 108 | "r1": ("r", ["r11", "r12"]), 109 | "r2": ("r", ["r21", "r22"]), 110 | "c1": ("c", ["c11", "c12"]), 111 | "c2": ("c", ["c21", "c22"]), 112 | }, 113 | ) 114 | x = x.set_index(r=["r1", "r2"], c=["c1", "c2"]) 115 | assert_to_csv(x, chunks, nogil, dtype) 116 | 117 | 118 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 119 | @pytest.mark.parametrize("nogil", [False, True]) 120 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}]) 121 | def test_no_header(chunks, nogil, dtype): 122 | x = xarray.DataArray([[1, 2], [3, 4]], dims=["r", "c"]) 123 | assert_to_csv(x, chunks, nogil, dtype, index=False, header=False) 124 | 125 | 126 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 127 | @pytest.mark.parametrize("nogil", [False, True]) 128 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}]) 129 | def test_custom_header(chunks, nogil, dtype): 130 | x = xarray.DataArray([[1, 2], [3, 4]], dims=["r", "c"]) 131 | assert_to_csv(x, chunks, nogil, dtype, header=["foo", "bar"]) 132 | 133 | 134 | @pytest.mark.parametrize("encoding", ["utf-8", "utf-16"]) 135 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 136 | @pytest.mark.parametrize("nogil", [False, True]) 137 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}]) 138 | @pytest.mark.parametrize("lineterminator", ["\n", "\r\n"]) 139 | def test_encoding(chunks, nogil, dtype, encoding, 
lineterminator): 140 | # Note: in Python 2.7, default encoding is ascii in pandas and utf-8 in 141 | # xarray_extras. Therefore we will not test the default. 142 | x = xarray.DataArray( 143 | [[1], [2]], dims=["r", "c"], coords={"r": ["crème", "foo"], "c": ["brûlée"]} 144 | ) 145 | assert_to_csv( 146 | x, chunks, nogil, dtype, encoding=encoding, lineterminator=lineterminator 147 | ) 148 | 149 | 150 | @pytest.mark.parametrize("sep", [",", "|"]) 151 | @pytest.mark.parametrize("float_format", ["%f", "%.2f", "%.15f", "%.5e"]) 152 | @pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64]) 153 | @pytest.mark.parametrize("nogil", [False, True]) 154 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 155 | def test_kwargs(chunks, nogil, dtype, float_format, sep): 156 | x = xarray.DataArray([1.0, 1.1, 1.000000000000001, 123.456789], dims=["x"]) 157 | assert_to_csv(x, chunks, nogil, dtype, float_format=float_format, sep=sep) 158 | 159 | 160 | @pytest.mark.parametrize("na_rep", ["", "nan"]) 161 | @pytest.mark.parametrize("nogil", [False, True]) 162 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 163 | def test_na_rep(chunks, nogil, na_rep): 164 | x = xarray.DataArray([np.nan, 1], dims=["x"]) 165 | assert_to_csv(x, chunks, nogil, np.float64, na_rep=na_rep) 166 | 167 | 168 | @pytest.mark.parametrize( 169 | "compression,open_func", 170 | [ 171 | (None, open), 172 | ("gzip", gzip.open), 173 | ("bz2", bz2.open), 174 | ("xz", lzma.open), 175 | ], 176 | ) 177 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 178 | @pytest.mark.parametrize("nogil", [False, True]) 179 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 180 | def test_compression(chunks, nogil, dtype, compression, open_func): 181 | x = xarray.DataArray([1, 2], dims=["x"]) 182 | assert_to_csv(x, chunks, nogil, dtype, compression=compression, open_func=open_func) 183 | 184 | 185 | @pytest.mark.parametrize( 186 | "ext,open_func", 187 | [ 188 | ("csv", open), 189 | ("csv.gz", gzip.open), 190 | ("csv.bz2", bz2.open), 191 | ("csv.xz", lzma.open), 192 | ], 193 | ) 194 | @pytest.mark.parametrize("nogil", [False, True]) 195 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 196 | def test_compression_infer(ext, open_func, nogil, chunks): 197 | x = xarray.DataArray([1, 2], dims=["x"]) 198 | assert_to_csv( 199 | x, chunks=chunks, nogil=nogil, dtype=np.float64, ext=ext, open_func=open_func 200 | ) 201 | 202 | 203 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 204 | @pytest.mark.parametrize("nogil", [False, True]) 205 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}]) 206 | def test_empty(chunks, nogil, dtype): 207 | x = xarray.DataArray( 208 | [[1, 2, 3, 4]], dims=["r", "c"], coords={"c": [10, 20, 30, 40]} 209 | ) 210 | x = x.isel(r=slice(0)) 211 | assert_to_csv(x, chunks, nogil, dtype) 212 | 213 | 214 | @pytest.mark.parametrize("x", [0, -(2**63)]) 215 | @pytest.mark.parametrize("index", ["a", "a" * 1000]) 216 | @pytest.mark.parametrize("nogil", [False, True]) 217 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 218 | def test_buffer_overflow_int(chunks, nogil, index, x): 219 | a = xarray.DataArray([x], dims=["x"], coords={"x": [index]}) 220 | assert_to_csv(a, chunks, nogil, np.int64) 221 | 222 | 223 | @pytest.mark.parametrize("x", [0, np.nan, 1.000000000000001, 1.7901234406790122e308]) 224 | @pytest.mark.parametrize("index,coord", [(False, ""), (True, "a"), (True, "a" * 1000)]) 225 | @pytest.mark.parametrize("na_rep", ["", "na" * 500]) 226 | 
@pytest.mark.parametrize("float_format", ["%.16f", "%.1000f", "a" * 1000 + "%.0f"]) 227 | @pytest.mark.parametrize("nogil", [False, True]) 228 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 229 | def test_buffer_overflow_float(chunks, nogil, float_format, na_rep, index, coord, x): 230 | if nogil and not index and np.isnan(x) and na_rep == "": 231 | # Expected: b'""\n' 232 | # Actual: b'\n' 233 | pytest.xfail("pandas prints useless for empty lines") 234 | 235 | a = xarray.DataArray([x], dims=["x"], coords={"x": [coord]}) 236 | assert_to_csv( 237 | a, 238 | chunks, 239 | nogil, 240 | np.float64, 241 | float_format=float_format, 242 | na_rep=na_rep, 243 | index=index, 244 | ) 245 | 246 | 247 | @pytest.mark.parametrize("encoding", ["utf-8", "utf-16"]) 248 | @pytest.mark.parametrize("dtype", [str, object]) 249 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 250 | @pytest.mark.parametrize("lineterminator", ["\n", "\r\n"]) 251 | def test_pandas_only(chunks, dtype, encoding, lineterminator): 252 | x = xarray.DataArray(["foo", "Crème brûlée"], dims=["x"]) 253 | assert_to_csv( 254 | x, 255 | chunks=chunks, 256 | nogil=False, 257 | dtype=dtype, 258 | encoding=encoding, 259 | lineterminator=lineterminator, 260 | ) 261 | 262 | 263 | @pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) 264 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 265 | def test_pandas_only_complex(chunks, dtype): 266 | x = xarray.DataArray([1 + 2j], dims=["x"]) 267 | assert_to_csv(x, chunks=chunks, nogil=False, dtype=dtype) 268 | 269 | 270 | @pytest.mark.parametrize("nogil", [False, True]) 271 | @pytest.mark.parametrize("chunks", [None, {"x": 1}]) 272 | def test_mode(chunks, nogil): 273 | x = xarray.DataArray([1, 2], dims=["x"]) 274 | y = xarray.DataArray([3, 4], dims=["x"]) 275 | if chunks: 276 | x = x.chunk(chunks) 277 | y = y.chunk(chunks) 278 | 279 | with tempfile.TemporaryDirectory() as tmp: 280 | f = to_csv(x, tmp + "/1.csv", mode="a", nogil=nogil, header=False, index=False) 281 | dask.compute(f) 282 | f = to_csv(y, tmp + "/1.csv", mode="a", nogil=nogil, header=False, index=False) 283 | dask.compute(f) 284 | with open(tmp + "/1.csv") as fh: 285 | assert fh.read() == "1\n2\n3\n4\n" 286 | 287 | f = to_csv(y, tmp + "/1.csv", mode="w", nogil=nogil, header=False, index=False) 288 | dask.compute(f) 289 | with open(tmp + "/1.csv") as fh: 290 | assert fh.read() == "3\n4\n" 291 | 292 | 293 | def test_none_fmt(): 294 | """float_format=None differs between C and pandas; can't use assert_to_csv""" 295 | x = xarray.DataArray([1.0, 1.1, 1.000000000000001, 123.456789]) 296 | y = x.astype(np.float32) 297 | 298 | with tempfile.TemporaryDirectory() as tmp: 299 | to_csv(x, tmp + "/1.csv", header=False) 300 | to_csv(y, tmp + "/2.csv", header=False) 301 | 302 | with open(tmp + "/1.csv") as fh: 303 | assert fh.read() == "0,1.0\n1,1.1\n2,1.0\n3,123.456789\n" 304 | with open(tmp + "/2.csv") as fh: 305 | assert fh.read() == "0,1.0\n1,1.1\n2,1.0\n3,123.456787\n" 306 | 307 | 308 | def test_pickle(): 309 | x = xarray.DataArray([1, 2]) 310 | with tempfile.TemporaryDirectory() as tmp: 311 | x.to_pandas().to_csv(tmp + "/1.csv") 312 | d = to_csv(x.chunk(1), tmp + "/2.csv") 313 | d = pickle.loads(pickle.dumps(d)) 314 | d.compute() 315 | 316 | with open(tmp + "/1.csv", "rb") as fh: 317 | d1 = fh.read() 318 | with open(tmp + "/2.csv", "rb") as fh: 319 | d2 = fh.read() 320 | assert d1 == d2 321 | --------------------------------------------------------------------------------