├── .codecov.yaml ├── .github ├── dependabot.yml └── workflows │ ├── pre-commit.yaml │ └── tests.yaml ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── ci └── requirements.txt ├── setup.cfg ├── setup.py ├── tests └── test_xpartition.py └── xpartition ├── __init__.py ├── xarray_utils.py └── xpartition.py /.codecov.yaml: -------------------------------------------------------------------------------- 1 | # Configuration taken from xarray 2 | codecov: 3 | require_ci_to_pass: yes 4 | 5 | coverage: 6 | status: 7 | project: 8 | default: 9 | # Require 1% coverage, i.e., always succeed 10 | target: 1 11 | patch: false 12 | changes: false 13 | 14 | comment: off 15 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Set update schedule for GitHub Actions 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | name: "pre-commit hooks" 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v5 15 | - uses: pre-commit/action@v3.0.1 16 | -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | env: 9 | FORCE_COLOR: 1 10 | 11 | jobs: 12 | pytest: 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | python-version: ["3.12"] 17 | os: [ubuntu-latest] 18 | platform: [x64] 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | cache: "pip" # caching pip dependencies 27 | - run: pip install -r ci/requirements.txt 28 | 29 | - name: Install xpartition 30 | run: pip install -v -e . 
--no-deps 31 | 32 | - name: Environment information 33 | run: python -m pip list 34 | 35 | - name: Run tests 36 | run: pytest -vv --cov=xpartition --cov-report=xml 37 | 38 | - name: Upload code coverage to Codecov 39 | uses: codecov/codecov-action@v5 40 | with: 41 | file: ./coverage.xml 42 | flags: unittests,${{ matrix.python-version }} 43 | name: codecov-umbrella 44 | fail_ci_if_error: false 45 | token: ${{ secrets.CODECOV_TOKEN }} 46 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.6.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/PyCQA/flake8.git 7 | rev: 6.1.0 8 | hooks: 9 | - id: flake8 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Spencer Clark 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # xpartition 2 | 3 | [![Build Status](https://github.com/spencerkclark/xpartition/actions/workflows/tests.yaml/badge.svg?branch=main)](https://github.com/spencerkclark/xpartition/actions) 4 | [![codecov](https://codecov.io/gh/spencerkclark/xpartition/branch/main/graph/badge.svg?token=H1DBBSTQ2V)](https://codecov.io/gh/spencerkclark/xpartition) 5 | [![PyPI](https://img.shields.io/pypi/v/xpartition.svg)](https://pypi.python.org/pypi/xpartition/) 6 | 7 | This is a tool that can make writing large xarray datasets to cohesive zarr 8 | stores from completely independent processes easier. 9 | 10 | ## Usage 11 | 12 | The primary use-case is something like this. Say you have a lot of netCDF files 13 | output from a simulation or observational dataset that you would like to stitch 14 | together into a zarr store. If you have a way of combining those files lazily — 15 | i.e. opening them into dask-backed arrays — into a single dataset with maybe 16 | some additional computations, then you can write contiguous "partitions" of that 17 | dataset out via independent processes. A "partition" corresponds to a 18 | contiguous group of dask "chunks." I.e. 
it can correspond to one or more chunks 19 | across any number of dimensions. A key detail is that no partition straddles any 20 | dask chunk; this makes writing from independent processes completely safe. 21 | 22 | `xpartition` provides an accessor called `partition` that implements 23 | `initialize_store` and `write` methods. The pattern is to have some code that 24 | constructs the dataset lazily, then call `initialize_store`, and finally, in a 25 | set of separate processes, call `write`. 26 | 27 | ### Simple serial example 28 | 29 | Before illustrating a use-case of `xpartition` on a cluster, we can start with a 30 | simple serial example. From this example it should be straightforward to 31 | imagine how to extend this to various distributed computing platforms, whether 32 | HPC or cloud-based, to do the same thing in parallel. 33 | 34 | Assume that through some external package we have a function that can construct a 35 | dataset lazily. To incrementally write it to zarr using `xpartition`, we would 36 | only need to do the following: 37 | 38 | ```python 39 | import xpartition 40 | 41 | from external import construct_lazy_dataset 42 | 43 | store = "store.zarr" 44 | partitions = 16 45 | partition_dims = ["tile", "time"] 46 | 47 | ds = construct_lazy_dataset() 48 | ds.partition.initialize_store(store) 49 | for partition in range(partitions): 50 |     ds.partition.write(store, partitions, partition_dims, partition) 51 | ``` 52 | 53 | `partition_dims` describes the dimensions over which to partition the dataset; 54 | if chunks exist along dimensions that are not among the partition dimensions, 55 | then those chunks will all be grouped together within each partition. If you are not particular about this, 56 | simply using `ds.dims` will also work out of the box. 57 | 58 | ### Parallelization using `multiprocessing` 59 | 60 | A parallel example can easily be illustrated using the built-in 61 | `multiprocessing` library; something similar could be done with `dask.bag`, as sketched after this example: 62 | 63 | ```python 64 | import multiprocessing 65 | import xpartition 66 | from external import construct_lazy_dataset 67 | 68 | store = "store.zarr" 69 | partitions = 16 70 | partition_dims = ["tile", "time"] 71 | 72 | ds = construct_lazy_dataset() 73 | ds.partition.initialize_store(store) 74 | with multiprocessing.get_context("spawn").Pool(partitions) as pool: 75 |     pool.map( 76 |         ds.partition.mappable_write(store, partitions, partition_dims), 77 |         range(partitions) 78 |     ) 79 | ``` 80 | 
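For completeness, a minimal sketch of that `dask.bag` variant might look like the following. It reuses the placeholder `construct_lazy_dataset` from the examples above and assumes dask's multiprocessing scheduler so that each partition is written in a separate process; treat it as an illustration rather than a tested recipe.

```python
import dask.bag as db
import xpartition

from external import construct_lazy_dataset

store = "store.zarr"
partitions = 16
partition_dims = ["tile", "time"]

ds = construct_lazy_dataset()
ds.partition.initialize_store(store)

# One bag element per partition index; each element is written by its own task.
bag = db.from_sequence(range(partitions), npartitions=partitions)
bag.map(ds.partition.mappable_write(store, partitions, partition_dims)).compute(
    scheduler="processes"
)
```

As in the `multiprocessing` example, each bag element corresponds to exactly one partition, so no two processes ever write to the same dask chunk.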
81 | ### Parallelization using a SLURM array job 82 | 83 | Finally, the example below describes how one might use `xpartition` on an HPC 84 | cluster using a SLURM array job. We start by writing a couple of 85 | command-line scripts: one that initializes the store and one that writes a partition. We'll 86 | start with one called `initialize_store.py`: 87 | 88 | ```python 89 | import argparse 90 | import xpartition 91 | 92 | from external import construct_lazy_dataset 93 | 94 | parser = argparse.ArgumentParser( 95 |     prog="initialize_store", 96 |     description="initialize a zarr store for a dataset" 97 | ) 98 | parser.add_argument("store", help="absolute path to directory to store zarr result") 99 | 100 | args = parser.parse_args() 101 | ds = construct_lazy_dataset() 102 | ds.partition.initialize_store(args.store) 103 | ``` 104 | 105 | Next we'll write one called `write_partition.py`: 106 | 107 | ```python 108 | import argparse 109 | import xpartition 110 | 111 | from external import construct_lazy_dataset 112 | 113 | parser = argparse.ArgumentParser( 114 |     prog="write_partition", 115 |     description="write a partition of a dataset" 116 | ) 117 | parser.add_argument("store", help="absolute path to directory to store zarr result") 118 | parser.add_argument("ranks", type=int, help="total number of available ranks") 119 | parser.add_argument("rank", type=int, help="rank of job") 120 | 121 | args = parser.parse_args() 122 | 123 | # xpartition uses these as the dimensions to partition the jobs over. 124 | dims = ["tile", "time"] 125 | 126 | ds = construct_lazy_dataset() 127 | ds.partition.write(args.store, args.ranks, dims, args.rank) 128 | ``` 129 | 130 | Now we can write a couple of bash scripts. The first will be a SLURM array job 131 | that writes all the partitions. The second will be a "main" script that 132 | controls the whole workflow. 133 | 134 | We call this one `write_partition.sh`: 135 | 136 | ``` 137 | #!/bin/bash 138 | #SBATCH --job-name=zarr-history-files-array-job 139 | #SBATCH --output=stdout/slurm-%A.%a.out # STDOUT file 140 | #SBATCH --error=stdout/slurm-%A.%a.err # STDERR file 141 | #SBATCH --time=16:00:00 # total run time limit (HH:MM:SS) 142 | #SBATCH --array=0-15 # job array with index values 0, 1, 2, ... 15 143 | 144 | echo "My SLURM_ARRAY_JOB_ID is $SLURM_ARRAY_JOB_ID." 145 | echo "My SLURM_ARRAY_TASK_ID is $SLURM_ARRAY_TASK_ID." 146 | 147 | STORE=$1 148 | RANKS=16 149 | RANK=$SLURM_ARRAY_TASK_ID 150 | 151 | python write_partition.py $STORE $RANKS $RANK 152 | ``` 153 | 154 | And we call this one `write_zarr.sh`: 155 | 156 | ``` 157 | #!/bin/bash 158 | set -e 159 | 160 | STORE=$1 161 | 162 | python initialize_store.py $STORE 163 | 164 | # Make a local directory for the stdout and stderr of the array jobs so that 165 | # they do not clutter up the local space. 166 | mkdir -p stdout 167 | 168 | # Submit the array job with the -W argument to sbatch; this tells SLURM to wait 169 | # until all array jobs have completed before returning from this script. 170 | sbatch -W write_partition.sh $STORE 171 | ``` 172 | 173 | Submitting the full task is then as simple as: 174 | 175 | ``` 176 | bash write_zarr.sh /path/to/store.zarr 177 | ``` 178 | 179 | ## Motivation 180 | 181 | It is not always advantageous to let all computations be controlled by a single 182 | dask client. At the moment, the dask scheduler breaks down when it has to 183 | manage a large number of memory-intensive tasks, often leading to slowdowns or 184 | out of memory errors ([this issue](https://github.com/dask/distributed/issues/6360) 185 | is perhaps a good summary of the state of things currently in dask). Breaking the 186 | problem down in the way that `xpartition` does allows you to gain the benefits of 187 | dask's laziness on each independent process while working in a distributed 188 | environment. 
*In an ideal world we wouldn't need a package like this — we would 189 | let dask and dask distributed handle everything — but in practice that does not 190 | work perfectly yet.* 191 | 192 | ## Installation 193 | 194 | `xpartition` can either be installed from PyPI: 195 | 196 | ``` 197 | $ pip install xpartition 198 | ``` 199 | 200 | or directly from source: 201 | 202 | ``` 203 | $ git clone https://github.com/spencerkclark/xpartition.git 204 | $ cd xpartition 205 | $ pip install -e . 206 | ``` 207 | 208 | ## See also 209 | 210 | There is some overlap between what this package does and what other libraries 211 | do, namely packages like: 212 | 213 | - [rechunker](https://github.com/pangeo-data/rechunker) 214 | - [pangeo-forge](https://github.com/pangeo-forge/pangeo-forge) 215 | - [xarray-beam](https://github.com/google/xarray-beam) 216 | -------------------------------------------------------------------------------- /ci/requirements.txt: -------------------------------------------------------------------------------- 1 | coveralls 2 | dask 3 | pytest 4 | pytest-cov 5 | xarray 6 | zarr 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = xpartition 3 | version = attr: xpartition.__version__ 4 | author = Spencer K. Clark 5 | author_email = spencerkclark@gmail.com 6 | license = MIT License 7 | description = Tool for writing large xarray datasets to zarr stores with independent processes 8 | long_description = 9 | xpartition provides a way to split N-dimensional dask-backed arrays into 10 | a user-specified number of blocks of dask chunks. This can be useful for 11 | assigning work to batch jobs on HPC systems or Dataflow workers in an 12 | Apache Beam pipeline in the cloud. 
13 | long_description_content_type = text/plain 14 | url = https://github.com/spencerkclark/xpartition 15 | classifiers = 16 | Development Status :: 2 - Pre-Alpha 17 | License :: OSI Approved :: MIT License 18 | Operating System :: OS Independent 19 | Intended Audience :: Science/Research 20 | Programming Language :: Python 21 | Programming Language :: Python :: 3 22 | Topic :: Scientific/Engineering 23 | 24 | [options] 25 | packages = xpartition 26 | zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.html 27 | include_package_data = True 28 | python_requires = >=3.6 29 | install_requires = 30 | xarray >= 2024.10.0 31 | dask[array] >= 2.9.0 32 | setuptools >= 38.4 # For pkg_resources 33 | dataclasses; python_version == "3.6" 34 | zarr 35 | setup_requires = 36 | setuptools >= 38.4 37 | setuptools_scm 38 | 39 | [flake8] 40 | ignore = 41 | E203 42 | E402 43 | E501 44 | E731 45 | W503 46 | exclude = 47 | .eggs 48 | doc 49 | __init__.py 50 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup 3 | 4 | setup(use_scm_version={"fallback_version": "999"}) 5 | -------------------------------------------------------------------------------- /tests/test_xpartition.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import string 4 | 5 | import dask 6 | import numpy as np 7 | import pytest 8 | import xarray as xr 9 | import zarr 10 | 11 | import xpartition 12 | 13 | from xpartition.xarray_utils import get_chunks_encoding, CountingScheduler 14 | from xpartition.xpartition import ( 15 | _zeros_like_dataarray, 16 | freeze_indexers, 17 | get_inner_chunk_size, 18 | get_inner_chunks_encoding, 19 | get_unchunked_data_var_names, 20 | get_unchunked_non_dimension_coord_names, 21 | get_unchunked_variable_names, 22 | unfreeze_indexers, 23 | validate_PartitionMapper_dataset, 24 | ) 25 | 26 | 27 | @pytest.mark.parametrize( 28 | ("block_indexers", "expected", "exception"), 29 | [ 30 | ({"x": slice(0, 3)}, {"x": slice(0, 6)}, None), 31 | ({"x": slice(1, 2)}, {"x": slice(2, 5)}, None), 32 | ({"x": slice(-3, -2)}, {"x": slice(0, 2)}, None), 33 | ({"x": slice(-3, -1)}, {"x": slice(0, 5)}, None), 34 | ({"x": slice(-3, None)}, {"x": slice(0, 6)}, None), 35 | ({"x": slice(None, 1)}, {"x": slice(0, 2)}, None), 36 | ({"x": slice(0, 10)}, {"x": slice(0, 6)}, None), 37 | ({"x": slice(-10, None)}, {"x": slice(0, 6)}, None), 38 | ({"x": slice(None, None)}, {"x": slice(0, 6)}, None), 39 | ({"x": slice(10, 12)}, {"x": slice(6, 6)}, None), 40 | ({"x": slice(2, 1)}, {"x": slice(5, 2)}, None), 41 | ({"x": 1}, {"x": slice(2, 5)}, None), 42 | ({"x": -1}, {"x": slice(5, 6)}, None), 43 | ({"x": -2}, {"x": slice(2, 5)}, None), 44 | ({"x": np.int32(2)}, {"x": slice(5, 6)}, None), 45 | ({"x": slice(0, 3), "y": 1}, {"x": slice(0, 6), "y": slice(3, 4)}, None), 46 | ({"x": 4}, None, IndexError), 47 | ({"x": -4}, None, IndexError), 48 | ({"z": 1}, None, KeyError), 49 | ({"x": slice(None, None, 2)}, None, NotImplementedError), 50 | ({"x": 2.0}, None, ValueError), 51 | ], 52 | ids=lambda x: f"{x}", 53 | ) 54 | def test_indexers(block_indexers, expected, exception): 55 | data = dask.array.zeros((6, 4), chunks=((2, 3, 1), (3, 1))) 56 | da = xr.DataArray(data, dims=["x", "y"]) 57 | if exception is None: 58 | result = da.blocks.indexers(**block_indexers) 59 | assert result 
== expected 60 | else: 61 | with pytest.raises(exception): 62 | da.blocks.indexers(**block_indexers) 63 | 64 | 65 | def test_isel(): 66 | data = dask.array.random.random((6, 4), chunks=((2, 3, 1), (3, 1))) 67 | da = xr.DataArray(data, dims=["x", "y"]) 68 | 69 | result = da.blocks.isel(x=slice(1, 2), y=1).data.compute() 70 | expected = data.blocks[1:2, 1].compute() 71 | 72 | np.testing.assert_array_equal(result, expected) 73 | 74 | 75 | @pytest.mark.filterwarnings("ignore:Specified Dask chunks") 76 | @pytest.mark.parametrize("ranks", [1, 2, 3, 5, 10, 11]) 77 | def test_dataarray_mappable_write(tmpdir, da, ranks): 78 | store = os.path.join(tmpdir, "test.zarr") 79 | ds = da.to_dataset() 80 | ds.to_zarr(store, compute=False) 81 | 82 | with multiprocessing.get_context("spawn").Pool(ranks) as pool: 83 | pool.map(da.partition.mappable_write(store, ranks, da.dims), range(ranks)) 84 | 85 | result = xr.open_zarr(store) 86 | xr.testing.assert_identical(result, ds) 87 | 88 | 89 | SHAPE_AND_CHUNK_PAIRS = [ 90 | ((5,), (1,)), 91 | ((5,), (2,)), 92 | ((5,), (5,)), 93 | ((2, 5), (1, 1)), 94 | ((2, 5), (2, 1)), 95 | ((2, 5), (2, 2)), 96 | ((2, 5), (2, 4)), 97 | ((2, 5), (2, 5)), 98 | ((2, 1, 6), (1, 1, 1)), 99 | ((2, 1, 6), (1, 1, 2)), 100 | ((2, 1, 6), (2, 1, 2)), 101 | ((2, 1, 6), (2, 1, 5)), 102 | ((2, 3, 4, 5), (1, 1, 1, 1)), 103 | ((2, 3, 4, 5), (2, 1, 3, 3)), 104 | ] 105 | 106 | 107 | @pytest.fixture(params=SHAPE_AND_CHUNK_PAIRS, ids=lambda x: str(x)) 108 | def da(request): 109 | shape, chunks = request.param 110 | name = "foo" 111 | return _construct_dataarray(shape, chunks, name) 112 | 113 | 114 | def _construct_dataarray(shape, chunks, name): 115 | dims = list(string.ascii_lowercase[: len(shape)]) 116 | data = np.random.random(shape) 117 | coords = [range(length) for length in shape] 118 | da = xr.DataArray(data, dims=dims, name=name, coords=coords) 119 | if chunks is not None: 120 | chunks = {dim: chunk for dim, chunk in zip(dims, chunks)} 121 | da = da.chunk(chunks) 122 | 123 | # Add coverage for chunked coordinates 124 | chunked_coord_name = f"{da.name}_chunked_coord" 125 | da = da.assign_coords({chunked_coord_name: da.chunk(chunks)}) 126 | return da 127 | 128 | 129 | ALIGNED_SHAPE_AND_CHUNK_PAIRS = [ 130 | ((5,), (1,)), 131 | ((5,), (2,)), 132 | ((5,), (5,)), 133 | ((5, 2), (1, 1)), 134 | ((5, 2), (1, 2)), 135 | ((5, 2), (2, 2)), 136 | ((5, 2), (4, 2)), 137 | ((5, 2), (5, 2)), 138 | ((5, 2, 6), (1, 1, 1)), 139 | ((5, 2, 6), (1, 1, 2)), 140 | ((5, 2, 6), (2, 1, 2)), 141 | ((5, 2, 6), (2, 2, 5)), 142 | ] 143 | 144 | 145 | @pytest.fixture 146 | def ds(): 147 | unchunked_dataarrays = [] 148 | for i, (shape, chunks) in enumerate(ALIGNED_SHAPE_AND_CHUNK_PAIRS): 149 | da = _construct_dataarray(shape, None, f"unchunked_{i}") 150 | unchunked_dataarrays.append(da) 151 | 152 | chunked_dataarrays = [] 153 | for i, (shape, chunks) in enumerate(ALIGNED_SHAPE_AND_CHUNK_PAIRS): 154 | da = _construct_dataarray(shape, chunks, f"chunked_{i}") 155 | chunked_dataarrays.append(da) 156 | 157 | return xr.merge(unchunked_dataarrays + chunked_dataarrays) 158 | 159 | 160 | def get_files(directory): 161 | names = os.listdir(directory) 162 | files = [] 163 | for name in names: 164 | path = os.path.join(directory, name) 165 | if os.path.isfile(path): 166 | files.append(path) 167 | return files 168 | 169 | 170 | def checkpoint_modification_times(store, variables): 171 | times = {} 172 | for variable in variables: 173 | directory = os.path.join(store, variable) 174 | files = get_files(directory) 175 | for file in files: 176 
| times[file] = os.path.getmtime(file) 177 | return times 178 | 179 | 180 | @pytest.mark.filterwarnings("ignore:Specified Dask chunks") 181 | @pytest.mark.parametrize("ranks", [1, 2, 3, 5, 10, 11]) 182 | @pytest.mark.parametrize("collect_variable_writes", [False, True]) 183 | def test_dataset_mappable_write(tmpdir, ds, ranks, collect_variable_writes): 184 | unchunked_variables = get_unchunked_variable_names(ds) 185 | 186 | store = os.path.join(tmpdir, "test.zarr") 187 | ds.partition.initialize_store(store) 188 | 189 | # Checkpoint modification times of all files associated with unchunked 190 | # variables. These should remain unchanged after initialization. 191 | expected_times = checkpoint_modification_times(store, unchunked_variables) 192 | 193 | with multiprocessing.get_context("spawn").Pool(ranks) as pool: 194 | pool.map( 195 | ds.partition.mappable_write( 196 | store, ranks, ds.dims, collect_variable_writes=collect_variable_writes 197 | ), 198 | range(ranks), 199 | ) 200 | 201 | result = xr.open_zarr(store) 202 | 203 | # Check that dataset roundtrips identically. 204 | xr.testing.assert_identical(result, ds) 205 | 206 | # Checkpoint modification times of all files associated with unchunked 207 | # variables after writing the chunked variables. The modification times of 208 | # the unchunked variables should be the same as before writing the chunked 209 | # variables. 210 | resulting_times = checkpoint_modification_times(store, unchunked_variables) 211 | assert expected_times == resulting_times 212 | 213 | 214 | @pytest.mark.parametrize("has_coord", [True, False]) 215 | @pytest.mark.parametrize("has_chunked_coord", [True, False]) 216 | @pytest.mark.parametrize( 217 | "original_chunks", [{"x": 2}, {"x": 2, "y": 5}], ids=lambda x: f"{x}" 218 | ) 219 | def test_PartitionMapper_integration( 220 | tmpdir, has_coord, has_chunked_coord, original_chunks 221 | ): 222 | def func(ds): 223 | return ds.rename(z="new_name").assign_attrs(dataset_attr="fun") 224 | 225 | ds = xr.Dataset({"z": (["x", "y"], np.ones((5, 10)), {"an": "attr"})}).chunk( 226 | original_chunks 227 | ) 228 | if has_coord: 229 | ds = ds.assign_coords(x=range(5)) 230 | if has_chunked_coord: 231 | chunked_coord = xr.DataArray(range(5), dims=["x"]).chunk({"x": 5}) 232 | ds = ds.assign_coords(b=chunked_coord) 233 | 234 | unchunked_variables = get_unchunked_variable_names(ds) 235 | 236 | store = str(tmpdir) 237 | mapper = ds.z.partition.map(store, ranks=3, dims=["x"], func=func, data=ds) 238 | for i, rank in enumerate(mapper): 239 | if i == 0: 240 | expected_times = checkpoint_modification_times(store, unchunked_variables) 241 | mapper.write(rank) 242 | 243 | resulting_times = checkpoint_modification_times(store, unchunked_variables) 244 | assert expected_times == resulting_times 245 | 246 | written = xr.open_zarr(store) 247 | xr.testing.assert_identical(func(ds), written) 248 | 249 | 250 | def test_PartitionMapper_integration_error(): 251 | func = lambda ds: ds 252 | a = xr.DataArray(np.ones((5, 10)), [range(5), range(10)], ["x", "y"], name="a") 253 | b = a.copy(deep=True).rename("b").chunk({"x": 1}) 254 | ds = xr.merge([a, b]) 255 | mapper = ds.b.partition.map("store", ranks=3, dims=["x"], func=func, data=ds) 256 | with pytest.raises(ValueError, match="The PartitionMapper approach"): 257 | for rank in mapper: 258 | mapper.write(rank) 259 | 260 | 261 | def test_partition_partition(): 262 | # Partitions have two qualities which we test using a DataArray that 263 | # has all unique values 264 | ds = xr.Dataset({"z": (["x", "y"], 
np.arange(50).reshape((5, 10)))}).chunk({"x": 2}) 265 | arr = ds["z"] 266 | 267 | n = 3 268 | regions = arr.partition.partition(n, dims=["x"]) 269 | assert n == len(regions) 270 | 271 | def to_set(arr): 272 | return set(arr.values.ravel().tolist()) 273 | 274 | # These are the properties of a partition 275 | # 1. sets in a partition are disjoint 276 | intersection = set.intersection(*[to_set(arr.isel(region)) for region in regions]) 277 | assert intersection == set() 278 | 279 | # assert that the values cover the set 280 | # 2. the sets cover the original set 281 | union = set.union(*[to_set(arr.isel(region)) for region in regions]) 282 | assert union == to_set(arr) 283 | 284 | 285 | @pytest.mark.parametrize( 286 | ("original_chunks", "override_chunks", "expected_chunks"), 287 | [ 288 | ({"x": 5, "y": 2}, None, ((5, 5), (2, 2, 2))), 289 | ({"x": 5, "y": 2}, {"y": 3}, ((5, 5), (3, 3))), 290 | ({"x": 5, "y": 2}, {"y": 3, "z": 1}, ((5, 5), (3, 3))), 291 | ], 292 | ids=lambda x: f"{x}", 293 | ) 294 | @pytest.mark.parametrize("dtype", [float, int]) 295 | def test__zeros_like_dataarray( 296 | original_chunks, override_chunks, expected_chunks, dtype 297 | ): 298 | da = xr.DataArray(np.zeros((10, 6), dtype=dtype), dims=["x", "y"]).chunk( 299 | original_chunks 300 | ) 301 | result = _zeros_like_dataarray(da, override_chunks) 302 | result_chunks = result.chunks 303 | assert result_chunks == expected_chunks 304 | assert result.dtype == da.dtype 305 | 306 | 307 | def test_zeros_like(): 308 | shape = (2, 4) 309 | dims = ["x", "y"] 310 | attrs = {"foo": "bar"} 311 | 312 | data1 = dask.array.random.random(shape) 313 | data2 = dask.array.random.randint(0, size=shape) 314 | data3 = dask.array.random.random(shape, chunks=(1, 1)) 315 | 316 | da1 = xr.DataArray(data1, dims=dims, name="a", attrs=attrs) 317 | da2 = xr.DataArray(data2, dims=dims, name="b", attrs=attrs) 318 | da3 = xr.DataArray(data3, dims=dims, name="c", attrs=attrs) 319 | ds = xr.merge([da1, da2, da3]) 320 | 321 | zeros1_data = dask.array.zeros(shape) 322 | zeros2_data = dask.array.zeros(shape, dtype=int) 323 | zeros3_data = dask.array.zeros(shape, chunks=(1, 1)) 324 | 325 | zeros1 = xr.DataArray(zeros1_data, dims=dims, name="a", attrs=attrs) 326 | zeros2 = xr.DataArray(zeros2_data, dims=dims, name="b", attrs=attrs) 327 | zeros3 = xr.DataArray(zeros3_data, dims=dims, name="c", attrs=attrs) 328 | expected = xr.merge([zeros1, zeros2, zeros3]) 329 | 330 | result = xpartition.zeros_like(ds) 331 | xr.testing.assert_identical(result, expected) 332 | 333 | for var in result: 334 | # assert_identical does not check dtype or chunks 335 | assert result[var].dtype == expected[var].dtype 336 | assert result[var].chunks == expected[var].chunks 337 | 338 | 339 | def test_partition_indexers_invalid_rank_error(): 340 | data = dask.array.zeros((6, 4), chunks=((6, 4))) 341 | da = xr.DataArray(data, dims=["x", "y"]) 342 | with pytest.raises(ValueError, match="greater than maximum rank"): 343 | da.partition.indexers(1, 1, ["x"]) 344 | 345 | 346 | @pytest.mark.parametrize( 347 | ("unfrozen_indexers", "frozen_indexers"), 348 | [ 349 | ( 350 | {"a": slice(None, None, 3), "b": slice(1, 10, 2)}, 351 | (("a", (None, None, 3)), ("b", (1, 10, 2))), 352 | ), 353 | (None, None), 354 | ], 355 | ids=lambda x: f"{x}", 356 | ) 357 | def test_freeze_unfreeze_indexers(unfrozen_indexers, frozen_indexers): 358 | assert freeze_indexers(unfrozen_indexers) == frozen_indexers 359 | assert unfreeze_indexers(frozen_indexers) == unfrozen_indexers 360 | 361 | 362 | @pytest.mark.parametrize( 
363 | ("a", "b"), 364 | [ 365 | ( 366 | {"a": slice(None, None, 3), "b": slice(1, 10, 2)}, 367 | {"b": slice(1, 10, 2), "a": slice(None, None, 3)}, 368 | ), 369 | (None, None), 370 | ], 371 | ids=lambda x: f"{x}", 372 | ) 373 | def test_hashability_of_frozen_indexers(a, b): 374 | assert a == b 375 | frozen_indexers_a = freeze_indexers(a) 376 | frozen_indexers_b = freeze_indexers(b) 377 | 378 | # Despite having different key orders, the hashes of the frozen indexers 379 | # should be equal. 380 | assert hash(frozen_indexers_a) == hash(frozen_indexers_b) 381 | 382 | 383 | @pytest.mark.parametrize( 384 | ("collect_variable_writes", "expected_computes"), [(False, 9), (True, 3)] 385 | ) 386 | def test_dataset_mappable_write_minimizes_compute_calls( 387 | tmpdir, collect_variable_writes, expected_computes 388 | ): 389 | # This tests to ensure that calls to compute are minimized when writing 390 | # partitioned Datasets. Previously, a compute was called separately for 391 | # each variable in the Dataset. For fields that have common intermediates -- 392 | # e.g. loading a particular variable from somewhere -- this is inefficient, 393 | # because it means these intermediates must be computed multiple times. If 394 | # the option to collect_variable_writes is turned, however, we expect more 395 | # computes to be called (one for each partition and data variable in the 396 | # Dataset). 397 | store = os.path.join(tmpdir, "test.zarr") 398 | 399 | foo = _construct_dataarray((2, 9), (2, 3), "foo") 400 | bar = (2 * foo).rename("bar") 401 | ds = xr.merge([foo, bar]) 402 | 403 | ds.partition.initialize_store(store) 404 | scheduler = CountingScheduler() 405 | 406 | with dask.config.set(scheduler=scheduler): 407 | ranks = 3 408 | for rank in range(ranks): 409 | ds.partition.write(store, ranks, ds.dims, rank, collect_variable_writes) 410 | 411 | assert scheduler.total_computes == expected_computes 412 | 413 | result = xr.open_zarr(store) 414 | xr.testing.assert_identical(result, ds) 415 | 416 | 417 | @pytest.fixture() 418 | def mixed_ds(): 419 | dims = ["x_unchunked", "y_unchunked"] 420 | coords = [range(3), range(5)] 421 | data = np.zeros((3, 5)) 422 | template_unchunked = xr.DataArray(data, coords, dims) 423 | template_chunked = xr.DataArray(data, coords, dims).chunk({"x_unchunked": 1}) 424 | 425 | data_var_unchunked = template_unchunked.copy(deep=True).rename("data_var_unchunked") 426 | data_var_chunked = template_chunked.copy(deep=True).rename("data_var_chunked") 427 | coord_unchunked = template_unchunked.copy(deep=True).rename("coord_unchunked") 428 | coord_chunked = template_chunked.copy(deep=True).rename("coord_chunked") 429 | 430 | ds = xr.merge([data_var_chunked, data_var_unchunked]) 431 | ds = ds.assign_coords(coord_unchunked=coord_unchunked, coord_chunked=coord_chunked) 432 | return ds 433 | 434 | 435 | def test_get_unchunked_variable_names(mixed_ds): 436 | expected = {"x_unchunked", "y_unchunked", "data_var_unchunked", "coord_unchunked"} 437 | result = set(get_unchunked_variable_names(mixed_ds)) 438 | assert result == expected 439 | 440 | 441 | def test_get_unchunked_non_dimension_coord_names(mixed_ds): 442 | expected = {"coord_unchunked"} 443 | result = set(get_unchunked_non_dimension_coord_names(mixed_ds)) 444 | assert result == expected 445 | 446 | 447 | def test_get_unchunked_data_var_names(mixed_ds): 448 | expected = {"data_var_unchunked"} 449 | result = set(get_unchunked_data_var_names(mixed_ds)) 450 | assert result == expected 451 | 452 | 453 | def 
test_validate_PartitionMapper_dataset(mixed_ds): 454 | with pytest.raises(ValueError, match="The PartitionMapper approach"): 455 | validate_PartitionMapper_dataset(mixed_ds) 456 | 457 | 458 | @pytest.mark.parametrize( 459 | ("mode", "raises_on_existing"), [(None, True), ("w-", True), ("w", False)] 460 | ) 461 | def test_mode(tmpdir, ds, mode, raises_on_existing): 462 | store = os.path.join(tmpdir, "test.zarr") 463 | ds.to_zarr(store) 464 | 465 | if raises_on_existing: 466 | with pytest.raises(FileExistsError): 467 | ds.partition.initialize_store(store, mode=mode) 468 | else: 469 | ranks = 3 470 | ds.partition.initialize_store(store, mode=mode) 471 | for rank in range(ranks): 472 | ds.partition.write(store, ranks, ds.dims, rank) 473 | 474 | result = xr.open_zarr(store) 475 | xr.testing.assert_identical(result, ds) 476 | 477 | 478 | @pytest.mark.parametrize("zarr_format", [None, 2, 3]) 479 | def test_zarr_format(tmpdir, ds, zarr_format): 480 | store = os.path.join(tmpdir, "test.zarr") 481 | 482 | ranks = 3 483 | ds.partition.initialize_store(store, zarr_format=zarr_format) 484 | for rank in range(ranks): 485 | ds.partition.write(store, ranks, ds.dims, rank) 486 | 487 | result = xr.open_zarr(store) 488 | xr.testing.assert_identical(result, ds) 489 | 490 | expected_zarr_format = 3 if zarr_format is None else zarr_format 491 | group = zarr.open_group(store) 492 | result_zarr_format = group.metadata.zarr_format 493 | assert result_zarr_format == expected_zarr_format 494 | 495 | 496 | @pytest.mark.parametrize( 497 | ("chunks", "expected", "raises", "match"), 498 | [ 499 | ({"a": 2, "b": 4}, (2, 4), False, None), 500 | ({"a": 2, "b": (1, 2, 2)}, None, True, "uniform chunk"), 501 | ({"a": 2, "b": (1, 1, 3)}, None, True, "Final chunk"), 502 | ], 503 | ids=lambda x: f"{x!r}", 504 | ) 505 | def test_get_chunks_encoding(chunks, expected, raises, match): 506 | da = xr.DataArray(np.arange(10).reshape((2, 5)), dims=["a", "b"]) 507 | da = da.chunk(chunks) 508 | if raises: 509 | with pytest.raises(ValueError, match=match): 510 | get_chunks_encoding(da) 511 | else: 512 | result = get_chunks_encoding(da) 513 | assert result == expected 514 | 515 | 516 | @pytest.mark.parametrize( 517 | ("inner_chunks", "dim_sizes", "dim", "expected", "raises"), 518 | [ 519 | ({"a": 1}, {"a": 2}, "a", 1, False), 520 | ({"a": -1}, {"a": 2}, "a", 2, False), 521 | ({"a": -2}, {"a": 2}, "a", None, True), 522 | ], 523 | ids=lambda x: f"{x!r}", 524 | ) 525 | def test_get_inner_chunk_size(inner_chunks, dim_sizes, dim, expected, raises): 526 | if raises: 527 | with pytest.raises(ValueError, match="greater than 0"): 528 | get_inner_chunk_size(inner_chunks, dim_sizes, dim) 529 | else: 530 | result = get_inner_chunk_size(inner_chunks, dim_sizes, dim) 531 | assert result == expected 532 | 533 | 534 | @pytest.mark.parametrize( 535 | ("inner_chunks", "raises"), 536 | [({"a": 1}, False), ({"a": 2}, True)], 537 | ids=lambda x: f"{x!r}", 538 | ) 539 | def test_get_inner_chunks_encoding(inner_chunks, raises): 540 | da = xr.DataArray(np.arange(5), dims=["a"]).chunk({"a": 3}) 541 | if raises: 542 | with pytest.raises(ValueError, match="evenly divide"): 543 | get_inner_chunks_encoding(da, inner_chunks) 544 | else: 545 | expected = (1,) 546 | result = get_inner_chunks_encoding(da, inner_chunks) 547 | assert result == expected 548 | 549 | 550 | def test_sharded_store(tmpdir, ds): 551 | inner_chunks = {"a": 1, "b": 1, "c": 1} 552 | store = os.path.join(tmpdir, "sharded.zarr") 553 | 554 | ranks = 3 555 | ds.partition.initialize_store(store, 
inner_chunks=inner_chunks) 556 | for rank in range(ranks): 557 | ds.partition.write(store, ranks, ds.dims, rank) 558 | 559 | # Check that initialize_store and write do not mutate the encoding of 560 | # any of the variables in the original Dataset. 561 | for da in {**ds.coords, **ds.data_vars}.values(): 562 | assert "shards" not in da.encoding 563 | assert "chunks" not in da.encoding 564 | 565 | # Check that the chunks and encoding of the loaded Dataset match our 566 | # expectations. 567 | result = xr.open_zarr(store) 568 | for name, original_da in {**ds.coords, **ds.data_vars}.items(): 569 | result_da = result[name] 570 | 571 | if isinstance(result_da.data, dask.array.Array): 572 | stored_chunks = {} 573 | for dim, (size, *_) in zip(result_da.dims, result_da.chunks): 574 | stored_chunks[dim] = size 575 | 576 | if isinstance(original_da.data, dask.array.Array): 577 | expected_chunks = {dim: inner_chunks[dim] for dim in stored_chunks} 578 | expected_shards_encoding = get_chunks_encoding(ds[name]) 579 | expected_chunks_encoding = get_chunks_encoding(result_da) 580 | else: 581 | expected_chunks = result_da.sizes 582 | expected_shards_encoding = None 583 | expected_chunks_encoding = get_chunks_encoding(result_da) 584 | 585 | assert stored_chunks == expected_chunks 586 | assert result_da.encoding["shards"] == expected_shards_encoding 587 | assert result_da.encoding["chunks"] == expected_chunks_encoding 588 | 589 | # Finally check that the written Dataset, modulo chunks, is identical 590 | # to the provided Dataset. 591 | xr.testing.assert_identical(result, ds) 592 | 593 | 594 | def test_inner_chunks_zarr_format_2_error(tmpdir, ds): 595 | inner_chunks = {"a": 1, "b": 1, "c": 1} 596 | zarr_format = 2 597 | store = os.path.join(tmpdir, "sharded.zarr") 598 | 599 | with pytest.raises(ValueError, match="zarr_format=2"): 600 | ds.partition.initialize_store( 601 | store, inner_chunks=inner_chunks, zarr_format=zarr_format 602 | ) 603 | -------------------------------------------------------------------------------- /xpartition/__init__.py: -------------------------------------------------------------------------------- 1 | from . import xpartition 2 | from .xpartition import __version__, zeros_like, PartitionMapper 3 | -------------------------------------------------------------------------------- /xpartition/xarray_utils.py: -------------------------------------------------------------------------------- 1 | # Code in this file has been adapted from xarray. We therefore include a copy 2 | # of the xarray license below. 3 | # Apache License 4 | # Version 2.0, January 2004 5 | # http://www.apache.org/licenses/ 6 | 7 | # TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | # 1. Definitions. 10 | 11 | # "License" shall mean the terms and conditions for use, reproduction, and 12 | # distribution as defined by Sections 1 through 9 of this document. 13 | 14 | # "Licensor" shall mean the copyright owner or entity authorized by the copyright 15 | # owner that is granting the License. 16 | 17 | # "Legal Entity" shall mean the union of the acting entity and all other entities 18 | # that control, are controlled by, or are under common control with that entity. 19 | # For the purposes of this definition, "control" means (i) the power, direct or 20 | # indirect, to cause the direction or management of such entity, whether by 21 | # contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | # outstanding shares, or (iii) beneficial ownership of such entity. 
23 | 24 | # "You" (or "Your") shall mean an individual or Legal Entity exercising 25 | # permissions granted by this License. 26 | 27 | # "Source" form shall mean the preferred form for making modifications, including 28 | # but not limited to software source code, documentation source, and configuration 29 | # files. 30 | 31 | # "Object" form shall mean any form resulting from mechanical transformation or 32 | # translation of a Source form, including but not limited to compiled object code, 33 | # generated documentation, and conversions to other media types. 34 | 35 | # "Work" shall mean the work of authorship, whether in Source or Object form, made 36 | # available under the License, as indicated by a copyright notice that is included 37 | # in or attached to the work (an example is provided in the Appendix below). 38 | 39 | # "Derivative Works" shall mean any work, whether in Source or Object form, that 40 | # is based on (or derived from) the Work and for which the editorial revisions, 41 | # annotations, elaborations, or other modifications represent, as a whole, an 42 | # original work of authorship. For the purposes of this License, Derivative Works 43 | # shall not include works that remain separable from, or merely link (or bind by 44 | # name) to the interfaces of, the Work and Derivative Works thereof. 45 | 46 | # "Contribution" shall mean any work of authorship, including the original version 47 | # of the Work and any modifications or additions to that Work or Derivative Works 48 | # thereof, that is intentionally submitted to Licensor for inclusion in the Work 49 | # by the copyright owner or by an individual or Legal Entity authorized to submit 50 | # on behalf of the copyright owner. For the purposes of this definition, 51 | # "submitted" means any form of electronic, verbal, or written communication sent 52 | # to the Licensor or its representatives, including but not limited to 53 | # communication on electronic mailing lists, source code control systems, and 54 | # issue tracking systems that are managed by, or on behalf of, the Licensor for 55 | # the purpose of discussing and improving the Work, but excluding communication 56 | # that is conspicuously marked or otherwise designated in writing by the copyright 57 | # owner as "Not a Contribution." 58 | 59 | # "Contributor" shall mean Licensor and any individual or Legal Entity on behalf 60 | # of whom a Contribution has been received by Licensor and subsequently 61 | # incorporated within the Work. 62 | 63 | # 2. Grant of Copyright License. 64 | 65 | # Subject to the terms and conditions of this License, each Contributor hereby 66 | # grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, 67 | # irrevocable copyright license to reproduce, prepare Derivative Works of, 68 | # publicly display, publicly perform, sublicense, and distribute the Work and such 69 | # Derivative Works in Source or Object form. 70 | 71 | # 3. Grant of Patent License. 
72 | 73 | # Subject to the terms and conditions of this License, each Contributor hereby 74 | # grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, 75 | # irrevocable (except as stated in this section) patent license to make, have 76 | # made, use, offer to sell, sell, import, and otherwise transfer the Work, where 77 | # such license applies only to those patent claims licensable by such Contributor 78 | # that are necessarily infringed by their Contribution(s) alone or by combination 79 | # of their Contribution(s) with the Work to which such Contribution(s) was 80 | # submitted. If You institute patent litigation against any entity (including a 81 | # cross-claim or counterclaim in a lawsuit) alleging that the Work or a 82 | # Contribution incorporated within the Work constitutes direct or contributory 83 | # patent infringement, then any patent licenses granted to You under this License 84 | # for that Work shall terminate as of the date such litigation is filed. 85 | 86 | # 4. Redistribution. 87 | 88 | # You may reproduce and distribute copies of the Work or Derivative Works thereof 89 | # in any medium, with or without modifications, and in Source or Object form, 90 | # provided that You meet the following conditions: 91 | 92 | # You must give any other recipients of the Work or Derivative Works a copy of 93 | # this License; and 94 | # You must cause any modified files to carry prominent notices stating that You 95 | # changed the files; and 96 | # You must retain, in the Source form of any Derivative Works that You distribute, 97 | # all copyright, patent, trademark, and attribution notices from the Source form 98 | # of the Work, excluding those notices that do not pertain to any part of the 99 | # Derivative Works; and 100 | # If the Work includes a "NOTICE" text file as part of its distribution, then any 101 | # Derivative Works that You distribute must include a readable copy of the 102 | # attribution notices contained within such NOTICE file, excluding those notices 103 | # that do not pertain to any part of the Derivative Works, in at least one of the 104 | # following places: within a NOTICE text file distributed as part of the 105 | # Derivative Works; within the Source form or documentation, if provided along 106 | # with the Derivative Works; or, within a display generated by the Derivative 107 | # Works, if and wherever such third-party notices normally appear. The contents of 108 | # the NOTICE file are for informational purposes only and do not modify the 109 | # License. You may add Your own attribution notices within Derivative Works that 110 | # You distribute, alongside or as an addendum to the NOTICE text from the Work, 111 | # provided that such additional attribution notices cannot be construed as 112 | # modifying the License. 113 | # You may add Your own copyright statement to Your modifications and may provide 114 | # additional or different license terms and conditions for use, reproduction, or 115 | # distribution of Your modifications, or for any such Derivative Works as a whole, 116 | # provided Your use, reproduction, and distribution of the Work otherwise complies 117 | # with the conditions stated in this License. 118 | 119 | # 5. Submission of Contributions. 120 | 121 | # Unless You explicitly state otherwise, any Contribution intentionally submitted 122 | # for inclusion in the Work by You to the Licensor shall be under the terms and 123 | # conditions of this License, without any additional terms or conditions. 
124 | # Notwithstanding the above, nothing herein shall supersede or modify the terms of 125 | # any separate license agreement you may have executed with Licensor regarding 126 | # such Contributions. 127 | 128 | # 6. Trademarks. 129 | 130 | # This License does not grant permission to use the trade names, trademarks, 131 | # service marks, or product names of the Licensor, except as required for 132 | # reasonable and customary use in describing the origin of the Work and 133 | # reproducing the content of the NOTICE file. 134 | 135 | # 7. Disclaimer of Warranty. 136 | 137 | # Unless required by applicable law or agreed to in writing, Licensor provides the 138 | # Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, 139 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, 140 | # including, without limitation, any warranties or conditions of TITLE, 141 | # NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are 142 | # solely responsible for determining the appropriateness of using or 143 | # redistributing the Work and assume any risks associated with Your exercise of 144 | # permissions under this License. 145 | 146 | # 8. Limitation of Liability. 147 | 148 | # In no event and under no legal theory, whether in tort (including negligence), 149 | # contract, or otherwise, unless required by applicable law (such as deliberate 150 | # and grossly negligent acts) or agreed to in writing, shall any Contributor be 151 | # liable to You for damages, including any direct, indirect, special, incidental, 152 | # or consequential damages of any character arising as a result of this License or 153 | # out of the use or inability to use the Work (including but not limited to 154 | # damages for loss of goodwill, work stoppage, computer failure or malfunction, or 155 | # any and all other commercial damages or losses), even if such Contributor has 156 | # been advised of the possibility of such damages. 157 | 158 | # 9. Accepting Warranty or Additional Liability. 159 | 160 | # While redistributing the Work or Derivative Works thereof, You may choose to 161 | # offer, and charge a fee for, acceptance of support, warranty, indemnity, or 162 | # other liability obligations and/or rights consistent with this License. However, 163 | # in accepting such obligations, You may act only on Your own behalf and on Your 164 | # sole responsibility, not on behalf of any other Contributor, and only if You 165 | # agree to indemnify, defend, and hold each Contributor harmless for any liability 166 | # incurred by, or claims asserted against, such Contributor by reason of your 167 | # accepting any such warranty or additional liability. 168 | 169 | # END OF TERMS AND CONDITIONS 170 | 171 | # APPENDIX: How to apply the Apache License to your work 172 | 173 | # To apply the Apache License to your work, attach the following boilerplate 174 | # notice, with the fields enclosed by brackets "[]" replaced with your own 175 | # identifying information. (Don't include the brackets!) The text should be 176 | # enclosed in the appropriate comment syntax for the file format. We also 177 | # recommend that a file or class name and description of purpose be included on 178 | # the same "printed page" as the copyright notice for easier identification within 179 | # third-party archives. 
180 | 181 | # Copyright 2014-2024 xarray Developers 182 | 183 | # Licensed under the Apache License, Version 2.0 (the "License"); 184 | # you may not use this file except in compliance with the License. 185 | # You may obtain a copy of the License at 186 | 187 | # http://www.apache.org/licenses/LICENSE-2.0 188 | 189 | # Unless required by applicable law or agreed to in writing, software 190 | # distributed under the License is distributed on an "AS IS" BASIS, 191 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 192 | # See the License for the specific language governing permissions and 193 | # limitations under the License. 194 | import dask 195 | import xarray as xr 196 | 197 | from typing import Tuple 198 | 199 | 200 | def get_chunks_encoding(da: xr.DataArray) -> Tuple[int, ...]: 201 | # Code adapted from xarray.backends.zarr._determine_zarr_chunks 202 | if any(len(set(chunks[:-1])) > 1 for chunks in da.chunks): 203 | raise ValueError( 204 | f"Zarr requires uniform chunk sizes except for final chunk. " 205 | f"Variable named {da.name!r} has incompatible dask chunks: " 206 | f"{da.chunks!r}. Consider rechunking using `chunk()`." 207 | ) 208 | if any((chunks[0] < chunks[-1]) for chunks in da.chunks): 209 | raise ValueError( 210 | f"Final chunk of Zarr array must be the same size or smaller " 211 | f"than the first. Variable named {da.name!r} has incompatible " 212 | f"chunks: {da.chunks!r}. Consider rechunking using `chunk()`." 213 | ) 214 | return tuple(chunk_size for chunk_size, *_ in da.chunks) 215 | 216 | 217 | class CountingScheduler: 218 | # Code adapted from xarray.tests.__init__.py 219 | 220 | def __init__(self): 221 | self.total_computes = 0 222 | 223 | def __call__(self, dsk, keys, **kwargs): 224 | self.total_computes += 1 225 | return dask.get(dsk, keys, **kwargs) 226 | -------------------------------------------------------------------------------- /xpartition/xpartition.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import dataclasses 3 | import functools 4 | import logging 5 | import math 6 | from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union 7 | 8 | import dask.array 9 | import numpy as np 10 | import xarray as xr 11 | 12 | from xpartition.xarray_utils import get_chunks_encoding 13 | 14 | __version__ = "2025.03.0" 15 | 16 | 17 | Region = Union[None, Mapping[Hashable, slice]] 18 | Partition = Sequence[Region] 19 | HashableSlice = Tuple[Union[None, int], Union[None, int], Union[None, int]] 20 | HashableIndexers = Union[None, Tuple[Tuple[Hashable, HashableSlice], ...]] 21 | 22 | 23 | def _is_integer(value): 24 | """Check if a value is a Python or NumPy integer instance.""" 25 | return isinstance(value, (int, np.integer)) 26 | 27 | 28 | def _convert_scalars_to_slices(indexers): 29 | """Convert a dict of xarray dimension-index pairs to solely use slices. 30 | 31 | Assumes that the index values have been validated already in 32 | _validate_indexers. 33 | 34 | Parameters 35 | ---------- 36 | indexers : dict 37 | Dictionary mapping dimension names to integers or slices. 
38 | 39 | Returns 40 | ------- 41 | dict 42 | """ 43 | result = {} 44 | for k, v in indexers.items(): 45 | if isinstance(v, slice): 46 | result[k] = v 47 | else: 48 | if v == -1: 49 | result[k] = slice(v, None) 50 | else: 51 | result[k] = slice(v, v + 1) 52 | return result 53 | 54 | 55 | def _validate_indexers(indexers, sizes): 56 | """Check that indexers for an array with given sizes are valid. 57 | 58 | xpartition does not support indexing the blocks with non-contiguous array 59 | regions, e.g. with slices that skip elements. It also does not support 60 | indexing with anything other than an integer or slice along a dimension. 61 | 62 | Parameters 63 | ---------- 64 | indexers : dict 65 | Dictionary mapping dimension names to possible indexers. 66 | sizes : dict 67 | Dictionary mapping dimension names to sizes of the array. 68 | 69 | Raises 70 | ------ 71 | KeyError, IndexError, NotImplementedError, or ValueError depending on the 72 | context. 73 | """ 74 | for k, v in indexers.items(): 75 | if k not in sizes: 76 | raise KeyError(f"Dimension {k!r} is not a valid dimension.") 77 | elif _is_integer(v): 78 | if abs(v) > sizes[k] - 1: 79 | raise IndexError( 80 | f"Index {v} is out of bounds for dimension {k!r} of length {sizes[k]}." 81 | ) 82 | elif isinstance(v, slice): 83 | if v.step is not None and v.step != 1: 84 | raise NotImplementedError( 85 | "xpartition does not support indexing with slices with a step size different than None or 1." 86 | ) 87 | else: 88 | raise ValueError(f"Invalid indexer provided for dim {k!r}: {v}.") 89 | 90 | 91 | def _convert_block_indexers_to_array_indexers(block_indexers, chunks): 92 | """Convert a dict of dask block indexers to array indexers. 93 | 94 | Parameters 95 | ---------- 96 | block_indexers : dict 97 | Dictionary mapping dimension names to slices. The slices 98 | represent slices in dask block space. 99 | chunks : dict 100 | Dictionary mapping dimension names to tuples representing 101 | the chunk structure of the given dimension. 102 | 103 | Returns 104 | ------- 105 | dict 106 | """ 107 | array_indexers = {} 108 | for dim, block_indexer in block_indexers.items(): 109 | if block_indexer.start is None: 110 | start = 0 111 | else: 112 | start = sum(chunks[dim][: block_indexer.start]) 113 | stop = sum(chunks[dim][: block_indexer.stop]) 114 | array_indexers[dim] = slice(start, stop) 115 | return array_indexers 116 | 117 | 118 | @xr.register_dataarray_accessor("blocks") 119 | class BlocksAccessor: 120 | def __init__(self, xarray_obj): 121 | self._obj = xarray_obj 122 | if not isinstance(self._obj.data, dask.array.Array): 123 | raise ValueError( 124 | "The blocks accessor is only valid for dask-backed arrays." 125 | ) 126 | 127 | @property 128 | def _chunks(self) -> Dict[Hashable, Tuple[int, ...]]: 129 | return {dim: self._obj.chunks[k] for k, dim in enumerate(self._obj.dims)} 130 | 131 | @property 132 | def shape(self) -> Tuple[int, ...]: 133 | return tuple(len(c) for c in self._obj.chunks) 134 | 135 | @property 136 | def sizes(self) -> Dict[Hashable, int]: 137 | return {dim: size for dim, size in zip(self._obj.dims, self.shape)} 138 | 139 | def indexers(self, **block_indexers) -> Region: 140 | """Return a dict of array indexers that correspond to the block indexers. 141 | 142 | Parameters 143 | ---------- 144 | **block_indexers 145 | Dimension-indexer pairs in dask block space. These can be integers 146 | or contiguous slices. 
147 | 148 | Returns 149 | ------- 150 | dict 151 | 152 | Examples 153 | -------- 154 | >>> import xarray as xr; import dask.array as darray; import xpartition 155 | >>> arr = darray.zeros((10, 20), chunks=(2, 5)) 156 | >>> da = xr.DataArray(arr, dims=["x", "y"], name="foo") 157 | >>> da 158 | 159 | dask.array 160 | Dimensions without coordinates: x, y 161 | >>> da.blocks.indexers(x=2, y=3) 162 | {'x': slice(4, 6, None), 'y': slice(15, 20, None)} 163 | >>> da.blocks.indexers(x=2) 164 | {'x': slice(4, 6, None)} 165 | >>> da.blocks.indexers(x=slice(None, None)) 166 | {'x': slice(0, 10, None)} 167 | >>> da.blocks.indexers(x=slice(None, 3)) 168 | {'x': slice(0, 6, None)} 169 | >>> da.blocks.indexers(x=slice(3, None)) 170 | {'x': slice(6, 10, None)} 171 | >>> da.blocks.indexers(x=2, y=slice(0, 2)) 172 | {'x': slice(4, 6, None), 'y': slice(0, 10, None)} 173 | """ 174 | _validate_indexers(block_indexers, self.sizes) 175 | block_indexers = _convert_scalars_to_slices(block_indexers) 176 | return _convert_block_indexers_to_array_indexers(block_indexers, self._chunks) 177 | 178 | def isel(self, **block_indexers) -> xr.DataArray: 179 | slices = self.indexers(**block_indexers) 180 | # TODO: should we squeeze out dimensions where scalars were passed? 181 | return self._obj.isel(slices) 182 | 183 | 184 | def _write_partition_dataarray( 185 | da: xr.DataArray, store: str, ranks: int, dims: Sequence[Hashable], rank: int 186 | ): 187 | ds = da.drop_vars(da.coords).to_dataset() 188 | partition = da.partition.indexers(ranks, rank, dims) 189 | if partition is not None: 190 | ds.isel(partition).to_zarr(store, region=partition) 191 | 192 | 193 | def freeze_indexers(indexers: Region) -> HashableIndexers: 194 | """Return an immutable (hashable) version of the indexers.""" 195 | if indexers is None: 196 | return indexers 197 | else: 198 | immutable = ((k, (s.start, s.stop, s.step)) for k, s in indexers.items()) 199 | return tuple(sorted(immutable, key=lambda x: x[0])) 200 | 201 | 202 | def unfreeze_indexers(frozen_indexers: HashableIndexers) -> Region: 203 | """Convert an immutable version of the indexers back to its usual type.""" 204 | if frozen_indexers is None: 205 | return frozen_indexers 206 | else: 207 | return {k: slice(*s) for k, s in frozen_indexers} 208 | 209 | 210 | def _collect_by_partition( 211 | ds: xr.Dataset, ranks: int, dims: Sequence[Hashable], rank: int 212 | ) -> Sequence[Tuple[Region, xr.Dataset]]: 213 | """Return a list of pairs of partitions and Datasets containing 214 | DataArrays that can be written out to those partitions. 
215 | """ 216 | dataarrays = collections.defaultdict(list) 217 | for da in {**ds.coords, **ds.data_vars}.values(): 218 | if isinstance(da.data, dask.array.Array): 219 | partition_dims = [dim for dim in dims if dim in da.dims] 220 | indexers = da.partition.indexers(ranks, rank, partition_dims) 221 | dataarrays[freeze_indexers(indexers)].append(da.drop_vars(da.coords)) 222 | return [(unfreeze_indexers(k), xr.merge(v)) for k, v in dataarrays.items()] 223 | 224 | 225 | def _write_partition_dataset_via_individual_variables( 226 | ds: xr.Dataset, store: str, ranks: int, dims: Sequence[Hashable], rank: int 227 | ): 228 | for da in {**ds.coords, **ds.data_vars}.values(): 229 | if isinstance(da.data, dask.array.Array): 230 | partition_dims = [dim for dim in dims if dim in da.dims] 231 | _write_partition_dataarray(da, store, ranks, partition_dims, rank) 232 | 233 | 234 | def _write_partition_dataset_via_collected_variables( 235 | ds: xr.Dataset, store: str, ranks: int, dims: Sequence[Hashable], rank: int 236 | ): 237 | collected_by_partition = _collect_by_partition(ds, ranks, dims, rank) 238 | for partition, d in collected_by_partition: 239 | if partition is not None: 240 | d.isel(partition).to_zarr(store, region=partition) 241 | 242 | 243 | class Map(Sequence): 244 | """Lazy sequence""" 245 | 246 | def __init__(self, func, seq): 247 | self.seq = seq 248 | self.func = func 249 | 250 | def __getitem__(self, i): 251 | return self.func(self.seq[i]) 252 | 253 | def __len__(self): 254 | return len(self.seq) 255 | 256 | 257 | @xr.register_dataarray_accessor("partition") 258 | class PartitionDataArrayAccessor: 259 | def __init__(self, xarray_obj): 260 | self._obj = xarray_obj 261 | if not isinstance(self._obj.data, dask.array.Array): 262 | raise ValueError( 263 | "The partition accessor is only valid for dask-backed arrays." 264 | ) 265 | 266 | def _meta_array(self, chunks: Dict[Hashable, int]) -> xr.DataArray: 267 | dummy_data = dask.array.zeros(self._obj.blocks.shape) 268 | da = xr.DataArray(dummy_data, dims=self._obj.dims, name="blocks") 269 | return da.chunk(chunks) 270 | 271 | def _optimal_meta_chunk_sizes( 272 | self, ranks: int, dims: Sequence[Hashable] 273 | ) -> Dict[Hashable, int]: 274 | """Determine the optimal meta chunk sizes for the DataArray. 275 | 276 | Partitions are prioritized based on the ordering of the dims 277 | provided. Priority means we will first make the meta chunk 278 | size one along those dimensions before moving to larger meta 279 | chunk sizes. 280 | 281 | Parameters 282 | ---------- 283 | ranks : int 284 | Total number of ranks available to partition across. 285 | dims : Sequence[Hashable] 286 | Dimensions to partition among; if a dimension is left out 287 | no partitions will be made along that dimension. 288 | 289 | Returns 290 | ------- 291 | Dict[Hashable, int] 292 | """ 293 | chunk_sizes = {} 294 | for dim in dims: 295 | block_sizes = [] 296 | for d, s in chunk_sizes.items(): 297 | block_size = math.ceil(self._obj.blocks.sizes[d] / s) 298 | block_sizes.append(block_size) 299 | blocks = np.prod(block_sizes) 300 | size = math.ceil(self._obj.blocks.sizes[dim] / (ranks // blocks)) 301 | chunk_sizes[dim] = min(size, self._obj.blocks.sizes[dim]) 302 | return chunk_sizes 303 | 304 | def partition(self, ranks, dims) -> Partition: 305 | """Compute a ranks-sized partition respecting dask block boundaries 306 | 307 | Parameters 308 | ---------- 309 | ranks : int 310 | Total number of ranks available to partition across. 
311 | dims : Sequence[Hashable] 312 | Dimensions to partition among; if a dimension is left out 313 | no partitions will be made along that dimension. 314 | 315 | Returns 316 | ------- 317 | a list of disjoint regions whose union is the full coordinate space 318 | """ 319 | return Map(functools.partial(self._indexers, ranks, dims), list(range(ranks))) 320 | 321 | def _indexers(self, ranks, dims, rank): 322 | """Needed for creating a partial function within the partition method.""" 323 | return self.indexers(ranks, rank, dims) 324 | 325 | def indexers(self, ranks: int, rank: int, dims: Sequence[Hashable]) -> Region: 326 | """Partition the dask blocks across the given dims. 327 | 328 | Parameters 329 | ---------- 330 | ranks : int 331 | Total number of ranks available to partition across. 332 | rank : int 333 | Specific rank to obtain the indexers for. 334 | dims : Sequence[Hashable] 335 | Dimensions to partition among; if a dimension is left out 336 | no partitions will be made along that dimension. 337 | 338 | Returns 339 | ------- 340 | Dict[Hashable, slice] 341 | """ 342 | if rank >= ranks: 343 | raise ValueError(f"Rank {rank} is greater than maximum rank {ranks - 1}.") 344 | 345 | meta_chunk_sizes = self._optimal_meta_chunk_sizes(ranks, dims) 346 | meta_array = self._meta_array(meta_chunk_sizes) 347 | try: 348 | meta_indices = np.unravel_index(rank, meta_array.blocks.shape) 349 | except ValueError: 350 | return None 351 | else: 352 | meta_indexers = dict(zip(meta_array.dims, meta_indices)) 353 | dask_indexers = meta_array.blocks.indexers(**meta_indexers) 354 | return self._obj.blocks.indexers(**dask_indexers) 355 | 356 | def write( 357 | self, 358 | store: str, 359 | ranks: int, 360 | dims: Sequence[Hashable], 361 | rank: int, 362 | collect_variable_writes: bool = False, 363 | ): 364 | self._obj.to_dataset().partition.write( 365 | store, ranks, dims, rank, collect_variable_writes 366 | ) 367 | 368 | def mappable_write( 369 | self, 370 | store: str, 371 | ranks: int, 372 | dims: Sequence[Hashable], 373 | collect_variable_writes: bool = False, 374 | ) -> Callable[[int], None]: 375 | return self._obj.to_dataset().partition.mappable_write( 376 | store, ranks, dims, collect_variable_writes 377 | ) 378 | 379 | @property 380 | def _chunks(self): 381 | return {dim: self._obj.chunks[k] for k, dim in enumerate(self._obj.dims)} 382 | 383 | def map( 384 | self, store: str, ranks: int, dims: Sequence[Hashable], func, data 385 | ) -> "PartitionMapper": 386 | plan = _ValidWorkPlan(self, ranks, dims) 387 | return PartitionMapper(plan, func, data, store) 388 | 389 | 390 | @xr.register_dataset_accessor("partition") 391 | class PartitionDatasetAccessor: 392 | def __init__(self, xarray_obj): 393 | self._obj = xarray_obj 394 | 395 | def initialize_store( 396 | self, 397 | store: str, 398 | inner_chunks: Optional[Dict[Hashable, int]] = None, 399 | mode: Optional[str] = None, 400 | zarr_format: Optional[int] = None, 401 | ): 402 | """Initialize a zarr store for partitioned writes. 403 | 404 | The ``inner_chunks`` and ``zarr_format`` parameters provided here 405 | will automatically be applied in the ``write`` step, as they are 406 | encoded on disk in the initialization process. 407 | 408 | Parameters 409 | ---------- 410 | store : str 411 | Path to zarr store. 412 | inner_chunks : dict (optional) 413 | Dictionary mapping dimension names to inner chunk sizes for writing 414 | a sharded zarr store. Outer chunks (a.k.a. shards) will be inferred 415 | from the dask chunks on the variables in the Dataset.
If not 416 | provided, a standard unsharded zarr store will be written, whose 417 | chunks will correspond to the dask chunks. 418 | mode : str or None 419 | ``mode`` to pass through to :py:meth:`xarray.Dataset.to_zarr`. 420 | zarr_format : int or None 421 | ``zarr_format`` to pass through to :py:meth:`xarray.Dataset.to_zarr`. 422 | """ 423 | ds = self._obj 424 | if inner_chunks is not None: 425 | if zarr_format == 2: 426 | raise ValueError( 427 | "It is not possible to specify inner_chunks when zarr_format=2. " 428 | "Sharded stores are only possible with zarr version 3." 429 | ) 430 | ds = set_shards_and_chunks_encoding(ds, inner_chunks) 431 | ds.to_zarr(store, compute=False, mode=mode, zarr_format=zarr_format) 432 | 433 | def write( 434 | self, 435 | store: str, 436 | ranks: int, 437 | dims: Sequence[Hashable], 438 | rank: int, 439 | collect_variable_writes: bool = False, 440 | ): 441 | """Write a Dataset partition to disk on a given rank. 442 | 443 | Parameters 444 | ---------- 445 | store : str 446 | Path to zarr store. 447 | ranks : int 448 | Total number of ranks available to partition across. 449 | dims : Sequence[Hashable] 450 | Dimensions to partition among; if a dimension is left out 451 | no partitions will be made along that dimension. 452 | rank : int 453 | Rank of process to write partition from. 454 | collect_variable_writes : bool 455 | Whether to collect data variables with like partition indexers 456 | together when writing data out to disk (default False). It can 457 | be beneficial to set this to True if data variables in the Dataset 458 | have like chunk structure, and also share intermediate data. An 459 | example of this would be two fields that derive from the same 460 | input data. By default this input data would need to be computed or 461 | loaded twice; with this option set to True, the input data would 462 | only need to be computed or loaded once. A caveat, however, is that 463 | it can increase memory usage. 464 | """ 465 | if collect_variable_writes: 466 | f = _write_partition_dataset_via_collected_variables 467 | else: 468 | f = _write_partition_dataset_via_individual_variables 469 | f(self._obj, store, ranks, dims, rank) 470 | 471 | def mappable_write( 472 | self, 473 | store: str, 474 | ranks: int, 475 | dims: Sequence[Hashable], 476 | collect_variable_writes: bool = False, 477 | ) -> Callable[[int], None]: 478 | """Return a function that can write data for a partition on a rank. 479 | 480 | Parameters 481 | ---------- 482 | store : str 483 | Path to zarr store. 484 | ranks : int 485 | Total number of ranks available to partition across. 486 | dims : Sequence[Hashable] 487 | Dimensions to partition among; if a dimension is left out 488 | no partitions will be made along that dimension. 489 | collect_variable_writes : bool 490 | Whether to collect data variables with like partition indexers 491 | together when writing data out to disk (default False). It can 492 | be beneficial to set this to True if data variables in the Dataset 493 | have like chunk structure, and also share intermediate data. An 494 | example of this would be two fields that derive from the same 495 | input data. By default this input data would need to be computed or 496 | loaded twice; with this option set to True, the input data would 497 | only need to be computed or loaded once. A caveat, however, is that 498 | it can increase memory usage.
499 | 500 | Returns 501 | ------- 502 | function 503 | """ 504 | if collect_variable_writes: 505 | f = _write_partition_dataset_via_collected_variables 506 | else: 507 | f = _write_partition_dataset_via_individual_variables 508 | return functools.partial(f, self._obj, store, ranks, dims) 509 | 510 | 511 | def _merge_chunks(arr, override_chunks): 512 | chunks_to_update = {} 513 | for dim, sizes in override_chunks.items(): 514 | if dim in arr.dims: 515 | axis = arr.get_axis_num(dim) 516 | chunks_to_update[axis] = sizes 517 | original_chunks = {axis: sizes for axis, sizes in enumerate(arr.chunks)} 518 | return {**original_chunks, **chunks_to_update} 519 | 520 | 521 | def _zeros_like_dataarray(arr, override_chunks): 522 | if override_chunks is None: 523 | override_chunks = {} 524 | chunks = _merge_chunks(arr, override_chunks) 525 | return xr.apply_ufunc( 526 | dask.array.zeros_like, arr, kwargs=dict(chunks=chunks), dask="allowed" 527 | ) 528 | 529 | 530 | def zeros_like(ds: xr.Dataset, override_chunks=None): 531 | """Performant implementation of zeros_like. 532 | 533 | xr.zeros_like(ds).chunk(chunks) is very slow for datasets with many 534 | chunks. 535 | 536 | Parameters 537 | ---------- 538 | ds : xr.Dataset 539 | Input dataset with dask-backed data variables. 540 | override_chunks : dict 541 | Dimension chunk-size pairs indicating any dimensions one would like to 542 | override the original chunk sizes along. For any dimensions that are not 543 | present, zeros_like will use the chunk size along that dimension for each 544 | variable in the input Dataset. 545 | 546 | Returns 547 | ------- 548 | xr.Dataset 549 | """ 550 | return ds.apply( 551 | _zeros_like_dataarray, override_chunks=override_chunks, keep_attrs=True 552 | ) 553 | 554 | 555 | class _ValidWorkPlan: 556 | """A mapping between input and output partitionings that will 557 | avoid race conditions in parallel jobs 558 | """ 559 | 560 | def __init__(self, partitioner, ranks: int, dims: Sequence[Hashable]): 561 | self._partitioner = partitioner 562 | self._ranks = ranks 563 | self.dims = dims 564 | 565 | @property 566 | def output_chunks(self): 567 | return {dim: self._partitioner._chunks[dim] for dim in self.dims} 568 | 569 | @property 570 | def input_partition(self): 571 | return self._partitioner.partition(self._ranks, self.dims) 572 | 573 | 574 | def get_unchunked_variable_names(ds): 575 | unchunked = [] 576 | for name, variable in ds.variables.items(): 577 | if isinstance(variable.data, np.ndarray): 578 | unchunked.append(name) 579 | return unchunked 580 | 581 | 582 | def get_unchunked_non_dimension_coord_names(ds): 583 | names = [] 584 | for name, da in ds.coords.items(): 585 | if name not in ds.dims and isinstance(da.data, np.ndarray): 586 | names.append(name) 587 | return names 588 | 589 | 590 | def get_unchunked_data_var_names(ds): 591 | names = [] 592 | for name, da in ds.data_vars.items(): 593 | if isinstance(da.data, np.ndarray): 594 | names.append(name) 595 | return names 596 | 597 | 598 | def validate_PartitionMapper_dataset(ds): 599 | unchunked_non_dimension_coords = get_unchunked_non_dimension_coord_names(ds) 600 | unchunked_data_vars = get_unchunked_data_var_names(ds) 601 | invalid_unchunked_vars = unchunked_non_dimension_coords + unchunked_data_vars 602 | if invalid_unchunked_vars: 603 | raise ValueError( 604 | f"The PartitionMapper approach does not support writing datasets that " 605 | f"contain unchunked non-dimension coordinates or data variables. 
" 606 | f"Consider dropping or chunking these before initiating the write or " 607 | f"switching to the traditional xpartition writing approach. The " 608 | f"variables in question are {invalid_unchunked_vars!r}." 609 | ) 610 | 611 | 612 | @dataclasses.dataclass 613 | class PartitionMapper: 614 | """Evaluate a function on each region of a partition and store the output 615 | to a zarr store 616 | """ 617 | 618 | plan: _ValidWorkPlan 619 | func: Callable[[xr.Dataset], xr.Dataset] 620 | data: xr.Dataset 621 | path: str 622 | 623 | @property 624 | def dims(self): 625 | return self.plan.dims 626 | 627 | def _initialize_store(self): 628 | region = self.plan.input_partition[0] 629 | iData = self.data.isel(region) 630 | iOut = self.func(iData) 631 | validate_PartitionMapper_dataset(iOut) 632 | 633 | full_indexers = {dim: self.data[dim] for dim in self.dims} 634 | 635 | dims_without_coords = (set(iOut.dims) - set(iOut.indexes)) & set(self.dims) 636 | for dim in dims_without_coords: 637 | iOut = iOut.assign_coords({dim: iOut[dim]}) 638 | 639 | schema = zeros_like( 640 | iOut.reindex(full_indexers), override_chunks=self.plan.output_chunks 641 | ) 642 | schema = schema.drop_vars(dims_without_coords) 643 | schema.partition.initialize_store(self.path) 644 | 645 | def write(self, rank): 646 | logging.info(f"Writing {rank + 1} of {len(self.plan.input_partition)}") 647 | region = self.plan.input_partition[rank] 648 | iData = self.data.isel(region) 649 | iOut = self.func(iData) 650 | unchunked_variables = get_unchunked_variable_names(iOut) 651 | iOut.drop_vars(unchunked_variables).to_zarr(self.path, region=region) 652 | logging.info(f"Done writing {rank + 1}.") 653 | 654 | def __iter__(self): 655 | self._initialize_store() 656 | return iter(range(len(self.plan.input_partition))) 657 | 658 | 659 | def get_inner_chunk_size( 660 | inner_chunks: Dict[Hashable, int], dim_sizes: Dict[Hashable, int], dim: Hashable 661 | ) -> int: 662 | chunk_size = inner_chunks.get(dim, dim_sizes[dim]) 663 | 664 | if chunk_size > 0 or chunk_size == -1: 665 | chunk_size = chunk_size if chunk_size > 0 else dim_sizes[dim] 666 | else: 667 | raise ValueError( 668 | f"Inner chunk size must be greater than 0 or be equal to -1; got chunk " 669 | f"size {chunk_size} along dim {dim!r}." 670 | ) 671 | return chunk_size 672 | 673 | 674 | def get_inner_chunks_encoding( 675 | da: xr.DataArray, inner_chunks: Dict[Hashable, int] 676 | ) -> Tuple[int, ...]: 677 | shards = dict(zip(da.dims, get_chunks_encoding(da))) 678 | 679 | chunks = [] 680 | for dim in da.dims: 681 | chunk_size = get_inner_chunk_size(inner_chunks, da.sizes, dim) 682 | if shards[dim] % chunk_size == 0: 683 | chunks.append(chunk_size) 684 | else: 685 | raise ValueError( 686 | f"Inner chunk size ({chunk_size}) for dimension {dim!r} does not " 687 | f"evenly divide shard size ({shards[dim]}) for DataArray " 688 | f"{da.name!r}." 689 | ) 690 | return tuple(chunks) 691 | 692 | 693 | def set_shards_and_chunks_encoding( 694 | ds: xr.Dataset, inner_chunks: Dict[Hashable, int] 695 | ) -> xr.Dataset: 696 | # Make a shallow copy to avoid mutating the encoding of the input dataset. 697 | ds = ds.copy(deep=False) 698 | 699 | for da in {**ds.coords, **ds.data_vars}.values(): 700 | if isinstance(da.data, dask.array.Array): 701 | shards = get_chunks_encoding(da) 702 | chunks = get_inner_chunks_encoding(da, inner_chunks) 703 | da.encoding["shards"] = shards 704 | da.encoding["chunks"] = chunks 705 | return ds 706 | --------------------------------------------------------------------------------