├── xarray_extras
│   ├── py.typed
│   ├── tests
│   │   ├── __init__.py
│   │   ├── test_numba_extras.py
│   │   ├── test_sort.py
│   │   ├── test_stack.py
│   │   ├── test_cumulatives.py
│   │   ├── test_interpolate.py
│   │   └── test_csv.py
│   ├── kernels
│   │   ├── __init__.py
│   │   ├── cumulatives.py
│   │   ├── csv.py
│   │   ├── np_to_csv_py.py
│   │   ├── np_to_csv.c
│   │   └── interpolate.py
│   ├── duck
│   │   ├── __init__.py
│   │   └── sort.py
│   ├── compat.py
│   ├── __init__.py
│   ├── numba_extras.py
│   ├── stack.py
│   ├── sort.py
│   ├── cumulatives.py
│   ├── csv.py
│   └── interpolate.py
├── doc
│   ├── _static
│   │   ├── .gitignore
│   │   └── style.css
│   ├── _templates
│   │   └── layout.html
│   ├── api
│   │   ├── csv.rst
│   │   ├── stack.rst
│   │   ├── cumulatives.rst
│   │   ├── interpolate.rst
│   │   ├── numba_extras.rst
│   │   └── sort.rst
│   ├── installing.rst
│   ├── develop.rst
│   ├── index.rst
│   ├── whats-new.rst
│   └── conf.py
├── .gitattributes
├── .github
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── pre-commit.yml
│       ├── docs.yml
│       └── pytest.yml
├── ci
│   ├── requirements-latest.yml
│   ├── requirements-docs.yml
│   ├── requirements-minimal.yml
│   └── requirements-upstream.yml
├── setup.py
├── .readthedocs.yaml
├── MANIFEST.in
├── README.md
├── .gitignore
├── .pre-commit-config.yaml
├── HOW_TO_RELEASE
├── pyproject.toml
└── LICENSE
/xarray_extras/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/xarray_extras/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/doc/_static/.gitignore:
--------------------------------------------------------------------------------
1 | examples*.png
2 | *.log
3 | *.pdf
4 | *.fbd_latexmk
5 | *.aux
6 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # reduce the number of merge conflicts
2 | doc/whats-new.rst merge=union
3 |
--------------------------------------------------------------------------------
/doc/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %} {% set css_files = css_files +
2 | ["_static/style.css"] %}
3 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | - [ ] Closes #xxxx
2 | - [ ] Tests added / passed
3 | - [ ] Passes `pre-commit run --all-files`
4 |
--------------------------------------------------------------------------------
/doc/api/csv.rst:
--------------------------------------------------------------------------------
1 | csv
2 | ===
3 |
4 | .. automodule:: xarray_extras.csv
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
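9 | A minimal usage sketch. The exact signature of the ``to_csv`` wrapper is the one
10 | documented by the automodule block above; the call below only assumes that it
11 | accepts the array, a target path, and :meth:`pandas.DataFrame.to_csv`-style
12 | keyword arguments:
13 |
14 | .. code-block:: python
15 |
16 |     import xarray
17 |
18 |     from xarray_extras.csv import to_csv
19 |
20 |     x = xarray.DataArray(
21 |         [[1.0, 2.5, 3.25], [4.0, 5.5, 6.75]],
22 |         dims=["r", "c"],
23 |         coords={"r": ["r0", "r1"], "c": ["c0", "c1", "c2"]},
24 |     )
25 |
26 |     # Hypothetical call: write x to disk, forwarding pandas-style kwargs
27 |     to_csv(x, "out.csv", float_format="%.2f")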
--------------------------------------------------------------------------------
/doc/api/stack.rst:
--------------------------------------------------------------------------------
1 | stack
2 | =====
3 |
4 | .. automodule:: xarray_extras.stack
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
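9 | A minimal sketch of :func:`~xarray_extras.stack.proper_unstack`, based on the
10 | behaviour exercised in the test suite (labels keep their order of first
11 | appearance instead of being sorted alphabetically):
12 |
13 | .. code-block:: python
14 |
15 |     import pandas as pd
16 |     import xarray
17 |
18 |     from xarray_extras.stack import proper_unstack
19 |
20 |     index = pd.MultiIndex.from_tuples(
21 |         [("x1", "first"), ("x1", "second"), ("x0", "first"), ("x0", "second")],
22 |         names=["x", "count"],
23 |     )
24 |     a = xarray.DataArray([0, 1, 2, 3], dims=["dim_0"], coords={"dim_0": index})
25 |
26 |     # Returns a 2D DataArray with x = ['x1', 'x0'] and count = ['first', 'second'],
27 |     # i.e. the labels are not reordered alphabetically on unstack
28 |     b = proper_unstack(a, "dim_0")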
--------------------------------------------------------------------------------
/doc/api/cumulatives.rst:
--------------------------------------------------------------------------------
1 | cumulatives
2 | ===========
3 |
4 | .. automodule:: xarray_extras.cumulatives
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
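9 | A minimal sketch of :func:`~xarray_extras.cumulatives.cummean`, based on the
10 | calls exercised in the test suite (the values in the comment are what a plain
11 | running mean produces):
12 |
13 | .. code-block:: python
14 |
15 |     import xarray
16 |
17 |     from xarray_extras.cumulatives import cummean
18 |
19 |     x = xarray.DataArray([2.0, 3.0, 4.0, 5.0], dims=["t"])
20 |
21 |     # Running mean along 't': mean(x[:1]), mean(x[:2]), mean(x[:3]), mean(x[:4])
22 |     cummean(x, "t")  # -> [2.0, 2.5, 3.0, 3.5]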
--------------------------------------------------------------------------------
/doc/api/interpolate.rst:
--------------------------------------------------------------------------------
1 | interpolate
2 | ===========
3 |
4 | .. automodule:: xarray_extras.interpolate
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/doc/api/numba_extras.rst:
--------------------------------------------------------------------------------
1 | numba\_extras
2 | =============
3 |
4 | .. automodule:: xarray_extras.numba_extras
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/xarray_extras/kernels/__init__.py:
--------------------------------------------------------------------------------
1 | """Pure numpy functions that either run in dask or are used directly.
2 | Typically invoked through ``xarray.apply_ufunc(dask='parallelized')``.
3 | """
4 |
--------------------------------------------------------------------------------
/xarray_extras/duck/__init__.py:
--------------------------------------------------------------------------------
1 | """Functions that accept either a numpy array or a dask array, and return
2 | another of the matching type.
3 | Typically invoked through ``xarray.apply_ufunc(dask='allowed')``.
4 | """
5 |
--------------------------------------------------------------------------------
/ci/requirements-latest.yml:
--------------------------------------------------------------------------------
1 | name: xarray-extras
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - dask
6 | - numba
7 | - numpy
8 | - pandas
9 | - pytest
10 | - pytest-cov
11 | - scipy
12 | - xarray
13 |
--------------------------------------------------------------------------------
/ci/requirements-docs.yml:
--------------------------------------------------------------------------------
1 | name: xarray-extras-docs
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - python=3.12
6 | - sphinx
7 | - sphinx_rtd_theme
8 | - gcc_linux-64
9 | - dask
10 | - numba
11 | - scipy
12 | - xarray
13 |
--------------------------------------------------------------------------------
/ci/requirements-minimal.yml:
--------------------------------------------------------------------------------
1 | name: xarray-extras-minimal
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - dask=2022.6.0
6 | - numba=0.56
7 | - numpy=1.23
8 | - pandas=1.5
9 | - pytest
10 | - pytest-cov
11 | - scipy=1.9
12 | - xarray=2022.11.0
13 |
--------------------------------------------------------------------------------
/xarray_extras/compat.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | try:
4 | from xarray.namedarray.pycompat import array_type
5 | except ImportError: # <2024.2.0
6 | from xarray.core.pycompat import array_type # type: ignore[no-redef]
7 |
8 | dask_array_type = array_type("dask")
9 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import Extension, setup
2 |
3 | setup(
4 | use_scm_version=True,
5 | # Compile CPython extensions
6 | ext_modules=[
7 | Extension(
8 | "xarray_extras.kernels.np_to_csv", ["xarray_extras/kernels/np_to_csv.c"]
9 | )
10 | ],
11 | )
12 |
--------------------------------------------------------------------------------
/xarray_extras/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 |
3 | try:
4 | __version__ = importlib.metadata.version("xarray_extras")
5 | except importlib.metadata.PackageNotFoundError: # pragma: nocover
6 | # Local copy, not installed with pip
7 | __version__ = "999"
8 |
9 | __all__ = ("__version__",)
10 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: ubuntu-22.04
5 | tools:
6 | python: mambaforge-22.9
7 |
8 | conda:
9 | environment: ci/requirements-docs.yml
10 |
11 | python:
12 | install:
13 | - method: pip
14 | path: .
15 |
16 | sphinx:
17 | builder: html
18 | configuration: doc/conf.py
19 | fail_on_warning: false
20 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include HOW_TO_RELEASE
2 | include LICENSE
3 | include *.md
4 | include *.py
5 | recursive-include ci *
6 | recursive-include doc *
7 | recursive-include xarray_extras *
8 | prune doc/_build
9 | global-exclude __pycache__
10 | global-exclude *.pyc
11 | global-exclude .DS_Store
12 | global-exclude .ipynb_checkpoints
13 | global-exclude dask-worker-space
14 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: Linting
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: ["*"]
8 |
9 | jobs:
10 | checks:
11 | name: pre-commit hooks
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v4
15 | - uses: actions/setup-python@v5
16 | - uses: pre-commit/action@v3.0.0
17 |
--------------------------------------------------------------------------------
/doc/_static/style.css:
--------------------------------------------------------------------------------
1 | @import url("theme.css");
2 |
3 | .wy-side-nav-search > a img.logo,
4 | .wy-side-nav-search .wy-dropdown > a img.logo {
5 | width: 12rem;
6 | }
7 |
8 | .wy-side-nav-search {
9 | background-color: #eee;
10 | }
11 |
12 | .wy-side-nav-search > div.version {
13 | display: none;
14 | }
15 |
16 | .wy-nav-top {
17 | background-color: #555;
18 | }
19 |
--------------------------------------------------------------------------------
/ci/requirements-upstream.yml:
--------------------------------------------------------------------------------
1 | name: xarray-extras
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - dask
6 | # - numba # Not compatible with numpy 2
7 | - numpy
8 | - pandas
9 | - pyarrow
10 | - pytest
11 | - pytest-cov
12 | - scipy
13 | - xarray
14 | - pip
15 | - pip:
16 | - git+https://github.com/dask/dask
17 | - git+https://github.com/dask/distributed
18 | - git+https://github.com/pydata/xarray
19 | # numpy, pandas, pyarrow, and scipy are upgraded to nightly builds by pytest.yml
20 |
--------------------------------------------------------------------------------
/doc/installing.rst:
--------------------------------------------------------------------------------
1 | .. _installing:
2 |
3 | Installation
4 | ============
5 |
6 | Required dependencies
7 | ---------------------
8 |
9 | - Python 3.8 or later
10 | - `scipy <https://scipy.org/>`__
11 | - `xarray <https://xarray.dev/>`__
12 | - `dask <https://www.dask.org/>`__
13 | - `numba <https://numba.pydata.org/>`__
14 | - C compiler (only if building from sources)
15 |
16 | Deployment
17 | ----------
18 |
19 | - With pip: :command:`pip install xarray-extras`
20 | - With `anaconda <https://www.anaconda.com/>`_:
21 | :command:`conda install -c conda-forge xarray-extras`
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # xarray_extras
2 |
3 | [![doc](https://github.com/crusaderky/xarray_extras/actions/workflows/docs.yml/badge.svg)](https://github.com/crusaderky/xarray_extras/actions)
4 | [![lint](https://github.com/crusaderky/xarray_extras/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/crusaderky/xarray_extras/actions)
5 | [![test](https://github.com/crusaderky/xarray_extras/actions/workflows/pytest.yml/badge.svg)](https://github.com/crusaderky/xarray_extras/actions)
6 | [![codecov](https://codecov.io/gh/crusaderky/xarray_extras/branch/main/graph/badge.svg)](https://codecov.io/gh/crusaderky/xarray_extras/branch/main)
7 |
8 | Advanced / experimental algorithms for xarray
9 |
10 | Full documentation at http://xarray-extras.readthedocs.io/
11 |
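12 | A quick, illustrative example using the top-k helpers from `xarray_extras.sort`
13 | (see the full documentation for the complete API):
14 |
15 | ```python
16 | import xarray
17 |
18 | from xarray_extras.sort import argtopk, take_along_dim, topk
19 |
20 | x = xarray.DataArray([[5, 3, 2, 8, 1], [0, 7, 1, 3, 2]], dims=["x", "s"])
21 | y = x.sum("x")  # embarrassingly parallel along dimension "s"
22 |
23 | top_y = topk(y, 3, "s")  # 3 largest values of y
24 | top_x = take_along_dim(x, argtopk(y, 3, "s"), "s")  # matching elements of x
25 | ```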
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__
3 |
4 | # C extensions
5 | *.so
6 |
7 | # Packages
8 | *.egg
9 | *.egg-info
10 | .eggs
11 | dist
12 | build
13 | eggs
14 | parts
15 | var
16 | sdist
17 | develop-eggs
18 | .installed.cfg
19 | lib
20 | lib64
21 |
22 | # Installer logs
23 | pip-log.txt
24 |
25 | # Unit test / coverage reports
26 | .coverage*
27 | coverage.xml
28 | .tox
29 | nosetests.xml
30 | .cache
31 | .mypy_cache/
32 | .ropeproject/
33 | .tags*
34 | .testmon*
35 | .pytest_cache
36 | dask-worker-space/
37 |
38 | # asv environments
39 | .asv
40 |
41 | # Translations
42 | *.mo
43 |
44 | # Mr Developer
45 | .mr.developer.cfg
46 | .project
47 | .pydevproject
48 |
49 | # IDEs
50 | .idea
51 | *.swp
52 | .DS_Store
53 | .vscode/
54 |
55 | # Sync tools
56 | Icon*
57 |
58 | .ipynb_checkpoints
59 |
60 | doc/_build
61 | Untitled.ipynb
62 | htmlcov/
63 |
--------------------------------------------------------------------------------
/xarray_extras/tests/test_numba_extras.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | pytest.importorskip("numba") # Not available in upstream CI
5 |
6 | from xarray_extras.numba_extras import guvectorize
7 |
8 | DTYPES = [
9 | # uint needs to appear before signed int:
10 | # https://github.com/numba/numba/issues/2934
11 | "uint8",
12 | "uint16",
13 | "uint32",
14 | "uint64",
15 | "int8",
16 | "int16",
17 | "int32",
18 | "int64",
19 | "float32",
20 | "float64",
21 | "complex64",
22 | "complex128",
23 | ]
24 |
25 |
26 | @guvectorize("{T}[:], {T}[:]", "()->()")
27 | def dumb_copy(x, y):
28 | for i in range(x.size):
29 | y.flat[i] = x.flat[i]
30 |
31 |
32 | @pytest.mark.parametrize("dtype", DTYPES)
33 | def test_guvectorize(dtype):
34 | x = np.arange(3, dtype=dtype)
35 | y = dumb_copy(x)
36 | np.testing.assert_equal(x, y)
37 | assert x.dtype == y.dtype
38 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Documentation
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: ["*"]
8 |
9 | defaults:
10 | run:
11 | shell: bash -l {0}
12 |
13 | jobs:
14 | build:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - name: Checkout
18 | uses: actions/checkout@v4
19 |
20 | - name: Setup Conda Environment
21 | uses: conda-incubator/setup-miniconda@v3
22 | with:
23 | miniforge-version: latest
24 | use-mamba: true
25 | environment-file: ci/requirements-docs.yml
26 | activate-environment: xarray-extras-docs
27 |
28 | - name: Show conda options
29 | run: conda config --show
30 |
31 | - name: conda info
32 | run: conda info
33 |
34 | - name: conda list
35 | run: conda list
36 |
37 | - name: Install
38 | run: python -m pip install --no-deps -e .
39 |
40 | - name: Build docs
41 | run: sphinx-build -n -j auto -b html -d build/doctrees doc build/html
42 |
43 | - uses: actions/upload-artifact@v4
44 | with:
45 | name: xarray_extras-docs
46 | path: build/html
47 |
--------------------------------------------------------------------------------
/doc/api/sort.rst:
--------------------------------------------------------------------------------
1 | sort
2 | ====
3 |
4 | .. automodule:: xarray_extras.sort
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 |
10 | An example that uses all of the above functions is *source attribution*.
11 | Given a generic function :math:`y = f(x_{0}, x_{1}, ..., x_{i})`, which is
12 | embarrassingly parallel along a given dimension, one wants to find:
13 |
14 | - the top k elements of y along the dimension
15 | - the elements of all x's that generated the top k elements of y
16 |
17 | .. code::
18 |
19 |     >>> from xarray import DataArray
20 |     >>> from xarray_extras.sort import *
21 |     >>> x = DataArray([[5, 3, 2, 8, 1],
22 |     ...                [0, 7, 1, 3, 2]], dims=['x', 's'])
23 |     >>> y = x.sum('x')  # y = f(x), embarrassingly parallel along dimension 's'
24 |     >>> y
25 |     <xarray.DataArray (s: 5)>
26 |     array([ 5, 10,  3, 11,  3])
27 |     Dimensions without coordinates: s
28 |     >>> top_y = topk(y, 3, 's')
29 |     >>> top_y
30 |     <xarray.DataArray (s: 3)>
31 |     array([11, 10,  5])
32 |     Dimensions without coordinates: s
33 |     >>> top_x = take_along_dim(x, argtopk(y, 3, 's'), 's')
34 |     >>> top_x
35 |     <xarray.DataArray (x: 2, s: 3)>
36 |     array([[8, 3, 5],
37 |            [3, 7, 0]])
38 |     Dimensions without coordinates: x, s
39 |
--------------------------------------------------------------------------------
/xarray_extras/kernels/cumulatives.py:
--------------------------------------------------------------------------------
1 | """Numba kernels for :mod:`cumulatives`"""
2 |
3 | import numpy as np
4 |
5 | from xarray_extras.numba_extras import guvectorize
6 |
7 |
8 | @guvectorize("{T}[:], intp[:], {T}[:]", "(i),(j)->()")
9 | def compound_sum(x: np.ndarray, c: np.ndarray, y: np.ndarray) -> None:
10 | """y = x[c[0]] + x[c[1]] + ... x[c[n]]
11 | stopping at the first c[i] == -1
12 | """
13 | acc = 0
14 | for i in c:
15 | if i == -1:
16 | break
17 | acc += x[i]
18 | y[0] = acc
19 |
20 |
21 | @guvectorize("{T}[:], intp[:], {T}[:]", "(i),(j)->()")
22 | def compound_prod(x: np.ndarray, c: np.ndarray, y: np.ndarray) -> None:
23 | """y = x[c[0]] * x[c[1]] * ... x[c[n]]
24 | stopping at the first c[i] == -1
25 | """
26 | acc = 1
27 | for i in c:
28 | if i == -1:
29 | break
30 | acc *= x[i]
31 | y[0] = acc
32 |
33 |
34 | @guvectorize("{T}[:], intp[:], {T}[:]", "(i),(j)->()")
35 | def compound_mean(x: np.ndarray, c: np.ndarray, y: np.ndarray) -> None:
36 | """y = mean(x[c[0]], x[c[1]], ... x[c[n]])
37 | stopping at the first c[i] == -1
38 | """
39 | acc = 0
40 | j = 0 # Initialise j explicitly for when x.shape == (0, )
41 | for j, i in enumerate(c): # noqa: B007
42 | if i == -1:
43 | break
44 | acc += x[i]
45 | else:
46 | # Reached the end of the row
47 | j += 1
48 | y[0] = acc / j
49 |
--------------------------------------------------------------------------------
/xarray_extras/numba_extras.py:
--------------------------------------------------------------------------------
1 | """Extensions to numba"""
2 |
3 | from __future__ import annotations
4 |
5 | from collections.abc import Callable
6 | from typing import Any
7 |
8 | import numba
9 |
10 | _DTYPES = (
11 | # uint needs to appear before signed int:
12 | # https://github.com/numba/numba/issues/2934
13 | "uint8",
14 | "uint16",
15 | "uint32",
16 | "uint64",
17 | "int8",
18 | "int16",
19 | "int32",
20 | "int64",
21 | "float32",
22 | "float64",
23 | "complex64",
24 | "complex128",
25 | )
26 |
27 |
28 | def guvectorize(
29 | signature: str, layout: str, **kwargs: Any
30 | ) -> Callable[[Callable], Any]:
31 | """Convenience wrapper around :func:`numba.guvectorize`.
32 | Generate signature for all possible data types and set a few healthy
33 | defaults.
34 |
35 | :param str signature:
36 | numba signature, containing {T}
37 | :param str layout:
38 | as in :func:`numba.guvectorize`
39 | :param kwargs:
40 | passed verbatim to :func:`numba.guvectorize`.
41 | This function changes the default for cache from False to True.
42 |
43 | example::
44 |
45 | guvectorize("{T}[:], {T}[:]", "(i)->(i)")
46 |
47 | Is the same as::
48 |
49 | numba.guvectorize([
50 | "float32[:], float32[:]",
51 | "float64[:], float64[:]",
52 | ...
53 | ], "(i)->(i)", cache=True)
54 |
55 | .. note::
56 | Discussing upstream fix; see
57 | ``_.
58 | """
59 | if "{T}" in signature:
60 | signatures = [signature.format(T=dtype) for dtype in _DTYPES]
61 | else:
62 | signatures = [signature]
63 | kwargs.setdefault("cache", True)
64 | return numba.guvectorize(signatures, layout, **kwargs)
65 |
--------------------------------------------------------------------------------
/xarray_extras/stack.py:
--------------------------------------------------------------------------------
1 | """Utilities for stacking/unstacking dimensions"""
2 |
3 | from __future__ import annotations
4 |
5 | from collections.abc import Hashable
6 | from typing import TypeVar
7 |
8 | import pandas as pd
9 | import xarray
10 |
11 | T = TypeVar("T", xarray.DataArray, xarray.Dataset)
12 |
13 |
14 | def proper_unstack(array: T, dim: Hashable) -> T:
15 | """Work around an issue in xarray that causes the data to be sorted
16 | alphabetically by label on unstack():
17 |
18 | ``_
19 |
20 | Also work around issue that causes string labels to be converted to
21 | objects:
22 |
23 | ``_
24 |
25 | :param array:
26 | xarray.DataArray or xarray.Dataset to unstack
27 | :param str dim:
28 | Name of existing dimension to unstack
29 | :returns:
30 | xarray.DataArray or xarray.Dataset with unstacked dimension
31 | """
32 | # Regenerate Pandas multi-index to be ordered by first appearance
33 | mindex = array.coords[dim].to_pandas().index
34 |
35 | levels = []
36 | codes = []
37 |
38 | for levels_i, codes_i in zip(mindex.levels, mindex.codes):
39 | level_map: dict[Hashable, int] = {}
40 |
41 | for code in codes_i:
42 | if code not in level_map:
43 | level_map[code] = len(level_map)
44 |
45 | levels.append([levels_i[k] for k in level_map])
46 | codes.append([level_map[k] for k in codes_i])
47 |
48 | mindex = pd.MultiIndex(levels, codes, names=mindex.names)
49 | array = array.copy()
50 | array.coords[dim] = mindex
51 |
52 | # Invoke builtin unstack
53 | array = array.unstack((dim,))
54 |
55 | # Convert numpy arrays of Python objects to numpy arrays of C floats, ints,
56 | # strings, etc.
57 | for name in mindex.names:
58 | if array.coords[name].dtype == object:
59 | array.coords[name] = array.coords[name].values.tolist()
60 |
61 | return array
62 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: check-added-large-files
6 | - id: check-case-conflict
7 | - id: check-merge-conflict
8 | - id: check-symlinks
9 | - id: check-yaml
10 | - id: debug-statements
11 | - id: end-of-file-fixer
12 | - id: mixed-line-ending
13 | - id: name-tests-test
14 | args: ["--pytest-test-first"]
15 | - id: requirements-txt-fixer
16 | - id: trailing-whitespace
17 |
18 | - repo: https://github.com/rbubley/mirrors-prettier
19 | rev: v3.4.2
20 | hooks:
21 | - id: prettier
22 | types_or: [yaml, markdown, html, css, scss, javascript, json]
23 | args: [--prose-wrap=always]
24 |
25 | - repo: https://github.com/astral-sh/ruff-pre-commit
26 | rev: v0.9.4
27 | hooks:
28 | - id: ruff-format
29 | - id: ruff
30 | args: ["--fix", "--show-fixes"]
31 |
32 | - repo: https://github.com/codespell-project/codespell
33 | rev: v2.4.1
34 | hooks:
35 | - id: codespell
36 | additional_dependencies:
37 | - tomli
38 |
39 | - repo: https://github.com/shellcheck-py/shellcheck-py
40 | rev: "v0.10.0.1"
41 | hooks:
42 | - id: shellcheck
43 |
44 | - repo: https://github.com/abravalheri/validate-pyproject
45 | rev: v0.23
46 | hooks:
47 | - id: validate-pyproject
48 | additional_dependencies: ["validate-pyproject-schema-store[all]"]
49 |
50 | - repo: https://github.com/python-jsonschema/check-jsonschema
51 | rev: "0.31.1"
52 | hooks:
53 | - id: check-dependabot
54 | - id: check-github-workflows
55 |
56 | - repo: https://github.com/pre-commit/mirrors-mypy
57 | rev: v1.14.1
58 | hooks:
59 | - id: mypy
60 | additional_dependencies:
61 | # Libraries exclusively imported under `if TYPE_CHECKING:`
62 | - typing_extensions
63 | # Typed libraries
64 | - numpy
65 | - dask
66 | - xarray
67 | - scipy
68 |
--------------------------------------------------------------------------------
/doc/develop.rst:
--------------------------------------------------------------------------------
1 | Development Guidelines
2 | ======================
3 |
4 | Install
5 | -------
6 |
7 | 1. Clone this repository with git:
8 |
9 | .. code-block:: bash
10 |
11 | git clone git@github.com:crusaderky/xarray_extras.git
12 | cd xarray_extras
13 |
14 | 2. Install anaconda or miniconda (OS-dependent)
15 | 3. .. code-block:: bash
16 |
17 | conda env create -n xarray_extras --file ci/requirements-latest.yml
18 | conda activate xarray_extras
19 |
20 | 4. Install C compilation stack:
21 |
22 | Linux
23 | .. code-block:: bash
24 |
25 | conda install gcc_linux-64
26 |
27 | MacOSX
28 | .. code-block:: bash
29 |
30 | conda install clang_osx-64
31 |
32 | Windows
33 | You need to manually install the Microsoft C compiler tools. Refer to CPython
34 | documentation.
35 |
36 |
37 | To keep a fork in sync with the upstream source:
38 |
39 | .. code-block:: bash
40 |
41 | cd xarray_extras
42 | git remote add upstream git@github.com:crusaderky/xarray_extras.git
43 | git remote -v
44 | git fetch -a upstream
45 | git checkout main
46 | git pull upstream main
47 | git push origin main
48 |
49 | Test
50 | ----
51 |
52 | Test using ``py.test``:
53 |
54 | .. code-block:: bash
55 |
56 | python setup.py build_ext --inplace
57 | py.test xarray_extras
58 |
59 | Code Formatting
60 | ---------------
61 |
62 | xarray_extras uses several code linters (ruff, ruff-format, mypy, codespell), which are enforced by CI.
63 | Developers should run them locally before they submit a PR, through the single command
64 |
65 | .. code-block:: bash
66 |
67 | pre-commit run --all-files
68 |
69 | This makes sure that linter versions and options are aligned for all developers.
70 |
71 | Optionally, you may wish to set up the `pre-commit hooks <https://pre-commit.com/>`_ to
72 | run automatically when you make a git commit. This can be done by running:
73 |
74 | .. code-block:: bash
75 |
76 | pre-commit install
77 |
78 | from the root of the xarray_extras repository. Now the code linters will be run each time
79 | you commit changes. You can skip these checks with ``git commit --no-verify`` or with
80 | the short version ``git commit -n``.
81 |
--------------------------------------------------------------------------------
/doc/index.rst:
--------------------------------------------------------------------------------
1 | xarray_extras: Advanced algorithms for xarray
2 | =============================================
3 | This module offers several extensions to `xarray <https://xarray.dev/>`_,
4 | which could not be included in the main module because they fall into one or
5 | more of the following categories:
6 |
7 | - They're too experimental
8 | - They're too niche
9 | - They introduce major new dependencies (e.g.
10 | `numba <https://numba.pydata.org/>`_ or a C compiler)
11 | - They would be better done by doing major rework on multiple packages, and
12 | then one would need to wait for said changes to reach a stable release of
13 | each package - *in the right order*.
14 |
15 | The API of xarray-extras is unstable by definition, as features will be
16 | progressively migrated upwards towards xarray, dask, numpy, pandas, etc.
17 |
18 | Features
19 | --------
20 | :doc:`api/csv`
21 | Multi-threaded CSV writer, much faster than
22 | :meth:`pandas.DataFrame.to_csv`, with full support for
23 | `dask <https://www.dask.org/>`_ and
24 | `dask distributed <https://distributed.dask.org/>`_.
25 | :doc:`api/cumulatives`
26 | Advanced cumulative sum/product/mean functions
27 | :doc:`api/interpolate`
28 | dask-optimized n-dimensional spline interpolation
29 | :doc:`api/numba_extras`
30 | Additions to `numba <https://numba.pydata.org/>`_
31 | :doc:`api/sort`
32 | Advanced sort/take functions
33 | :doc:`api/stack`
34 | Tools for stacking/unstacking dimensions
35 |
36 |
37 | Index
38 | -----
39 |
40 | .. toctree::
41 | :maxdepth: 1
42 |
43 | installing
44 | develop
45 | whats-new
46 | api/csv
47 | api/cumulatives
48 | api/interpolate
49 | api/numba_extras
50 | api/sort
51 | api/stack
52 |
53 |
54 | Credits
55 | -------
56 | - :func:`~xarray_extras.stack.proper_unstack` was originally developed by
57 | Legal & General and released to the open source community in 2018.
58 | - All boilerplate is from
59 | `python_project_template <https://github.com/crusaderky/python_project_template>`_,
60 | which in turn is from `xarray <https://github.com/pydata/xarray>`_.
61 |
62 | License
63 | -------
64 |
65 | xarray-extras is available under the open source `Apache License`__.
66 |
67 | __ http://www.apache.org/licenses/LICENSE-2.0.html
68 |
--------------------------------------------------------------------------------
/xarray_extras/sort.py:
--------------------------------------------------------------------------------
1 | """Sorting functions"""
2 |
3 | from __future__ import annotations
4 |
5 | from collections.abc import Hashable
6 | from typing import TypeVar
7 |
8 | import xarray
9 |
10 | from xarray_extras.duck import sort as duck
11 |
12 | __all__ = ("argtopk", "take_along_dim", "topk")
13 |
14 |
15 | T = TypeVar("T", xarray.DataArray, xarray.Dataset)
16 | TV = TypeVar("TV", xarray.DataArray, xarray.Dataset, xarray.Variable)
17 |
18 |
19 | def topk(a: TV, k: int, dim: Hashable, split_every: int | None = None) -> TV:
20 | """Extract the k largest elements from a on the given dimension, and return
21 | them sorted from largest to smallest. If k is negative, extract the -k
22 | smallest elements instead, and return them sorted from smallest to largest.
23 |
24 | This assumes that ``k`` is small. All results will be returned in a single
25 | chunk along the given axis.
26 | """
27 | return xarray.apply_ufunc(
28 | duck.topk,
29 | a,
30 | kwargs={"k": k, "split_every": split_every},
31 | input_core_dims=[[dim]],
32 | output_core_dims=[["__temp_topk__"]],
33 | dask="allowed",
34 | ).rename({"__temp_topk__": dim})
35 |
36 |
37 | def argtopk(a: TV, k: int, dim: Hashable, split_every: int | None = None) -> TV:
38 | """Extract the indexes of the k largest elements from a on the given
39 | dimension, and return them sorted from largest to smallest. If k is
40 | negative, extract the -k smallest elements instead, and return them
41 | sorted from smallest to largest.
42 |
43 | This assumes that ``k`` is small. All results will be returned in a single
44 | chunk along the given axis.
45 | """
46 | return xarray.apply_ufunc(
47 | duck.argtopk,
48 | a,
49 | kwargs={"k": k, "split_every": split_every},
50 | input_core_dims=[[dim]],
51 | output_core_dims=[["__temp_topk__"]],
52 | dask="allowed",
53 | ).rename({"__temp_topk__": dim})
54 |
55 |
56 | def take_along_dim(a: T, ind: T, dim: Hashable) -> T:
57 | """Use the output of :func:`argtopk` to pick points from a.
58 |
59 | :param a:
60 | xarray.DataArray or xarray.Dataset
61 | :param ind:
62 | array of ints, as returned by :func:`argtopk`
63 | :param dim:
64 | dimension along which argtopk was executed
65 | """
66 | a = a.rename({dim: "__temp_take_along_dim__"})
67 |
68 | return xarray.apply_ufunc(
69 | duck.take_along_axis,
70 | a,
71 | ind,
72 | input_core_dims=[["__temp_take_along_dim__"], [dim]],
73 | output_core_dims=[[dim]],
74 | dask="allowed",
75 | )
76 |
--------------------------------------------------------------------------------
/xarray_extras/tests/test_sort.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from xarray import DataArray
3 | from xarray.testing import assert_equal
4 |
5 | from xarray_extras.sort import argtopk, take_along_dim, topk
6 |
7 |
8 | @pytest.mark.parametrize("use_dask", [False, True])
9 | @pytest.mark.parametrize("split_every", [None, 2])
10 | @pytest.mark.parametrize("transpose", [False, True])
11 | @pytest.mark.parametrize(
12 | "func,k,expect",
13 | [
14 | (topk, 3, [[5, 4, 3], [8, 7, 2]]),
15 | (topk, -3, [[1, 2, 3], [0, 1, 2]]),
16 | (argtopk, 3, [[3, 1, 4], [2, 0, 4]]),
17 | (argtopk, -3, [[0, 2, 4], [3, 1, 4]]),
18 | ],
19 | )
20 | def test_topk_argtopk(use_dask, split_every, transpose, func, k, expect):
21 | a = DataArray(
22 | [[1, 4, 2, 5, 3], [7, 1, 8, 0, 2]],
23 | dims=["y", "x"],
24 | coords={"y": ["y1", "y2"], "x": ["x1", "x2", "x3", "x4", "x5"]},
25 | )
26 |
27 | if transpose:
28 | a = a.T
29 | if use_dask:
30 | a = a.chunk(1)
31 |
32 | expect = DataArray(expect, dims=["y", "x"], coords={"y": ["y1", "y2"]})
33 | actual = func(a, k, "x", split_every=split_every)
34 | assert_equal(expect, actual)
35 | if use_dask:
36 | assert actual.chunks
37 |
38 |
39 | @pytest.mark.parametrize("ind_use_dask", [False, True])
40 | @pytest.mark.parametrize("a_use_dask", [False, True])
41 | def test_take_along_dim(a_use_dask, ind_use_dask):
42 | """ind.ndim < a.ndim after broadcast"""
43 | a = DataArray(
44 | [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
45 | dims=["z", "y", "x"],
46 | coords={"z": ["z1", "z2"], "y": ["y1", "y2"], "x": ["x1", "x2", "x3"]},
47 | )
48 | ind = DataArray(
49 | [[[1, 0], [2, 1]], [[2, 0], [2, 2]]],
50 | dims=["w", "y", "x"],
51 | coords={"y": ["y1", "y2"]},
52 | )
53 |
54 | expect = DataArray(
55 | [
56 | [[[2, 1], [3, 1]], [[6, 5], [6, 6]]],
57 | [[[8, 7], [9, 7]], [[12, 11], [12, 12]]],
58 | ],
59 | dims=["z", "y", "w", "x"],
60 | coords={"z": ["z1", "z2"], "y": ["y1", "y2"]},
61 | )
62 |
63 | if a_use_dask:
64 | a = a.chunk(1)
65 | if ind_use_dask:
66 | ind = ind.chunk(1)
67 |
68 | actual = take_along_dim(a, ind, "x")
69 | assert_equal(expect, actual)
70 | if a_use_dask or ind_use_dask:
71 | assert actual.chunks
72 |
73 |
74 | @pytest.mark.parametrize("ind_use_dask", [False, True])
75 | @pytest.mark.parametrize("a_use_dask", [False, True])
76 | def test_take_along_dim2(a_use_dask, ind_use_dask):
77 | """ind.ndim > a.ndim after broadcast"""
78 | a = DataArray([1, 2, 3], dims=["x"])
79 | ind = DataArray([[1, 0], [2, 1]], dims=["x", "y"])
80 |
81 | expect = DataArray([[2, 3], [1, 2]], dims=["y", "x"])
82 |
83 | if a_use_dask:
84 | a = a.chunk(1)
85 | if ind_use_dask:
86 | ind = ind.chunk(1)
87 |
88 | actual = take_along_dim(a, ind, "x")
89 | assert_equal(expect, actual)
90 | if a_use_dask or ind_use_dask:
91 | assert actual.chunks
92 |
--------------------------------------------------------------------------------
/HOW_TO_RELEASE:
--------------------------------------------------------------------------------
1 | How to issue a release in 14 easy steps
2 |
3 | Time required: about an hour.
4 |
5 | 1. Ensure your main branch is synced to origin:
6 | git pull origin main
7 | 2. Look over whats-new.rst and the docs. Make sure "What's New" is complete
8 | (check the date!) and add a brief summary note describing the release at the
9 | top.
10 | 3. If you have any doubts, run the full test suite one final time!
11 | py.test
12 | 4. On the main branch, commit the release in git:
13 | git commit -a -m 'Release vX.Y.Z'
14 | 5. Tag the release:
15 | git tag -a vX.Y.Z -m 'vX.Y.Z'
16 | 6. Push your changes to main:
17 | git push origin main
18 | git push origin --tags
19 | 7. Update the stable branch (used by ReadTheDocs) and switch back to main:
20 | git checkout stable
21 | git rebase main
22 | git push origin stable
23 | git checkout main
24 | It's OK to force push to 'stable' if necessary.
25 | We also update the stable branch with `git cherry-pick` for documentation-only
26 | fixes that apply to the current released version.
27 | 8. Build and test the release package
28 | python -m pip install --upgrade build twine
29 | rm -rf dist
30 | python -m build
31 | python -m twine check dist/*
32 | 9. Add a section for the next release to doc/whats-new.rst.
33 | 10. Commit your changes and push to main again:
34 | git commit -a -m 'Revert to dev version'
35 | git push origin main
36 | You're done pushing to main!
37 | 11. Issue the release on GitHub. Open https://github.com/crusaderky/xarray_extras/releases;
38 | the new release should have automatically appeared. Otherwise, click on
39 | "Draft a new release" and paste in the latest from whats-new.rst.
40 | 12. Use twine to register and upload the release on pypi. Be careful, you can't
41 | take this back!
42 | twine upload dist/*
43 | You will need to be listed as a package owner at
44 | https://pypi.python.org/pypi/xarray_extras for this to work.
45 | 13. Update the docs. Login to https://readthedocs.org/projects/xarray_extras/versions/
46 | and switch your new release tag (at the bottom) from "Inactive" to "Active".
47 | It should now build automatically.
48 | Make sure that both the new tagged version and 'stable' build successfully.
49 | 14. Update conda-forge.
50 | 14a. Clone https://github.com/conda-forge/xarray_extras-feedstock
51 | 14b. Update the version number and sha256 in meta.yaml.
52 | You can calculate sha256 with
53 | sha256sum dist/*
54 | 14c. Double check dependencies in meta.yaml and update them to match pyproject.toml.
55 | 14d. Submit a pull request.
56 | 14e. Write a comment in the PR:
57 | @conda-forge-admin, please rerender
58 | Wait for the rerender commit (it may take a few minutes).
59 | 14f. Wait for CI to pass and merge.
60 | 14g. The next day, test the conda-forge release
61 | conda search xarray_extras
62 | conda create -n xarray_extras-test xarray_extras
63 | conda activate xarray_extras-test
64 | conda list
65 | python -c 'import xarray_extras'
66 |
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: ["*"]
8 | workflow_dispatch: # allows you to trigger manually
9 |
10 | # When this workflow is queued, automatically cancel any previous running
11 | # or pending jobs from the same branch
12 | concurrency:
13 | group: tests-${{ github.ref }}
14 | cancel-in-progress: true
15 |
16 | defaults:
17 | run:
18 | shell: bash -l {0}
19 |
20 | jobs:
21 | build:
22 | name:
23 | ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.requirements }}
24 | runs-on: ${{ matrix.os }}-latest
25 | strategy:
26 | fail-fast: false
27 | matrix:
28 | os: [ubuntu]
29 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
30 | requirements: [latest]
31 | include:
32 | # Test on macos and windows (first and last version of python only)
33 | - os: macos
34 | python-version: "3.8"
35 | requirements: latest
36 | - os: macos
37 | python-version: "3.13"
38 | requirements: latest
39 | - os: windows
40 | python-version: "3.8"
41 | requirements: latest
42 | - os: windows
43 | python-version: "3.13"
44 | requirements: latest
45 | # Test on minimal requirements
46 | - os: ubuntu
47 | python-version: "3.8"
48 | requirements: minimal
49 | - os: macos
50 | python-version: "3.8"
51 | requirements: minimal
52 | - os: windows
53 | python-version: "3.8"
54 | requirements: minimal
55 | # Test on nightly builds of requirements
56 | - os: ubuntu
57 | python-version: "3.13"
58 | requirements: upstream
59 |
60 | steps:
61 | - name: Checkout
62 | uses: actions/checkout@v4
63 | with:
64 | fetch-depth: 0
65 |
66 | - name: Setup Conda Environment
67 | uses: conda-incubator/setup-miniconda@v3
68 | with:
69 | miniforge-version: latest
70 | use-mamba: true
71 | python-version: ${{ matrix.python-version }}
72 | environment-file: ci/requirements-${{ matrix.requirements }}.yml
73 | activate-environment: xarray-extras
74 |
75 | - name: Install Linux compile env
76 | if: ${{ matrix.os == 'ubuntu' }}
77 | run: mamba install gcc_linux-64
78 |
79 | - name: Install MacOS compile env
80 | if: ${{ matrix.os == 'macos' }}
81 | run: mamba install clang_osx-64
82 |
83 | - name: Install nightly builds
84 | if: ${{ matrix.requirements == 'upstream' }}
85 | run: |
86 | mamba uninstall --force numpy pandas scipy pyarrow
87 | python -m pip install --no-deps --pre --prefer-binary \
88 | --extra-index-url https://pypi.fury.io/arrow-nightlies/ \
89 | pyarrow
90 | python -m pip install --no-deps --pre \
91 | -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \
92 | numpy pandas scipy
93 |
94 | - name: Show conda options
95 | run: conda config --show
96 |
97 | - name: conda info
98 | run: conda info
99 |
100 | - name: conda list
101 | run: conda list
102 |
103 | - name: Install
104 | run: python -m pip install --no-deps -e .
105 |
106 | - name: pytest
107 | run: py.test --verbose --cov=xarray_extras --cov-report=xml
108 |
109 | - name: codecov.io
110 | uses: codecov/codecov-action@v3
111 |
--------------------------------------------------------------------------------
/xarray_extras/duck/sort.py:
--------------------------------------------------------------------------------
1 | """Helper functions for :mod:`xarray_extras.sort`, which accept either
2 | numpy arrays or dask arrays.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | from typing import TypeVar
8 |
9 | import dask.array as da
10 | import numpy as np
11 | from dask.array.slicing import slice_with_int_dask_array_on_axis
12 | from xarray.core.duck_array_ops import broadcast_to
13 |
14 | T = TypeVar("T", np.ndarray, da.Array)
15 |
16 |
17 | def topk(a: T, k: int, split_every: int | None = None) -> T:
18 | """If a is a :class:`dask.array.Array`, invoke a.topk; else reimplement
19 | the functionality in plain numpy.
20 | """
21 | if isinstance(a, da.Array):
22 | return a.topk(k, split_every=split_every)
23 |
24 | if abs(k) < a.shape[-1]:
25 | a = np.partition(a, -k)
26 | if k > 0:
27 | a = a[..., -k:]
28 | else:
29 | a = a[..., :-k]
30 |
31 | # Sort the partitioned output
32 | a = np.sort(a)
33 | if k > 0:
34 | # Sort from greatest to smallest
35 | return a[..., ::-1]
36 | return a
37 |
38 |
39 | def argtopk(a: T, k: int, split_every: int | None = None) -> T:
40 | """If a is a :class:`dask.array.Array`, invoke a.argtopk; else reimplement
41 | the functionality in plain numpy.
42 | """
43 | if isinstance(a, da.Array):
44 | return a.argtopk(k, split_every=split_every)
45 |
46 | idx = np.argpartition(a, -k)
47 | if k > 0:
48 | idx = idx[..., -k:]
49 | else:
50 | idx = idx[..., :-k]
51 |
52 | a = np.take_along_axis(a, idx, axis=-1)
53 | idx = np.take_along_axis(idx, a.argsort(), axis=-1)
54 | if k > 0:
55 | # Sort from greatest to smallest
56 | return idx[..., ::-1]
57 | return idx
58 |
59 |
60 | def take_along_axis(
61 | a: np.ndarray | da.Array, ind: np.ndarray | da.Array
62 | ) -> np.ndarray | da.Array:
63 | """Easily use the outputs of argsort on ND arrays to pick the results."""
64 | if isinstance(a, np.ndarray) and isinstance(ind, np.ndarray):
65 | a = a.reshape((1,) * (ind.ndim - a.ndim) + a.shape)
66 | ind = ind.reshape((1,) * (a.ndim - ind.ndim) + ind.shape)
67 | return np.take_along_axis(a, ind, axis=-1)
68 |
69 | # a and/or ind are dask arrays. This is not yet implemented upstream.
70 | # Upstream tracker: https://github.com/dask/dask/issues/3663
71 |
72 | # This is going to be an ugly and slow mess, as dask does not support
73 | # fancy indexing.
74 |
75 | # Normalize a and ind. The end result is that a can have more axes than
76 | # ind on the left, but not vice versa, and that all axes except the
77 | # extra ones on the left and the rightmost one (the axis to take
78 | # along) are the same shape.
79 | if ind.ndim > a.ndim:
80 | a = a.reshape((1,) * (ind.ndim - a.ndim) + a.shape)
81 | common_shape = tuple(np.maximum(a.shape[-ind.ndim : -1], ind.shape[:-1]))
82 | a_extra_shape = a.shape[: -ind.ndim]
83 | a = broadcast_to(a, a_extra_shape + common_shape + a.shape[-1:])
84 | ind = broadcast_to(ind, common_shape + ind.shape[-1:])
85 |
86 | # Flatten all common axes onto axis -2
87 | final_shape = a.shape[: -ind.ndim] + ind.shape
88 | ind = ind.reshape(ind.size // ind.shape[-1], ind.shape[-1])
89 | a = a.reshape(*a_extra_shape, ind.shape[0], a.shape[-1])
90 |
91 | # Now we have a[..., i, j] and ind[i, j], where i are the flattened
92 | # common axes and j is the axis to take along.
93 | res = []
94 |
95 | # Cycle a and ind along i, perform 1D slices, and then stack them back
96 | # together
97 | for i in range(ind.shape[0]):
98 | a_i = a[..., i, :]
99 | ind_i = ind[i, :]
100 |
101 | if not isinstance(a_i, da.Array):
102 | a_i = da.from_array(a_i, chunks=a_i.shape)
103 |
104 | if isinstance(ind_i, da.Array):
105 | res_i = slice_with_int_dask_array_on_axis(a_i, ind_i, axis=a_i.ndim - 1)
106 | else:
107 | res_i = a_i[..., ind_i]
108 | res.append(res_i)
109 |
110 | res_arr = da.stack(res, axis=-2)
111 | # Un-flatten axis i
112 | return res_arr.reshape(*final_shape)
113 |
--------------------------------------------------------------------------------
/xarray_extras/tests/test_stack.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 | import xarray
5 |
6 | from xarray_extras.stack import proper_unstack
7 |
8 | # FIXME https://github.com/crusaderky/xarray_extras/issues/33
9 | pytestmark = pytest.mark.filterwarnings(
10 | "ignore:Updating MultiIndexed coordinate .* would corrupt indices",
11 | "ignore:invalid value encountered in cast:RuntimeWarning",
12 | )
13 |
14 |
15 | def test_proper_unstack_order():
16 | # Note: using MultiIndex.from_tuples is NOT the same thing as
17 | # round-tripping DataArray.stack().unstack(), as the latter is not
18 | # affected by the re-ordering issue
19 | index = pd.MultiIndex.from_tuples(
20 | [
21 | ["x1", "first"],
22 | ["x1", "second"],
23 | ["x1", "third"],
24 | ["x1", "fourth"],
25 | ["x0", "first"],
26 | ["x0", "second"],
27 | ["x0", "third"],
28 | ["x0", "fourth"],
29 | ],
30 | names=["x", "count"],
31 | )
32 | xa = xarray.DataArray(np.arange(8), dims=["dim_0"], coords={"dim_0": index})
33 |
34 | a = proper_unstack(xa, "dim_0")
35 | b = xarray.DataArray(
36 | [[0, 1, 2, 3], [4, 5, 6, 7]],
37 | dims=["x", "count"],
38 | coords={"x": ["x1", "x0"], "count": ["first", "second", "third", "fourth"]},
39 | )
40 | xarray.testing.assert_equal(a, b)
41 | with pytest.raises(AssertionError):
42 | # Order is different
43 | xarray.testing.assert_equal(a, xa.unstack("dim_0"))
44 |
45 |
46 | def test_proper_unstack_dtype():
47 | """Test that we don't accidentally end up with dtype=O for the coords"""
48 | a = xarray.DataArray(
49 | [[0, 1, 2, 3], [4, 5, 6, 7]],
50 | dims=["r", "c"],
51 | coords={
52 | "r": pd.to_datetime(["2000/01/01", "2000/01/02"]),
53 | "c": [1, 2, 3, 4],
54 | },
55 | )
56 | b = a.stack(s=["r", "c"])
57 | c = proper_unstack(b, "s")
58 | xarray.testing.assert_equal(a, c)
59 |
60 |
61 | def test_proper_unstack_mixed_coords():
62 | a = xarray.DataArray(
63 | [[0, 1, 2, 3], [4, 5, 6, 7]],
64 | dims=["r", "c"],
65 | coords={"r": [1, "x0"], "c": [1, 2.2, "3", "fourth"]},
66 | )
67 | b = a.stack(s=["r", "c"])
68 | c = proper_unstack(b, "s")
69 | xarray.testing.assert_equal(a, c)
70 |
71 |
72 | def test_proper_unstack_dataset():
73 | a = xarray.DataArray(
74 | [[1, 2, 3, 4], [5, 6, 7, 8]],
75 | dims=["x", "col"],
76 | coords={
77 | "x": ["x0", "x1"],
78 | "col": pd.MultiIndex.from_tuples(
79 | [("u0", "v0"), ("u0", "v1"), ("u1", "v0"), ("u1", "v1")],
80 | names=["u", "v"],
81 | ),
82 | },
83 | )
84 | xa = xarray.Dataset({"foo": a, "bar": ("w", [1, 2]), "baz": np.pi})
85 | b = proper_unstack(xa, "col")
86 | c = xarray.DataArray(
87 | [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
88 | dims=["x", "u", "v"],
89 | coords={"x": ["x0", "x1"], "u": ["u0", "u1"], "v": ["v0", "v1"]},
90 | )
91 | d = xarray.Dataset({"foo": c, "bar": ("w", [1, 2]), "baz": np.pi})
92 | xarray.testing.assert_equal(b, d)
93 | for c in b.coords:
94 | assert b.coords[c].dtype.kind == "U"
95 |
96 |
97 | def test_proper_unstack_other_mi():
98 | a = xarray.DataArray(
99 | [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]],
100 | dims=["row", "col"],
101 | coords={
102 | "row": pd.MultiIndex.from_tuples(
103 | [("x0", "w0"), ("x0", "w1"), ("x1", "w0"), ("x1", "w1")],
104 | names=["x", "w"],
105 | ),
106 | "col": pd.MultiIndex.from_tuples(
107 | [("y0", "z0"), ("y0", "z1"), ("y1", "z0"), ("y1", "z1")],
108 | names=["y", "z"],
109 | ),
110 | },
111 | )
112 | b = proper_unstack(a, "row")
113 | c = xarray.DataArray(
114 | [[[1, 5], [1, 5]], [[2, 6], [2, 6]], [[3, 7], [3, 7]], [[4, 8], [4, 8]]],
115 | dims=["col", "x", "w"],
116 | coords={
117 | "col": pd.MultiIndex.from_tuples(
118 | [("y0", "z0"), ("y0", "z1"), ("y1", "z0"), ("y1", "z1")],
119 | names=["y", "z"],
120 | ),
121 | "x": ["x0", "x1"],
122 | "w": ["w0", "w1"],
123 | },
124 | )
125 | xarray.testing.assert_equal(b, c)
126 |
--------------------------------------------------------------------------------
/xarray_extras/tests/test_cumulatives.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import xarray
4 | from xarray.testing import assert_equal
5 |
6 | pytest.importorskip("numba") # Not available in upstream CI
7 |
8 | import xarray_extras.cumulatives as cum
9 |
10 | # Skip 0 and 1 as they're neutral in addition and multiplication
11 | INPUT = xarray.DataArray(
12 | [[2, 20, 25], [3, 30, 35], [4, 40, 45], [5, 50, 55]],
13 | dims=["t", "s"],
14 | coords={
15 | "t": np.array(
16 | ["1990-12-30", "2000-12-30", "2005-12-30", "2010-12-30"], dtype="M8[ns]"
17 | ),
18 | "s": ["s1", "s2", "s3"],
19 | },
20 | )
21 |
22 |
23 | T_COMPOUND_MATRIX = xarray.DataArray(
24 | np.array(
25 | [
26 | ["1990-12-30", "NaT", "NaT"],
27 | ["1990-12-30", "2005-12-30", "NaT"],
28 | ["2000-12-30", "1990-12-30", "NaT"],
29 | ["2010-12-30", "1990-12-30", "2005-12-30"],
30 | ],
31 | dtype="M8[ns]",
32 | ),
33 | dims=["t2", "c"],
34 | coords={"t2": [10, 20, 30, 40]},
35 | )
36 |
37 |
38 | S_COMPOUND_MATRIX = xarray.DataArray(
39 | [["s3", "s2"], ["s1", ""]], dims=["s2", "c"], coords={"s2": ["foo", "bar"]}
40 | )
41 |
42 | DTYPES = (
43 | # There's a bug in numba.guvectorize for u8, u16, u32
44 | # i8 and i16 are too short to store the output
45 | "int32",
46 | "int64",
47 | "uint64",
48 | "float32",
49 | "float64",
50 | "complex64",
51 | "complex128",
52 | )
53 |
54 |
55 | @pytest.mark.parametrize(
56 | "func, meth",
57 | [
58 | (cum.compound_sum, "sum"),
59 | (cum.compound_prod, "prod"),
60 | (cum.compound_mean, "mean"),
61 | ],
62 | )
63 | @pytest.mark.parametrize("dtype", DTYPES)
64 | @pytest.mark.parametrize("use_dask", [False, True])
65 | def test_compound_t(func, meth, dtype, use_dask):
66 | x = INPUT.astype(dtype)
67 | c = T_COMPOUND_MATRIX
68 | expect = xarray.concat(
69 | [
70 | getattr(x.isel(t=[0]), meth)("t"),
71 | getattr(x.isel(t=[0, 2]), meth)("t"),
72 | getattr(x.isel(t=[1, 0]), meth)("t"),
73 | getattr(x.isel(t=[3, 0, 2]), meth)("t"),
74 | ],
75 | dim="t2",
76 | ).T.astype(dtype)
77 | expect.coords["t2"] = c.coords["t2"]
78 |
79 | if use_dask:
80 | x = x.chunk({"s": 2})
81 | expect = expect.chunk({"s": 2})
82 | c = c.chunk()
83 |
84 | actual = func(x, c, "t", "c")
85 |
86 | if use_dask:
87 | assert_equal(expect.compute(), actual.compute())
88 | else:
89 | assert_equal(expect, actual)
90 |
91 | assert expect.dtype == actual.dtype
92 | assert actual.chunks == expect.chunks
93 |
94 |
95 | @pytest.mark.parametrize(
96 | "func, meth",
97 | [
98 | (cum.compound_sum, "sum"),
99 | (cum.compound_prod, "prod"),
100 | (cum.compound_mean, "mean"),
101 | ],
102 | )
103 | @pytest.mark.parametrize("dtype", DTYPES)
104 | @pytest.mark.parametrize("use_dask", [False, True])
105 | def test_compound_s(func, meth, dtype, use_dask):
106 | x = INPUT.astype(dtype)
107 | c = S_COMPOUND_MATRIX
108 | expect = xarray.concat(
109 | [
110 | getattr(x.sel(s=["s3", "s2"]), meth)("s"),
111 | getattr(x.sel(s=["s1"]), meth)("s"),
112 | ],
113 | dim="s2",
114 | ).T.astype(dtype)
115 | expect.coords["s2"] = c.coords["s2"]
116 |
117 | if use_dask:
118 | x = x.chunk({"t": 2})
119 | expect = expect.chunk({"t": 2})
120 | c = c.chunk()
121 |
122 | actual = func(x, c, "s", "c")
123 |
124 | if use_dask:
125 | assert_equal(expect.compute(), actual.compute())
126 | else:
127 | assert_equal(expect, actual)
128 |
129 | assert expect.dtype == actual.dtype
130 | assert actual.chunks == expect.chunks
131 |
132 |
133 | @pytest.mark.parametrize("dtype", [float, int, "complex128"])
134 | @pytest.mark.parametrize("skipna", [False, True, None])
135 | @pytest.mark.parametrize("use_dask", [False, True])
136 | def test_cummean(use_dask, skipna, dtype):
137 | x = INPUT.copy(deep=True).astype(dtype)
138 | if dtype in (float, "complex128"):
139 | x[2, 1] = np.nan
140 |
141 | expect = xarray.concat(
142 | [
143 | x[:1].mean("t", skipna=skipna),
144 | x[:2].mean("t", skipna=skipna),
145 | x[:3].mean("t", skipna=skipna),
146 | x[:4].mean("t", skipna=skipna),
147 | ],
148 | dim="t",
149 | )
150 | expect.coords["t"] = x.coords["t"]
151 | if use_dask:
152 | x = x.chunk({"s": 2, "t": 3})
153 | expect = expect.chunk({"s": 2, "t": 3})
154 |
155 | actual = cum.cummean(x, "t", skipna=skipna)
156 | assert_equal(expect, actual)
157 | assert expect.dtype == actual.dtype
158 | assert actual.chunks == expect.chunks
159 |
--------------------------------------------------------------------------------
/xarray_extras/kernels/csv.py:
--------------------------------------------------------------------------------
1 | """dask kernels for :mod:`xarray_extras.csv`"""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 |
7 | import numpy as np
8 | import pandas as pd
9 |
10 | from xarray_extras.kernels.np_to_csv_py import snprintcsvd, snprintcsvi
11 |
12 |
13 | def to_csv(
14 | x: np.ndarray,
15 | index: pd.Index,
16 | columns: pd.Index,
17 | first_chunk: bool,
18 | nogil: bool,
19 | kwargs: dict,
20 | ) -> bytes:
21 | """Format x into CSV and encode it to binary
22 |
23 | :param x:
24 | numpy.ndarray with 1 or 2 dimensions
25 | :param pandas.Index index:
26 | row index
27 | :param pandas.Index columns:
28 | column index. None for Series or for DataFrame chunks beyond the first.
29 | :param bool first_chunk:
30 | True if this is the first chunk; False otherwise
31 | :param bool nogil:
32 | If True, use accelerated C implementation. Several kwargs won't be
33 | processed correctly. If False, use pandas to_csv method (slow, and does
34 | not release the GIL).
35 | :param kwargs:
36 | arguments passed to pandas to_csv methods
37 | :returns:
38 | CSV file contents, encoded in UTF-8
39 | """
40 | if x.ndim == 1:
41 | assert columns is None
42 | x_pd = pd.Series(x, index)
43 | elif x.ndim == 2:
44 | x_pd = pd.DataFrame(x, index, columns)
45 | else:
46 | # proper ValueError already raised in wrapper
47 | raise AssertionError("unreachable") # pragma: nocover
48 |
49 | encoding = kwargs.pop("encoding", "utf-8")
50 | header = kwargs.pop("header", True)
51 | lineterminator = kwargs.pop("lineterminator", os.linesep)
52 |
53 | if not nogil or not x.size:
54 | out = x_pd.to_csv(header=header, lineterminator=lineterminator, **kwargs)
55 | bout = out.encode(encoding)
56 | if encoding == "utf-16" and not first_chunk:
57 | # utf-16 starts with a byte order mark (BOM). However,
58 | # when concatenating multiple chunks we don't want to replicate it.
59 | assert bout[:2] == b"\xff\xfe"
60 | bout = bout[2:]
61 | return bout
62 |
63 | sep = kwargs.get("sep", ",")
64 | fmt = kwargs.get("float_format")
65 | na_rep = kwargs.get("na_rep", "")
66 |
67 | # Use pandas to format index
68 | if x.ndim == 1:
69 | x_df = x_pd.to_frame()
70 | else:
71 | x_df = x_pd
72 |
73 | index_csv = x_df.iloc[:, :0].to_csv(header=False, lineterminator="\n", **kwargs)
74 | index_csv = index_csv.strip().split("\n")
75 | if len(index_csv) != x.shape[0]:
76 | index_csv = "\n" * x.shape[0]
77 | else:
78 | index_csv = "\n".join(r + sep if r else "" for r in index_csv) + "\n"
79 |
80 | # Invoke C code to format the values. This releases the GIL.
81 | if x.dtype.kind == "i":
82 | body_bytes = snprintcsvi(x, index_csv, sep)
83 | elif x.dtype.kind == "f":
84 | body_bytes = snprintcsvd(x, index_csv, sep, fmt, na_rep)
85 | else:
86 | raise NotImplementedError("only int and float are supported when nogil=True")
87 |
88 | if header is not False:
89 | header_bytes = (
90 | x_df.iloc[:0, :]
91 | .to_csv(header=header, lineterminator="\n", **kwargs)
92 | .encode("utf-8")
93 | )
94 | body_bytes = header_bytes + body_bytes
95 |
96 | if encoding not in {"ascii", "utf-8"}:
97 | # Everything is encoded in UTF-8 until this moment. Recode if needed.
98 | body_str = body_bytes.decode("utf-8")
99 | if lineterminator != "\n":
100 | body_str = body_str.replace("\n", lineterminator)
101 | body_bytes = body_str.encode(encoding)
102 | if encoding == "utf-16" and not first_chunk:
103 | # utf-16 starts with a byte order mark (BOM). However,
104 | # when concatenating multiple chunks we don't want to replicate it.
105 | assert body_bytes[:2] == b"\xff\xfe"
106 | body_bytes = body_bytes[2:]
107 | elif lineterminator != "\n":
108 | body_bytes = body_bytes.replace(b"\n", lineterminator.encode("utf-8"))
109 |
110 | return body_bytes
111 |
112 |
113 | def to_file(fname: str, mode: str, data: str | bytes, rr_token: object) -> None: # noqa: ARG001
114 | """Write data to file
115 |
116 | :param fname:
117 | File path on disk
118 | :param mode:
119 | As in 'open'
120 | :param data:
121 | Binary or text data to write
122 | :param rr_token:
123 | Round-robin token passed by to_file from the previous chunk. It
124 | guarantees write order across multiple, otherwise parallel, tasks. This
125 | parameter is only used by the dask scheduler.
126 | """
127 | with open(fname, mode) as fh:
128 | fh.write(data)
129 |
--------------------------------------------------------------------------------
/xarray_extras/kernels/np_to_csv_py.py:
--------------------------------------------------------------------------------
1 | """Thin ctypes wrapper around :file:`np_to_csv.c`.
2 | This is a helper module of :mod:`xarray_extras.kernels.csv`.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import ctypes
8 | import warnings
9 |
10 | import numpy as np
11 |
12 | from xarray_extras.kernels import np_to_csv # type:ignore[attr-defined]
13 |
14 | with warnings.catch_warnings():
15 | warnings.simplefilter("ignore", category=DeprecationWarning)
16 | np_to_csv = np.ctypeslib.load_library("np_to_csv", np_to_csv.__file__)
17 |
18 | np_to_csv.snprintcsvd.argtypes = [
19 | ctypes.c_char_p, # char * buf
20 | ctypes.c_int32, # int bufsize
21 | # const double * array
22 | np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, flags="C_CONTIGUOUS"),
23 | ctypes.c_int32, # int h
24 | ctypes.c_int32, # int w
25 | ctypes.c_char_p, # const char * index
26 | ctypes.c_char_p, # const char * fmt
27 | ctypes.c_bool, # bool trim_zeros
28 | ctypes.c_char_p, # const char * na_rep
29 | ]
30 | np_to_csv.snprintcsvd.restype = ctypes.c_int32
31 |
32 |
33 | np_to_csv.snprintcsvi.argtypes = [
34 | ctypes.c_char_p, # char * buf
35 | ctypes.c_int32, # int bufsize
36 | # const int64_t * array
37 | np.ctypeslib.ndpointer(dtype=np.int64, ndim=2, flags="C_CONTIGUOUS"),
38 | ctypes.c_int32, # int h
39 | ctypes.c_int32, # int w
40 | ctypes.c_char_p, # const char * index
41 | ctypes.c_char, # char sep
42 | ]
43 | np_to_csv.snprintcsvi.restype = ctypes.c_int32
44 |
45 |
46 | def snprintcsvd(
47 | a: np.ndarray,
48 | index: str,
49 | sep: str = ",",
50 | fmt: str | None = None,
51 | na_rep: str = "",
52 | ) -> bytes:
53 | """Convert array to CSV.
54 |
55 | :param a:
56 | 1D or 2D numpy array of floats
57 | :param str index:
58 | newline-separated list of prefixes for every row of a
59 | :param str sep:
60 | cell separator
61 | :param str fmt:
62 | printf formatting string for a single float number
63 | Set to None to replicate pandas to_csv default behaviour
64 | :param str na_rep:
65 | string representation of NaN
66 | :return:
67 | CSV file contents, binary-encoded in ascii format.
68 | The line terminator is always \n on all OSs.
69 | """
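    # Illustrative sketch (hypothetical values). Note that each index prefix must
    # already carry the trailing cell separator, as prepared by
    # xarray_extras.kernels.csv:
    #
    #     a = np.array([[1.5, np.nan], [2.0, 3.25]])
    #     snprintcsvd(a, index="r0,\nr1,\n", na_rep="NA")
    #     # -> b"r0,1.5,NA\nr1,2.0,3.25\n"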
70 | if a.ndim == 1:
71 | a = a.reshape((-1, 1))
72 | if a.ndim != 2 or a.dtype.kind != "f":
73 | raise ValueError("Expected 2d numpy array of floats")
74 | a = a.astype(np.float64)
75 | a = np.ascontiguousarray(a)
76 | if len(sep) != 1:
77 | raise ValueError("sep must be exactly 1 character")
78 | bsep = sep.encode("ascii")
79 |
80 | # Test fmt while in Python - much better to get
81 | # an Exception here than a segfault in C!
82 | if fmt is not None:
83 | fmt % 1.23
84 | bfmt = fmt.encode("ascii") + bsep
85 | trim_zeros = False
86 | else:
87 | bfmt = b"%f" + bsep
88 | trim_zeros = True
89 | bna_rep = na_rep.encode("ascii") + bsep
90 | # We're relying on the fact that ascii is a strict subset of UTF-8
91 | bindex = index.encode("utf-8")
92 |
93 | # Blindly try ever-larger bufsizes until it fits
94 | # The first iteration should be sufficient in all but the most
95 | # degenerate cases.
96 | # FIXME: is there a better way?
97 | cellsize = 40
98 | while True:
99 | bufsize = cellsize * a.size + len(bindex)
100 | buf = ctypes.create_string_buffer(bufsize)
101 | nchar = np_to_csv.snprintcsvd(
102 | buf, bufsize, a, a.shape[0], a.shape[1], bindex, bfmt, trim_zeros, bna_rep
103 | )
104 | if nchar < bufsize:
105 | return bytes(buf[:nchar]) # type: ignore[arg-type]
106 | cellsize *= 2
107 |
108 |
109 | def snprintcsvi(a: np.ndarray, index: str, sep: str = ",") -> bytes:
110 | """Convert array to CSV.
111 |
112 | :param a:
113 | 1D or 2D numpy array of integers
114 | :param str index:
115 | newline-separated list of prefixes for every row of a
116 | :param str sep:
117 | cell separator
118 | :return:
119 | CSV file contents, binary-encoded in ascii format.
120 | The line terminator is always \n on all OSs.
121 | """
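    # Illustrative sketch (hypothetical values; prefixes already include the separator):
    #
    #     snprintcsvi(np.array([[1, 2], [3, 4]]), index="r0,\nr1,\n")
    #     # -> b"r0,1,2\nr1,3,4\n"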
122 | if a.ndim == 1:
123 | a = a.reshape((-1, 1))
124 | if a.ndim != 2 or a.dtype.kind != "i":
125 | raise ValueError("Expected 2d numpy array of ints")
126 | a = a.astype(np.int64)
127 | a = np.ascontiguousarray(a)
128 | if len(sep) != 1:
129 | raise ValueError("sep must be exactly 1 character")
130 | bsep = sep.encode("ascii")
131 | # We're relying on the fact that ascii is a strict subset of UTF-8
132 | bindex = index.encode("utf-8")
133 |
134 |     cellsize = 22  # len('%d' % -(2**63)) == 20, plus separator, plus 1 spare
135 | bufsize = cellsize * a.size + len(bindex)
136 | buf = ctypes.create_string_buffer(bufsize)
137 | nchar = np_to_csv.snprintcsvi(buf, bufsize, a, a.shape[0], a.shape[1], bindex, bsep)
138 | assert nchar < bufsize
139 | return bytes(buf[:nchar]) # type: ignore[arg-type]
140 |
--------------------------------------------------------------------------------
/doc/whats-new.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: xarray_extras
2 |
3 | What's New
4 | ==========
5 |
6 | .. _whats-new.0.7.0:
7 |
8 | v0.7.0 (unreleased)
9 | -------------------
10 | - ``interpolate``: Added support for timedelta coordinates;
11 | fixed support for datetime coordinates other than ``M8[ns]`` in xarray 2025.1.2.
12 | - Added formal support for Python 3.13 (the previous release already works on it)
13 |
14 |
15 | .. _whats-new.0.6.0:
16 |
17 | v0.6.0 (2024-03-16)
18 | -------------------
19 | - Bumped minimum version of all dependencies:
20 |
21 | ========== ======== =========
22 | Dependency v0.5.0 v0.6.0
23 | ========== ======== =========
24 | python 3.7 3.8
25 | dask 2021.4.0 2022.6.0
26 | numba 0.52 0.56
27 | numpy 1.18 1.23
28 | pandas 1.1 1.5
29 | scipy 1.5 1.9
30 | xarray 0.16 2022.11.0
31 | ========== ======== =========
32 |
33 | - Added support for Python 3.10, 3.11, and 3.12
34 | - Added support for recent versions of Pandas (tested up to 2.2) and xarray
35 | - Added support for :class:`pathlib.Path` in function arguments
36 | - Migrated from setup.cfg to pyproject.toml
37 | - Migrated from flake8+isort+pyupgrade to ruff
38 |
39 |
40 | .. _whats-new.0.5.0:
41 |
42 | v0.5.0 (2021-12-01)
43 | -------------------
44 | - Bumped minimum version of all dependencies:
45 |
46 | ========== ====== ========
47 | Dependency v0.4.2 v0.5.0
48 | ========== ====== ========
49 | python 3.5 3.7
50 | dask 0.19 2021.4.0
51 | numba 0.34 0.52
52 | numpy 1.13 1.18
53 | pandas 0.21 1.1
54 | scipy 1.0 1.5
55 | xarray 0.10.1 0.16
56 | ========== ====== ========
57 |
58 | - Added support for Python 3.8 and 3.9
59 | - Removed ``xarray_extras.backports`` module
60 | - Migrated CI to github workflows
61 | - Added code linters: pyupgrade, isort, black
62 | - Run all code linters through pre-commit
63 | - Use setuptools-scm for versioning
64 | - Moved the whole contents of setup.py to setup.cfg
65 |
66 |
67 | .. _whats-new.0.4.2:
68 |
69 | v0.4.2 (2019-06-03)
70 | -------------------
71 |
72 | - Type annotations
73 | - Mandatory mypy validation in CI
74 | - CI unit tests for Windows now run on Python 3.7
75 | - Compatibility with dask >= 1.1
76 | - Suppress deprecation warnings with pandas >= 0.24
77 | - :func:`~xarray_extras.csv.to_csv` changes:
78 |
79 | - When invoked on a 1-dimensional DataArray,
80 | the default value for the ``index`` parameter has been changed from False to
81 |     True, consistently with the default of pandas.Series.to_csv from pandas 0.24.
82 |     This also applies to users who have pandas < 0.24 installed.
83 | - support for ``line_terminator`` parameter (all pandas versions);
84 | - fix incorrect line terminator in Windows with pandas >= 0.24
85 | - support for ``compression='infer'`` (all pandas versions)
86 | - support for ``compression`` parameter with pandas < 0.23
87 |
88 |
89 | .. _whats-new.0.4.1:
90 |
91 | v0.4.1 (2019-02-02)
92 | -------------------
93 |
94 | - Fixed build regression on readthedocs
95 |
96 |
97 | .. _whats-new.0.4.0:
98 |
99 | v0.4.0 (2019-02-02)
100 | -------------------
101 |
102 | - Moved ``recursive_diff``, ``recursive_eq`` and ``ncdiff``
103 |   to their own package ``recursive_diff``
104 | - Fixed bug in :func:`~xarray_extras.stack.proper_unstack` where unstacking
105 | coords with dtype=datetime64 would convert them to integer
106 | - Mandatory flake8 in CI
107 |
108 |
109 | .. _whats-new.0.3.0:
110 |
111 | v0.3.0 (2018-12-13)
112 | -------------------
113 |
114 | - Changed license to Apache 2.0
115 | - Increased minimum versions: dask >= 0.19, pandas >= 0.21,
116 | xarray >= 0.10.1, pytest >= 3.6
117 | - New function :func:`~xarray_extras.stack.proper_unstack`
118 | - New functions ``recursive_diff`` and ``recursive_eq``
119 | - New command-line tool ``ncdiff``
120 | - Blacklisted Python 3.7 conda-forge builds in CI tests
121 |
122 |
123 | .. _whats-new.0.2.2:
124 |
125 | v0.2.2 (2018-07-24)
126 | -------------------
127 |
128 | - Fixed segmentation faults in :func:`~xarray_extras.csv.to_csv`
129 | - Added conda-forge travis build
130 | - Blacklisted dask-0.18.2 because of regression in argtopk(split_every=2)
131 |
132 |
133 | .. _whats-new.0.2.1:
134 |
135 | v0.2.1 (2018-07-22)
136 | -------------------
137 |
138 | - Added parameter nogil=True to :func:`~xarray_extras.csv.to_csv`, which
139 |   switches to a C-accelerated implementation instead of pandas to_csv (albeit
140 |   with caveats). Fixed deadlock in to_csv as well as compatibility with dask
141 |   distributed. Pandas code (when using nogil=False) is no longer wrapped by a
142 |   subprocess, which means it won't be able to use more than 1 CPU
143 |   (although compression can still run concurrently).
144 | to_csv has lost the ability to write to a buffer - only file paths are
145 | supported now.
146 | - AppVeyor integration
147 |
148 |
149 | .. _whats-new.0.2.0:
150 |
151 | v0.2.0 (2018-07-15)
152 | -------------------
153 |
154 | - New function :func:`xarray_extras.csv.to_csv`
155 | - Speed up interpolation for k=2 and k=3
156 | - CI: Rigorous tracking of minimum dependency versions
157 | - CI: Explicit support for Python 3.7
158 |
159 |
160 | .. _whats-new.0.1.0:
161 |
162 | v0.1.0 (2018-05-19)
163 | -------------------
164 |
165 | Initial release.
166 |
--------------------------------------------------------------------------------
/xarray_extras/kernels/np_to_csv.c:
--------------------------------------------------------------------------------
1 | /*
2 | * High speed implementation for to_csv()
3 | */
4 | #include <Python.h>
5 | #include <assert.h>
6 | #include <inttypes.h>
7 | #include <math.h>
8 | #include <stdbool.h>
9 | #include <stdio.h>
10 | #include <string.h>
11 |
12 | #ifdef _WIN32
13 | # define LIBRARY_API __declspec(dllexport)
14 | #else
15 | # define LIBRARY_API
16 | #endif
17 |
18 | // In case of buffer overflow, snprintf returns -1 on Windows;
19 | // on Linux, it returns the number of characters that would have been written
20 | #define CHECK_OVERFLOW if (ret < 0 || char_count >= bufsize) return bufsize;
21 |
22 |
23 | static PyMethodDef module_methods[] = {
24 | {NULL, NULL, 0, NULL}
25 | };
26 |
27 |
28 | PyMODINIT_FUNC PyInit_np_to_csv(void)
29 | {
30 | PyObject *module;
31 | static struct PyModuleDef moduledef = {
32 | PyModuleDef_HEAD_INIT,
33 | "np_to_csv",
34 | "High speed implementation for to_csv()",
35 | -1,
36 | module_methods,
37 | NULL,
38 | NULL,
39 | NULL,
40 | NULL
41 | };
42 | module = PyModule_Create(&moduledef);
43 | if (!module) return NULL;
44 |
45 | return module;
46 | }
47 |
48 |
49 | /* Convert 2D array of doubles to CSV
50 | *
51 | * buf : output buffer
52 | * bufsize : maximum number of characters that can be written to buf
53 | * array : input numerical data
54 | * h : number of rows in array
55 | * w : number of columns in array
56 | * index : newline-separated list of prefix strings, one per row
57 | * fmt : printf formatting, including the cell separator at the end.
58 | * cell separator must be exactly 1 character.
59 | * trim_zeros : if true, trim trailing zeros after the . beyond the first
60 | * e.g. 1.000 -> 1.0
61 | * na_rep : string representation for NaN, including the cell separator at the end
62 | *
63 | * The line terminator is always \n, regardless of OS.
64 | */
65 | LIBRARY_API
66 | int snprintcsvd(char * buf, int bufsize, const double * array, int h, int w,
67 | const char * index, const char * fmt, bool trim_zeros, const char * na_rep)
68 | {
69 | int char_count = 0;
70 | int ret = 0;
71 | int i, j;
72 |
73 | // Move along a single column, printing the value of each row
74 | for (i = 0; i < h; i++) {
75 | // Print row header
76 | while (1) {
77 | CHECK_OVERFLOW;
78 | char c = *(index++);
79 | assert(c != 0);
80 | if (c == '\n') {
81 | break;
82 | }
83 | buf[char_count++] = c;
84 | }
85 | CHECK_OVERFLOW;
86 |
87 | // Print row values
88 | for (j = 0; j < w; j++) {
89 | double n = *(array++);
90 | if (isnan(n)) {
91 | ret = snprintf(buf + char_count, bufsize - char_count, "%s", na_rep);
92 | char_count += ret;
93 | CHECK_OVERFLOW;
94 | }
95 | else {
96 | ret = snprintf(buf + char_count, bufsize - char_count, fmt, n);
97 | char_count += ret;
98 | CHECK_OVERFLOW;
99 |
100 | if (trim_zeros) {
101 | while (char_count > 2 &&
102 | buf[char_count - 2] == '0' &&
103 | buf[char_count - 3] != '.') {
104 | buf[char_count - 2] = buf[char_count - 1];
105 | char_count--;
106 | }
107 | }
108 | }
109 | }
110 | // Replace latest column separator with line terminator
111 | buf[char_count - 1] = '\n';
112 | }
113 |
114 | return char_count;
115 | }
116 |
117 |
118 | /* Convert 2D array of int64's to CSV
119 | *
120 | * buf : output buffer
121 | * bufsize : maximum number of characters that can be written to buf
122 | * array : input numerical data
123 | * h : number of rows in array
124 | * w : number of columns in array
125 | * index : newline-separated list of prefix strings, one per row
126 | * sep : cell separator
127 | *
128 | * The line terminator is always \n, regardless of OS.
129 | */
130 | LIBRARY_API
131 | int snprintcsvi(char * buf, int bufsize, const int64_t * array, int h, int w,
132 | const char * index, char sep)
133 | {
134 | int char_count = 0;
135 | int ret = 0;
136 | int i, j;
137 |
138 | // '%d' + sep, but for int64_t
139 | char fmt[sizeof(PRId64) + 2];
140 | fmt[0] = '%';
141 | strcpy(fmt + 1, PRId64);
142 | fmt[sizeof(PRId64)] = sep;
143 | fmt[sizeof(PRId64) + 1] = 0;
144 |
145 | // Move along a single column, printing the value of each row
146 | for (i = 0; i < h; i++) {
147 | // Print row header
148 | while (1) {
149 | CHECK_OVERFLOW;
150 | char c = *(index++);
151 | if (c == 0 || c == '\n') {
152 | break;
153 | }
154 | buf[char_count++] = c;
155 | }
156 | CHECK_OVERFLOW;
157 |
158 | // Print row values
159 | for (j = 0; j < w; j++) {
160 | ret = snprintf(buf + char_count, bufsize - char_count, fmt, *(array++));
161 | char_count += ret;
162 | CHECK_OVERFLOW;
163 | }
164 | // Replace latest column separator with line terminator
165 | buf[char_count - 1] = '\n';
166 | }
167 |
168 | return char_count;
169 | }
170 |
--------------------------------------------------------------------------------
/xarray_extras/kernels/interpolate.py:
--------------------------------------------------------------------------------
1 | """dask kernels for :mod:`xarray_extras.interpolate`
2 |
3 | .. codeauthor:: Guido Imperiale
4 | """
5 |
6 | from __future__ import annotations
7 |
8 | from collections.abc import Iterable
9 | from typing import TYPE_CHECKING
10 |
11 | import numpy as np
12 | from scipy.interpolate import BSpline, make_interp_spline
13 | from scipy.interpolate._bsplines import _as_float_array, _augknt, _not_a_knot
14 |
15 | if TYPE_CHECKING:
16 | # TODO Python 3.9 notations
17 | from typing import Tuple, Union
18 |
19 | # TODO import from typing (requires Python 3.10)
20 | from typing_extensions import TypeAlias
21 |
22 | Boundary: TypeAlias = Iterable[Tuple[int, float]]
23 | BCType: TypeAlias = Union[Tuple[Boundary, Boundary], str, None]
24 |
25 |
26 | def _memoryview_safe(x: np.ndarray) -> np.ndarray:
27 | """Make array safe to run in a Cython memoryview-based kernel. These
28 | kernels typically break down with the error ``ValueError: buffer source
29 | array is read-only`` when running in dask distributed.
30 | """
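    # Illustrative behaviour (sketch; assumes x owns its buffer):
    #
    #     x = np.arange(3)
    #     x.setflags(write=False)
    #     _memoryview_safe(x).flags.writeable  # -> True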
31 | if not x.flags.writeable:
32 | if not x.flags.owndata:
33 | x = x.copy(order="C")
34 | x.setflags(write=True)
35 | return x
36 |
37 |
38 | def make_interp_knots(
39 | x: np.ndarray, k: int = 3, bc_type: BCType | None = None, check_finite: bool = True
40 | ) -> np.ndarray:
41 | """Compute the knots of the B-spline.
42 |
43 | .. note::
44 | This is a temporary implementation that should be moved to the main
45 |        scipy library.
46 |
47 | Parameters
48 | ----------
49 | x : array_like, shape (n,)
50 | Abscissas.
51 | k : int, optional
52 | B-spline degree. Default is cubic, k=3.
53 | bc_type : 2-tuple or None
54 | Boundary conditions.
55 | Default is None, which means choosing the boundary conditions
56 | automatically. Otherwise, it must be a length-two tuple where the first
57 | element sets the boundary conditions at ``x[0]`` and the second
58 | element sets the boundary conditions at ``x[-1]``. Each of these must
59 | be an iterable of pairs ``(order, value)`` which gives the values of
60 | derivatives of specified orders at the given edge of the interpolation
61 | interval.
62 | check_finite : bool, optional
63 | Whether to check that the input arrays contain only finite numbers.
64 | Disabling may give a performance gain, but may result in problems
65 | (crashes, non-termination) if the inputs do contain infinities or NaNs.
66 | Default is True.
67 |
68 | Returns
69 | -------
70 | numpy array with size = x.size + k + 1, representing the B-spline knots.
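
    Example (illustrative)::

        >>> make_interp_knots(np.array([0.0, 1.0, 2.0, 3.0]), k=1)
        array([0., 0., 1., 2., 3., 3.])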
71 | """
72 | if k < 2 and bc_type is not None:
73 | raise ValueError("Too much info for k<2: bc_type can only be None.")
74 |
75 | x = np.array(x)
76 | if x.ndim != 1 or np.any(x[1:] <= x[:-1]):
77 | raise ValueError("Expect x to be a 1-D sorted array-like.")
78 |
79 | if k == 0:
80 | t = np.r_[x, x[-1]]
81 | elif k == 1:
82 | t = np.r_[x[0], x, x[-1]]
83 | elif bc_type is None:
84 | if k == 2:
85 | # OK, it's a bit ad hoc: Greville sites + omit
86 | # 2nd and 2nd-to-last points, a la not-a-knot
87 | t = (x[1:] + x[:-1]) / 2.0
88 | t = np.r_[(x[0],) * (k + 1), t[1:-1], (x[-1],) * (k + 1)]
89 | else:
90 | t = _not_a_knot(x, k)
91 | else:
92 | t = _augknt(x, k)
93 |
94 | return _as_float_array(t, check_finite)
95 |
96 |
97 | def make_interp_coeffs(
98 | x: np.ndarray,
99 | y: np.ndarray,
100 | k: int = 3,
101 | t: np.ndarray | None = None,
102 | bc_type: BCType = None,
103 | axis: int = 0,
104 | check_finite: bool = True,
105 | ) -> np.ndarray:
106 |     """Compute the coefficients of the B-spline.
107 |
108 | .. note::
109 | This is a temporary implementation that should be moved to the main
110 |        scipy library.
111 |
112 | See :func:`scipy.interpolate.make_interp_spline` for parameters.
113 |
114 | :param t:
115 | Knots array, as calculated by :func:`make_interp_knots`.
116 |
117 | - For k=0, must always be None (the coefficients are not a function of
118 | the knots).
119 | - For k=1, set to None if t has been calculated by
120 | :func:`make_interp_knots`; pass a vector if it already existed
121 | before.
122 | - For k=2 and k=3, must always pass either the output of
123 | :func:`make_interp_knots` or a pre-generated vector.
124 | """
125 | x = _memoryview_safe(x)
126 | y = _memoryview_safe(y)
127 | if t is not None:
128 | t = _memoryview_safe(t)
129 |
130 | return make_interp_spline(
131 | x, y, k, t, bc_type=bc_type, axis=axis, check_finite=check_finite
132 | ).c
133 |
134 |
135 | def splev(
136 | x_new: np.ndarray,
137 | t: np.ndarray,
138 | c: np.ndarray,
139 | k: int = 3,
140 | extrapolate: bool | str = True,
141 | ) -> np.ndarray:
142 | """Generate a BSpline object on the fly from knots and coefficients and
143 | evaluate it on x_new.
144 |
145 | See :class:`scipy.interpolate.BSpline` for all parameters.
146 | """
147 | t = _memoryview_safe(t)
148 | c = _memoryview_safe(c)
149 | x_new = _memoryview_safe(x_new)
150 | spline = BSpline.construct_fast(t, c, k, axis=0, extrapolate=extrapolate)
151 | return spline(x_new)
152 |
--------------------------------------------------------------------------------
/xarray_extras/cumulatives.py:
--------------------------------------------------------------------------------
1 | """Advanced cumulative sum/product/mean functions"""
2 |
3 | from __future__ import annotations
4 |
5 | from collections.abc import Callable, Hashable
6 | from typing import TypeVar
7 |
8 | import dask.array as da
9 | import numpy as np
10 | import xarray
11 |
12 | from xarray_extras.kernels import cumulatives as kernels
13 |
14 | __all__ = ("compound_mean", "compound_prod", "compound_sum", "cummean")
15 |
16 |
17 | T = TypeVar("T", xarray.DataArray, xarray.Dataset)
18 | TV = TypeVar("TV", xarray.DataArray, xarray.Dataset, xarray.Variable)
19 |
20 |
21 | def cummean(x: T, dim: Hashable, skipna: bool | None = None) -> T:
22 | """
23 | .. math::
24 |
25 | y_{i} = mean(x_{0}, x_{1}, ... x_{i})
26 |
27 | :param x:
28 | :class:`~xarray.DataArray` or :class:`~xarray.Dataset`
29 | :param hashable dim:
30 | dimension along which to calculate the mean
31 | :param bool skipna:
32 | If True, skip missing values (as marked by NaN). By default, only skips
33 | missing values for float dtypes; other dtypes either do not have a
34 | sentinel missing value (int) or skipna=True has not been implemented
35 | (object, datetime64 or timedelta64).
36 | :returns:
37 | xarray object of the same type, dtype, and shape as x
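
    Example (illustrative)::

        >>> x = xarray.DataArray([1.0, 2.0, 6.0], dims=["t"])
        >>> cummean(x, "t").values
        array([1. , 1.5, 3. ])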
38 | """
39 | if skipna is False or (skipna is None and x.dtype.kind not in "fc"):
40 | # n is a simple arange
41 | if x.chunks:
42 | if isinstance(x, xarray.DataArray):
43 | chunks = x.chunks[x.dims.index(dim)]
44 | else:
45 | chunks = x.chunks[dim]
46 | n = da.arange(1, x.sizes[dim] + 1, chunks=chunks)
47 | else:
48 | n = np.arange(1, x.sizes[dim] + 1)
49 | n = xarray.DataArray(n, dims=[dim], coords={dim: x.coords[dim]})
50 | else:
51 | # heavier computation
52 | n = (~x.isnull()).cumsum((dim,), skipna=False)
53 |
54 | return x.cumsum((dim,), skipna=skipna) / n
55 |
56 |
57 | def compound_sum(x: T, c: xarray.DataArray, xdim: Hashable, cdim: Hashable) -> T:
58 | """Compound sum on arbitrary points of x along dim.
59 |
60 | :param x:
61 | :class:`~xarray.DataArray` or :class:`~xarray.Dataset` containing the
62 | data to be compounded
63 | :param xarray.DataArray c:
64 | array where every row contains elements of x.coords[xdim] and
65 | is used to build a point of the output.
66 |         The cells in the row are matched against x.coords[xdim] and summed
67 |         together. If different rows of c require different numbers of points from x,
68 | they must be padded on the right with NaN, NaT, or '' (respectively for
69 | numbers, datetimes, and strings).
70 | :param hashable xdim:
71 | dimension of x to acquire data from. The coord associated to it must be
72 | monotonic ascending.
73 | :param hashable cdim:
74 |         dimension of c that represents the vector of points to be compounded for
75 |         every output point
76 | :returns:
77 | xarray object of the same type and dtype as x, with all dims from x
78 | and c except xdim and cdim.
79 |
80 | example::
81 |
82 | >>> x = xarray.DataArray(
83 | >>> [10, 20, 30],
84 | >>> dims=['x'], coords={'x': ['foo', 'bar', 'baz']})
85 | >>> c = xarray.DataArray(
86 | >>> [['foo', 'baz', None],
87 | >>> ['bar', 'baz', 'baz']],
88 | >>> dims=['y', 'c'], coords={'y': ['new1', 'new2']})
89 | >>> compound_sum(x, c, 'x', 'c')
90 |         <xarray.DataArray (y: 2)>
91 |         array([40, 80])
92 |         Coordinates:
93 |           * y        (y) <U4 'new1' 'new2'
94 |     """
95 |     return _compound(x, c, xdim, cdim, kernels.compound_sum)
96 | 
97 | 
98 | def compound_prod(x: T, c: xarray.DataArray, xdim: Hashable, cdim: Hashable) -> T:
99 | """Compound product among arbitrary points of x along dim
100 | See :func:`compound_sum`.
101 | """
102 | return _compound(x, c, xdim, cdim, kernels.compound_prod)
103 |
104 |
105 | def compound_mean(x: T, c: xarray.DataArray, xdim: Hashable, cdim: Hashable) -> T:
106 | """Compound mean among arbitrary points of x along dim
107 | See :func:`compound_sum`.
108 | """
109 | return _compound(x, c, xdim, cdim, kernels.compound_mean)
110 |
111 |
112 | def _compound(
113 | x: T,
114 | c: xarray.DataArray,
115 | xdim: Hashable,
116 | cdim: Hashable,
117 | kernel: Callable[[T, xarray.DataArray], T],
118 | ) -> T:
119 | """Implementation of all compound functions
120 |
121 | :param kernel:
122 | numba kernel to apply to (x, idx), where
123 | idx is an array of indices with the same shape as c,
124 | containing the indices along x.coords[xdim] or -1 where c is null.
125 | """
126 | # Convert coord points to indexes of x.coords[dim]
127 | idx = xarray.DataArray(x.coords[xdim].searchsorted(c), dims=c.dims, coords=c.coords)
128 | # searchsorted(NaN) returns 0; replace it with -1.
129 | # isnull('') returns False. We could have asked for None, however
130 | # searchsorted will refuse to compare strings and None's
131 | if c.dtype.kind == "U":
132 | idx = idx.where(c != "", -1)
133 | else:
134 | idx = idx.where(~c.isnull(), -1)
135 |
136 | dtype = x.dtypes if isinstance(x, xarray.Dataset) else x.dtype
137 |
138 | return xarray.apply_ufunc(
139 | kernel,
140 | x,
141 | idx,
142 | input_core_dims=[[xdim], [cdim]],
143 | output_core_dims=[[]],
144 | dask="parallelized",
145 | output_dtypes=[dtype],
146 | )
147 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "xarray_extras"
3 | authors = [{name = "Guido Imperiale", email = "crusaderky@gmail.com"}]
4 | license = {text = "Apache"}
5 | description = "Advanced / experimental algorithms for xarray"
6 | keywords = ["xarray"]
7 | classifiers = [
8 | "Development Status :: 3 - Alpha",
9 | "License :: OSI Approved :: Apache Software License",
10 | "Operating System :: OS Independent",
11 | "Intended Audience :: Science/Research",
12 | "Topic :: Scientific/Engineering",
13 | "Programming Language :: Python",
14 | "Programming Language :: Python :: 3",
15 | "Programming Language :: Python :: 3.8",
16 | "Programming Language :: Python :: 3.9",
17 | "Programming Language :: Python :: 3.10",
18 | "Programming Language :: Python :: 3.11",
19 | "Programming Language :: Python :: 3.12",
20 | "Programming Language :: Python :: 3.13",
21 | ]
22 | requires-python = ">=3.8"
23 | dependencies = [
24 | "dask >= 2022.6.0",
25 | "numba >= 0.56",
26 | "numpy >= 1.23",
27 | "pandas >= 1.5",
28 | "scipy >= 1.9",
29 | "xarray >= 2022.11.0",
30 | ]
31 | dynamic = ["version"]
32 |
33 | [project.urls]
34 | Homepage = "https://github.com/crusaderky/xarray_extras"
35 | "Bug Tracker" = "https://github.com/crusaderky/xarray_extras/issues"
36 | Changelog = "https://xarray-extras.readthedocs.io/en/latest/whats-new.html"
37 |
38 | [project.readme]
39 | text = "Advanced / experimental algorithms for xarray"
40 | content-type = "text/x-rst"
41 |
42 | [tool.setuptools]
43 | packages = ["xarray_extras"]
44 | zip-safe = false # https://mypy.readthedocs.io/en/latest/installed_packages.html
45 | include-package-data = true
46 |
47 | [tool.setuptools_scm]
48 | # Use hardcoded version when .git has been removed and this is not a package created
49 | # by sdist. This is the case e.g. of a remote deployment with PyCharm.
50 | fallback_version = "9999"
51 |
52 | [tool.setuptools.package-data]
53 | xarray_extras = [
54 | "py.typed",
55 | "tests/data/*",
56 | ]
57 |
58 | [build-system]
59 | requires = [
60 | "setuptools>=66",
61 | "setuptools_scm[toml]",
62 | ]
63 | build-backend = "setuptools.build_meta"
64 |
65 | [tool.pytest.ini_options]
66 | addopts = "--strict-markers --strict-config -v -r sxfE --color=yes"
67 | xfail_strict = true
68 | python_files = ["test_*.py"]
69 | testpaths = ["xarray_extras/tests"]
70 | filterwarnings = [
71 | "error",
72 | # FIXME these need to be fixed in xarray
73 | "ignore:__array_wrap__ must accept context and return:DeprecationWarning",
74 | # FIXME these need to be looked at
75 | 'ignore:.*will no longer be implicitly promoted:FutureWarning',
76 | 'ignore:.*updating coordinate .* with a PandasMultiIndex would leave the multi-index level coordinates .* in an inconsistent state:FutureWarning',
77 | # These have been fixed; still needed for Python 3.9 CI
78 | "ignore:__array__ implementation doesn't accept a copy keyword, so passing copy=False failed:DeprecationWarning",
79 | 'ignore:Converting non-nanosecond precision datetime:UserWarning',
80 | 'ignore:Converting non-nanosecond precision timedelta:UserWarning',
81 | ]
82 |
83 | [tool.coverage.report]
84 | show_missing = true
85 | exclude_lines = [
86 | "pragma: nocover",
87 | "pragma: no cover",
88 | "TYPE_CHECKING",
89 | "except ImportError",
90 | "@overload",
91 | '@(abc\.)?abstractmethod',
92 | '@(numba\.)?jit',
93 | '@(numba\.)?vectorize',
94 | '@(numba\.)?guvectorize',
95 | ]
96 |
97 | [tool.codespell]
98 | ignore-words-list = ["ND"]
99 |
100 | [tool.ruff]
101 | exclude = [".eggs"]
102 | target-version = "py38"
103 |
104 | [tool.ruff.lint]
105 | ignore = [
106 | "EM101", # Exception must not use a string literal, assign to variable first
107 | "EM102", # Exception must not use an f-string literal, assign to variable first
108 | "N802", # Function name should be lowercase
109 | "N803", # Argument name should be lowercase
110 | "N806", # Variable should be lowercase
111 | "N816", # Variable in global scope should not be mixedCase
112 | "PD901", # Avoid using the generic variable name `df` for DataFrames
113 | "PT006", # Wrong type passed to first argument of `pytest.mark.parametrize`; expected `tuple`
114 | "PLC0414", # Import alias does not rename original package
115 | "PLR0912", # Too many branches
116 | "PLR0913", # Too many arguments in function definition
117 | "PLR2004", # Magic value used in comparison, consider replacing `123` with a constant variable
118 | "PLW2901", # for loop variable overwritten by assignment target
119 | "SIM108", # Use ternary operator instead of if-else block
120 | ]
121 | select = [
122 | "YTT", # flake8-2020
123 | "B", # flake8-bugbear
124 | "C4", # flake8-comprehensions
125 | "EM", # flake8-errmsg
126 | "EXE", # flake8-executable
127 | "ICN", # flake8-import-conventions
128 | "G", # flake8-logging-format
129 | "PIE", # flake8-pie
130 | "PT", # flake8-pytest-style
131 | "RET", # flake8-return
132 | "SIM", # flake8-simplify
133 | "ARG", # flake8-unused-arguments
134 | "I", # isort
135 | "NPY", # NumPy specific rules
136 | "N", # pep8-naming
137 | "E", # Pycodestyle
138 | "W", # Pycodestyle
139 | "PGH", # pygrep-hooks
140 | "F", # Pyflakes
141 | "PL", # pylint
142 | "UP", # pyupgrade
143 |     "RUF", # Ruff-specific rules
144 |     "TID", # flake8-tidy-imports
145 | "EXE001", # Shebang is present but file is not executable
146 | ]
147 |
148 | [tool.ruff.lint.isort]
149 | known-first-party = ["xarray_extras"]
150 |
151 | [tool.mypy]
152 | disallow_incomplete_defs = true
153 | disallow_untyped_decorators = true
154 | disallow_untyped_defs = true
155 | ignore_missing_imports = true
156 | no_implicit_optional = true
157 | show_error_codes = true
158 | warn_redundant_casts = true
159 | warn_unused_ignores = true
160 | warn_unreachable = true
161 |
162 | [[tool.mypy.overrides]]
163 | module = ["*.tests.*"]
164 | disallow_untyped_defs = false
165 |
--------------------------------------------------------------------------------
/xarray_extras/csv.py:
--------------------------------------------------------------------------------
1 | """Multi-threaded CSV writer, much faster than :meth:`pandas.DataFrame.to_csv`,
2 | with full support for `dask <https://dask.org>`_ and `dask distributed
3 | <https://distributed.dask.org>`_.
4 | """
5 |
6 | from __future__ import annotations
7 |
8 | from collections.abc import Callable
9 | from pathlib import Path
10 | from typing import TYPE_CHECKING, Any, cast
11 |
12 | import xarray
13 | from dask.base import tokenize
14 | from dask.delayed import Delayed
15 | from dask.highlevelgraph import HighLevelGraph
16 |
17 | from xarray_extras.kernels import csv as kernels
18 |
19 | __all__ = ("to_csv",)
20 |
21 | if TYPE_CHECKING:
22 | # TODO: remove TYPE_CHECKING (requires dask >=2023.9.1)
23 | from dask.typing import DaskCollection, Key
24 |
25 |
26 | def to_csv( # noqa: PLR0915
27 | x: xarray.DataArray, path: str | Path, *, nogil: bool = True, **kwargs: Any
28 | ) -> Any:
29 | """Print DataArray to CSV.
30 |
31 | When x has numpy backend, this function is functionally equivalent to (but
32 |     much faster than)::
33 |
34 | x.to_pandas().to_csv(path_or_buf, **kwargs)
35 |
36 | When x has dask backend, this function returns a dask delayed object which
37 | will write to the disk only when its .compute() method is invoked.
38 |
39 | Formatting and optional compression are parallelised across all available
40 | CPUs, using one dask task per chunk on the first dimension. Chunks on other
41 | dimensions will be merged ahead of computation.
42 |
43 | :param x:
44 | :class:`~xarray.DataArray` with one or two dimensions
45 | :param path:
46 | Output file path
47 | :param bool nogil:
48 | If True, use accelerated C implementation. Several kwargs won't be
49 | processed correctly (see limitations below). If False, use pandas
50 | to_csv method (slow, and does not release the GIL).
51 |         nogil=True exclusively supports float and integer value dtypes (but
52 | the coords can be anything). In case of incompatible dtype, nogil
53 | is automatically switched to False.
54 | :param kwargs:
55 | Passed verbatim to :meth:`pandas.DataFrame.to_csv` or
56 | :meth:`pandas.Series.to_csv`
57 |
58 | **Limitations**
59 |
60 | - Fancy URIs are not (yet) supported.
61 | - compression='zip' is not supported. All other compression methods (gzip,
62 | bz2, xz) are supported.
63 | - When running with nogil=True, the following parameters are ignored:
64 | columns, quoting, quotechar, doublequote, escapechar, chunksize, decimal
65 |
66 | **Distributed computing**
67 |
68 | This function supports `dask distributed`_, with the caveat that all workers
69 | must write to the same shared mountpoint and that the shared filesystem
70 | must strictly guarantee **close-open coherency**, meaning that one must be
71 | able to call write() and then close() on a file descriptor from one host
72 | and then immediately afterwards open() from another host and see the output
73 | from the first host. Note that, for performance reasons, most network
74 | filesystems do not enable this feature by default.
75 |
76 | Alternatively, one may write to local mountpoints and then manually collect
77 | and concatenate the partial outputs.
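
    **Example**

    A minimal sketch (file name, shape, and chunking are arbitrary)::

        >>> import numpy as np
        >>> import xarray
        >>> x = xarray.DataArray(np.random.rand(100_000, 10)).chunk(10_000)
        >>> delayed = to_csv(x, "out.csv.gz", float_format="%.6f")
        >>> delayed.compute()  # writes the gzip-compressed CSV to disk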
78 | """
79 | if not isinstance(x, xarray.DataArray):
80 | raise TypeError("first argument must be a DataArray")
81 |
82 | # Health checks
83 | if not isinstance(path, (str, Path)):
84 |         raise TypeError("path must be a string or a pathlib.Path object")
85 |
86 | path = Path(path)
87 |
88 | if x.ndim not in (1, 2):
89 | raise ValueError(
90 | f"cannot convert arrays with {x.ndim} dimensions into pandas objects"
91 | )
92 |
93 | if nogil and x.dtype.kind not in "if":
94 | nogil = False
95 |
96 | # Extract row and columns indices
97 | indices = [x.get_index(dim) for dim in x.dims]
98 | if x.ndim == 2:
99 | index, columns = indices
100 | else:
101 | index = indices[0]
102 | columns = None
103 |
104 | compression = kwargs.pop("compression", "infer")
105 | compress = _compress_func(path, compression)
106 | mode = kwargs.pop("mode", "w")
107 | if mode not in "wa":
108 | raise ValueError(f"mode: expected w or a; got {mode!r}")
109 |
110 | # Fast exit for numpy backend
111 | if not x.chunks:
112 | bdata = kernels.to_csv(x.values, index, columns, True, nogil, kwargs)
113 | if compress:
114 | bdata = compress(bdata)
115 | with open(path, mode + "b") as fh:
116 | fh.write(bdata)
117 | return None
118 |
119 | # Merge chunks on all dimensions beyond the first
120 | x = x.chunk({dim: -1 for dim in x.dims[1:]})
121 |
122 | # Manually define the dask graph
123 | tok = tokenize(x.data, index, columns, compression, path, kwargs)
124 | name1 = "to_csv_encode-" + tok
125 | name2 = "to_csv_compress-" + tok
126 | name3 = "to_csv_write-" + tok
127 | name4 = "to_csv-" + tok
128 |
129 | dsk: dict[Key, Any] = {}
130 |
131 | assert x.chunks
132 | assert x.chunks[0]
133 | offset = 0
134 | for i, size in enumerate(x.chunks[0]):
135 | # Slice index
136 | index_i = index[offset : offset + size]
137 | offset += size
138 |
139 | x_i = (x.data.name, i) + (0,) * (x.ndim - 1)
140 |
141 | # Step 1: convert to CSV and encode to binary blob
142 | if i == 0:
143 | # First chunk: print header
144 | dsk[name1, i] = (kernels.to_csv, x_i, index_i, columns, True, nogil, kwargs)
145 | else:
146 | kwargs_i = kwargs.copy()
147 | kwargs_i["header"] = False
148 | dsk[name1, i] = (kernels.to_csv, x_i, index_i, None, False, nogil, kwargs_i)
149 |
150 | # Step 2 (optional): compress
151 | if compress:
152 | prevname = name2
153 | dsk[name2, i] = compress, (name1, i)
154 | else:
155 | prevname = name1
156 |
157 | # Step 3: write to file
158 | if i == 0:
159 | # First chunk: overwrite file if it already exists
160 | # Convert PosixPath / WindowsPath to str to support dask Client
161 | # on Windows and Worker on Linux or vice versa
162 | dsk[name3, i] = (
163 | kernels.to_file,
164 | str(path),
165 | mode + "b",
166 | (prevname, i),
167 | None,
168 | )
169 | else:
170 | # Next chunks: wait for previous chunk to complete and append
171 | dsk[name3, i] = (
172 | kernels.to_file,
173 | str(path),
174 | "ab",
175 | (prevname, i),
176 | (name3, i - 1),
177 | )
178 |
179 | # Rename final key
180 | dsk[name4] = dsk.pop((name3, i))
181 |
182 | hlg = HighLevelGraph.from_collections(name4, dsk, [cast("DaskCollection", x)])
183 | return Delayed(name4, hlg)
184 |
185 |
186 | def _compress_func(
187 | path: Path, compression: str | None
188 | ) -> Callable[[bytes], bytes] | None:
189 | if compression == "infer":
190 | compression = path.suffix[1:].lower()
191 | if compression == "gz":
192 | compression = "gzip"
193 | elif compression == "csv":
194 | compression = None
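    # e.g. Path("out.csv.gz") -> "gzip"; Path("out.csv") -> None (no compression)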
195 |
196 | if compression is None:
197 | return None
198 | if compression == "gzip":
199 | import gzip
200 |
201 | return gzip.compress
202 | if compression == "bz2":
203 | import bz2
204 |
205 | return bz2.compress
206 | if compression == "xz":
207 | import lzma
208 |
209 | return lzma.compress
210 | if compression == "zip":
211 | raise NotImplementedError("zip compression is not supported")
212 | raise ValueError(f"Unrecognized compression: {compression}")
213 |
--------------------------------------------------------------------------------
/xarray_extras/interpolate.py:
--------------------------------------------------------------------------------
1 | """xarray spline interpolation functions"""
2 |
3 | from __future__ import annotations
4 |
5 | from collections.abc import Hashable
6 |
7 | import dask.array as da
8 | import numpy as np
9 | import xarray
10 |
11 | from xarray_extras.compat import dask_array_type
12 | from xarray_extras.kernels import interpolate as kernels
13 |
14 | __all__ = ("splev", "splrep")
15 |
16 |
17 | def splrep(a: xarray.DataArray, dim: Hashable, k: int = 3) -> xarray.Dataset:
18 | """Calculate the univariate B-spline for an N-dimensional array
19 |
20 | :param xarray.DataArray a:
21 | any :class:`~xarray.DataArray`
22 | :param dim:
23 | dimension of a to be interpolated. ``a.coords[dim]`` must be strictly
24 | monotonic ascending. All int, float (not complex), or datetime dtypes
25 | are supported.
26 | :param int k:
27 | B-spline order:
28 |
29 | = ==================
30 | k interpolation kind
31 | = ==================
32 | 0 nearest neighbour
33 | 1 linear
34 | 2 quadratic
35 | 3 cubic
36 | = ==================
37 |
38 | :returns:
39 | :class:`~xarray.Dataset` with t, c, k (knots, coefficients, order)
40 | variables, the same shape and coords as the input, that can be passed
41 | to :func:`splev`.
42 |
43 | Example::
44 |
45 | >>> x = np.arange(0, 120, 20)
46 | >>> x = xarray.DataArray(x, dims=['x'], coords={'x': x})
47 | >>> s = xarray.DataArray(np.linspace(1, 20, 5), dims=['s'])
48 | >>> y = np.exp(-x / s)
49 | >>> x_new = np.arange(0, 120, 1)
50 | >>> tck = splrep(y, 'x')
51 | >>> y_new = splev(x_new, tck)
52 |
53 | **Features**
54 |
55 | - Interpolate a ND array on any arbitrary dimension
56 |     - dask supported on both the interpolated array and x_new
57 | - Supports ND x_new arrays
58 | - The CPU-heavy interpolator generation (:func:`splrep`) is executed only
59 | once and then can be applied to multiple x_new (:func:`splev`)
60 | - memory-efficient
61 | - Can be pickled and used on dask distributed
62 |
63 | **Limitations**
64 |
65 |     - Chunks are not supported along the interpolation dim.
66 | """
67 | # Make sure that dim is on axis 0
68 | a = a.transpose(dim, *[d for d in a.dims if d != dim])
69 | x = a.coords[dim].values
70 |
71 | if x.dtype.kind == "M": # datetime
72 | # Same treatment will be applied to x_new.
73 | # Allow x_new.dtype==M8[D] and x.dtype==M8[ns], or vice versa
74 | x = x.astype("M8[ns]").astype(float)
75 | elif x.dtype.kind == "m": # timedelta
76 | x = x.astype("m8[ns]").astype(float)
77 |
78 | t = kernels.make_interp_knots(x, k, check_finite=False)
79 | if k < 2:
80 | t_c_param = None
81 | else:
82 | t_c_param = t
83 |
84 | if isinstance(a.data, dask_array_type):
85 | from dask.array import map_blocks
86 |
87 | if len(a.data.chunks[0]) > 1:
88 | raise NotImplementedError(
89 | "Unsupported: multiple chunks on interpolation dim"
90 | )
91 |
92 | c = map_blocks(
93 | kernels.make_interp_coeffs,
94 | x,
95 | a.data,
96 | k=k,
97 | t=t_c_param,
98 | check_finite=False,
99 | dtype=float,
100 | )
101 | else:
102 | c = kernels.make_interp_coeffs(x, a.data, k=k, t=t_c_param, check_finite=False)
103 |
104 | return xarray.Dataset(
105 | data_vars={
106 | "t": ("__t__", t),
107 | "c": (a.dims, c),
108 | },
109 | coords=a.coords,
110 | attrs={
111 | "spline_dim": dim,
112 | "k": k,
113 | },
114 | )
115 |
116 |
117 | def splev(
118 | x_new: object, tck: xarray.Dataset, extrapolate: bool | str = True
119 | ) -> xarray.DataArray:
120 | """Evaluate the B-spline generated with :func:`splrep`.
121 |
122 | :param x_new:
123 | Any :class:`~xarray.DataArray` with any number of dims, not necessarily
124 | the original interpolation dim.
125 | Alternatively, it can be any 1-dimensional array-like; it will be
126 | automatically converted to a :class:`~xarray.DataArray` on the
127 | interpolation dim.
128 |
129 | :param xarray.Dataset tck:
130 | As returned by :func:`splrep`.
131 | It can have been:
132 |
133 | - transposed (not recommended, as performance will
134 | drop if c is not C-contiguous)
135 | - sliced, reordered, or (re)chunked, on any
136 | dim except the interpolation dim
137 | - computed from dask to numpy backend
138 | - round-tripped to disk
139 |
140 | :param extrapolate:
141 | True
142 | Extrapolate the first and last polynomial pieces of b-spline
143 | functions active on the base interval
144 | False
145 | Return NaNs outside of the base interval
146 | 'periodic'
147 | Periodic extrapolation is used
148 | 'clip'
149 | Return y[0] and y[-1] outside of the base interval
150 |
151 | :returns:
152 | :class:`~xarray.DataArray` with all dims of the interpolated array,
153 | minus the interpolation dim, plus all dims of x_new
154 |
155 | See :func:`splrep` for usage example.
156 | """
157 | # Pre-process x_new into a DataArray
158 | if not isinstance(x_new, xarray.DataArray):
159 | if not isinstance(x_new, dask_array_type):
160 | x_new = np.array(x_new)
161 | if x_new.ndim == 0:
162 | dims = []
163 | elif x_new.ndim == 1:
164 | dims = [tck.spline_dim]
165 | else:
166 | raise ValueError(
167 | "N-dimensional x_new is only supported if x_new is a DataArray"
168 | )
169 | x_new = xarray.DataArray(x_new, dims=dims, coords={tck.spline_dim: x_new})
170 |
171 | dim = tck.spline_dim
172 | t = tck.t
173 | c = tck.c
174 | k = tck.k
175 |
176 | invalid_dims = {*x_new.dims} & {*c.dims} - {dim}
177 | if invalid_dims:
178 | raise ValueError(
179 | "Overlapping dims between interpolated "
180 | "array and x_new: " + ",".join(str(d) for d in invalid_dims)
181 | )
182 |
183 | if t.shape != (c.sizes[dim] + k + 1,):
184 | raise ValueError("Interpolated dimension has been sliced")
185 |
186 | if x_new.dtype.kind == "M": # datetime
187 | # Note that we're modifying the x_new values, not the x_new coords
188 | x_new = x_new.astype("M8[ns]").astype(float)
189 | elif x_new.dtype.kind == "m": # timedelta
190 | x_new = x_new.astype("m8[ns]").astype(float)
191 |
192 | if extrapolate == "clip":
193 | x = tck.coords[dim].values
194 |
195 | if x.dtype.kind == "M": # datetime
196 | x = x.astype("M8[ns]").astype(float)
197 | elif x.dtype.kind == "m": # timedelta
198 | x = x.astype("m8[ns]").astype(float)
199 |
200 | x_new = np.clip(x_new, x[0].tolist(), x[-1].tolist())
201 | extrapolate = False
202 |
203 | if c.dims[0] != dim:
204 | c = c.transpose(dim, *[d for d in c.dims if d != dim])
205 |
206 | if any(isinstance(v.data, dask_array_type) for v in (x_new, t, c)):
207 | if t.chunks and len(t.chunks[0]) > 1:
208 | raise NotImplementedError(
209 | "Unsupported: multiple chunks on interpolation dim"
210 | )
211 | if c.chunks and len(c.chunks[0]) > 1:
212 | raise NotImplementedError(
213 | "Unsupported: multiple chunks on interpolation dim"
214 | )
215 |
216 | # omitting t and c
217 | x_new_axes = "abdefghijklm"[: x_new.ndim]
218 | c_axes = "nopqrsuvwxyz"[: c.ndim - 1]
219 |
220 | y_new = da.blockwise(
221 | kernels.splev,
222 | x_new_axes + c_axes,
223 | x_new.data,
224 | x_new_axes,
225 | t.data,
226 | "t",
227 | c.data,
228 | "c" + c_axes,
229 | k=k,
230 | extrapolate=extrapolate,
231 | concatenate=True,
232 | dtype=float,
233 | )
234 | else:
235 | y_new = kernels.splev(
236 | x_new.values, t.values, c.values, k, extrapolate=extrapolate
237 | )
238 |
239 | y_new = xarray.DataArray(y_new, dims=x_new.dims + c.dims[1:], coords=x_new.coords)
240 | y_new.coords.update({k: c for k, c in c.coords.items() if dim not in c.dims})
241 | return y_new
242 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction, and
10 | distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by the copyright
13 | owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all other entities
16 | that control, are controlled by, or are under common control with that entity.
17 | For the purposes of this definition, "control" means (i) the power, direct or
18 | indirect, to cause the direction or management of such entity, whether by
19 | contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
20 | outstanding shares, or (iii) beneficial ownership of such entity.
21 |
22 | "You" (or "Your") shall mean an individual or Legal Entity exercising
23 | permissions granted by this License.
24 |
25 | "Source" form shall mean the preferred form for making modifications, including
26 | but not limited to software source code, documentation source, and configuration
27 | files.
28 |
29 | "Object" form shall mean any form resulting from mechanical transformation or
30 | translation of a Source form, including but not limited to compiled object code,
31 | generated documentation, and conversions to other media types.
32 |
33 | "Work" shall mean the work of authorship, whether in Source or Object form, made
34 | available under the License, as indicated by a copyright notice that is included
35 | in or attached to the work (an example is provided in the Appendix below).
36 |
37 | "Derivative Works" shall mean any work, whether in Source or Object form, that
38 | is based on (or derived from) the Work and for which the editorial revisions,
39 | annotations, elaborations, or other modifications represent, as a whole, an
40 | original work of authorship. For the purposes of this License, Derivative Works
41 | shall not include works that remain separable from, or merely link (or bind by
42 | name) to the interfaces of, the Work and Derivative Works thereof.
43 |
44 | "Contribution" shall mean any work of authorship, including the original version
45 | of the Work and any modifications or additions to that Work or Derivative Works
46 | thereof, that is intentionally submitted to Licensor for inclusion in the Work
47 | by the copyright owner or by an individual or Legal Entity authorized to submit
48 | on behalf of the copyright owner. For the purposes of this definition,
49 | "submitted" means any form of electronic, verbal, or written communication sent
50 | to the Licensor or its representatives, including but not limited to
51 | communication on electronic mailing lists, source code control systems, and
52 | issue tracking systems that are managed by, or on behalf of, the Licensor for
53 | the purpose of discussing and improving the Work, but excluding communication
54 | that is conspicuously marked or otherwise designated in writing by the copyright
55 | owner as "Not a Contribution."
56 |
57 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf
58 | of whom a Contribution has been received by Licensor and subsequently
59 | incorporated within the Work.
60 |
61 | 2. Grant of Copyright License.
62 |
63 | Subject to the terms and conditions of this License, each Contributor hereby
64 | grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
65 | irrevocable copyright license to reproduce, prepare Derivative Works of,
66 | publicly display, publicly perform, sublicense, and distribute the Work and such
67 | Derivative Works in Source or Object form.
68 |
69 | 3. Grant of Patent License.
70 |
71 | Subject to the terms and conditions of this License, each Contributor hereby
72 | grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
73 | irrevocable (except as stated in this section) patent license to make, have
74 | made, use, offer to sell, sell, import, and otherwise transfer the Work, where
75 | such license applies only to those patent claims licensable by such Contributor
76 | that are necessarily infringed by their Contribution(s) alone or by combination
77 | of their Contribution(s) with the Work to which such Contribution(s) was
78 | submitted. If You institute patent litigation against any entity (including a
79 | cross-claim or counterclaim in a lawsuit) alleging that the Work or a
80 | Contribution incorporated within the Work constitutes direct or contributory
81 | patent infringement, then any patent licenses granted to You under this License
82 | for that Work shall terminate as of the date such litigation is filed.
83 |
84 | 4. Redistribution.
85 |
86 | You may reproduce and distribute copies of the Work or Derivative Works thereof
87 | in any medium, with or without modifications, and in Source or Object form,
88 | provided that You meet the following conditions:
89 |
90 | You must give any other recipients of the Work or Derivative Works a copy of
91 | this License; and
92 | You must cause any modified files to carry prominent notices stating that You
93 | changed the files; and
94 | You must retain, in the Source form of any Derivative Works that You distribute,
95 | all copyright, patent, trademark, and attribution notices from the Source form
96 | of the Work, excluding those notices that do not pertain to any part of the
97 | Derivative Works; and
98 | If the Work includes a "NOTICE" text file as part of its distribution, then any
99 | Derivative Works that You distribute must include a readable copy of the
100 | attribution notices contained within such NOTICE file, excluding those notices
101 | that do not pertain to any part of the Derivative Works, in at least one of the
102 | following places: within a NOTICE text file distributed as part of the
103 | Derivative Works; within the Source form or documentation, if provided along
104 | with the Derivative Works; or, within a display generated by the Derivative
105 | Works, if and wherever such third-party notices normally appear. The contents of
106 | the NOTICE file are for informational purposes only and do not modify the
107 | License. You may add Your own attribution notices within Derivative Works that
108 | You distribute, alongside or as an addendum to the NOTICE text from the Work,
109 | provided that such additional attribution notices cannot be construed as
110 | modifying the License.
111 | You may add Your own copyright statement to Your modifications and may provide
112 | additional or different license terms and conditions for use, reproduction, or
113 | distribution of Your modifications, or for any such Derivative Works as a whole,
114 | provided Your use, reproduction, and distribution of the Work otherwise complies
115 | with the conditions stated in this License.
116 |
117 | 5. Submission of Contributions.
118 |
119 | Unless You explicitly state otherwise, any Contribution intentionally submitted
120 | for inclusion in the Work by You to the Licensor shall be under the terms and
121 | conditions of this License, without any additional terms or conditions.
122 | Notwithstanding the above, nothing herein shall supersede or modify the terms of
123 | any separate license agreement you may have executed with Licensor regarding
124 | such Contributions.
125 |
126 | 6. Trademarks.
127 |
128 | This License does not grant permission to use the trade names, trademarks,
129 | service marks, or product names of the Licensor, except as required for
130 | reasonable and customary use in describing the origin of the Work and
131 | reproducing the content of the NOTICE file.
132 |
133 | 7. Disclaimer of Warranty.
134 |
135 | Unless required by applicable law or agreed to in writing, Licensor provides the
136 | Work (and each Contributor provides its Contributions) on an "AS IS" BASIS,
137 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
138 | including, without limitation, any warranties or conditions of TITLE,
139 | NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
140 | solely responsible for determining the appropriateness of using or
141 | redistributing the Work and assume any risks associated with Your exercise of
142 | permissions under this License.
143 |
144 | 8. Limitation of Liability.
145 |
146 | In no event and under no legal theory, whether in tort (including negligence),
147 | contract, or otherwise, unless required by applicable law (such as deliberate
148 | and grossly negligent acts) or agreed to in writing, shall any Contributor be
149 | liable to You for damages, including any direct, indirect, special, incidental,
150 | or consequential damages of any character arising as a result of this License or
151 | out of the use or inability to use the Work (including but not limited to
152 | damages for loss of goodwill, work stoppage, computer failure or malfunction, or
153 | any and all other commercial damages or losses), even if such Contributor has
154 | been advised of the possibility of such damages.
155 |
156 | 9. Accepting Warranty or Additional Liability.
157 |
158 | While redistributing the Work or Derivative Works thereof, You may choose to
159 | offer, and charge a fee for, acceptance of support, warranty, indemnity, or
160 | other liability obligations and/or rights consistent with this License. However,
161 | in accepting such obligations, You may act only on Your own behalf and on Your
162 | sole responsibility, not on behalf of any other Contributor, and only if You
163 | agree to indemnify, defend, and hold each Contributor harmless for any liability
164 | incurred by, or claims asserted against, such Contributor by reason of your
165 | accepting any such warranty or additional liability.
166 |
167 | END OF TERMS AND CONDITIONS
168 |
169 | APPENDIX: How to apply the Apache License to your work
170 |
171 | To apply the Apache License to your work, attach the following boilerplate
172 | notice, with the fields enclosed by brackets "[]" replaced with your own
173 | identifying information. (Don't include the brackets!) The text should be
174 | enclosed in the appropriate comment syntax for the file format. We also
175 | recommend that a file or class name and description of purpose be included on
176 | the same "printed page" as the copyright notice for easier identification within
177 | third-party archives.
178 |
179 | Copyright [yyyy] [name of copyright owner]
180 |
181 | Licensed under the Apache License, Version 2.0 (the "License");
182 | you may not use this file except in compliance with the License.
183 | You may obtain a copy of the License at
184 |
185 | http://www.apache.org/licenses/LICENSE-2.0
186 |
187 | Unless required by applicable law or agreed to in writing, software
188 | distributed under the License is distributed on an "AS IS" BASIS,
189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
190 | See the License for the specific language governing permissions and
191 | limitations under the License.
192 |
--------------------------------------------------------------------------------
/xarray_extras/tests/test_interpolate.py:
--------------------------------------------------------------------------------
1 | import dask.array as da
2 | import numpy as np
3 | import pytest
4 | from xarray import DataArray, Dataset, apply_ufunc
5 | from xarray.testing import assert_allclose, assert_equal
6 |
7 | import xarray_extras.kernels.interpolate as kernels
8 | from xarray_extras.interpolate import splev, splrep
9 |
10 |
11 | @pytest.mark.parametrize(
12 | "k,expect",
13 | [
14 | (0, [40, 55, 4]),
15 | (1, [45, 35, 4.5]),
16 | (2, [45, 37.90073529, np.nan]),
17 | (3, [45, 39.69583333, np.nan]),
18 | ],
19 | )
20 | def test_0d(k, expect):
21 | """
22 | - Test different orders
23 | - Test unsorted x
24 | - Test what happens when a series contains NaN
25 | """
26 | y = DataArray(
27 | [[10, 20, 30, 40, 50, 60], [11, 28, 39, 55, 15, -2], [np.nan, 2, 3, 4, 5, 6]],
28 | dims=["y", "x"],
29 | coords={"x": [1, 2, 3, 4, 5, 6], "y": ["y1", "y2", "y3"]},
30 | )
31 | tck = splrep(y, "x", k)
32 | expect = DataArray(
33 | expect, dims=["y"], coords={"x": 4.5, "y": ["y1", "y2", "y3"]}
34 | ).astype(float)
35 | assert_allclose(splev(4.5, tck), expect, rtol=0, atol=1e-6)
36 |
37 |
38 | @pytest.mark.parametrize(
39 | "x_new,expect",
40 | [
41 | # list
42 | ([3.5, 4.5], DataArray([35.0, 45.0], dims=["x"], coords={"x": [3.5, 4.5]})),
43 | # tuple
44 | ((3.5, 4.5), DataArray([35.0, 45.0], dims=["x"], coords={"x": [3.5, 4.5]})),
45 | # np.array
46 | (
47 | np.array([3.5, 4.5]),
48 | DataArray([35.0, 45.0], dims=["x"], coords={"x": [3.5, 4.5]}),
49 | ),
50 | # da.Array
51 | (
52 | da.from_array(np.array([3.5, 4.5]), chunks=1),
53 | DataArray([35.0, 45.0], dims=["x"], coords={"x": [3.5, 4.5]}).chunk(1),
54 | ),
55 | # DataArray, same dim as y, no coord
56 | (DataArray([3.5, 4.5], dims=["x"]), DataArray([35.0, 45.0], dims=["x"])),
57 | # DataArray, same dim as y, with coord
58 | (
59 | DataArray([3.5, 4.5], dims=["x"], coords={"x": [100, 200]}),
60 | DataArray([35.0, 45.0], dims=["x"], coords={"x": [100, 200]}),
61 | ),
62 |         # DataArray, different dim than y
63 | (DataArray([3.5, 4.5], dims=["t"]), DataArray([35.0, 45.0], dims=["t"])),
64 | ],
65 | )
66 | def test_1d(x_new, expect):
67 | """
68 | - Test 1d case
69 | - Test auto-casting of various types of x_new
70 | """
71 | y = DataArray(
72 | [10, 20, 30, 40, 50, 60], dims=["x"], coords={"x": [1, 2, 3, 4, 5, 6]}
73 | )
74 | tck = splrep(y, "x", k=1)
75 | y_new = splev(x_new, tck)
76 | assert_equal(y_new, expect)
77 | assert y_new.chunks == expect.chunks
78 |
79 |
80 | @pytest.mark.parametrize(
81 | "chunk_y,chunk_x_new,expect_chunks_tck,expect_chunks_y_new",
82 | [
83 | (False, False, {}, None),
84 | (False, True, {}, ((1, 1), (1, 1), (2,))),
85 | (True, False, {"y": (1, 1), "x": (6,)}, ((2,), (2,), (1, 1))),
86 | (True, True, {"y": (1, 1), "x": (6,)}, ((1, 1), (1, 1), (1, 1))),
87 | ],
88 | )
89 | def test_nd(chunk_y, chunk_x_new, expect_chunks_tck, expect_chunks_y_new):
90 | """
91 | - Test ND y vs. ND x_new
92 | - Test dask
93 | """
94 | y = DataArray(
95 | [[10, 20, 30, 40, 50, 60], [11, 28, 39, 55, 15, -2]],
96 | dims=["y", "x"],
97 | coords={"x": [1, 2, 3, 4, 5, 6], "y": ["y1", "y2"]},
98 | )
99 | x_new = DataArray(
100 | [[3.5, 4.5], [1.5, 5.5]],
101 | dims=["w", "z"],
102 | coords={"w": [100, 200], "z": ["foo", "bar"]},
103 | )
104 | expect_tck = Dataset(
105 | data_vars={
106 | "t": ("__t__", [1.0, 1.0, 1.0, 1.0, 3.0, 4.0, 6.0, 6.0, 6.0, 6.0]),
107 | "c": (
108 | ("x", "y"),
109 | [
110 | [10.0, 11.0],
111 | [16.666667, 33.118519],
112 | [26.666667, 20.762963],
113 | [43.333333, 84.003704],
114 | [53.333333, -25.251852],
115 | [60.0, -2],
116 | ],
117 | ),
118 | },
119 | coords={"x": [1, 2, 3, 4, 5, 6], "y": ["y1", "y2"]},
120 | attrs={"spline_dim": "x", "k": 3},
121 | )
122 |
123 | expect_y_new = DataArray(
124 | [
125 | [[35.0, 51.0375], [45.0, 39.69583333]],
126 | [[15.0, 22.72083333], [55.0, -3.945833]],
127 | ],
128 | dims=["w", "z", "y"],
129 | coords={
130 | "w": [100, 200],
131 | "y": ["y1", "y2"],
132 | "z": ["foo", "bar"],
133 | },
134 | )
135 |
136 | if chunk_y:
137 | y = y.chunk({"y": 1})
138 | if chunk_x_new:
139 | x_new = x_new.chunk(1)
140 |
141 | tck = splrep(y, "x", k=3)
142 | assert_allclose(tck.compute(), expect_tck, atol=1e-6, rtol=0)
143 | y_new = splev(x_new, tck)
144 | assert_allclose(y_new.compute(), expect_y_new, atol=1e-6, rtol=0)
145 |
146 | assert tck.chunks == expect_chunks_tck
147 | assert y_new.chunks == expect_chunks_y_new
148 |
149 |
150 | @pytest.mark.parametrize("contiguous", [False, True])
151 | @pytest.mark.parametrize("transpose", [False, True])
152 | def test_transpose(transpose, contiguous):
153 | """Test that the interpolation dim does not need to be on axis 0"""
154 | y = DataArray(
155 | [[10, 20], [30, 40], [50, 60]],
156 | dims=["y", "x"],
157 | coords={"x": [1, 2], "y": ["y1", "y2", "y3"]},
158 | )
159 | expect = DataArray(
160 | [[15.0, 35.0, 55.0], [10.0, 30.0, 50.0]],
161 | dims=["x", "y"],
162 | coords={"x": [1.5, 1.0], "y": ["y1", "y2", "y3"]},
163 | )
164 |
165 | if transpose:
166 | y = y.T
167 | if contiguous:
168 | y = apply_ufunc(np.ascontiguousarray, y)
169 |
170 | tck = splrep(y, "x", 1)
171 | y_new = splev([1.5, 1.0], tck)
172 | assert_equal(expect, y_new)
173 |
174 |
175 | @pytest.mark.parametrize(
176 | "x_new_dtype", [np.uint32, np.uint64, np.int32, np.int64, np.float32]
177 | )
178 | @pytest.mark.parametrize(
179 | "x_dtype", [np.uint32, np.uint64, np.int32, np.int64, np.float32]
180 | )
181 | def test_nonfloat(x_dtype, x_new_dtype):
182 | """Test numeric x that isn't float64"""
183 | x = np.array([0, 100])
184 |
185 | y = DataArray((x * 3).astype(x_dtype), dims=["x"], coords={"x": x.astype(x_dtype)})
186 | x_new = np.array([50]).astype(x_new_dtype)
187 | expect = DataArray([150.0], dims=["x"], coords={"x": x_new})
188 |
189 | tck = splrep(y, "x", k=1)
190 | y_new = splev(x_new, tck)
191 | assert_equal(expect, y_new)
192 |
193 |
194 | @pytest.mark.filterwarnings("ignore:Converting non-nanosecond precision datetime ")
195 | @pytest.mark.parametrize("x_new_dtype", ["M8[D]", "M8[s]", "M8[ns]"])
196 | @pytest.mark.parametrize("x_dtype", ["M8[D]", "M8[s]", "M8[ns]"])
197 | def test_dates(x_dtype, x_new_dtype):
198 | """
199 | - Test mismatched date formats on x and x_new
200 |     - Test clip extrapolation on dates
201 | """
202 | y = DataArray(
203 | [10, 20],
204 | dims=["x"],
205 | coords={"x": np.array(["2000-01-01", "2001-01-01"], dtype=x_dtype)},
206 | )
207 | x_new = np.array(["2000-04-20", "2002-07-28"], dtype=x_new_dtype)
208 | expect = DataArray([13.00546448, 20.0], dims=["x"], coords={"x": x_new})
209 |
210 | tck = splrep(y, "x", k=1)
211 | y_new = splev(x_new, tck, extrapolate="clip")
212 | assert y_new.x.dtype == expect.x.dtype
213 | assert_allclose(expect, y_new, atol=1e-6, rtol=0)
214 |
215 |
216 | @pytest.mark.filterwarnings("ignore:Converting non-nanosecond precision datetime ")
217 | @pytest.mark.parametrize("x_new_dtype", ["m8[D]", "m8[s]", "m8[ns]"])
218 | @pytest.mark.parametrize("x_dtype", ["m8[D]", "m8[s]", "m8[ns]"])
219 | def test_timedeltas(x_dtype, x_new_dtype):
220 | """
221 |     - Test mismatched timedelta formats on x and x_new
222 |     - Test clip extrapolation on timedeltas
223 | """
224 | y = DataArray(
225 | [10, 20],
226 | dims=["x"],
227 | coords={"x": np.array([30, 50], dtype="m8[D]").astype(x_dtype)},
228 | )
229 | x_new = np.array([35, 45], dtype="m8[D]").astype(x_new_dtype)
230 | expect = DataArray([12.5, 17.5], dims=["x"], coords={"x": x_new})
231 |
232 | tck = splrep(y, "x", k=1)
233 | y_new = splev(x_new, tck, extrapolate="clip")
234 | assert y_new.x.dtype == expect.x.dtype
235 | assert_allclose(expect, y_new, atol=1e-6, rtol=0)
236 |
237 |
238 | @pytest.mark.parametrize(
239 | "extrapolate,expect",
240 | [
241 | (True, [-1.36401507, -1.22694955]),
242 | (False, [np.nan, np.nan]),
243 | ("clip", [0, np.sin(9)]),
244 | ("periodic", [0.98935825, 0.84147098]),
245 | ],
246 | )
247 | def test_extrapolate(extrapolate, expect):
248 | """Test all possible extrapolate parameters"""
249 | x = np.arange(10)
250 | y = DataArray(np.sin(x), dims=["x"], coords={"x": x})
251 | x_new = [-1, 10]
252 | expect = DataArray(expect, dims=["x"], coords={"x": x_new})
253 |
254 | tck = splrep(y, "x", k=3)
255 | y_new = splev(x_new, tck, extrapolate=extrapolate)
256 | assert_allclose(expect, y_new, atol=1e-6, rtol=0)
257 |
258 |
259 | def test_dim_collision():
260 | """y and x_new have overlapping dims besides the interpolation dim"""
261 | y = DataArray(
262 | [[10, 20], [11, 28]], dims=["y", "x"], coords={"x": [1, 2], "y": ["y1", "y2"]}
263 | )
264 | x_new = DataArray([1, 1], dims=["y"])
265 | tck = splrep(y, "x", 1)
266 | with pytest.raises(
267 | ValueError,
268 | match="Overlapping dims between interpolated array and x_new: y",
269 | ):
270 | splev(x_new, tck)
271 |
272 |
273 | def test_duplicates():
274 | """x contains non-unique points"""
275 | y = DataArray([10, 20, 30], dims=["x"], coords={"x": [1, 1, 2]})
276 | with pytest.raises(ValueError, match="Expect x to be a 1-D sorted array-like"):
277 | splrep(y, "x", 1)
278 |
279 |
280 | def test_chunked_x():
281 | """x is chunked"""
282 | y = DataArray([10, 20], dims=["x"], coords={"x": [1, 2]}).chunk(1)
283 |
284 | with pytest.raises(
285 | NotImplementedError,
286 | match="Unsupported: multiple chunks on interpolation dim",
287 | ):
288 | splrep(y, "x", 1)
289 |
290 |
291 | def test_distributed():
292 | def ro_array(a):
293 | a = np.array(a)
294 | a.setflags(write=False)
295 |         # Return a view of a, so that setting the write flag back on the view
296 |         # is not enough to make the underlying data writeable again
297 | return a[:]
298 |
299 | x = ro_array([1.0, 2.0])
300 | y = ro_array([10.0, 20.0])
301 | t = kernels.make_interp_knots(x, k=1)
302 | t = ro_array(t)
303 | c = kernels.make_interp_coeffs(x, y, k=1, t=t)
304 | c = ro_array(c)
305 | x_new = ro_array([1.5, 1.8])
306 | kernels.splev(x_new, t, c, k=1)
307 |
--------------------------------------------------------------------------------
/doc/conf.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | # documentation build configuration file, created by
4 | # sphinx-quickstart on Thu Feb 6 18:57:54 2014.
5 | #
6 | # This file is execfile()d with the current directory set to its
7 | # containing dir.
8 | #
9 | # Note that not all possible configuration values are present in this
10 | # autogenerated file.
11 | #
12 | # All configuration values have a default; values that are commented out
13 | # serve to show the default.
14 | import datetime
15 | import os
16 | import sys
17 |
18 | import xarray
19 |
20 | import xarray_extras
21 |
22 | print("python exec:", sys.executable)
23 | print("sys.path:", sys.path)
24 | print("xarray_extras version: ", xarray_extras.__version__)
25 |
26 |
27 | # -- General configuration ------------------------------------------------
28 |
29 | # If your documentation needs a minimal Sphinx version, state it here.
30 | # needs_sphinx = '1.0'
31 |
32 | # Add any Sphinx extension module names here, as strings. They can be
33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
34 | # ones.
35 | extensions = [
36 | "sphinx.ext.autodoc",
37 | "sphinx.ext.autosummary",
38 | "sphinx.ext.intersphinx",
39 | "sphinx.ext.extlinks",
40 | "sphinx.ext.mathjax",
41 | ]
42 |
43 | extlinks = {
44 | "issue": ("https://github.com/crusaderky/xarray_extras/issues/%s", "GH"),
45 | "pull": ("https://github.com/crusaderky/xarray_extras/pull/%s", "PR"),
46 | }
47 |
48 | autosummary_generate = True
49 |
50 | # Add any paths that contain templates here, relative to this directory.
51 | templates_path = ["_templates"]
52 |
53 | # The suffix of source filenames.
54 | source_suffix = ".rst"
55 |
56 | # The encoding of source files.
57 | # source_encoding = 'utf-8-sig'
58 |
59 | # The master toctree document.
60 | master_doc = "index"
61 |
62 | # General information about the project.
63 | project = "xarray_extras"
64 | copyright = f"2018-{datetime.datetime.now().year}, xarray_extras Developers"
65 |
66 | # The version info for the project you're documenting, acts as replacement for
67 | # |version| and |release|, also used in various other places throughout the
68 | # built documents.
69 | #
70 | # The short X.Y version.
71 | version = xarray_extras.__version__.split("+")[0]
72 | # The full version, including alpha/beta/rc tags.
73 | release = xarray_extras.__version__
74 |
75 | # The language for content autogenerated by Sphinx. Refer to documentation
76 | # for a list of supported languages.
77 | # language = None
78 |
79 | # There are two options for replacing |today|: either, you set today to some
80 | # non-false value, then it is used:
81 | # today = ''
82 | # Else, today_fmt is used as the format for a strftime call.
83 | today_fmt = "%Y-%m-%d"
84 |
85 | # List of patterns, relative to source directory, that match files and
86 | # directories to ignore when looking for source files.
87 | exclude_patterns = ["_build"]
88 |
89 | # The reST default role (used for this markup: `text`) to use for all
90 | # documents.
91 | # default_role = None
92 |
93 | # If true, '()' will be appended to :func: etc. cross-reference text.
94 | # add_function_parentheses = True
95 |
96 | # If true, the current module name will be prepended to all description
97 | # unit titles (such as .. function::).
98 | # add_module_names = True
99 |
100 | # If true, sectionauthor and moduleauthor directives will be shown in the
101 | # output. They are ignored by default.
102 | # show_authors = False
103 |
104 | # The name of the Pygments (syntax highlighting) style to use.
105 | pygments_style = "sphinx"
106 |
107 | # A list of ignored prefixes for module index sorting.
108 | # modindex_common_prefix = []
109 |
110 | # If true, keep warnings as "system message" paragraphs in the built documents.
111 | # keep_warnings = False
112 |
113 |
114 | # -- Options for HTML output ----------------------------------------------
115 |
116 | # The theme to use for HTML and HTML Help pages. See the documentation for
117 | # a list of builtin themes.
118 | html_theme = "sphinx_rtd_theme"
119 |
120 | # Theme options are theme-specific and customize the look and feel of a theme
121 | # further. For a list of options available for each theme, see the
122 | # documentation.
123 | html_theme_options = {"logo_only": True}
124 |
125 | # Add any paths that contain custom themes here, relative to this directory.
126 | # html_theme_path = []
127 |
128 | # The name for this set of Sphinx documents. If None, it defaults to
129 | # " v documentation".
130 | # html_title = None
131 |
132 | # A shorter title for the navigation bar. Default is the same as html_title.
133 | # html_short_title = None
134 |
135 | # The name of an image file (relative to this directory) to place at the top
136 | # of the sidebar.
137 | # html_logo = "_static/logo.png"
138 |
139 | # The name of an image file (within the static path) to use as favicon of the
140 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
141 | # pixels large.
142 | # html_favicon = None
143 |
144 | # Add any paths that contain custom static files (such as style sheets) here,
145 | # relative to this directory. They are copied after the builtin static files,
146 | # so a file named "default.css" will overwrite the builtin "default.css".
147 | html_static_path = ["_static"]
148 |
149 | # Sometimes the savefig directory doesn't exist and needs to be created
150 | # https://github.com/ipython/ipython/issues/8733
151 | # becomes obsolete when we can pin ipython>=5.2; see doc/environment.yml
152 | ipython_savefig_dir = os.path.join(
153 | os.path.dirname(os.path.abspath(__file__)), "_build", "html", "_static"
154 | )
155 | if not os.path.exists(ipython_savefig_dir):
156 | os.makedirs(ipython_savefig_dir)
157 |
158 | # Add any extra paths that contain custom files (such as robots.txt or
159 | # .htaccess) here, relative to this directory. These files are copied
160 | # directly to the root of the documentation.
161 | # html_extra_path = []
162 |
163 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
164 | # using the given strftime format.
165 | html_last_updated_fmt = today_fmt
166 |
167 | # If true, SmartyPants will be used to convert quotes and dashes to
168 | # typographically correct entities.
169 | # html_use_smartypants = True
170 |
171 | # Custom sidebar templates, maps document names to template names.
172 | # html_sidebars = {}
173 |
174 | # Additional templates that should be rendered to pages, maps page names to
175 | # template names.
176 | # html_additional_pages = {}
177 |
178 | # If false, no module index is generated.
179 | # html_domain_indices = True
180 |
181 | # If false, no index is generated.
182 | # html_use_index = True
183 |
184 | # If true, the index is split into individual pages for each letter.
185 | # html_split_index = False
186 |
187 | # If true, links to the reST sources are added to the pages.
188 | # html_show_sourcelink = True
189 |
190 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
191 | # html_show_sphinx = True
192 |
193 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
194 | # html_show_copyright = True
195 |
196 | # If true, an OpenSearch description file will be output, and all pages will
197 | # contain a <link> tag referring to it. The value of this option must be the
198 | # base URL from which the finished HTML is served.
199 | # html_use_opensearch = ''
200 |
201 | # This is the file name suffix for HTML files (e.g. ".xhtml").
202 | # html_file_suffix = None
203 |
204 | # Output file base name for HTML help builder.
205 | htmlhelp_basename = "xarray_extrasdoc"
206 |
207 |
208 | # -- Options for LaTeX output ---------------------------------------------
209 |
210 | latex_elements: dict[str, str] = {
211 | # The paper size ('letterpaper' or 'a4paper').
212 | # 'papersize': 'letterpaper',
213 | # The font size ('10pt', '11pt' or '12pt').
214 | # 'pointsize': '10pt',
215 | # Additional stuff for the LaTeX preamble.
216 | # 'preamble': '',
217 | }
218 |
219 | # Grouping the document tree into LaTeX files. List of tuples
220 | # (source start file, target name, title,
221 | # author, documentclass [howto, manual, or own class]).
222 | latex_documents = [
223 | (
224 | "index",
225 | "xarray_extras.tex",
226 | "xarray_extras Documentation",
227 | "xarray_extras Developers",
228 | "manual",
229 | ),
230 | ]
231 |
232 | # The name of an image file (relative to this directory) to place at the top of
233 | # the title page.
234 | # latex_logo = None
235 |
236 | # For "manual" documents, if this is true, then toplevel headings are parts,
237 | # not chapters.
238 | # latex_use_parts = False
239 |
240 | # If true, show page references after internal links.
241 | # latex_show_pagerefs = False
242 |
243 | # If true, show URL addresses after external links.
244 | # latex_show_urls = False
245 |
246 | # Documents to append as an appendix to all manuals.
247 | # latex_appendices = []
248 |
249 | # If false, no module index is generated.
250 | # latex_domain_indices = True
251 |
252 |
253 | # -- Options for manual page output ---------------------------------------
254 |
255 | # One entry per manual page. List of tuples
256 | # (source start file, name, description, authors, manual section).
257 | man_pages = [
258 | (
259 | "index",
260 | "xarray_extras",
261 | "xarray_extras Documentation",
262 | ["xarray_extras Developers"],
263 | 1,
264 | )
265 | ]
266 |
267 | # If true, show URL addresses after external links.
268 | # man_show_urls = False
269 |
270 |
271 | # -- Options for Texinfo output -------------------------------------------
272 |
273 | # Grouping the document tree into Texinfo files. List of tuples
274 | # (source start file, target name, title, author,
275 | # dir menu entry, description, category)
276 | texinfo_documents = [
277 | (
278 | "index",
279 | "xarray_extras",
280 | "xarray_extras Documentation",
281 | "xarray_extras Developers",
282 | "xarray_extras",
283 | "Advanced / experimental algorithms for xarray",
284 | "Miscellaneous",
285 | ),
286 | ]
287 |
288 | # Documents to append as an appendix to all manuals.
289 | # texinfo_appendices = []
290 |
291 | # If false, no module index is generated.
292 | # texinfo_domain_indices = True
293 |
294 | # How to display URL addresses: 'footnote', 'no', or 'inline'.
295 | # texinfo_show_urls = 'footnote'
296 |
297 | # If true, do not generate a @detailmenu in the "Top" node's menu.
298 | # texinfo_no_detailmenu = False
299 |
300 |
301 | # Example configuration for intersphinx: refer to the Python standard library.
302 | intersphinx_mapping = {
303 | "python": ("https://docs.python.org/3/", None),
304 | "dask": ("https://docs.dask.org/en/latest/", None),
305 | "distributed": ("https://distributed.dask.org/en/latest/", None),
306 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
307 | "numpy": ("https://docs.scipy.org/doc/numpy/", None),
308 | "numba": ("https://numba.pydata.org/numba-doc/latest/", None),
309 | "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
310 | "xarray": ("https://xarray.pydata.org/en/stable/", None),
311 | }
312 |
313 | # Work around intersphinx issue: make references resolve as xarray.* rather than xarray.core.*
314 | xarray.DataArray.__module__ = "xarray"
315 | xarray.Dataset.__module__ = "xarray"
316 | xarray.Variable.__module__ = "xarray"
317 |
--------------------------------------------------------------------------------
/xarray_extras/tests/test_csv.py:
--------------------------------------------------------------------------------
1 | import bz2
2 | import gzip
3 | import lzma
4 | import pickle
5 | import tempfile
6 | from pathlib import Path
7 |
8 | import dask
9 | import numpy as np
10 | import pytest
11 | import xarray
12 |
13 | from xarray_extras.csv import to_csv
14 |
15 |
16 | def assert_to_csv(
17 | x, chunks, nogil, dtype, open_func=open, float_format="%f", ext="csv", **kwargs
18 | ):
19 | x = x.astype(dtype)
20 | if chunks:
21 | x = x.chunk(chunks)
22 | with tempfile.TemporaryDirectory() as tmp:
23 | x.to_pandas().to_csv(tmp + "/1." + ext, float_format=float_format, **kwargs)
24 | f = to_csv(
25 | x, tmp + "/2." + ext, nogil=nogil, float_format=float_format, **kwargs
26 | )
27 | dask.compute(f)
28 |
29 | with open_func(tmp + "/1." + ext, "rb") as fh:
30 | d1 = fh.read()
31 | with open_func(tmp + "/2." + ext, "rb") as fh:
32 | d2 = fh.read()
33 | assert d2 == d1
34 |
35 |
36 | def assert_to_csv_with_path_type(
37 | x,
38 | path_type,
39 | chunks,
40 | nogil,
41 | dtype,
42 | open_func=open,
43 | float_format="%f",
44 | ext="csv",
45 | **kwargs,
46 | ):
47 | x = x.astype(dtype)
48 | if chunks:
49 | x = x.chunk(chunks)
50 | with tempfile.TemporaryDirectory() as tmp:
51 | path_1 = tmp + "/1." + ext if path_type == "str" else Path(tmp) / ("1." + ext)
52 | path_2 = tmp + "/2." + ext if path_type == "str" else Path(tmp) / ("2." + ext)
53 | x.to_pandas().to_csv(path_1, float_format=float_format, **kwargs)
54 | f = to_csv(x, path_2, nogil=nogil, float_format=float_format, **kwargs)
55 | dask.compute(f)
56 |
57 | with open_func(tmp + "/1." + ext, "rb") as fh:
58 | d1 = fh.read()
59 | with open_func(tmp + "/2." + ext, "rb") as fh:
60 | d2 = fh.read()
61 | assert d2 == d1
62 |
63 |
64 | @pytest.mark.parametrize("dtype", [np.int64, np.float64])
65 | @pytest.mark.parametrize("nogil", [False, True])
66 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
67 | @pytest.mark.parametrize("header", [False, True])
68 | @pytest.mark.parametrize("lineterminator", ["\n", "\r\n"])
69 | def test_series(chunks, nogil, dtype, header, lineterminator):
70 | x = xarray.DataArray([1, 2, 3, 4], dims=["x"], coords={"x": [10, 20, 30, 40]})
71 | assert_to_csv(x, chunks, nogil, dtype, header=header, lineterminator=lineterminator)
72 |
73 |
74 | @pytest.mark.parametrize("path_type", ["path", "str"])
75 | @pytest.mark.parametrize("dtype", [np.int64, np.float64])
76 | @pytest.mark.parametrize("nogil", [False, True])
77 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
78 | @pytest.mark.parametrize("header", [False, True])
79 | @pytest.mark.parametrize("lineterminator", ["\n", "\r\n"])
80 | def test_series_with_path(path_type, chunks, nogil, dtype, header, lineterminator):
81 | x = xarray.DataArray([1, 2, 3, 4], dims=["x"], coords={"x": [10, 20, 30, 40]})
82 | assert_to_csv_with_path_type(
83 | x, path_type, chunks, nogil, dtype, header=header, lineterminator=lineterminator
84 | )
85 |
86 |
87 | @pytest.mark.parametrize("dtype", [np.int64, np.float64])
88 | @pytest.mark.parametrize("nogil", [False, True])
89 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}])
90 | @pytest.mark.parametrize("lineterminator", ["\n", "\r\n"])
91 | def test_dataframe(chunks, nogil, dtype, lineterminator):
92 | x = xarray.DataArray(
93 | [[1, 2, 3, 4], [5, 6, 7, 8]],
94 | dims=["r", "c"],
95 | coords={"r": ["a", "b"], "c": [10, 20, 30, 40]},
96 | )
97 | assert_to_csv(x, chunks, nogil, dtype, lineterminator=lineterminator)
98 |
99 |
100 | @pytest.mark.parametrize("dtype", [np.int64, np.float64])
101 | @pytest.mark.parametrize("nogil", [False, True])
102 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}])
103 | def test_multiindex(chunks, nogil, dtype):
104 | x = xarray.DataArray(
105 | [[1, 2], [3, 4]],
106 | dims=["r", "c"],
107 | coords={
108 | "r1": ("r", ["r11", "r12"]),
109 | "r2": ("r", ["r21", "r22"]),
110 | "c1": ("c", ["c11", "c12"]),
111 | "c2": ("c", ["c21", "c22"]),
112 | },
113 | )
114 | x = x.set_index(r=["r1", "r2"], c=["c1", "c2"])
115 | assert_to_csv(x, chunks, nogil, dtype)
116 |
117 |
118 | @pytest.mark.parametrize("dtype", [np.int64, np.float64])
119 | @pytest.mark.parametrize("nogil", [False, True])
120 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}])
121 | def test_no_header(chunks, nogil, dtype):
122 | x = xarray.DataArray([[1, 2], [3, 4]], dims=["r", "c"])
123 | assert_to_csv(x, chunks, nogil, dtype, index=False, header=False)
124 |
125 |
126 | @pytest.mark.parametrize("dtype", [np.int64, np.float64])
127 | @pytest.mark.parametrize("nogil", [False, True])
128 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}])
129 | def test_custom_header(chunks, nogil, dtype):
130 | x = xarray.DataArray([[1, 2], [3, 4]], dims=["r", "c"])
131 | assert_to_csv(x, chunks, nogil, dtype, header=["foo", "bar"])
132 |
133 |
134 | @pytest.mark.parametrize("encoding", ["utf-8", "utf-16"])
135 | @pytest.mark.parametrize("dtype", [np.int64, np.float64])
136 | @pytest.mark.parametrize("nogil", [False, True])
137 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}])
138 | @pytest.mark.parametrize("lineterminator", ["\n", "\r\n"])
139 | def test_encoding(chunks, nogil, dtype, encoding, lineterminator):
140 | # Note: in Python 2.7, default encoding is ascii in pandas and utf-8 in
141 | # xarray_extras. Therefore we will not test the default.
142 | x = xarray.DataArray(
143 | [[1], [2]], dims=["r", "c"], coords={"r": ["crème", "foo"], "c": ["brûlée"]}
144 | )
145 | assert_to_csv(
146 | x, chunks, nogil, dtype, encoding=encoding, lineterminator=lineterminator
147 | )
148 |
149 |
150 | @pytest.mark.parametrize("sep", [",", "|"])
151 | @pytest.mark.parametrize("float_format", ["%f", "%.2f", "%.15f", "%.5e"])
152 | @pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
153 | @pytest.mark.parametrize("nogil", [False, True])
154 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
155 | def test_kwargs(chunks, nogil, dtype, float_format, sep):
156 | x = xarray.DataArray([1.0, 1.1, 1.000000000000001, 123.456789], dims=["x"])
157 | assert_to_csv(x, chunks, nogil, dtype, float_format=float_format, sep=sep)
158 |
159 |
160 | @pytest.mark.parametrize("na_rep", ["", "nan"])
161 | @pytest.mark.parametrize("nogil", [False, True])
162 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
163 | def test_na_rep(chunks, nogil, na_rep):
164 | x = xarray.DataArray([np.nan, 1], dims=["x"])
165 | assert_to_csv(x, chunks, nogil, np.float64, na_rep=na_rep)
166 |
167 |
168 | @pytest.mark.parametrize(
169 | "compression,open_func",
170 | [
171 | (None, open),
172 | ("gzip", gzip.open),
173 | ("bz2", bz2.open),
174 | ("xz", lzma.open),
175 | ],
176 | )
177 | @pytest.mark.parametrize("dtype", [np.int64, np.float64])
178 | @pytest.mark.parametrize("nogil", [False, True])
179 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
180 | def test_compression(chunks, nogil, dtype, compression, open_func):
181 | x = xarray.DataArray([1, 2], dims=["x"])
182 | assert_to_csv(x, chunks, nogil, dtype, compression=compression, open_func=open_func)
183 |
184 |
185 | @pytest.mark.parametrize(
186 | "ext,open_func",
187 | [
188 | ("csv", open),
189 | ("csv.gz", gzip.open),
190 | ("csv.bz2", bz2.open),
191 | ("csv.xz", lzma.open),
192 | ],
193 | )
194 | @pytest.mark.parametrize("nogil", [False, True])
195 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
196 | def test_compression_infer(ext, open_func, nogil, chunks):
197 | x = xarray.DataArray([1, 2], dims=["x"])
198 | assert_to_csv(
199 | x, chunks=chunks, nogil=nogil, dtype=np.float64, ext=ext, open_func=open_func
200 | )
201 |
202 |
203 | @pytest.mark.parametrize("dtype", [np.int64, np.float64])
204 | @pytest.mark.parametrize("nogil", [False, True])
205 | @pytest.mark.parametrize("chunks", [None, {"r": 1, "c": 1}])
206 | def test_empty(chunks, nogil, dtype):
207 | x = xarray.DataArray(
208 | [[1, 2, 3, 4]], dims=["r", "c"], coords={"c": [10, 20, 30, 40]}
209 | )
210 | x = x.isel(r=slice(0))
211 | assert_to_csv(x, chunks, nogil, dtype)
212 |
213 |
214 | @pytest.mark.parametrize("x", [0, -(2**63)])
215 | @pytest.mark.parametrize("index", ["a", "a" * 1000])
216 | @pytest.mark.parametrize("nogil", [False, True])
217 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
218 | def test_buffer_overflow_int(chunks, nogil, index, x):
219 | a = xarray.DataArray([x], dims=["x"], coords={"x": [index]})
220 | assert_to_csv(a, chunks, nogil, np.int64)
221 |
222 |
223 | @pytest.mark.parametrize("x", [0, np.nan, 1.000000000000001, 1.7901234406790122e308])
224 | @pytest.mark.parametrize("index,coord", [(False, ""), (True, "a"), (True, "a" * 1000)])
225 | @pytest.mark.parametrize("na_rep", ["", "na" * 500])
226 | @pytest.mark.parametrize("float_format", ["%.16f", "%.1000f", "a" * 1000 + "%.0f"])
227 | @pytest.mark.parametrize("nogil", [False, True])
228 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
229 | def test_buffer_overflow_float(chunks, nogil, float_format, na_rep, index, coord, x):
230 | if nogil and not index and np.isnan(x) and na_rep == "":
231 | # Expected: b'""\n'
232 | # Actual: b'\n'
233 |         pytest.xfail("pandas prints useless quotes for empty lines")
234 |
235 | a = xarray.DataArray([x], dims=["x"], coords={"x": [coord]})
236 | assert_to_csv(
237 | a,
238 | chunks,
239 | nogil,
240 | np.float64,
241 | float_format=float_format,
242 | na_rep=na_rep,
243 | index=index,
244 | )
245 |
246 |
247 | @pytest.mark.parametrize("encoding", ["utf-8", "utf-16"])
248 | @pytest.mark.parametrize("dtype", [str, object])
249 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
250 | @pytest.mark.parametrize("lineterminator", ["\n", "\r\n"])
251 | def test_pandas_only(chunks, dtype, encoding, lineterminator):
252 | x = xarray.DataArray(["foo", "Crème brûlée"], dims=["x"])
253 | assert_to_csv(
254 | x,
255 | chunks=chunks,
256 | nogil=False,
257 | dtype=dtype,
258 | encoding=encoding,
259 | lineterminator=lineterminator,
260 | )
261 |
262 |
263 | @pytest.mark.parametrize("dtype", [np.complex64, np.complex128])
264 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
265 | def test_pandas_only_complex(chunks, dtype):
266 | x = xarray.DataArray([1 + 2j], dims=["x"])
267 | assert_to_csv(x, chunks=chunks, nogil=False, dtype=dtype)
268 |
269 |
270 | @pytest.mark.parametrize("nogil", [False, True])
271 | @pytest.mark.parametrize("chunks", [None, {"x": 1}])
272 | def test_mode(chunks, nogil):
273 | x = xarray.DataArray([1, 2], dims=["x"])
274 | y = xarray.DataArray([3, 4], dims=["x"])
275 | if chunks:
276 | x = x.chunk(chunks)
277 | y = y.chunk(chunks)
278 |
279 | with tempfile.TemporaryDirectory() as tmp:
280 | f = to_csv(x, tmp + "/1.csv", mode="a", nogil=nogil, header=False, index=False)
281 | dask.compute(f)
282 | f = to_csv(y, tmp + "/1.csv", mode="a", nogil=nogil, header=False, index=False)
283 | dask.compute(f)
284 | with open(tmp + "/1.csv") as fh:
285 | assert fh.read() == "1\n2\n3\n4\n"
286 |
287 | f = to_csv(y, tmp + "/1.csv", mode="w", nogil=nogil, header=False, index=False)
288 | dask.compute(f)
289 | with open(tmp + "/1.csv") as fh:
290 | assert fh.read() == "3\n4\n"
291 |
292 |
293 | def test_none_fmt():
294 | """float_format=None differs between C and pandas; can't use assert_to_csv"""
295 | x = xarray.DataArray([1.0, 1.1, 1.000000000000001, 123.456789])
296 | y = x.astype(np.float32)
297 |
298 | with tempfile.TemporaryDirectory() as tmp:
299 | to_csv(x, tmp + "/1.csv", header=False)
300 | to_csv(y, tmp + "/2.csv", header=False)
301 |
302 | with open(tmp + "/1.csv") as fh:
303 | assert fh.read() == "0,1.0\n1,1.1\n2,1.0\n3,123.456789\n"
304 | with open(tmp + "/2.csv") as fh:
305 | assert fh.read() == "0,1.0\n1,1.1\n2,1.0\n3,123.456787\n"
306 |
307 |
308 | def test_pickle():
309 | x = xarray.DataArray([1, 2])
310 | with tempfile.TemporaryDirectory() as tmp:
311 | x.to_pandas().to_csv(tmp + "/1.csv")
312 | d = to_csv(x.chunk(1), tmp + "/2.csv")
313 | d = pickle.loads(pickle.dumps(d))
314 | d.compute()
315 |
316 | with open(tmp + "/1.csv", "rb") as fh:
317 | d1 = fh.read()
318 | with open(tmp + "/2.csv", "rb") as fh:
319 | d2 = fh.read()
320 | assert d1 == d2
321 |
--------------------------------------------------------------------------------