├── .github
│   └── workflows
│       └── ci-build.yml
├── .gitignore
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── conftest.py
├── docs
│   ├── Makefile
│   ├── _static
│   │   ├── README.md
│   │   ├── xarray-beam-logo.png
│   │   └── xarray-beam-vs-xarray-dask.png
│   ├── aggregation.ipynb
│   ├── api.md
│   ├── conf.py
│   ├── data-model.ipynb
│   ├── index.md
│   ├── make.bat
│   ├── read-write.ipynb
│   ├── rechunking.ipynb
│   ├── requirements.txt
│   └── why-xarray-beam.md
├── examples
│   ├── README.md
│   ├── __init__.py
│   ├── era5_climatology.py
│   ├── era5_climatology_test.py
│   ├── xbeam_rechunk.py
│   └── xbeam_rechunk_test.py
├── pyproject.toml
├── setup.py
└── xarray_beam
    ├── __init__.py
    └── _src
        ├── __init__.py
        ├── combiners.py
        ├── combiners_test.py
        ├── core.py
        ├── core_test.py
        ├── dataset.py
        ├── dataset_test.py
        ├── integration_test.py
        ├── rechunk.py
        ├── rechunk_test.py
        ├── test_util.py
        ├── threadmap.py
        ├── threadmap_test.py
        ├── zarr.py
        └── zarr_test.py
/.github/workflows/ci-build.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 |   # Triggers the workflow on push or pull request events but only for the main branch
5 |   push:
6 |     branches: [ main ]
7 |   pull_request:
8 |     branches: [ main ]
9 |   # Allows you to run this workflow manually from the Actions tab
10 |   workflow_dispatch:
11 |
12 | jobs:
13 |   build:
14 |     name: "python ${{ matrix.python-version }}"
15 |     runs-on: ubuntu-latest
16 |     strategy:
17 |       fail-fast: false
18 |       matrix:
19 |         python-version: ["3.9", "3.10", "3.11"]
20 |     steps:
21 |       - name: Cancel previous
22 |         uses: styfle/cancel-workflow-action@0.7.0
23 |         with:
24 |           access_token: ${{ github.token }}
25 |         if: ${{ github.ref != 'refs/heads/main' }}
26 |       - uses: actions/checkout@v4
27 |       - name: Set up Python ${{ matrix.python-version }}
28 |         uses: actions/setup-python@v5
29 |         with:
30 |           python-version: ${{ matrix.python-version }}
31 |       - name: Get pip cache dir
32 |         id: pip-cache
33 |         run: |
34 |           python -m pip install --upgrade pip wheel
35 |           echo "::set-output name=dir::$(pip cache dir)"
36 |       - name: pip cache
37 |         uses: actions/cache@v4
38 |         with:
39 |           path: ${{ steps.pip-cache.outputs.dir }}
40 |           key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
41 |           restore-keys: |
42 |             ${{ runner.os }}-pip-
43 |       - name: Install Xarray-Beam
44 |         run: |
45 |           pip install -e .[tests]
46 |       - name: Run unit tests
47 |         # TODO(shoyer): remove the pangeo-forge module? the tests no longer pass
48 |         run: |
49 |           pytest xarray_beam
50 |       - name: Run example tests
51 |         # The examples define some of the same flags, so we run pytest in separate processes.
52 |         run: |
53 |           pytest examples/era5_climatology_test.py
54 |           pytest examples/xbeam_rechunk_test.py
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | .DS_Store
3 | build
4 | dist
5 | docs/.ipynb_checkpoints
6 | docs/_build
7 | docs/_autosummary
8 | docs/*.zarr
9 | __pycache__
10 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | build:
9 |   os: ubuntu-22.04
10 |   tools:
11 |     python: "3.10"
12 |
13 | # Build documentation in the docs/ directory with Sphinx
14 | sphinx:
15 |   configuration: docs/conf.py
16 |
17 | # Optionally declare the Python requirements required to build your docs
18 | python:
19 |   install:
20 |     - method: pip
21 |       path: .
22 |     - requirements: docs/requirements.txt
23 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Xarray-Beam
2 |
3 | Xarray-Beam is a Python library for building
4 | [Apache Beam](https://beam.apache.org/) pipelines with
5 | [Xarray](http://xarray.pydata.org/en/stable/) datasets.
6 |
7 | The project aims to facilitate data transformations and analysis on large-scale
8 | multi-dimensional labeled arrays, such as:
9 |
10 | - Ad-hoc computation on Xarray data, by dividing a `xarray.Dataset` into many
11 | smaller pieces ("chunks").
12 | - Adjusting array chunks, using the
13 | [Rechunker algorithm](https://rechunker.readthedocs.io/en/latest/algorithm.html).
14 | - Ingesting large, multi-dimensional array datasets into an analysis-ready,
15 | cloud-optimized format, namely [Zarr](https://zarr.readthedocs.io/) (see
16 | also [Pangeo Forge](https://github.com/pangeo-forge/pangeo-forge-recipes)).
17 | - Calculating statistics (e.g., "climatology") across distributed datasets
18 | with arbitrary groups.
19 |
20 | For more about our approach and how to get started,
21 | **[read the documentation](https://xarray-beam.readthedocs.io/)**!
22 |
23 | **Warning: Xarray-Beam is a sharp tool 🔪**
24 |
25 | Xarray-Beam is relatively new, and focused on expert users:
26 |
27 | - We use it extensively at Google for processing large-scale weather datasets,
28 | but there is not yet a vibrant external community.
29 | - It provides low-level abstractions that facilitate writing very large
30 | scale data pipelines (e.g., 100+ TB), but by design it requires explicitly
31 | thinking about how every operation is parallelized.
32 |
33 | ## Installation
34 |
35 | Xarray-Beam requires recent versions of immutabledict, Xarray, Dask, Rechunker,
36 | Zarr, and Apache Beam. For best performance when writing Zarr files, use Xarray
37 | 0.19.0 or later.
38 |
39 | ## Disclaimer
40 |
41 | Xarray-Beam is an experiment that we are sharing with the outside world in the
42 | hope that it will be useful. It is not a supported Google product. We welcome
43 | feedback, bug reports and code contributions, but cannot guarantee they will be
44 | addressed.
45 |
46 | See the "Contribution guidelines" for more.
47 |
48 | ## Credits
49 |
50 | Contributors:
51 |
52 | - Stephan Hoyer
53 | - Jason Hickey
54 | - Cenk Gazen
55 | - Alex Merose
56 |
--------------------------------------------------------------------------------
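The Installation section above names the required dependencies but not a quick-start command or a first pipeline. A minimal sketch (assuming the package is published on PyPI as `xarray-beam`; the import name `xarray_beam` matches the docs below):

```python
# pip install xarray-beam  # assumed PyPI name; pulls in apache-beam, xarray, etc.
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

# A small in-memory dataset, streamed through Beam as (Key, Dataset) chunks.
ds = xarray.Dataset({'foo': ('x', np.arange(10))})

with beam.Pipeline() as p:
    p | xbeam.DatasetToChunks(ds, chunks={'x': 5}) | beam.Map(print)
```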
/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Configure FLAGS with default values for absltest."""
15 | from absl import app
16 |
17 | try:
18 |   app.run(lambda argv: None)
19 | except SystemExit:
20 |   pass
21 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/README.md:
--------------------------------------------------------------------------------
1 | These images were created in Google drawings.
2 |
3 | - [Logo](https://docs.google.com/drawings/d/1gBt0tOOnY7pMDB1g45T7ARuwCc1CZjojZT2IfdvA5FI/edit?usp=sharing)
4 | - [Xarray-Beam vs Xarray-Dask](https://docs.google.com/drawings/d/1LBM4AKHadzuS8tsHQ5TcTAWaBM7E6NeCd7Aq16hpIxk/edit?usp=sharing)
5 |
--------------------------------------------------------------------------------
/docs/_static/xarray-beam-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/xarray-beam/da2ff60f48209e9c82ea1760d1d98540556e7271/docs/_static/xarray-beam-logo.png
--------------------------------------------------------------------------------
/docs/_static/xarray-beam-vs-xarray-dask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/xarray-beam/da2ff60f48209e9c82ea1760d1d98540556e7271/docs/_static/xarray-beam-vs-xarray-dask.png
--------------------------------------------------------------------------------
/docs/aggregation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "aaf2f2cb",
6 | "metadata": {},
7 | "source": [
8 | "# Aggregation\n",
9 | "\n",
10 | "Xarray-Beam can perform efficient distributed data aggregation in the \"map-reduce\" model. \n",
11 | "\n",
12 | "This currently only includes `Mean`, but we would welcome contributions of other aggregation functions such as `Sum`, `Std`, `Var`, `Min`, `Max`, etc."
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "id": "4b77b953",
18 | "metadata": {},
19 | "source": [
20 | "## High-level API\n",
21 | "\n",
22 | "The `Mean` transformation comes in three forms: {py:class}`Mean `, {py:class}`Mean.Globally `, and {py:class}`Mean.PerKey `. The implementation is highly scalable, based on a Beam's [`CombineFn`](https://beam.apache.org/documentation/transforms/python/aggregation/combineglobally/#example-4-combining-with-a-combinefn)."
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "id": "fca45196",
28 | "metadata": {},
29 | "source": [
30 | "The high-level `Mean` transform can be used to aggregate a distributed dataset across an existing dimension or dimensions, similar to Xarray's `.mean()` method:"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 1,
36 | "id": "3e387dd1",
37 | "metadata": {},
38 | "outputs": [
39 | {
40 | "name": "stdout",
41 | "output_type": "stream",
42 | "text": [
43 | "\n",
44 | "Dimensions: (lat: 25, time: 2920, lon: 53)\n",
45 | "Coordinates:\n",
46 | " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
47 | " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
48 | " * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n",
49 | "Data variables:\n",
50 | " air (time, lat, lon) float32 241.2 242.5 243.5 ... 296.5 296.2 295.7\n",
51 | "Attributes:\n",
52 | " Conventions: COARDS\n",
53 | " title: 4x daily NMC reanalysis (1948)\n",
54 | " description: Data is from NMC initialized reanalysis\\n(4x/day). These a...\n",
55 | " platform: Model\n",
56 | " references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...\n"
57 | ]
58 | }
59 | ],
60 | "source": [
61 | "import apache_beam as beam\n",
62 | "import numpy as np\n",
63 | "import xarray_beam as xbeam\n",
64 | "import xarray\n",
65 | "\n",
66 | "ds = xarray.tutorial.load_dataset('air_temperature')\n",
67 | "print(ds)"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 2,
73 | "id": "5967340a",
74 | "metadata": {},
75 | "outputs": [
76 | {
77 | "data": {
78 | "application/javascript": [
79 | "\n",
80 | " if (typeof window.interactive_beam_jquery == 'undefined') {\n",
81 | " var jqueryScript = document.createElement('script');\n",
82 | " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n",
83 | " jqueryScript.type = 'text/javascript';\n",
84 | " jqueryScript.onload = function() {\n",
85 | " var datatableScript = document.createElement('script');\n",
86 | " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n",
87 | " datatableScript.type = 'text/javascript';\n",
88 | " datatableScript.onload = function() {\n",
89 | " window.interactive_beam_jquery = jQuery.noConflict(true);\n",
90 | " window.interactive_beam_jquery(document).ready(function($){\n",
91 | " \n",
92 | " });\n",
93 | " }\n",
94 | " document.head.appendChild(datatableScript);\n",
95 | " };\n",
96 | " document.head.appendChild(jqueryScript);\n",
97 | " } else {\n",
98 | " window.interactive_beam_jquery(document).ready(function($){\n",
99 | " \n",
100 | " });\n",
101 | " }"
102 | ]
103 | },
104 | "metadata": {},
105 | "output_type": "display_data"
106 | },
107 | {
108 | "name": "stderr",
109 | "output_type": "stream",
110 | "text": [
111 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[2]: Mean/PerKey/CombinePerKey(MeanCombineFn)/GroupByKey'. \n",
112 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[2]: Mean/PerKey/CombinePerKey(MeanCombineFn)/GroupByKey'. \n",
113 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[2]: Mean/PerKey/CombinePerKey(MeanCombineFn)'. \n"
114 | ]
115 | },
116 | {
117 | "name": "stdout",
118 | "output_type": "stream",
119 | "text": [
120 | "(Key(offsets={'lat': 0, 'lon': 0}, vars=None), \n",
121 | "Dimensions: (lat: 25, lon: 53)\n",
122 | "Coordinates:\n",
123 | " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
124 | " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
125 | "Data variables:\n",
126 | " air (lat, lon) float64 260.4 260.2 259.9 259.5 ... 297.3 297.3 297.3)\n"
127 | ]
128 | }
129 | ],
130 | "source": [
131 | "with beam.Pipeline() as p:\n",
132 | " p | xbeam.DatasetToChunks(ds, chunks={'time': 1000}) | xbeam.Mean('time') | beam.Map(print)"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "id": "7a584485",
138 | "metadata": {},
139 | "source": [
140 | "## Lower-level API"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "id": "6f7585a9",
146 | "metadata": {},
147 | "source": [
148 | "Xarray-Beam also includes lower-level transforations modelled off of [`beam.Mean`](https://beam.apache.org/documentation/transforms/python/aggregation/mean/) rather than {py:meth}`xarray.Dataset.mean`: they compute averages over sequences of `xarray.Dataset` objects or (`key`, `xarray.Dataset`) pairs, rather than calculating an average over an existing Xarray dimension or based on `xarray_beam.Key` objects, e.g.,"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 3,
154 | "id": "86a925af",
155 | "metadata": {},
156 | "outputs": [
157 | {
158 | "name": "stderr",
159 | "output_type": "stream",
160 | "text": [
161 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-b19331c2-e263-4737-bd64-012081154884.json']\n"
162 | ]
163 | },
164 | {
165 | "data": {
166 | "text/plain": [
167 | "[\n",
168 | " Dimensions: (x: 3)\n",
169 | " Dimensions without coordinates: x\n",
170 | " Data variables:\n",
171 | " foo (x) float64 0.05667 -0.02306 -0.1648]"
172 | ]
173 | },
174 | "execution_count": 3,
175 | "metadata": {},
176 | "output_type": "execute_result"
177 | }
178 | ],
179 | "source": [
180 | "datasets = [\n",
181 | " xarray.Dataset({'foo': ('x', np.random.randn(3))})\n",
182 | " for _ in range(100)\n",
183 | "]\n",
184 | "datasets | xbeam.Mean.Globally()"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "id": "5670130b",
190 | "metadata": {},
191 | "source": [
192 | "Notice how existing dimensions on each datasets are unchanged by the transformation. If you want to average over existing dimensions, use the high-level `Mean` transform or do that aggregation yourself, e.g., by averaging inside each chunk before combining the data."
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "id": "6ab152b6",
198 | "metadata": {},
199 | "source": [
200 | "Similarly, the keys fed into `xbeam.Mean.PerKey` can be any hashables, including but not limited to `xbeam.Key`:"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 4,
206 | "id": "b2399483",
207 | "metadata": {},
208 | "outputs": [
209 | {
210 | "name": "stderr",
211 | "output_type": "stream",
212 | "text": [
213 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-b19331c2-e263-4737-bd64-012081154884.json']\n"
214 | ]
215 | },
216 | {
217 | "data": {
218 | "text/plain": [
219 | "[('DJF',\n",
220 | " \n",
221 | " Dimensions: ()\n",
222 | " Data variables:\n",
223 | " air float64 273.6),\n",
224 | " ('MAM',\n",
225 | " \n",
226 | " Dimensions: ()\n",
227 | " Data variables:\n",
228 | " air float64 279.0),\n",
229 | " ('JJA',\n",
230 | " \n",
231 | " Dimensions: ()\n",
232 | " Data variables:\n",
233 | " air float64 289.2),\n",
234 | " ('SON',\n",
235 | " \n",
236 | " Dimensions: ()\n",
237 | " Data variables:\n",
238 | " air float64 283.0)]"
239 | ]
240 | },
241 | "execution_count": 4,
242 | "metadata": {},
243 | "output_type": "execute_result"
244 | }
245 | ],
246 | "source": [
247 | "datasets = [\n",
248 | " (time.dt.season.item(), ds.sel(time=time).mean())\n",
249 | " for time in ds.time\n",
250 | "]\n",
251 | "datasets | xbeam.Mean.PerKey()"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "id": "c9db1627",
257 | "metadata": {},
258 | "source": [
259 | "`Mean.PerKey` is particularly useful in combination with {class}`beam.GroupByKey` for performing large-scale \"group by\" operations. For example, that a look at the [ERA5 climatology example](https://github.com/google/xarray-beam/blob/main/examples/era5_climatology.py)."
260 | ]
261 | },
262 | {
263 | "cell_type": "markdown",
264 | "id": "0fb6fc6a",
265 | "metadata": {},
266 | "source": [
267 | "## Custom aggregations\n",
268 | "\n",
269 | "The \"tree reduction\" algorithm used by the combiner inside `Mean` is great, but it isn't the only way to aggregate a dataset with Xarray-Beam.\n",
270 | "\n",
271 | "In many cases, the easiest way to scale up an aggregation pipeline is to make use of [rechunking](rechunking.ipynb) to convert the many small datasets inside your pipeline into a form that is easier to calculate in a scalable way. However, rechunking is much less efficient than using combiner, because each use of `Rechunk` requires a complete shuffle of the input data (i.e., writing all data in the pipepilne to temporary files on disk).\n",
272 | "\n",
273 | "For example, here's how one could compute the `median`, which is a notoriously difficult statistic to calculate with distributed algorithms:"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": 5,
279 | "id": "ef1ef099",
280 | "metadata": {
281 | "scrolled": false
282 | },
283 | "outputs": [
284 | {
285 | "name": "stderr",
286 | "output_type": "stream",
287 | "text": [
288 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: Rechunk/Stage0/Consolidate/GroupByTempKeys'. \n",
289 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: Rechunk/Stage0/Consolidate/GroupByTempKeys'. \n",
290 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: Rechunk/Stage2/Consolidate/GroupByTempKeys'. \n",
291 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: Rechunk/Stage2/Consolidate/GroupByTempKeys'. \n",
292 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: ConsolidateChunks/GroupByTempKeys'. \n",
293 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: ConsolidateChunks/GroupByTempKeys'. \n"
294 | ]
295 | },
296 | {
297 | "name": "stdout",
298 | "output_type": "stream",
299 | "text": [
300 | "\n",
301 | "Dimensions: (lat: 25, lon: 53)\n",
302 | "Coordinates:\n",
303 | " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
304 | " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
305 | "Data variables:\n",
306 | " air (lat, lon) float32 261.3 261.1 260.9 260.3 ... 297.3 297.3 297.3\n"
307 | ]
308 | }
309 | ],
310 | "source": [
311 | "source_chunks = {'time': 100, 'lat': -1, 'lon': -1}\n",
312 | "working_chunks = {'lat': 10, 'lon': 10, 'time': -1}\n",
313 | "\n",
314 | "with beam.Pipeline() as p:\n",
315 | " (\n",
316 | " p\n",
317 | " | xbeam.DatasetToChunks(ds, source_chunks)\n",
318 | " | xbeam.Rechunk(ds.sizes, source_chunks, working_chunks, itemsize=4)\n",
319 | " | beam.MapTuple(lambda k, v: (k.with_offsets(time=None), v.median('time')))\n",
320 | " | xbeam.ConsolidateChunks({'lat': -1, 'lon': -1})\n",
321 | " | beam.MapTuple(lambda k, v: print(v))\n",
322 | " )"
323 | ]
324 | }
325 | ],
326 | "metadata": {
327 | "kernelspec": {
328 | "display_name": "Python 3 (ipykernel)",
329 | "language": "python",
330 | "name": "python3"
331 | },
332 | "language_info": {
333 | "codemirror_mode": {
334 | "name": "ipython",
335 | "version": 3
336 | },
337 | "file_extension": ".py",
338 | "mimetype": "text/x-python",
339 | "name": "python",
340 | "nbconvert_exporter": "python",
341 | "pygments_lexer": "ipython3",
342 | "version": "3.9.13"
343 | }
344 | },
345 | "nbformat": 4,
346 | "nbformat_minor": 5
347 | }
348 |
--------------------------------------------------------------------------------
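The custom-aggregations discussion above suggests averaging inside each chunk before combining, as a cheaper alternative to `Rechunk`. A minimal sketch of that pattern, assuming the same air-temperature dataset and equal-sized chunks along `time` (so that a mean of per-chunk means equals the overall mean):

```python
import apache_beam as beam
import xarray
import xarray_beam as xbeam

ds = xarray.tutorial.load_dataset('air_temperature')  # 2920 time steps

with beam.Pipeline() as p:
    (
        p
        | xbeam.DatasetToChunks(ds, chunks={'time': 146})  # 20 equal-sized chunks
        # Average within each chunk first, dropping 'time' from the key...
        | beam.MapTuple(lambda k, v: (k.with_offsets(time=None), v.mean('time')))
        # ...then combine the much smaller per-chunk means, grouped by key.
        | xbeam.Mean.PerKey()
        | beam.MapTuple(lambda k, v: print(v))
    )
```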
/docs/api.md:
--------------------------------------------------------------------------------
1 | # API docs
2 |
3 | ```{eval-rst}
4 | .. currentmodule:: xarray_beam
5 | ```
6 |
7 | ## Core data model
8 |
9 | ```{eval-rst}
10 | .. autosummary::
11 | :toctree: _autosummary
12 |
13 | Key
14 | ```
15 |
16 | ## Reading and writing data
17 |
18 | ```{eval-rst}
19 | .. autosummary::
20 | :toctree: _autosummary
21 |
22 | open_zarr
23 | DatasetToChunks
24 | ChunksToZarr
25 | DatasetToZarr
26 | make_template
27 | replace_template_dims
28 | setup_zarr
29 | validate_zarr_chunk
30 | write_chunk_to_zarr
31 | ```
32 |
33 | ## Aggregation
34 |
35 | ```{eval-rst}
36 | .. autosummary::
37 | :toctree: _autosummary
38 |
39 | Mean
40 | Mean.Globally
41 | Mean.PerKey
42 | MeanCombineFn
43 | ```
44 |
45 | ## Rechunking
46 |
47 | ```{eval-rst}
48 | .. autosummary::
49 | :toctree: _autosummary
50 |
51 | ConsolidateChunks
52 | ConsolidateVariables
53 | SplitChunks
54 | SplitVariables
55 | Rechunk
56 | ```
57 |
58 | ## Utility transforms
59 |
60 | ```{eval-rst}
61 | .. autosummary::
62 | :toctree: _autosummary
63 |
64 | ValidateEachChunk
65 | ```
66 |
67 | ## Utility functions
68 |
69 | ```{eval-rst}
70 | .. autosummary::
71 | :toctree: _autosummary
72 |
73 | offsets_to_slices
74 | validate_chunk
75 | consolidate_chunks
76 | consolidate_variables
77 | consolidate_fully
78 | split_chunks
79 | split_variables
80 | in_memory_rechunk
81 | ```
--------------------------------------------------------------------------------
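The listing above names `make_template` and `ChunksToZarr` but does not show how they are typically paired. A hedged sketch of that pairing (the `template=` and `zarr_chunks=` keywords are assumptions here; see the read-write notebook for the authoritative usage):

```python
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

ds = xarray.Dataset({'foo': (('x', 'y'), np.zeros((8, 2)))})
chunks = {'x': 4}

# make_template builds a lazy copy of ds that describes the full shape, dtypes
# and metadata of the target Zarr store without holding the data in memory.
template = xbeam.make_template(ds)

with beam.Pipeline() as p:
    (
        p
        | xbeam.DatasetToChunks(ds, chunks)
        | xbeam.ChunksToZarr('example.zarr', template=template, zarr_chunks=chunks)
    )
```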
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 | # Print Python environment info for easier debugging on ReadTheDocs
18 |
19 | import sys
20 | import subprocess
21 | import xarray_beam # verify this works
22 |
23 | print("python exec:", sys.executable)
24 | print("sys.path:", sys.path)
25 | print("pip environment:")
26 | subprocess.run([sys.executable, "-m", "pip", "list"])
27 |
28 | print(f"xarray_beam: {xarray_beam.__version__}, {xarray_beam.__file__}")
29 |
30 | # -- Project information -----------------------------------------------------
31 |
32 | project = 'Xarray-Beam'
33 | copyright = '2021, Google LLC'
34 | author = 'The Xarray-Beam authors'
35 |
36 |
37 | # -- General configuration ---------------------------------------------------
38 |
39 | # Add any Sphinx extension module names here, as strings. They can be
40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
41 | # ones.
42 | extensions = [
43 | 'sphinx.ext.autodoc',
44 | 'sphinx.ext.autosummary',
45 | 'sphinx.ext.napoleon',
46 | 'myst_nb',
47 | ]
48 |
49 | # Add any paths that contain templates here, relative to this directory.
50 | templates_path = ['_templates']
51 |
52 | # List of patterns, relative to source directory, that match files and
53 | # directories to ignore when looking for source files.
54 | # This pattern also affects html_static_path and html_extra_path.
55 | exclude_patterns = ['_build', '_templates', 'Thumbs.db', '.DS_Store']
56 |
57 | intersphinx_mapping = {
58 | "xarray": ("https://xarray.pydata.org/en/latest/", None),
59 | }
60 |
61 | # -- Options for HTML output -------------------------------------------------
62 |
63 | # The theme to use for HTML and HTML Help pages. See the documentation for
64 | # a list of builtin themes.
65 | #
66 | html_theme = 'sphinx_rtd_theme'
67 |
68 | # Add any paths that contain custom static files (such as style sheets) here,
69 | # relative to this directory. They are copied after the builtin static files,
70 | # so a file named "default.css" will overwrite the builtin "default.css".
71 | html_static_path = ['_static']
72 |
73 | # -- Extension config
74 |
75 | autosummary_generate = True
76 |
77 | # https://myst-nb.readthedocs.io/en/latest/use/execute.html
78 | jupyter_execute_notebooks = "cache"
79 | # https://myst-nb.readthedocs.io/en/latest/use/formatting_outputs.html#removing-stdout-and-stderr
80 | nb_output_stderr = "remove-warn"
81 |
82 | # https://stackoverflow.com/a/66295922/809705
83 | autodoc_typehints = "description"
84 |
--------------------------------------------------------------------------------
/docs/data-model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Core data model"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Xarray-Beam tries to make it _straightforward_ to write distributed pipelines with Xarray objects, but unlike libraries like [Xarray with Dask](http://xarray.pydata.org/en/stable/user-guide/dask.html) or Dask/Spark DataFrames, it doesn't hide the distributed magic inside high-level objects.\n",
15 | "\n",
16 | "Xarray-Beam is a lower-level tool. You will be manipulating large datasets piece-by-piece yourself, and you as the developer will be responsible for maintaining Xarray-Beam's internal invariants. This means that to successfully use Xarray-Beam, **you will need to understand how how it represents distributed datasets**.\n",
17 | "\n",
18 | "This responsibility requires a bit more coding and understanding, but offers benefits in performance and flexibility. This brief tutorial will show you how.\n",
19 | "\n",
20 | "We'll start off with some standard imports:"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 1,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "import apache_beam as beam\n",
30 | "import numpy as np\n",
31 | "import xarray_beam as xbeam\n",
32 | "import xarray"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "## Keys in Xarray-Beam\n",
40 | "\n",
41 | "Xarray-Beam is designed around the model that every stage in your Beam pipeline _could_ be stored in a single `xarray.Dataset` object, but is instead represented by a distributed beam `PCollection` of smaller `xarray.Dataset` objects, distributed in two possible ways:\n",
42 | "\n",
43 | "- Distinct _variables_ in a Dataset may be separated across multiple records.\n",
44 | "- Individual arrays can also be split into multiple _chunks_, similar to those used by [dask.array](https://docs.dask.org/en/latest/array.html).\n",
45 | "\n",
46 | "To keep track of how individual records could be combined into a larger (virtual) dataset, Xarray-Beam defines a {py:class}`~xarray_beam.Key` object. Key objects consist of:\n",
47 | "\n",
48 | "1. `offsets`: integer offests for chunks from the origin in an `immutabledict`\n",
49 | "2. `vars`: The subset of variables included in each chunk, either as a `frozenset`, or as `None` to indicate \"all variables\"."
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "Making a {py:class}`~xarray_beam.Key` from scratch is simple:"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 2,
62 | "metadata": {},
63 | "outputs": [
64 | {
65 | "data": {
66 | "text/plain": [
67 | "Key(offsets={'x': 0, 'y': 10}, vars=None)"
68 | ]
69 | },
70 | "execution_count": 2,
71 | "metadata": {},
72 | "output_type": "execute_result"
73 | }
74 | ],
75 | "source": [
76 | "key = xbeam.Key({'x': 0, 'y': 10}, vars=None)\n",
77 | "key"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "Or given an existing {py:class}`~xarray_beam.Key`, you can easily modify it with `replace()` or `with_offsets()`:"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 3,
90 | "metadata": {},
91 | "outputs": [
92 | {
93 | "data": {
94 | "text/plain": [
95 | "Key(offsets={'x': 0, 'y': 10}, vars={'bar', 'foo'})"
96 | ]
97 | },
98 | "execution_count": 3,
99 | "metadata": {},
100 | "output_type": "execute_result"
101 | }
102 | ],
103 | "source": [
104 | "key.replace(vars={'foo', 'bar'})"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 4,
110 | "metadata": {},
111 | "outputs": [
112 | {
113 | "data": {
114 | "text/plain": [
115 | "Key(offsets={'y': 10, 'z': 1}, vars=None)"
116 | ]
117 | },
118 | "execution_count": 4,
119 | "metadata": {},
120 | "output_type": "execute_result"
121 | }
122 | ],
123 | "source": [
124 | "key.with_offsets(x=None, z=1)"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {},
130 | "source": [
131 | "{py:class}`~xarray_beam.Key` objects don't do very much. They are just simple structs with two attributes, along with various special methods required to use them as `dict` keys or as keys in Beam pipelines. You can find a more examples of manipulating keys in the docstring of {py:class}`~xarray_beam.Key`."
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {},
137 | "source": [
138 | "## Creating PCollections"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "metadata": {},
144 | "source": [
145 | "The standard inputs & outputs for Xarray-Beam are PCollections of `(xbeam.Key, xarray.Dataset)` pairs. Xarray-Beam provides a bunch of PCollections for typical tasks, but many pipelines will still involve some manual manipulation of `Key` and `Dataset` objects, e.g., with builtin Beam transforms like `beam.Map`.\n",
146 | "\n",
147 | "To start off, let's write a helper functions for creating our first collection from scratch:"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 5,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "def create_records():\n",
157 | " for offset in [0, 4]:\n",
158 | " key = xbeam.Key({'x': offset, 'y': 0})\n",
159 | " data = 2 * offset + np.arange(8).reshape(4, 2)\n",
160 | " chunk = xarray.Dataset({\n",
161 | " 'foo': (('x', 'y'), data),\n",
162 | " 'bar': (('x', 'y'), 100 + data),\n",
163 | " })\n",
164 | " yield key, chunk"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "metadata": {},
170 | "source": [
171 | "Let's take a look the entries, which are lazily constructed with the generator:"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 6,
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "inputs = list(create_records())"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 7,
186 | "metadata": {},
187 | "outputs": [
188 | {
189 | "data": {
190 | "text/plain": [
191 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n",
192 | " \n",
193 | " Dimensions: (x: 4, y: 2)\n",
194 | " Dimensions without coordinates: x, y\n",
195 | " Data variables:\n",
196 | " foo (x, y) int64 0 1 2 3 4 5 6 7\n",
197 | " bar (x, y) int64 100 101 102 103 104 105 106 107),\n",
198 | " (Key(offsets={'x': 4, 'y': 0}, vars=None),\n",
199 | " \n",
200 | " Dimensions: (x: 4, y: 2)\n",
201 | " Dimensions without coordinates: x, y\n",
202 | " Data variables:\n",
203 | " foo (x, y) int64 8 9 10 11 12 13 14 15\n",
204 | " bar (x, y) int64 108 109 110 111 112 113 114 115)]"
205 | ]
206 | },
207 | "execution_count": 7,
208 | "metadata": {},
209 | "output_type": "execute_result"
210 | }
211 | ],
212 | "source": [
213 | "inputs"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {},
219 | "source": [
220 | "```{note}\n",
221 | "There are multiple valid ways to represent a chunk of a larger dataset with a `Key`.\n",
222 | "\n",
223 | "- **Offsets for unchunked dimensions are optional**. Because all chunks have the same offset along the `y` axis, including `y` in `offsets` is not required as long as we don't need to create multiple chunks along that dimension.\n",
224 | "- **Indicating variables is optional, if all chunks have the same variables**. We could have set `vars={'foo', 'bar'}` on each of these `Key` objects instead of `vars=None`. This would be an equally valid representation of the same records, since all of our datasets have the same variables.\n",
225 | "```"
226 | ]
227 | },
228 | {
229 | "cell_type": "markdown",
230 | "metadata": {},
231 | "source": [
232 | "We now have the inputs we need to use Xarray-Beam's helper functions and PTransforms. For example, we can fully consolidate chunks & variables to see what single `xarray.Dataset` these values would correspond to:"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": 8,
238 | "metadata": {},
239 | "outputs": [
240 | {
241 | "data": {
242 | "text/plain": [
243 | "(Key(offsets={'x': 0, 'y': 0}, vars={'bar', 'foo'}),\n",
244 | " \n",
245 | " Dimensions: (x: 8, y: 2)\n",
246 | " Dimensions without coordinates: x, y\n",
247 | " Data variables:\n",
248 | " foo (x, y) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15\n",
249 | " bar (x, y) int64 100 101 102 103 104 105 ... 110 111 112 113 114 115)"
250 | ]
251 | },
252 | "execution_count": 8,
253 | "metadata": {},
254 | "output_type": "execute_result"
255 | }
256 | ],
257 | "source": [
258 | "xbeam.consolidate_fully(inputs)"
259 | ]
260 | },
261 | {
262 | "cell_type": "markdown",
263 | "metadata": {},
264 | "source": [
265 | "To execute with Beam, of course, we need to turn Python lists/generators into Beam PCollections, e.g., with `beam.Create()`:"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": 9,
271 | "metadata": {},
272 | "outputs": [
273 | {
274 | "data": {
275 | "application/javascript": [
276 | "\n",
277 | " if (typeof window.interactive_beam_jquery == 'undefined') {\n",
278 | " var jqueryScript = document.createElement('script');\n",
279 | " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n",
280 | " jqueryScript.type = 'text/javascript';\n",
281 | " jqueryScript.onload = function() {\n",
282 | " var datatableScript = document.createElement('script');\n",
283 | " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n",
284 | " datatableScript.type = 'text/javascript';\n",
285 | " datatableScript.onload = function() {\n",
286 | " window.interactive_beam_jquery = jQuery.noConflict(true);\n",
287 | " window.interactive_beam_jquery(document).ready(function($){\n",
288 | " \n",
289 | " });\n",
290 | " }\n",
291 | " document.head.appendChild(datatableScript);\n",
292 | " };\n",
293 | " document.head.appendChild(jqueryScript);\n",
294 | " } else {\n",
295 | " window.interactive_beam_jquery(document).ready(function($){\n",
296 | " \n",
297 | " });\n",
298 | " }"
299 | ]
300 | },
301 | "metadata": {},
302 | "output_type": "display_data"
303 | },
304 | {
305 | "name": "stdout",
306 | "output_type": "stream",
307 | "text": [
308 | "(Key(offsets={'x': 0, 'y': 0}, vars=None), \n",
309 | "Dimensions: (x: 4, y: 2)\n",
310 | "Dimensions without coordinates: x, y\n",
311 | "Data variables:\n",
312 | " foo (x, y) int64 0 1 2 3 4 5 6 7\n",
313 | " bar (x, y) int64 100 101 102 103 104 105 106 107)\n",
314 | "(Key(offsets={'x': 4, 'y': 0}, vars=None), \n",
315 | "Dimensions: (x: 4, y: 2)\n",
316 | "Dimensions without coordinates: x, y\n",
317 | "Data variables:\n",
318 | " foo (x, y) int64 8 9 10 11 12 13 14 15\n",
319 | " bar (x, y) int64 108 109 110 111 112 113 114 115)\n"
320 | ]
321 | }
322 | ],
323 | "source": [
324 | "with beam.Pipeline() as p:\n",
325 | " p | beam.Create(create_records()) | beam.Map(print)"
326 | ]
327 | },
328 | {
329 | "cell_type": "markdown",
330 | "metadata": {},
331 | "source": [
332 | "## Writing pipelines"
333 | ]
334 | },
335 | {
336 | "cell_type": "markdown",
337 | "metadata": {},
338 | "source": [
339 | "Transforms in Xarray-Beam typically act on (key, value) pairs of `(xbeam.Key, xarray.Dataset)`. For example, we can dump our dataset on disk in the scalable [Zarr](https://zarr.readthedocs.io/) format using {py:class}`~xarray_beam.ChunksToZarr`:"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": 10,
345 | "metadata": {},
346 | "outputs": [
347 | {
348 | "name": "stderr",
349 | "output_type": "stream",
350 | "text": [
351 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2d587cc9-3460-4867-a777-2f0c8bc4e743.json']\n"
352 | ]
353 | },
354 | {
355 | "data": {
356 | "text/plain": [
357 | "[None, None]"
358 | ]
359 | },
360 | "execution_count": 10,
361 | "metadata": {},
362 | "output_type": "execute_result"
363 | }
364 | ],
365 | "source": [
366 | "inputs | xbeam.ChunksToZarr('my-data.zarr')"
367 | ]
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "metadata": {},
372 | "source": [
373 | "Xarray-Beam doesn't try to provide transformations for everything. In particular, it omits most [embarrassingly parallel](https://en.wikipedia.org/wiki/Embarrassingly_parallel) operations that can be performed independently on each chunk of a larger dataset. You can write these yourself using [`beam.Map`](https://beam.apache.org/documentation/transforms/python/elementwise/map/).\n",
374 | "\n",
375 | "For example, consider elementwise arithmetic. We can write a `lambda` function that acts on each key-value pair updating the xarray.Dataset objects appropriately, and put it into an Xarray-Beam pipeline using `beam.MapTuple`:"
376 | ]
377 | },
378 | {
379 | "cell_type": "code",
380 | "execution_count": 11,
381 | "metadata": {},
382 | "outputs": [
383 | {
384 | "name": "stderr",
385 | "output_type": "stream",
386 | "text": [
387 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-beb6b630-5659-435b-8a62-c213f6177340.json']\n"
388 | ]
389 | },
390 | {
391 | "data": {
392 | "text/plain": [
393 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n",
394 | " \n",
395 | " Dimensions: (x: 4, y: 2)\n",
396 | " Dimensions without coordinates: x, y\n",
397 | " Data variables:\n",
398 | " foo (x, y) int64 1 2 3 4 5 6 7 8\n",
399 | " bar (x, y) int64 101 102 103 104 105 106 107 108),\n",
400 | " (Key(offsets={'x': 4, 'y': 0}, vars=None),\n",
401 | " \n",
402 | " Dimensions: (x: 4, y: 2)\n",
403 | " Dimensions without coordinates: x, y\n",
404 | " Data variables:\n",
405 | " foo (x, y) int64 9 10 11 12 13 14 15 16\n",
406 | " bar (x, y) int64 109 110 111 112 113 114 115 116)]"
407 | ]
408 | },
409 | "execution_count": 11,
410 | "metadata": {},
411 | "output_type": "execute_result"
412 | }
413 | ],
414 | "source": [
415 | "inputs | beam.MapTuple(lambda k, v: (k, v + 1))"
416 | ]
417 | },
418 | {
419 | "cell_type": "markdown",
420 | "metadata": {},
421 | "source": [
422 | "For operations that add or remove (unchunked) dimensions, you may need to update `Key` objects as well to maintain the Xarray-Beam invariants, e.g., if we want to remove the `y` dimension entirely:"
423 | ]
424 | },
425 | {
426 | "cell_type": "code",
427 | "execution_count": 12,
428 | "metadata": {},
429 | "outputs": [
430 | {
431 | "name": "stderr",
432 | "output_type": "stream",
433 | "text": [
434 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-beb6b630-5659-435b-8a62-c213f6177340.json']\n"
435 | ]
436 | },
437 | {
438 | "data": {
439 | "text/plain": [
440 | "[(Key(offsets={'x': 0}, vars=None),\n",
441 | " \n",
442 | " Dimensions: (x: 4)\n",
443 | " Dimensions without coordinates: x\n",
444 | " Data variables:\n",
445 | " foo (x) float64 0.5 2.5 4.5 6.5\n",
446 | " bar (x) float64 100.5 102.5 104.5 106.5),\n",
447 | " (Key(offsets={'x': 4}, vars=None),\n",
448 | " \n",
449 | " Dimensions: (x: 4)\n",
450 | " Dimensions without coordinates: x\n",
451 | " Data variables:\n",
452 | " foo (x) float64 8.5 10.5 12.5 14.5\n",
453 | " bar (x) float64 108.5 110.5 112.5 114.5)]"
454 | ]
455 | },
456 | "execution_count": 12,
457 | "metadata": {},
458 | "output_type": "execute_result"
459 | }
460 | ],
461 | "source": [
462 | "inputs | beam.MapTuple(lambda k, v: (k.with_offsets(y=None), v.mean('y')))"
463 | ]
464 | },
465 | {
466 | "cell_type": "markdown",
467 | "metadata": {},
468 | "source": [
469 | "```{note}\n",
470 | "Missing transformations in Xarray-Beam is partially an intentional design decision to reduce scope, and partially just a reflection of what we've gotten around to implementing. If after reading through the rest of docs you notice missing transformations or are wondering how to compute something in Xarray-Beam, please [open an issue](https://github.com/google/xarray-beam/issues) to discuss.\n",
471 | "```"
472 | ]
473 | }
474 | ],
475 | "metadata": {
476 | "kernelspec": {
477 | "display_name": "Python 3 (ipykernel)",
478 | "language": "python",
479 | "name": "python3"
480 | },
481 | "language_info": {
482 | "codemirror_mode": {
483 | "name": "ipython",
484 | "version": 3
485 | },
486 | "file_extension": ".py",
487 | "mimetype": "text/x-python",
488 | "name": "python",
489 | "nbconvert_exporter": "python",
490 | "pygments_lexer": "ipython3",
491 | "version": "3.9.13"
492 | }
493 | },
494 | "nbformat": 4,
495 | "nbformat_minor": 2
496 | }
497 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Xarray-Beam: distributed Xarray with Apache Beam
2 |
3 | Xarray-Beam is a library for writing [Apache Beam](http://beam.apache.org/) pipelines consisting of [xarray](http://xarray.pydata.org) Dataset objects. This documentation (and Xarray-Beam itself) assumes basic familiarity with both Beam and Xarray.
4 |
5 | The documentation includes narrative guides that walk you through the basics of writing a pipeline with Xarray-Beam, as well as comprehensive API docs.
6 |
7 | We recommend reading both, as well as a few [end-to-end examples](https://github.com/google/xarray-beam/tree/main/examples), to understand what code using Xarray-Beam typically looks like.
8 |
9 | ## Contents
10 |
11 | ```{toctree}
12 | :maxdepth: 1
13 | why-xarray-beam.md
14 | data-model.ipynb
15 | read-write.ipynb
16 | aggregation.ipynb
17 | rechunking.ipynb
18 | api.md
19 | ```
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/rechunking.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Rechunking"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Rechunking lets us re-distribute how datasets are split between variables and chunks across a Beam PCollection."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "To get started we'll recreate our dummy data from the data model tutorial:"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {
28 | "tags": [
29 | "hide-inputs"
30 | ]
31 | },
32 | "outputs": [],
33 | "source": [
34 | "import apache_beam as beam\n",
35 | "import numpy as np\n",
36 | "import xarray_beam as xbeam\n",
37 | "import xarray\n",
38 | "\n",
39 | "def create_records():\n",
40 | " for offset in [0, 4]:\n",
41 | " key = xbeam.Key({'x': offset, 'y': 0})\n",
42 | " data = 2 * offset + np.arange(8).reshape(4, 2)\n",
43 | " chunk = xarray.Dataset({\n",
44 | " 'foo': (('x', 'y'), data),\n",
45 | " 'bar': (('x', 'y'), 100 + data),\n",
46 | " })\n",
47 | " yield key, chunk\n",
48 | " \n",
49 | "inputs = list(create_records())"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "## Choosing chunks\n",
57 | "\n",
58 |     "Chunking can be essential: some operations are very hard or impossible to perform with certain chunking schemes. For example, to make a plot, all the data needs to come together on a single machine. Other calculations, such as calculating a median, are _possible_ to perform on distributed data, but require tricky algorithms and/or approximation.\n",
59 | "\n",
60 | "More broadly, chunking can have critical performance implications, similar to [those for Xarray and Dask](http://xarray.pydata.org/en/stable/user-guide/dask.html#chunking-and-performance). As a rule of thumb, chunk sizes of 10-100 MB work well. The optimal chunk size is a balance among a number of considerations, adapted here [from Dask docs](https://docs.dask.org/en/latest/array-chunks.html):\n",
61 | "\n",
62 | "1. Chunks should be small enough to fit comfortably into memory on a single machine. As an upper limit, chunks over roughly 2 GB in size will not fit into the protocol buffers Beam uses to pass data between workers. \n",
63 | "2. There should be enough chunks for Beam runners (like Cloud Dataflow) to elastically shard work over many workers.\n",
64 | "3. Chunks should be large enough to amortize the overhead of networking and the Python interpreter, which starts to become noticeable for arrays with fewer than 1 million elements.\n",
65 | "\n",
66 |     "The `nbytes` attribute on both NumPy arrays and `xarray.Dataset` objects is an easy way to figure out how large chunks are."
67 | ]
68 | },
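  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, here is a minimal sketch of checking chunk size with `nbytes`, using the `inputs` defined above (each element is a `(key, chunk)` pair):\n",
    "\n",
    "```python\n",
    "key, chunk = inputs[0]\n",
    "print(chunk.nbytes)        # bytes in this chunk\n",
    "print(chunk.nbytes / 1e6)  # ...or in MB, to compare against the 10-100 MB rule of thumb\n",
    "```"
   ]
  },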
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {},
72 | "source": [
73 | "## Adjusting variables"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {},
79 | "source": [
80 |     "The simplest transformation is splitting (or consolidating) different _variables_ in a Dataset with {py:class}`~xarray_beam.SplitVariables()` and {py:class}`~xarray_beam.ConsolidateVariables()`, e.g.,"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": 3,
86 | "metadata": {},
87 | "outputs": [
88 | {
89 | "name": "stderr",
90 | "output_type": "stream",
91 | "text": [
92 | "WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.\n"
93 | ]
94 | },
95 | {
96 | "data": {
97 | "application/javascript": [
98 | "\n",
99 | " if (typeof window.interactive_beam_jquery == 'undefined') {\n",
100 | " var jqueryScript = document.createElement('script');\n",
101 | " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n",
102 | " jqueryScript.type = 'text/javascript';\n",
103 | " jqueryScript.onload = function() {\n",
104 | " var datatableScript = document.createElement('script');\n",
105 | " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n",
106 | " datatableScript.type = 'text/javascript';\n",
107 | " datatableScript.onload = function() {\n",
108 | " window.interactive_beam_jquery = jQuery.noConflict(true);\n",
109 | " window.interactive_beam_jquery(document).ready(function($){\n",
110 | " \n",
111 | " });\n",
112 | " }\n",
113 | " document.head.appendChild(datatableScript);\n",
114 | " };\n",
115 | " document.head.appendChild(jqueryScript);\n",
116 | " } else {\n",
117 | " window.interactive_beam_jquery(document).ready(function($){\n",
118 | " \n",
119 | " });\n",
120 | " }"
121 | ]
122 | },
123 | "metadata": {},
124 | "output_type": "display_data"
125 | },
126 | {
127 | "name": "stderr",
128 | "output_type": "stream",
129 | "text": [
130 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n",
131 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n"
132 | ]
133 | },
134 | {
135 | "data": {
136 | "text/plain": [
137 | "[(Key(offsets={'x': 0, 'y': 0}, vars={'foo'}),\n",
138 | " \n",
139 | " Dimensions: (x: 4, y: 2)\n",
140 | " Dimensions without coordinates: x, y\n",
141 | " Data variables:\n",
142 | " foo (x, y) int64 0 1 2 3 4 5 6 7),\n",
143 | " (Key(offsets={'x': 0, 'y': 0}, vars={'bar'}),\n",
144 | " \n",
145 | " Dimensions: (x: 4, y: 2)\n",
146 | " Dimensions without coordinates: x, y\n",
147 | " Data variables:\n",
148 | " bar (x, y) int64 100 101 102 103 104 105 106 107),\n",
149 | " (Key(offsets={'x': 4, 'y': 0}, vars={'foo'}),\n",
150 | " \n",
151 | " Dimensions: (x: 4, y: 2)\n",
152 | " Dimensions without coordinates: x, y\n",
153 | " Data variables:\n",
154 | " foo (x, y) int64 8 9 10 11 12 13 14 15),\n",
155 | " (Key(offsets={'x': 4, 'y': 0}, vars={'bar'}),\n",
156 | " \n",
157 | " Dimensions: (x: 4, y: 2)\n",
158 | " Dimensions without coordinates: x, y\n",
159 | " Data variables:\n",
160 | " bar (x, y) int64 108 109 110 111 112 113 114 115)]"
161 | ]
162 | },
163 | "execution_count": 3,
164 | "metadata": {},
165 | "output_type": "execute_result"
166 | }
167 | ],
168 | "source": [
169 | "inputs | xbeam.SplitVariables()"
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {},
175 | "source": [
176 | "```{tip}\n",
177 | "Instead of a separate transform for splitting variables, you can also set `split_vars=True` in {py:class}`~xarray_beam.DatasetToChunks`.\n",
178 | "```"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {},
184 | "source": [
185 | "## Adjusting chunks"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 |     "You can also adjust _chunks_ in a dataset, i.e., redistribute the data into differently sized pieces along each dimension. Here you have two choices of API:\n",
193 | "\n",
194 |     "1. The lower level {py:class}`~xarray_beam.SplitChunks` and {py:class}`~xarray_beam.ConsolidateChunks`. These transformations apply a single splitting (with indexing) or consolidation (with {py:func}`xarray.concat`) function to array elements.\n",
195 | "2. The high level {py:class}`~xarray_beam.Rechunk`, which uses a pipeline of multiple split/consolidate steps (as needed) to efficiently rechunk a dataset.\n"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | "### Low level rechunking"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {},
208 | "source": [
209 | "For minor adjustments (e.g., mostly along a single dimension), the more explicit `SplitChunks()` and `ConsolidateChunks()` are good options. They take a dict of _desired_ chunk sizes as a parameter, which can also be `-1` to indicate \"no chunking\" along a dimension:"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 4,
215 | "metadata": {},
216 | "outputs": [
217 | {
218 | "name": "stderr",
219 | "output_type": "stream",
220 | "text": [
221 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n",
222 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n",
223 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'ConsolidateChunks/GroupByTempKeys'. \n",
224 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'ConsolidateChunks/GroupByTempKeys'. \n"
225 | ]
226 | },
227 | {
228 | "data": {
229 | "text/plain": [
230 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n",
231 | " \n",
232 | " Dimensions: (x: 8, y: 2)\n",
233 | " Dimensions without coordinates: x, y\n",
234 | " Data variables:\n",
235 | " foo (x, y) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15\n",
236 | " bar (x, y) int64 100 101 102 103 104 105 ... 110 111 112 113 114 115)]"
237 | ]
238 | },
239 | "execution_count": 4,
240 | "metadata": {},
241 | "output_type": "execute_result"
242 | }
243 | ],
244 | "source": [
245 | "inputs | xbeam.ConsolidateChunks({'x': -1})"
246 | ]
247 | },
248 | {
249 | "cell_type": "markdown",
250 | "metadata": {},
251 | "source": [
252 |     "Note that because these transformations only split _or_ consolidate, they cannot necessarily fully rechunk a dataset in a single step if the new chunk sizes are not multiples of the old chunks (with consolidate) or do not evenly divide the old chunks (with split), e.g.,"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": 5,
258 | "metadata": {},
259 | "outputs": [
260 | {
261 | "name": "stderr",
262 | "output_type": "stream",
263 | "text": [
264 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n",
265 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n"
266 | ]
267 | },
268 | {
269 | "data": {
270 | "text/plain": [
271 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n",
272 | " \n",
273 | " Dimensions: (x: 4, y: 2)\n",
274 | " Dimensions without coordinates: x, y\n",
275 | " Data variables:\n",
276 | " foo (x, y) int64 0 1 2 3 4 5 6 7\n",
277 | " bar (x, y) int64 100 101 102 103 104 105 106 107),\n",
278 | " (Key(offsets={'x': 4, 'y': 0}, vars=None),\n",
279 | " \n",
280 | " Dimensions: (x: 1, y: 2)\n",
281 | " Dimensions without coordinates: x, y\n",
282 | " Data variables:\n",
283 | " foo (x, y) int64 8 9\n",
284 | " bar (x, y) int64 108 109),\n",
285 | " (Key(offsets={'x': 5, 'y': 0}, vars=None),\n",
286 | " \n",
287 | " Dimensions: (x: 3, y: 2)\n",
288 | " Dimensions without coordinates: x, y\n",
289 | " Data variables:\n",
290 | " foo (x, y) int64 10 11 12 13 14 15\n",
291 | " bar (x, y) int64 110 111 112 113 114 115)]"
292 | ]
293 | },
294 | "execution_count": 5,
295 | "metadata": {},
296 | "output_type": "execute_result"
297 | }
298 | ],
299 | "source": [
300 | "inputs | xbeam.SplitChunks({'x': 5}) # notice that the first two chunks are still separate!"
301 | ]
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "metadata": {},
306 | "source": [
307 | "For such uneven cases, you'll need to use split followed by consolidate:"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": 6,
313 | "metadata": {},
314 | "outputs": [
315 | {
316 | "name": "stderr",
317 | "output_type": "stream",
318 | "text": [
319 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n",
320 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n",
321 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n",
322 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n",
323 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'ConsolidateChunks/GroupByTempKeys'. \n",
324 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'ConsolidateChunks/GroupByTempKeys'. \n"
325 | ]
326 | },
327 | {
328 | "data": {
329 | "text/plain": [
330 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n",
331 | " \n",
332 | " Dimensions: (x: 5, y: 2)\n",
333 | " Dimensions without coordinates: x, y\n",
334 | " Data variables:\n",
335 | " foo (x, y) int64 0 1 2 3 4 5 6 7 8 9\n",
336 | " bar (x, y) int64 100 101 102 103 104 105 106 107 108 109),\n",
337 | " (Key(offsets={'x': 5, 'y': 0}, vars=None),\n",
338 | " \n",
339 | " Dimensions: (x: 3, y: 2)\n",
340 | " Dimensions without coordinates: x, y\n",
341 | " Data variables:\n",
342 | " foo (x, y) int64 10 11 12 13 14 15\n",
343 | " bar (x, y) int64 110 111 112 113 114 115)]"
344 | ]
345 | },
346 | "execution_count": 6,
347 | "metadata": {},
348 | "output_type": "execute_result"
349 | }
350 | ],
351 | "source": [
352 | "inputs | xbeam.SplitChunks({'x': 5}) | xbeam.ConsolidateChunks({'x': 5})"
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "metadata": {},
358 | "source": [
359 | "### High level rechunking"
360 | ]
361 | },
362 | {
363 | "cell_type": "markdown",
364 | "metadata": {},
365 | "source": [
366 | "Alternatively, the high-level `Rechunk()` method applies multiple split and consolidate steps based on the [Rechunker](https://github.com/pangeo-data/rechunker) algorithm:"
367 | ]
368 | },
369 | {
370 | "cell_type": "code",
371 | "execution_count": 7,
372 | "metadata": {},
373 | "outputs": [
374 | {
375 | "name": "stderr",
376 | "output_type": "stream",
377 | "text": [
378 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n",
379 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n",
380 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'Rechunk/Stage0/Consolidate/GroupByTempKeys'. \n",
381 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'Rechunk/Stage0/Consolidate/GroupByTempKeys'. \n",
382 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'Rechunk/Stage2/Consolidate/GroupByTempKeys'. \n",
383 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'Rechunk/Stage2/Consolidate/GroupByTempKeys'. \n"
384 | ]
385 | },
386 | {
387 | "data": {
388 | "text/plain": [
389 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n",
390 | " \n",
391 | " Dimensions: (x: 5, y: 2)\n",
392 | " Dimensions without coordinates: x, y\n",
393 | " Data variables:\n",
394 | " foo (x, y) int64 0 1 2 3 4 5 6 7 8 9\n",
395 | " bar (x, y) int64 100 101 102 103 104 105 106 107 108 109),\n",
396 | " (Key(offsets={'x': 5, 'y': 0}, vars=None),\n",
397 | " \n",
398 | " Dimensions: (x: 3, y: 2)\n",
399 | " Dimensions without coordinates: x, y\n",
400 | " Data variables:\n",
401 | " foo (x, y) int64 10 11 12 13 14 15\n",
402 | " bar (x, y) int64 110 111 112 113 114 115)]"
403 | ]
404 | },
405 | "execution_count": 7,
406 | "metadata": {},
407 | "output_type": "execute_result"
408 | }
409 | ],
410 | "source": [
411 | "inputs | xbeam.Rechunk(dim_sizes={'x': 6}, source_chunks={'x': 3}, target_chunks={'x': 5}, itemsize=8)"
412 | ]
413 | },
414 | {
415 | "cell_type": "markdown",
416 | "metadata": {},
417 | "source": [
418 |     "`Rechunk` requires specifying a few more parameters, but based on that information **it can be _much_ more efficient for more complex rechunking tasks**, particularly in cases where data needs to be distributed into a very different shape (e.g., distributing a matrix across rows vs. columns).\n",
419 | "\n",
420 | "The naive \"splitting\" approach in such cases may divide datasets into extremely small tasks corresponding to individual array elements, which adds a huge amount of overhead."
421 | ]
422 | }
423 | ],
424 | "metadata": {
425 | "celltoolbar": "Tags",
426 | "interpreter": {
427 | "hash": "aef148d7ea0dbd1f91630322dd5bc9e24a2135d95f24fe1a9dab9696856be2b9"
428 | },
429 | "kernelspec": {
430 | "display_name": "Python 3 (ipykernel)",
431 | "language": "python",
432 | "name": "python3"
433 | },
434 | "language_info": {
435 | "codemirror_mode": {
436 | "name": "ipython",
437 | "version": 3
438 | },
439 | "file_extension": ".py",
440 | "mimetype": "text/x-python",
441 | "name": "python",
442 | "nbconvert_exporter": "python",
443 | "pygments_lexer": "ipython3",
444 | "version": "3.9.13"
445 | }
446 | },
447 | "nbformat": 4,
448 | "nbformat_minor": 2
449 | }
450 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | # doc requirements
2 | Jinja2==3.1.3
3 | myst-nb==0.17.2
4 | myst-parser==0.18.1
5 | sphinx_rtd_theme==1.2.1
6 | sphinx==5.3.0
7 | scipy==1.10.1
8 |
9 | # xarray-beam requirements
10 | apache-beam==2.47.0
11 | dask==2023.5.1
12 | immutabledict==2.2.4
13 | numpy==1.24.3
14 | pandas==1.5.3
15 | pooch==1.7.0
16 | rechunker==0.5.1
17 | xarray==2023.5.0
18 | zarr==2.14.2
19 |
--------------------------------------------------------------------------------
/docs/why-xarray-beam.md:
--------------------------------------------------------------------------------
1 | # Why Xarray-Beam
2 |
3 | ## Our goals
4 |
5 | Xarray-Beam is a Python library for building
6 | [Apache Beam](https://beam.apache.org/) pipelines with
7 | [Xarray](http://xarray.pydata.org/en/stable/) datasets.
8 |
9 | The project aims to facilitate data transformations and analysis on large-scale
10 | multi-dimensional labeled arrays, such as:
11 |
12 | - Ad-hoc computation on Xarray data, by dividing a `xarray.Dataset` into many
13 | smaller pieces ("chunks").
14 | - Adjusting array chunks, using the
15 | [Rechunker algorithm](https://rechunker.readthedocs.io/en/latest/algorithm.html).
16 | - Ingesting large, multi-dimensional array datasets into an analysis-ready,
17 | cloud-optimized format, namely [Zarr](https://zarr.readthedocs.io/) (see
18 | also [Pangeo Forge](https://github.com/pangeo-forge/pangeo-forge-recipes)).
19 | - Calculating statistics (e.g., "climatology") across distributed datasets
20 | with arbitrary groups.
21 |
22 | ## Our approach
23 |
24 | In Xarray-Beam, distributed Xarray datasets are represented by Beam PCollections
25 | of `(xarray_beam.Key, xarray.Dataset)` pairs, corresponding to a "chunk" of a
26 | larger (virtual) dataset. The {py:class}`~xarray_beam.Key` provides sufficient
27 | metadata for Beam PTransforms like those included in Xarray-Beam to perform
28 | collective operations on the entire dataset. This chunking model is highly
29 | flexible, allowing datasets to be split across multiple variables and/or
30 | into orthogonal, contiguous "chunks" along dimensions.
31 |
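As a rough sketch of what this representation looks like in code (the variable
and dimension names below are made up for illustration):

```python
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

# Two chunks of a larger (virtual) dataset, split along the 'time' dimension.
chunks = [
    (xbeam.Key({'time': 0}), xarray.Dataset({'t2m': ('time', np.zeros(5))})),
    (xbeam.Key({'time': 5}), xarray.Dataset({'t2m': ('time', np.ones(5))})),
]

with beam.Pipeline() as p:
  # Collective transforms such as ConsolidateChunks use the Key metadata to
  # reassemble these chunks into larger pieces of the same virtual dataset.
  p | beam.Create(chunks) | xbeam.ConsolidateChunks({'time': -1})
```
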
32 | Xarray-Beam does not (yet) include high-level abstractions like a "distributed
33 | dataset" object. Users need to have a mental model for how their data pipeline
34 | is distributed across many machines, which is facilitated by its direct
35 | representation as a Beam pipeline. (In our experience, building such a mental
36 | model is basically required to get good performance out of large-scale
37 | pipelines, anyway.)
38 |
39 | Implementation-wise, Xarray-Beam is a _thin layer_ on top of existing libraries
40 | for working with large-scale Xarray datasets. For example, it leverages
41 | [Dask](https://dask.org/) for describing lazy arrays and for executing
42 | multi-threaded computation on a single machine.
43 |
44 | ## How does Dask compare?
45 |
46 | We love Dask! Xarray-Beam explores a different part of the design space for
47 | distributed data pipelines than Xarray's built-in Dask integration:
48 |
49 | - Xarray-Beam is built around explicit manipulation of `(xarray_beam.Key,
50 | xarray.Dataset)`. This requires more boilerplate but is also
51 | more robust than generating distributed computation graphs in Dask using
52 | Xarray's built-in API.
53 | - Xarray-Beam distributes datasets by splitting them into many
54 | `xarray.Dataset` chunks, rather than the chunks of NumPy arrays typically
55 | used by Xarray with Dask (unless using
56 | [xarray.map_blocks](http://xarray.pydata.org/en/stable/user-guide/dask.html#automatic-parallelization-with-apply-ufunc-and-map-blocks)).
57 |   Chunks of datasets are a more convenient data model for writing ad-hoc whole
58 |   dataset transformations, but are potentially a bit less efficient.
59 | - Beam ([like Spark](https://docs.dask.org/en/latest/spark.html)) was designed
60 | around a higher-level model for distributed computation than Dask (although
61 | Dask has been making
62 | [progress in this direction](https://coiled.io/blog/dask-under-the-hood-scheduler-refactor/)).
63 | Roughly speaking, this trade-off favors scalability over flexibility.
64 | - Beam allows for executing distributed computation using multiple runners,
65 | notably including Google Cloud Dataflow and Apache Spark. These runners are
66 | more mature than Dask, and in many cases are supported as a service by major
67 | commercial cloud providers.
68 |
69 | 
70 |
71 | These design choices are not set in stone. In particular, in the future we
72 | _could_ imagine writing a high-level `xarray_beam.Dataset` that emulates the
73 | `xarray.Dataset` API, similar to the popular high-level DataFrame APIs in Beam,
74 | Spark and Dask. This could be built on top of the lower-level transformations
75 | currently in Xarray-Beam, or alternatively could use a "chunks of NumPy arrays"
76 | representation similar to that used by dask.array.
77 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Xarray-Beam examples
2 |
3 | The examples in this directory use the ERA5 surface dataset (as assembled by
4 | Pangeo), which consists of 19 data variables stored in float32 precision. It
5 | totals 24.8 TB in size:
6 |
7 | ```
8 | >>> xarray.open_zarr('gs://pangeo-era5/reanalysis/spatial-analysis', consolidated=True)
9 | <xarray.Dataset>
9 |
10 | Dimensions: (latitude: 721, longitude: 1440, time: 350640)
11 | Coordinates:
12 | * latitude (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
13 | * longitude (longitude) float32 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8
14 | * time (time) datetime64[ns] 1979-01-01 ... 2018-12-31T23:00:00
15 | Data variables: (12/17)
16 | asn (time, latitude, longitude) float32 dask.array
17 | d2m (time, latitude, longitude) float32 dask.array
18 | e (time, latitude, longitude) float32 dask.array
19 | mn2t (time, latitude, longitude) float32 dask.array
20 | mx2t (time, latitude, longitude) float32 dask.array
21 | ptype (time, latitude, longitude) float32 dask.array
22 | ... ...
23 | tcc (time, latitude, longitude) float32 dask.array
24 | tcrw (time, latitude, longitude) float32 dask.array
25 | tp (time, latitude, longitude) float32 dask.array
26 | tsn (time, latitude, longitude) float32 dask.array
27 | u10 (time, latitude, longitude) float32 dask.array
28 | v10 (time, latitude, longitude) float32 dask.array
29 | Attributes:
30 | Conventions: CF-1.6
31 | history: 2019-09-20 05:15:01 GMT by grib_to_netcdf-2.10.0: /opt/ecmw...
32 | ```
33 |
34 | TODO(shoyer): add instructions for running these examples using Google Cloud
35 | DataFlow.
36 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/era5_climatology.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Calculate climatology for the Pangeo ERA5 surface dataset."""
15 | from typing import Tuple
16 |
17 | from absl import app
18 | from absl import flags
19 | import apache_beam as beam
20 | import numpy as np
21 | import xarray
22 | import xarray_beam as xbeam
23 |
24 |
25 | INPUT_PATH = flags.DEFINE_string('input_path', None, help='Input Zarr path')
26 | OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='Output Zarr path')
27 | RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')
28 |
29 |
30 | # pylint: disable=expression-not-assigned
31 |
32 |
33 | def rekey_chunk_on_month_hour(
34 | key: xbeam.Key, dataset: xarray.Dataset
35 | ) -> Tuple[xbeam.Key, xarray.Dataset]:
36 | """Replace the 'time' dimension with 'month'/'hour'."""
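  # For example: a chunk keyed along 'time' whose single timestamp falls in
  # February at 06:00 gets re-keyed to offsets {'month': 1, 'hour': 6}
  # (month offsets are zero-based, matching positions in the output arrays).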
37 | month = dataset.time.dt.month.item()
38 | hour = dataset.time.dt.hour.item()
39 | new_key = key.with_offsets(time=None, month=month - 1, hour=hour)
40 | new_dataset = dataset.squeeze('time', drop=True).expand_dims(
41 | month=[month], hour=[hour]
42 | )
43 | return new_key, new_dataset
44 |
45 |
46 | def main(argv):
47 | source_dataset, source_chunks = xbeam.open_zarr(INPUT_PATH.value)
48 |
49 |   # This lazy "template" allows us to set up the Zarr outputs before running the
50 | # pipeline.
51 | max_month = source_dataset.time.dt.month.max().item() # normally 12
52 | template = (
53 | xbeam.make_template(source_dataset)
54 | .isel(time=0, drop=True)
55 | .expand_dims(month=np.arange(1, max_month + 1), hour=np.arange(24))
56 | )
57 | output_chunks = {'hour': 1, 'month': 1}
58 |
59 | with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
60 | (
61 | root
62 | | xbeam.DatasetToChunks(source_dataset, source_chunks)
63 | | xbeam.SplitChunks({'time': 1})
64 | | beam.MapTuple(rekey_chunk_on_month_hour)
65 | | xbeam.Mean.PerKey()
66 | | xbeam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks)
67 | )
68 |
69 |
70 | if __name__ == '__main__':
71 | app.run(main)
72 |
--------------------------------------------------------------------------------
/examples/era5_climatology_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for era5_climatology."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import flagsaver
18 | import xarray
19 |
20 | from . import era5_climatology
21 | from xarray_beam._src import test_util
22 |
23 |
24 | class Era5ClimatologyTest(test_util.TestCase):
25 |
26 | def test(self):
27 | input_path = self.create_tempdir('source').full_path
28 | output_path = self.create_tempdir('destination').full_path
29 |
30 | input_ds = test_util.dummy_era5_surface_dataset(times=90 * 24, freq='1H')
31 | input_ds.chunk({'time': 31}).to_zarr(input_path)
32 |
33 | expected = input_ds.groupby('time.month').apply(
34 | lambda x: x.groupby('time.hour').mean('time')
35 | )
36 |
37 | with flagsaver.flagsaver(
38 | input_path=input_path,
39 | output_path=output_path,
40 | ):
41 | era5_climatology.main([])
42 |
43 | actual = xarray.open_zarr(output_path)
44 | xarray.testing.assert_allclose(actual, expected, atol=1e-7)
45 |
46 |
47 | if __name__ == '__main__':
48 | absltest.main()
49 |
--------------------------------------------------------------------------------
/examples/xbeam_rechunk.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Rechunk a Zarr dataset."""
15 | from typing import Dict
16 |
17 | from absl import app
18 | from absl import flags
19 | import apache_beam as beam
20 | import xarray_beam as xbeam
21 |
22 |
23 | INPUT_PATH = flags.DEFINE_string('input_path', None, help='input Zarr path')
24 | OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='output Zarr path')
25 | TARGET_CHUNKS = flags.DEFINE_string(
26 | 'target_chunks',
27 | '',
28 | help=(
29 | 'chunks on the input Zarr dataset to change on the outputs, in the '
30 |         'form of comma-separated dimension=size pairs, e.g., '
31 | "--target_chunks='x=10,y=10'. Omitted dimensions are not changed and a "
32 | 'chunksize of -1 indicates not to chunk a dimension.'
33 | ),
34 | )
35 | RUNNER = flags.DEFINE_string('runner', None, help='beam.runners.Runner')
36 |
37 |
38 | # pylint: disable=expression-not-assigned
39 |
40 |
41 | def _parse_chunks_str(chunks_str: str) -> Dict[str, int]:
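  """Parse a flag value like 'x=10,y=10' into a dict, e.g. {'x': 10, 'y': 10}."""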
42 | chunks = {}
43 | parts = chunks_str.split(',')
44 | for part in parts:
45 | k, v = part.split('=')
46 | chunks[k] = int(v)
47 | return chunks
48 |
49 |
50 | def main(argv):
51 | source_dataset, source_chunks = xbeam.open_zarr(INPUT_PATH.value)
52 | template = xbeam.make_template(source_dataset)
53 | target_chunks = dict(source_chunks, **_parse_chunks_str(TARGET_CHUNKS.value))
54 | itemsize = max(variable.dtype.itemsize for variable in template.values())
55 |
56 | with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
57 | (
58 | root
59 | | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=True)
60 | | xbeam.Rechunk( # pytype: disable=wrong-arg-types
61 | source_dataset.sizes,
62 | source_chunks,
63 | target_chunks,
64 | itemsize=itemsize,
65 | )
66 | | xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks)
67 | )
68 |
69 |
70 | if __name__ == '__main__':
71 | app.run(main)
72 |
--------------------------------------------------------------------------------
/examples/xbeam_rechunk_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for xbeam_rechunk."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import flagsaver
18 | import xarray
19 |
20 | from . import xbeam_rechunk
21 | from xarray_beam._src import test_util
22 |
23 |
24 | class Era5RechunkTest(test_util.TestCase):
25 |
26 | def test(self):
27 | input_path = self.create_tempdir('source').full_path
28 | output_path = self.create_tempdir('destination').full_path
29 |
30 | input_ds = test_util.dummy_era5_surface_dataset(times=365)
31 | input_ds.chunk({'time': 31}).to_zarr(input_path)
32 |
33 | with flagsaver.flagsaver(
34 | input_path=input_path,
35 | output_path=output_path,
36 | target_chunks='latitude=5,longitude=5,time=-1',
37 | ):
38 | xbeam_rechunk.main([])
39 |
40 | output_ds = xarray.open_zarr(output_path)
41 | self.assertEqual(
42 | {k: v[0] for k, v in output_ds.chunks.items()},
43 | {'latitude': 5, 'longitude': 5, 'time': 365}
44 | )
45 | xarray.testing.assert_identical(input_ds, output_ds)
46 |
47 |
48 | if __name__ == '__main__':
49 | absltest.main()
50 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pyink]
2 | line-length = 80
3 | preview = true
4 | pyink-indentation = 2
5 | pyink-use-majority-quotes = true
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Setup Xarray-Beam."""
16 | import setuptools
17 |
18 |
19 | base_requires = [
20 | 'apache_beam>=2.31.0',
21 | 'dask',
22 | 'immutabledict',
23 | 'rechunker>=0.5.1',
24 | 'zarr',
25 | 'xarray',
26 | ]
27 | docs_requires = [
28 | 'myst-nb',
29 | 'myst-parser',
30 | 'sphinx',
31 | 'sphinx_rtd_theme',
32 | 'scipy',
33 | ]
34 | tests_requires = [
35 | 'absl-py',
36 | 'pandas',
37 | 'pytest',
38 | 'scipy',
39 | 'h5netcdf'
40 | ]
41 |
42 | setuptools.setup(
43 | name='xarray-beam',
44 | version='0.8.1', # keep in sync with __init__.py
45 | license='Apache 2.0',
46 | author='Google LLC',
47 | author_email='noreply@google.com',
48 | install_requires=base_requires,
49 | extras_require={
50 | 'tests': tests_requires,
51 | 'docs': docs_requires,
52 | },
53 | url='https://github.com/google/xarray-beam',
54 | packages=setuptools.find_packages(exclude=['examples']),
55 | python_requires='>=3',
56 | )
57 |
--------------------------------------------------------------------------------
/xarray_beam/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Public API for Xarray-Beam."""
15 |
16 | # pylint: disable=g-multiple-import
17 | from xarray_beam._src.combiners import (
18 | Mean,
19 | MeanCombineFn,
20 | )
21 | from xarray_beam._src.core import (
22 | Key,
23 | DatasetToChunks,
24 | ValidateEachChunk,
25 | offsets_to_slices,
26 | validate_chunk
27 | )
28 | from xarray_beam._src.dataset import (
29 | Dataset,
30 | )
31 | from xarray_beam._src.rechunk import (
32 | ConsolidateChunks,
33 | ConsolidateVariables,
34 | SplitChunks,
35 | SplitVariables,
36 | Rechunk,
37 | consolidate_chunks,
38 | consolidate_variables,
39 | consolidate_fully,
40 | split_chunks,
41 | split_variables,
42 | in_memory_rechunk,
43 | )
44 | from xarray_beam._src.zarr import (
45 | open_zarr,
46 | make_template,
47 | replace_template_dims,
48 | setup_zarr,
49 | validate_zarr_chunk,
50 | write_chunk_to_zarr,
51 | ChunksToZarr,
52 | DatasetToZarr,
53 | )
54 |
55 | __version__ = '0.8.1' # keep in sync with setup.py
56 |
--------------------------------------------------------------------------------
/xarray_beam/_src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/xarray_beam/_src/combiners.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Combiners for xarray-beam."""
15 | from __future__ import annotations
16 | import dataclasses
17 | from typing import Optional, Sequence, Union
18 |
19 | import apache_beam as beam
20 | import numpy.typing as npt
21 | import xarray
22 |
23 | from xarray_beam._src import core
24 |
25 |
26 | # TODO(shoyer): add other combiners: sum, std, var, min, max, etc.
27 |
28 |
29 | DimLike = Optional[Union[str, Sequence[str]]]
30 |
31 |
32 | @dataclasses.dataclass
33 | class MeanCombineFn(beam.transforms.CombineFn):
34 | """CombineFn for computing an arithmetic mean of xarray.Dataset objects."""
35 |
36 | dim: DimLike = None
37 | skipna: bool = True
38 | dtype: Optional[npt.DTypeLike] = None
39 |
40 | def create_accumulator(self):
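    # The accumulator is a (running sum, running count) pair; the integer
    # zeros broadcast against xarray.Dataset values on the first add_input.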
41 | return (0, 0)
42 |
43 | def add_input(self, sum_count, element):
44 | (sum_, count) = sum_count
45 |
46 | if self.dtype is not None:
47 | element = element.astype(self.dtype)
48 |
49 | if self.skipna:
50 | sum_increment = element.fillna(0)
51 | count_increment = element.notnull()
52 | else:
53 | sum_increment = element
54 | count_increment = xarray.ones_like(element)
55 |
56 | if self.dim is not None:
57 |       # unconditionally set skipna=False because we already explicitly fill in
58 |       # missing values above
59 | sum_increment = sum_increment.sum(self.dim, skipna=False)
60 | count_increment = count_increment.sum(self.dim)
61 |
62 | new_sum = sum_ + sum_increment
63 | new_count = count + count_increment
64 |
65 | return new_sum, new_count
66 |
67 | def merge_accumulators(self, accumulators):
68 | sums, counts = zip(*accumulators)
69 | return sum(sums), sum(counts)
70 |
71 | def extract_output(self, sum_count):
72 | (sum_, count) = sum_count
73 | return sum_ / count
74 |
75 | def for_input_type(self, input_type):
76 | return self
77 |
78 |
79 | @dataclasses.dataclass
80 | class Mean(beam.PTransform):
81 | """Calculate the mean over one or more distributed dataset dimensions."""
82 |
83 | dim: Union[str, Sequence[str]]
84 | skipna: bool = True
85 | dtype: Optional[npt.DTypeLike] = None
86 | fanout: Optional[int] = None
87 |
88 | def _update_key(
89 | self, key: core.Key, chunk: xarray.Dataset
90 | ) -> tuple[core.Key, xarray.Dataset]:
91 | dims = [self.dim] if isinstance(self.dim, str) else self.dim
92 | new_key = key.with_offsets(**{d: None for d in dims if d in key.offsets})
93 | return new_key, chunk
94 |
95 | def expand(self, pcoll):
96 | return (
97 | pcoll
98 | | beam.MapTuple(self._update_key)
99 | | Mean.PerKey(self.dim, self.skipna, self.dtype, self.fanout)
100 | )
101 |
102 | @dataclasses.dataclass
103 | class Globally(beam.PTransform):
104 | """Calculate global mean over a pcollection of xarray.Dataset objects."""
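    # Usage sketch: `datasets | Mean.Globally(dim='time')` computes the mean
    # over the 'time' dimension across every xarray.Dataset in the collection.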
105 |
106 | dim: DimLike = None
107 | skipna: bool = True
108 | dtype: Optional[npt.DTypeLike] = None
109 | fanout: Optional[int] = None
110 |
111 | def expand(self, pcoll):
112 | combine_fn = MeanCombineFn(self.dim, self.skipna, self.dtype)
113 | return pcoll | beam.CombineGlobally(combine_fn).with_fanout(self.fanout)
114 |
115 | @dataclasses.dataclass
116 | class PerKey(beam.PTransform):
117 | """Calculate per-key mean over a pcollection of (hashable, Dataset)."""
118 |
119 | dim: DimLike = None
120 | skipna: bool = True
121 | dtype: Optional[npt.DTypeLike] = None
122 | fanout: Optional[int] = None
123 |
124 | def expand(self, pcoll):
125 | combine_fn = MeanCombineFn(self.dim, self.skipna, self.dtype)
126 | return pcoll | beam.CombinePerKey(combine_fn).with_hot_key_fanout(
127 | self.fanout
128 | )
129 |
--------------------------------------------------------------------------------
/xarray_beam/_src/combiners_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for xarray_beam._src.combiners."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | import numpy as np
19 | import xarray
20 | import xarray_beam as xbeam
21 | from xarray_beam._src import test_util
22 |
23 |
24 | # pylint: disable=expression-not-assigned
25 | # pylint: disable=pointless-statement
26 |
27 |
28 | class MeanTest(test_util.TestCase):
29 |
30 | def test_globally(self):
31 | nan = np.nan
32 | data_with_nans = np.array(
33 | [[1, 2, 3], [4, 5, nan], [6, nan, nan], [nan, nan, nan]]
34 | )
35 | dataset = xarray.Dataset({'foo': (('x', 'y'), data_with_nans)})
36 | inputs_x = [dataset.isel(x=i) for i in range(4)]
37 | inputs_y = [dataset.isel(y=i) for i in range(3)]
38 |
39 | with self.subTest('skipna-default'):
40 | expected = dataset.mean('y', skipna=True)
41 | (actual,) = inputs_y | xbeam.Mean.Globally()
42 | xarray.testing.assert_allclose(expected, actual)
43 |
44 | with self.subTest('skipna=True'):
45 | expected = dataset.mean('y', skipna=True)
46 | (actual,) = inputs_y | xbeam.Mean.Globally(skipna=True)
47 | xarray.testing.assert_allclose(expected, actual)
48 |
49 | expected = dataset.mean('x', skipna=True)
50 | (actual,) = inputs_x | xbeam.Mean.Globally(skipna=True)
51 | xarray.testing.assert_allclose(expected, actual)
52 |
53 | with self.subTest('skipna=False', skipna=False):
54 | expected = dataset.mean('y', skipna=False)
55 | (actual,) = inputs_y | xbeam.Mean.Globally(skipna=False)
56 | xarray.testing.assert_allclose(expected, actual)
57 |
58 | expected = dataset.mean('x', skipna=False)
59 | (actual,) = inputs_x | xbeam.Mean.Globally(skipna=False)
60 | xarray.testing.assert_allclose(expected, actual)
61 |
62 | with self.subTest('with-fanout'):
63 | expected = dataset.mean('y', skipna=True)
64 | (actual,) = inputs_y | xbeam.Mean.Globally(fanout=2)
65 | xarray.testing.assert_allclose(expected, actual)
66 |
67 | def test_dim_globally(self):
68 | inputs = [
69 | xarray.Dataset({'x': ('time', [1, 2])}),
70 | xarray.Dataset({'x': ('time', [3])}),
71 | ]
72 | expected = xarray.Dataset({'x': 2.0})
73 | (actual,) = inputs | xbeam.Mean.Globally(dim='time')
74 | xarray.testing.assert_allclose(expected, actual)
75 |
76 | def test_per_key(self):
77 | inputs = [
78 | (0, xarray.Dataset({'x': 1})),
79 | (0, xarray.Dataset({'x': 2})),
80 | (1, xarray.Dataset({'x': 3})),
81 | (1, xarray.Dataset({'x': 4})),
82 | ]
83 | expected = [
84 | (0, xarray.Dataset({'x': 1.5})),
85 | (1, xarray.Dataset({'x': 3.5})),
86 | ]
87 | actual = inputs | xbeam.Mean.PerKey()
88 | self.assertAllCloseChunks(actual, expected)
89 |
90 | def test_mean_1d(self):
91 | inputs = [
92 | (xbeam.Key({'x': 0}), xarray.Dataset({'y': ('x', [1, 2, 3])})),
93 | (xbeam.Key({'x': 3}), xarray.Dataset({'y': ('x', [4, 5, 6])})),
94 | ]
95 | expected = [
96 | (xbeam.Key({}), xarray.Dataset({'y': 3.5})),
97 | ]
98 | actual = inputs | xbeam.Mean('x')
99 | self.assertAllCloseChunks(actual, expected)
100 | actual = inputs | xbeam.Mean(['x'])
101 | self.assertAllCloseChunks(actual, expected)
102 |
103 | def test_mean_many(self):
104 | inputs = []
105 | for i in range(0, 100, 10):
106 | inputs.append(
107 | (xbeam.Key({'x': i}), xarray.Dataset({'y': ('x', i + np.arange(10))}))
108 | )
109 | expected = [
110 | (xbeam.Key({}), xarray.Dataset({'y': 49.5})),
111 | ]
112 | actual = inputs | xbeam.Mean('x', fanout=2)
113 | self.assertAllCloseChunks(actual, expected)
114 |
115 | def test_mean_nans(self):
116 | nan = np.nan
117 | data_with_nans = np.array(
118 | [[1, 2, 3], [4, 5, nan], [6, nan, nan], [nan, nan, nan]]
119 | )
120 | dataset = xarray.Dataset({'foo': (('x', 'y'), data_with_nans)})
121 | eager = test_util.EagerPipeline()
122 | chunks = eager | xbeam.DatasetToChunks(dataset, {'x': 1, 'y': 1})
123 |
124 | expected = eager | xbeam.DatasetToChunks(
125 | dataset.mean('y', skipna=False), {'x': 1}
126 | )
127 | actual = chunks | xbeam.Mean('y', skipna=False)
128 | self.assertAllCloseChunks(actual, expected)
129 |
130 | expected = eager | xbeam.DatasetToChunks(
131 | dataset.mean('y', skipna=True), {'x': 1}
132 | )
133 | actual = chunks | xbeam.Mean('y', skipna=True)
134 | self.assertAllCloseChunks(actual, expected)
135 |
136 | expected = eager | xbeam.DatasetToChunks(
137 | dataset.mean('x', skipna=True), {'y': 1}
138 | )
139 | actual = chunks | xbeam.Mean('x', skipna=True)
140 | self.assertAllCloseChunks(actual, expected)
141 |
142 | expected = eager | xbeam.DatasetToChunks(
143 | dataset.mean('x', skipna=False), {'y': 1}
144 | )
145 | actual = chunks | xbeam.Mean('x', skipna=False)
146 | self.assertAllCloseChunks(actual, expected)
147 |
148 | def test_mean_2d(self):
149 | inputs = [
150 | (xbeam.Key({'y': 0}), xarray.Dataset({'z': (('x', 'y'), [[1, 2, 3]])})),
151 | (xbeam.Key({'y': 3}), xarray.Dataset({'z': (('x', 'y'), [[4, 5, 6]])})),
152 | ]
153 |
154 | expected = [
155 | (xbeam.Key({'y': 0}), xarray.Dataset({'z': ('y', [1, 2, 3])})),
156 | (xbeam.Key({'y': 3}), xarray.Dataset({'z': ('y', [4, 5, 6])})),
157 | ]
158 | actual = inputs | xbeam.Mean('x')
159 | self.assertAllCloseChunks(actual, expected)
160 |
161 | expected = [
162 | (xbeam.Key({}), xarray.Dataset({'z': (('x',), [3.5])})),
163 | ]
164 | actual = inputs | xbeam.Mean('y')
165 | self.assertAllCloseChunks(actual, expected)
166 |
167 | expected = [
168 | (xbeam.Key({}), xarray.Dataset({'z': 3.5})),
169 | ]
170 | actual = inputs | xbeam.Mean(['x', 'y'])
171 | self.assertAllCloseChunks(actual, expected)
172 |
173 |
174 | if __name__ == '__main__':
175 | absltest.main()
176 |
--------------------------------------------------------------------------------
/xarray_beam/_src/core.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Core data model for xarray-beam."""
15 | import itertools
16 | import math
17 | from typing import (
18 | AbstractSet,
19 | Dict,
20 | Generic,
21 | Iterator,
22 | List,
23 | Mapping,
24 | Optional,
25 | Sequence,
26 | Tuple,
27 | TypeVar,
28 | Union,
29 | )
30 |
31 | import apache_beam as beam
32 | import immutabledict
33 | import numpy as np
34 | import xarray
35 | from xarray_beam._src import threadmap
36 |
37 | _DEFAULT = object()
38 |
39 |
40 | class Key:
41 | """A key for keeping track of chunks of a distributed xarray.Dataset.
42 |
43 |   Key objects in Xarray-Beam include two components:
44 |
45 | - "offsets": an immutable dict indicating integer offsets (total number of
46 | array elements) from the origin along each dimension for this chunk.
47 |   - "vars": either a frozenset or None, indicating the subset of Dataset
48 | variables included in this chunk. None means that all variables are
49 | included.
50 |
51 | Key objects are "deterministically encoded" by Beam, which makes them suitable
52 | for use as keys in Beam pipelines, i.e., with beam.GroupByKey. They are also
53 | immutable and hashable, which makes them usable as keys in Python
54 | dictionaries.
55 |
56 | Example usage::
57 |
58 | >>> key = xarray_beam.Key(offsets={'x': 10}, vars={'foo'})
59 |
60 | >>> key
61 | xarray_beam.Key(offsets={'x': 10}, vars={'foo'})
62 |
63 | >>> key.offsets
64 | immutabledict({'x': 10})
65 |
66 | >>> key.vars
67 | frozenset({'foo'})
68 |
69 | To replace some offsets::
70 |
71 | >>> key.with_offsets(y=0) # insert
72 | xarray_beam.Key(offsets={'x': 10, 'y': 0}, vars={'foo'})
73 |
74 | >>> key.with_offsets(x=20) # override
75 | xarray_beam.Key(offsets={'x': 20}, vars={'foo'})
76 |
77 | >>> key.with_offsets(x=None) # remove
78 | xarray_beam.Key(offsets={}, vars={'foo'})
79 |
80 | To entirely replace offsets or variables::
81 |
82 | >>> key.replace(offsets={'y': 0})
83 | xarray_beam.Key(offsets={'y': 0}, vars={'foo'})
84 |
85 | >>> key.replace(vars=None)
86 | xarray_beam.Key(offsets={'x': 10}, vars=None)
87 | """
88 |
89 | # pylint: disable=redefined-builtin
90 |
91 | def __init__(
92 | self,
93 | offsets: Optional[Mapping[str, int]] = None,
94 | vars: Optional[AbstractSet[str]] = None,
95 | ):
96 | if offsets is None:
97 | offsets = {}
98 | if isinstance(vars, str):
99 | raise TypeError(f"vars must be a set or None, but is {vars!r}")
100 | self.offsets = immutabledict.immutabledict(offsets)
101 | self.vars = None if vars is None else frozenset(vars)
102 |
103 | def replace(
104 | self,
105 | offsets: Union[Mapping[str, int], object] = _DEFAULT,
106 | vars: Union[AbstractSet[str], None, object] = _DEFAULT,
107 | ) -> "Key":
108 | if offsets is _DEFAULT:
109 | offsets = self.offsets
110 | if vars is _DEFAULT:
111 | vars = self.vars
112 | return type(self)(offsets, vars)
113 |
114 | def with_offsets(self, **offsets: Optional[int]) -> "Key":
115 | new_offsets = dict(self.offsets)
116 | for k, v in offsets.items():
117 | if v is None:
118 | del new_offsets[k]
119 | else:
120 | new_offsets[k] = v
121 | return self.replace(offsets=new_offsets)
122 |
123 | def __repr__(self) -> str:
124 | offsets = dict(self.offsets)
125 | vars = set(self.vars) if self.vars is not None else None
126 | return f"{type(self).__name__}(offsets={offsets}, vars={vars})"
127 |
128 | def __hash__(self) -> int:
129 | return hash((self.offsets, self.vars))
130 |
131 | def __eq__(self, other) -> bool:
132 | if not isinstance(other, Key):
133 | return NotImplemented
134 | return self.offsets == other.offsets and self.vars == other.vars
135 |
136 | def __ne__(self, other) -> bool:
137 | return not self == other
138 |
139 | # Beam uses these methods (also used for pickling) for "deterministic
140 | # encoding" of groupby keys
141 | def __getstate__(self):
142 | offsets_state = sorted(self.offsets.items())
143 | vars_state = None if self.vars is None else sorted(self.vars)
144 | return (offsets_state, vars_state)
145 |
146 | def __setstate__(self, state):
147 | self.__init__(*state)
148 |
149 |
150 | def offsets_to_slices(
151 | offsets: Mapping[str, int],
152 | sizes: Mapping[str, int],
153 | base: Optional[Mapping[str, int]] = None,
154 | ) -> Dict[str, slice]:
155 | """Convert offsets into slices with an optional base offset.
156 |
157 | Args:
158 | offsets: integer offsets from the origin along each axis.
159 | sizes: dimension sizes for the corresponding chunks.
160 | base: optional base offset to subtract from this key. This allows for
161 | relative indexing, e.g., into a chunk of a larger Dataset.
162 |
163 | Returns:
164 | Slices suitable for indexing with xarray.Dataset.isel().
165 |
166 | Raises:
167 | ValueError: if an offset is specified for a dimension where there is no
168 | corresponding size specified.
169 |
170 | Example usage::
171 |
172 | >>> offsets_to_slices({'x': 100}, sizes={'x': 10})
173 | {'x': slice(100, 110, 1)}
174 | >>> offsets_to_slices({'x': 100}, sizes={'x': 10}, base={'x': 100})
175 | {'x': slice(0, 10, 1)}
176 | """
177 | if base is None:
178 | base = {}
179 | slices = {}
180 | missing_chunk_sizes = [k for k in offsets.keys() if k not in sizes]
181 | if missing_chunk_sizes:
182 | raise ValueError(
183 | "some dimensions have offsets specified but no dimension sizes: "
184 | f"offsets={offsets} and sizes={sizes}"
185 | )
186 | for k, size in sizes.items():
187 | offset = offsets.get(k, 0) - base.get(k, 0)
188 | slices[k] = slice(offset, offset + size, 1)
189 | return slices
190 |
191 |
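# _chunks_to_offsets converts per-dimension chunk sizes into the starting
# offset of each chunk, e.g. {'x': (3, 3, 3, 1)} -> {'x': [0, 3, 6, 9]}.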
192 | def _chunks_to_offsets(
193 | chunks: Mapping[str, Sequence[int]],
194 | ) -> Dict[str, List[int]]:
195 | return {
196 | dim: np.concatenate([[0], np.cumsum(sizes)[:-1]]).tolist()
197 | for dim, sizes in chunks.items()
198 | }
199 |
200 |
201 | def iter_chunk_keys(
202 | offsets: Mapping[str, Sequence[int]],
203 | vars: Optional[AbstractSet[str]] = None, # pylint: disable=redefined-builtin
204 | ) -> Iterator[Key]:
205 | """Iterate over the Key objects corresponding to the given chunks."""
206 | chunk_indices = [range(len(sizes)) for sizes in offsets.values()]
207 | for indices in itertools.product(*chunk_indices):
208 | key_offsets = {
209 | dim: offsets[dim][index] for dim, index in zip(offsets, indices)
210 | }
211 | yield Key(key_offsets, vars)
212 |
213 |
214 | def compute_offset_index(
215 | offsets: Mapping[str, Sequence[int]],
216 | ) -> Dict[str, Dict[int, int]]:
217 | """Compute a mapping from chunk offsets to chunk indices."""
218 | index = {}
219 | for dim, dim_offsets in offsets.items():
220 | index[dim] = {}
221 | for i, offset in enumerate(dim_offsets):
222 | index[dim][offset] = i
223 | return index
224 |
225 |
226 | def normalize_expanded_chunks(
227 | chunks: Mapping[str, Union[int, Tuple[int, ...]]],
228 | dim_sizes: Mapping[str, int],
229 | ) -> Dict[str, Tuple[int, ...]]:
230 | # pylint: disable=g-doc-args
231 | # pylint: disable=g-doc-return-or-yield
232 | """Normalize a dict of chunks to give the expanded size of each block.
233 |
234 | Example usage::
235 |
236 | >>> normalize_expanded_chunks({'x': 3}, {'x': 10})
237 | {'x': (3, 3, 3, 1)}
238 | """
239 | result = {}
240 | for dim, dim_size in dim_sizes.items():
241 | if dim not in chunks or chunks[dim] == -1:
242 | result[dim] = (dim_size,)
243 | elif isinstance(chunks[dim], tuple):
244 | total = sum(chunks[dim])
245 | if total != dim_size:
246 | raise ValueError(
247 | f"sum of provided chunks does not match size of dimension {dim}: "
248 | f"{total} vs {dim_size}"
249 | )
250 | result[dim] = chunks[dim]
251 | else:
252 | multiple, remainder = divmod(dim_size, chunks[dim])
253 | result[dim] = multiple * (chunks[dim],) + (
254 | (remainder,) if remainder else ()
255 | )
256 | return result
257 |
258 |
259 | DatasetOrDatasets = TypeVar(
260 | "DatasetOrDatasets", xarray.Dataset, List[xarray.Dataset]
261 | )
262 |
263 |
264 | class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]):
265 | """Split one or more xarray.Datasets into keyed chunks."""
266 |
267 | def __init__(
268 | self,
269 | dataset: DatasetOrDatasets,
270 | chunks: Optional[Mapping[str, Union[int, Tuple[int, ...]]]] = None,
271 | split_vars: bool = False,
272 | num_threads: Optional[int] = None,
273 | shard_keys_threshold: int = 200_000,
274 | ):
275 | """Initialize DatasetToChunks.
276 |
277 | Args:
278 | dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
279 | [xarray.Dataset, ...]) pairs.
280 | chunks: optional chunking scheme. Required if the dataset is *not* already
281 | chunked. If the dataset *is* already chunked with Dask, `chunks` takes
282 | precedence over the existing chunks.
283 | split_vars: whether to split the dataset into separate records for each
284 | data variable or to keep all data variables together. This is
285 | recommended if you don't need to perform joint operations on different
286 | dataset variables and individual variable chunks are sufficiently large.
287 | num_threads: optional number of Dataset chunks to load in parallel per
288 | worker. More threads can increase throughput, but also increase memory
289 | usage and make it harder for Beam runners to shard work. Note that each
290 | variable in a Dataset is already loaded in parallel, so this is most
291 | useful for Datasets with a small number of variables or when using
292 | split_vars=True.
293 | shard_keys_threshold: threshold at which to compute keys on Beam workers,
294 | rather than only on the host process. This is important for scaling
295 | pipelines to millions of tasks.
296 | """
297 | self.dataset = dataset
298 | self._validate(dataset, split_vars)
299 | if chunks is None:
300 | chunks = self._first.chunks
301 | if not chunks:
302 | raise ValueError("dataset must be chunked or chunks must be provided")
303 | for dim in chunks:
304 | if not any(dim in ds.dims for ds in self._datasets):
305 | raise ValueError(
306 | f"chunks key {dim!r} is not a dimension on the provided dataset(s)"
307 | )
308 | expanded_chunks = normalize_expanded_chunks(chunks, self._first.sizes) # pytype: disable=wrong-arg-types # always-use-property-annotation
309 | self.expanded_chunks = expanded_chunks
310 | self.split_vars = split_vars
311 | self.num_threads = num_threads
312 | self.shard_keys_threshold = shard_keys_threshold
313 | # TODO(shoyer): consider recalculating these potentially large properties on
314 | # each worker, rather than only once on the host.
315 | self.offsets = _chunks_to_offsets(expanded_chunks)
316 | self.offset_index = compute_offset_index(self.offsets)
317 | # We use the simple heuristic of only sharding inputs along the dimension
318 | # with the most chunks.
319 | lengths = {k: len(v) for k, v in self.offsets.items()}
320 | self.sharded_dim = max(lengths, key=lengths.get) if lengths else None
321 | self.shard_count = self._shard_count()
322 |
323 | @property
324 | def _first(self) -> xarray.Dataset:
325 | return self._datasets[0]
326 |
327 | @property
328 | def _datasets(self) -> List[xarray.Dataset]:
329 | if isinstance(self.dataset, xarray.Dataset):
330 | return [self.dataset]
331 | return list(self.dataset) # pytype: disable=bad-return-type
332 |
333 | def _validate(self, dataset, split_vars):
334 | """Raise errors if input parameters are invalid."""
335 | if not isinstance(dataset, xarray.Dataset):
336 | if not (
337 | isinstance(dataset, list)
338 | and all(isinstance(ds, xarray.Dataset) for ds in dataset)
339 | ):
340 | raise TypeError(
341 | "'dataset' must be an 'xarray.Dataset' or 'list[xarray.Dataset]'"
342 | )
343 | if not dataset:
344 | raise ValueError("dataset list cannot be empty")
345 | for ds in self._datasets[1:]:
346 | for dim, size in ds.sizes.items():
347 | if dim not in self._first.dims:
348 | raise ValueError(
349 | f"dimension {dim} does not appear on the first dataset"
350 | )
351 | if size != self._first.sizes[dim]:
352 | raise ValueError(
353 | f"dimension {dim} has an inconsistent size on different datasets"
354 | )
355 | if split_vars:
356 | for ds in self._datasets:
357 | if not ds.keys() <= self._first.keys():
358 | raise ValueError(
359 | "inconsistent data_vars when splitting variables:"
360 | f" {tuple(ds.keys())} != {tuple(self._first.keys())}"
361 | )
362 |
363 | def _task_count(self) -> int:
364 | """Count the number of tasks emitted by this transform."""
365 | counts = {k: len(v) for k, v in self.expanded_chunks.items()}
366 | if not self.split_vars:
367 | return int(np.prod(list(counts.values())))
368 | total = 0
369 | for variable in self._first.values():
370 | count_list = [v for k, v in counts.items() if k in variable.dims]
371 | total += int(np.prod(count_list))
372 | return total
373 |
374 | def _shard_count(self) -> Optional[int]:
375 | """Determine the number of times to shard input keys."""
376 | task_count = self._task_count()
377 | if task_count <= self.shard_keys_threshold:
378 | return None # no sharding
379 |
380 | if not self.split_vars:
381 | return math.ceil(task_count / self.shard_keys_threshold)
382 |
383 | var_count = sum(
384 | self.sharded_dim in var.dims for var in self._first.values()
385 | )
386 | return math.ceil(task_count / (var_count * self.shard_keys_threshold))
387 |
388 | def _iter_all_keys(self) -> Iterator[Key]:
389 | """Iterate over all Key objects."""
390 | if not self.split_vars:
391 | yield from iter_chunk_keys(self.offsets)
392 | else:
393 | for name, variable in self._first.items():
394 | relevant_offsets = {
395 | k: v for k, v in self.offsets.items() if k in variable.dims
396 | }
397 | yield from iter_chunk_keys(relevant_offsets, vars={name}) # pytype: disable=wrong-arg-types # always-use-property-annotation
398 |
399 | def _iter_shard_keys(
400 | self, shard_id: Optional[int], var_name: Optional[str]
401 | ) -> Iterator[Key]:
402 | """Iterate over Key objects for a specific shard and variable."""
403 | if var_name is None:
404 | offsets = self.offsets
405 | else:
406 | offsets = {dim: self.offsets[dim] for dim in self._first[var_name].dims}
407 |
408 | if shard_id is None:
409 | assert self.split_vars
410 | yield from iter_chunk_keys(offsets, vars={var_name})
411 | else:
412 | assert self.split_vars == (var_name is not None)
413 | dim = self.sharded_dim
414 | count = math.ceil(len(self.offsets[dim]) / self.shard_count)
415 | dim_slice = slice(shard_id * count, (shard_id + 1) * count)
416 | offsets = {**offsets, dim: offsets[dim][dim_slice]}
417 | vars_ = {var_name} if self.split_vars else None
418 | yield from iter_chunk_keys(offsets, vars=vars_)
419 |
420 | def _shard_inputs(self) -> List[Tuple[Optional[int], Optional[str]]]:
421 | """Create inputs for sharded key iterators."""
422 | if not self.split_vars:
423 | return [(i, None) for i in range(self.shard_count)]
424 |
425 | inputs = []
426 | for name, variable in self._first.items():
427 | if self.sharded_dim in variable.dims:
428 | inputs.extend([(i, name) for i in range(self.shard_count)])
429 | else:
430 | inputs.append((None, name))
431 | return inputs # pytype: disable=bad-return-type # always-use-property-annotation
432 |
433 | def _key_to_chunks(self, key: Key) -> Iterator[Tuple[Key, DatasetOrDatasets]]:
434 | """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
435 | sizes = {
436 | dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
437 | for dim, offset in key.offsets.items()
438 | }
439 | slices = offsets_to_slices(key.offsets, sizes)
440 | results = []
441 | for ds in self._datasets:
442 | dataset = ds if key.vars is None else ds[list(key.vars)]
443 | valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
444 | chunk = dataset.isel(valid_slices)
445 | # Load the data, using a separate thread for each variable
446 | num_threads = len(dataset)
447 | result = chunk.chunk().compute(num_workers=num_threads)
448 | results.append(result)
449 |
450 | if isinstance(self.dataset, xarray.Dataset):
451 | yield key, results[0]
452 | else:
453 | yield key, list(results)
454 |
455 | def expand(self, pcoll):
456 | if self.shard_count is None:
457 | # Create all keys on the machine launching the Beam pipeline. This is
458 | # faster if the number of keys is small.
459 | key_pcoll = pcoll | beam.Create(self._iter_all_keys())
460 | else:
461 | # Create keys in separate shards on Beam workers. This is more scalable.
462 | key_pcoll = (
463 | pcoll
464 | | beam.Create(self._shard_inputs())
465 | | "GenerateKeys" >> beam.FlatMapTuple(self._iter_shard_keys)
466 | | beam.Reshuffle()
467 | )
468 |
469 | return key_pcoll | "KeyToChunks" >> threadmap.FlatThreadMap(
470 | self._key_to_chunks, num_threads=self.num_threads
471 | )
472 |
473 |
474 | def validate_chunk(key: Key, datasets: DatasetOrDatasets) -> None:
475 | """Verify that a key and dataset(s) are valid for xarray-beam transforms."""
476 | if isinstance(datasets, xarray.Dataset):
477 | datasets: list[xarray.Dataset] = [datasets]
478 |
479 | for dataset in datasets:
480 | # Verify that no variables are chunked with Dask
481 | for var_name, variable in dataset.variables.items():
482 | if variable.chunks is not None:
483 | raise ValueError(
484 | f"Dataset variable {var_name!r} corresponding to key {key} is"
485 | " chunked with Dask. Datasets passed to validate_chunk must be"
486 | f" fully computed (not chunked): {dataset}\nThis typically arises"
487 | " with datasets originating with `xarray.open_zarr()`, which by"
488 | " default use Dask. If this is the case, you can fix it by passing"
489 | " `chunks=None` or xarray_beam.open_zarr(). Alternatively, you"
490 | " can load datasets explicitly into memory with `.compute()`."
491 | )
492 |
493 | # Validate key offsets
494 | missing_keys = [
495 | repr(k) for k in key.offsets.keys() if k not in dataset.dims
496 | ]
497 | if missing_keys:
498 | raise ValueError(
499 | f"Key offset(s) {', '.join(missing_keys)} in {key} not found in"
500 | f" Dataset dimensions: {dataset!r}"
501 | )
502 |
503 | # Validate key vars
504 | if key.vars is not None:
505 | missing_vars = [repr(v) for v in key.vars if v not in dataset.data_vars]
506 | if missing_vars:
507 | raise ValueError(
508 | f"Key var(s) {', '.join(missing_vars)} in {key} not found in"
509 | f" Dataset data variables: {dataset!r}"
510 | )
511 |
512 |
513 | class ValidateEachChunk(beam.PTransform):
514 | """Check that keys and dataset(s) are valid for xarray-beam transforms."""
515 |
516 | def _validate(self, key, dataset):
517 | validate_chunk(key, dataset)
518 | return key, dataset
519 |
520 | def expand(self, pcoll):
521 | return pcoll | beam.MapTuple(self._validate)
522 |
--------------------------------------------------------------------------------
/xarray_beam/_src/core_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for xarray_beam._src.core."""
15 |
16 | import re
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import apache_beam as beam
20 | import immutabledict
21 | import numpy as np
22 | import xarray
23 | import xarray_beam as xbeam
24 | from xarray_beam._src import core
25 | from xarray_beam._src import test_util
26 |
27 |
28 | # pylint: disable=expression-not-assigned
29 | # pylint: disable=pointless-statement
30 |
31 |
32 | class KeyTest(test_util.TestCase):
33 |
34 | def test_constructor(self):
35 | key = xbeam.Key({'x': 0, 'y': 10})
36 | self.assertIsInstance(key.offsets, immutabledict.immutabledict)
37 | self.assertEqual(dict(key.offsets), {'x': 0, 'y': 10})
38 | self.assertIsNone(key.vars)
39 |
40 | key = xbeam.Key(vars={'foo'})
41 | self.assertEqual(dict(key.offsets), {})
42 | self.assertIsInstance(key.vars, frozenset)
43 | self.assertEqual(set(key.vars), {'foo'})
44 |
45 | with self.assertRaisesRegex(TypeError, 'vars must be a set or None'):
46 | xbeam.Key(vars='foo')
47 |
48 | def test_replace(self):
49 | key = xbeam.Key({'x': 0}, {'foo'})
50 |
51 | expected = xbeam.Key({'x': 1}, {'foo'})
52 | actual = key.replace({'x': 1})
53 | self.assertEqual(expected, actual)
54 |
55 | expected = xbeam.Key({'y': 1}, {'foo'})
56 | actual = key.replace({'y': 1})
57 | self.assertEqual(expected, actual)
58 |
59 | expected = xbeam.Key({'x': 0})
60 | actual = key.replace(vars=None)
61 | self.assertEqual(expected, actual)
62 |
63 | expected = xbeam.Key({'x': 0}, {'bar'})
64 | actual = key.replace(vars={'bar'})
65 | self.assertEqual(expected, actual)
66 |
67 | expected = xbeam.Key({'y': 1}, {'foo'})
68 | actual = key.replace({'y': 1}, {'foo'})
69 | self.assertEqual(expected, actual)
70 |
71 | expected = xbeam.Key({'y': 1}, {'bar'})
72 | actual = key.replace({'y': 1}, {'bar'})
73 | self.assertEqual(expected, actual)
74 |
75 | def test_with_offsets(self):
76 | key = xbeam.Key({'x': 0})
77 |
78 | expected = xbeam.Key({'x': 1})
79 | actual = key.with_offsets(x=1)
80 | self.assertEqual(expected, actual)
81 |
82 | expected = xbeam.Key({'x': 0, 'y': 1})
83 | actual = key.with_offsets(y=1)
84 | self.assertEqual(expected, actual)
85 |
86 | expected = xbeam.Key()
87 | actual = key.with_offsets(x=None)
88 | self.assertEqual(expected, actual)
89 |
90 | expected = xbeam.Key({'y': 1, 'z': 2})
91 | actual = key.with_offsets(x=None, y=1, z=2)
92 | self.assertEqual(expected, actual)
93 |
94 | key2 = xbeam.Key({'x': 0}, vars={'foo'})
95 | expected = xbeam.Key({'x': 1}, vars={'foo'})
96 | actual = key2.with_offsets(x=1)
97 | self.assertEqual(expected, actual)
98 |
99 | def test_repr(self):
100 | key = xbeam.Key({'x': 0, 'y': 10})
101 | expected = "Key(offsets={'x': 0, 'y': 10}, vars=None)"
102 | self.assertEqual(repr(key), expected)
103 |
104 | key = xbeam.Key(vars={'foo'})
105 | expected = "Key(offsets={}, vars={'foo'})"
106 | self.assertEqual(repr(key), expected)
107 |
108 | def test_dict_key(self):
109 | first = {xbeam.Key({'x': 0, 'y': 10}): 1}
110 | second = {xbeam.Key({'x': 0, 'y': 10}): 1}
111 | self.assertEqual(first, second)
112 |
113 | def test_equality(self):
114 | key = xbeam.Key({'x': 0, 'y': 10})
115 | self.assertEqual(key, key)
116 | self.assertNotEqual(key, None)
117 |
118 | key2 = xbeam.Key({'x': 0, 'y': 10}, {'bar'})
119 | self.assertEqual(key2, key2)
120 | self.assertNotEqual(key, key2)
121 | self.assertNotEqual(key2, key)
122 |
123 | def test_offsets_as_beam_key(self):
124 | inputs = [
125 | (xbeam.Key({'x': 0, 'y': 1}), 1),
126 | (xbeam.Key({'x': 0, 'y': 2}), 2),
127 | (xbeam.Key({'y': 1, 'x': 0}), 3),
128 | ]
129 | expected = [
130 | (xbeam.Key({'x': 0, 'y': 1}), [1, 3]),
131 | (xbeam.Key({'x': 0, 'y': 2}), [2]),
132 | ]
133 | actual = inputs | beam.GroupByKey()
134 | self.assertEqual(actual, expected)
135 |
136 | def test_vars_as_beam_key(self):
137 | inputs = [
138 | (xbeam.Key(vars={'foo'}), 1),
139 | (xbeam.Key(vars={'bar'}), 2),
140 | (xbeam.Key(vars={'foo'}), 3),
141 | ]
142 | expected = [
143 | (xbeam.Key(vars={'foo'}), [1, 3]),
144 | (xbeam.Key(vars={'bar'}), [2]),
145 | ]
146 | actual = inputs | beam.GroupByKey()
147 | self.assertEqual(actual, expected)
148 |
149 |
150 | class TestOffsetsToSlices(test_util.TestCase):
151 |
152 | def test_offsets_to_slices(self):
153 | offsets = {'x': 0, 'y': 10}
154 |
155 | expected = {'x': slice(0, 5, 1), 'y': slice(10, 20, 1)}
156 | slices = core.offsets_to_slices(offsets, {'x': 5, 'y': 10})
157 | self.assertEqual(slices, expected)
158 |
159 | expected = {
160 | 'x': slice(0, 5, 1),
161 | 'y': slice(10, 20, 1),
162 | 'extra_key': slice(0, 100, 1),
163 | }
164 | slices = core.offsets_to_slices(
165 | offsets, {'x': 5, 'y': 10, 'extra_key': 100}
166 | )
167 | self.assertEqual(slices, expected)
168 |
169 | with self.assertRaises(ValueError):
170 | core.offsets_to_slices(offsets, {'y': 10})
171 |
172 | def test_offsets_to_slices_base(self):
173 | offsets = {'x': 100, 'y': 210}
174 |
175 | base = {'x': 100, 'y': 200}
176 | expected = {'x': slice(0, 5, 1), 'y': slice(10, 20, 1)}
177 | slices = core.offsets_to_slices(offsets, {'x': 5, 'y': 10}, base=base)
178 | self.assertEqual(slices, expected)
179 |
180 | base = {'x': 100}
181 | expected = {'x': slice(0, 5, 1), 'y': slice(210, 220, 1)}
182 | slices = core.offsets_to_slices(offsets, {'x': 5, 'y': 10}, base=base)
183 | self.assertEqual(slices, expected)
184 |
185 |
186 | class DatasetToChunksTest(test_util.TestCase):
187 |
188 | def test_iter_chunk_keys(self):
189 | actual = list(core.iter_chunk_keys({'x': (0, 3), 'y': (0, 2, 4)}))
190 | expected = [
191 | xbeam.Key({'x': 0, 'y': 0}),
192 | xbeam.Key({'x': 0, 'y': 2}),
193 | xbeam.Key({'x': 0, 'y': 4}),
194 | xbeam.Key({'x': 3, 'y': 0}),
195 | xbeam.Key({'x': 3, 'y': 2}),
196 | xbeam.Key({'x': 3, 'y': 4}),
197 | ]
198 | self.assertEqual(actual, expected)
199 |
200 | def test_compute_offset_index(self):
201 | actual = core.compute_offset_index({'x': (0, 3), 'y': (0, 2, 4)})
202 | expected = {'x': {0: 0, 3: 1}, 'y': {0: 0, 2: 1, 4: 2}}
203 | self.assertEqual(actual, expected)
204 |
205 | def test_normalize_expanded_chunks(self):
206 | actual = core.normalize_expanded_chunks({}, {'x': 10})
207 | expected = {'x': (10,)}
208 | self.assertEqual(actual, expected)
209 |
210 | actual = core.normalize_expanded_chunks({'x': -1}, {'x': 10})
211 | expected = {'x': (10,)}
212 | self.assertEqual(actual, expected)
213 |
214 | actual = core.normalize_expanded_chunks({'x': (5, 5)}, {'x': 10})
215 | expected = {'x': (5, 5)}
216 | self.assertEqual(actual, expected)
217 |
218 | with self.assertRaisesRegex(
219 | ValueError,
220 | 'sum of provided chunks does not match',
221 | ):
222 | core.normalize_expanded_chunks({'x': (5, 5, 5)}, {'x': 10})
223 |
224 | actual = core.normalize_expanded_chunks({'x': 3, 'y': 2}, {'x': 9, 'y': 4})
225 | expected = {'x': (3, 3, 3), 'y': (2, 2)}
226 | self.assertEqual(actual, expected)
227 |
228 | actual = core.normalize_expanded_chunks({'x': 3}, {'x': 10})
229 | expected = {'x': (3, 3, 3, 1)}
230 | self.assertEqual(actual, expected)
231 |
232 | def test_dataset_to_chunks_multiple(self):
233 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
234 | expected = [
235 | (xbeam.Key({'x': 0}), dataset.head(x=3)),
236 | (xbeam.Key({'x': 3}), dataset.tail(x=3)),
237 | ]
238 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
239 | dataset.chunk({'x': 3})
240 | )
241 | self.assertIdenticalChunks(actual, expected)
242 |
243 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
244 | dataset.chunk({'x': 3}), num_threads=2
245 | )
246 | self.assertIdenticalChunks(actual, expected)
247 |
248 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
249 | dataset, chunks={'x': 3}
250 | )
251 | self.assertIdenticalChunks(actual, expected)
252 |
253 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
254 | dataset.chunk({'x': 3}), shard_keys_threshold=1
255 | )
256 | self.assertIdenticalChunks(actual, expected)
257 |
258 | def test_datasets_to_chunks_multiple(self):
259 | datasets = [
260 | xarray.Dataset({'foo': ('x', i + np.arange(6))}) for i in range(7)
261 | ]
262 | expected = [
263 | (xbeam.Key({'x': 0}), [ds.head(x=3) for ds in datasets]),
264 | (xbeam.Key({'x': 3}), [ds.tail(x=3) for ds in datasets]),
265 | ]
266 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
267 | [ds.chunk({'x': 3}) for ds in datasets]
268 | )
269 | self.assertIdenticalChunks(actual, expected)
270 |
271 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
272 | [ds.chunk({'x': 3}) for ds in datasets], num_threads=2
273 | )
274 | self.assertIdenticalChunks(actual, expected)
275 |
276 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
277 | datasets, chunks={'x': 3}
278 | )
279 | self.assertIdenticalChunks(actual, expected)
280 |
281 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
282 | [ds.chunk({'x': 3}) for ds in datasets], shard_keys_threshold=1
283 | )
284 | self.assertIdenticalChunks(actual, expected)
285 |
286 | def test_dataset_to_chunks_whole(self):
287 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
288 | expected = [(xbeam.Key({'x': 0}), dataset)]
289 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
290 | dataset, chunks={'x': -1}
291 | )
292 | self.assertIdenticalChunks(actual, expected)
293 |
294 | def test_datasets_to_chunks_whole(self):
295 | datasets = [
296 | xarray.Dataset({'foo': ('x', i + np.arange(6))}) for i in range(11)
297 | ]
298 | expected = [(xbeam.Key({'x': 0}), datasets)]
299 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
300 | datasets, chunks={'x': -1}
301 | )
302 | self.assertIdenticalChunks(actual, expected)
303 |
304 | def test_dataset_to_chunks_vars(self):
305 | dataset = xarray.Dataset({
306 | 'foo': ('x', np.arange(6)),
307 | 'bar': ('x', -np.arange(6)),
308 | })
309 | expected = [
310 | (xbeam.Key({'x': 0}, {'foo'}), dataset.head(x=3)[['foo']]),
311 | (xbeam.Key({'x': 0}, {'bar'}), dataset.head(x=3)[['bar']]),
312 | (xbeam.Key({'x': 3}, {'foo'}), dataset.tail(x=3)[['foo']]),
313 | (xbeam.Key({'x': 3}, {'bar'}), dataset.tail(x=3)[['bar']]),
314 | ]
315 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
316 | dataset, chunks={'x': 3}, split_vars=True
317 | )
318 | self.assertIdenticalChunks(actual, expected)
319 |
320 | def test_datasets_to_chunks_vars(self):
321 | datasets = [
322 | xarray.Dataset({
323 | 'foo': ('x', i + np.arange(6)),
324 | 'bar': ('x', i - np.arange(6)),
325 | })
326 | for i in range(12)
327 | ]
328 | expected = [
329 | (
330 | xbeam.Key({'x': 0}, {'foo'}),
331 | [ds.head(x=3)[['foo']] for ds in datasets],
332 | ),
333 | (
334 | xbeam.Key({'x': 0}, {'bar'}),
335 | [ds.head(x=3)[['bar']] for ds in datasets],
336 | ),
337 | (
338 | xbeam.Key({'x': 3}, {'foo'}),
339 | [ds.tail(x=3)[['foo']] for ds in datasets],
340 | ),
341 | (
342 | xbeam.Key({'x': 3}, {'bar'}),
343 | [ds.tail(x=3)[['bar']] for ds in datasets],
344 | ),
345 | ]
346 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
347 | datasets, chunks={'x': 3}, split_vars=True
348 | )
349 | self.assertIdenticalChunks(actual, expected)
350 |
351 | @parameterized.parameters(
352 | {'shard_keys_threshold': 1},
353 | {'shard_keys_threshold': 2},
354 | {'shard_keys_threshold': 10},
355 | )
356 | def test_dataset_to_chunks_split_with_different_dims(
357 | self, shard_keys_threshold
358 | ):
359 | dataset = xarray.Dataset({
360 | 'foo': (('x', 'y'), np.array([[1, 2, 3], [4, 5, 6]])),
361 | 'bar': ('x', np.array([1, 2])),
362 | 'baz': ('z', np.array([1, 2, 3])),
363 | })
364 | expected = [
365 | (xbeam.Key({'x': 0, 'y': 0}, {'foo'}), dataset[['foo']].head(x=1)),
366 | (xbeam.Key({'x': 0}, {'bar'}), dataset[['bar']].head(x=1)),
367 | (xbeam.Key({'x': 1, 'y': 0}, {'foo'}), dataset[['foo']].tail(x=1)),
368 | (xbeam.Key({'x': 1}, {'bar'}), dataset[['bar']].tail(x=1)),
369 | (xbeam.Key({'z': 0}, {'baz'}), dataset[['baz']]),
370 | ]
371 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
372 | dataset,
373 | chunks={'x': 1},
374 | split_vars=True,
375 | shard_keys_threshold=shard_keys_threshold,
376 | )
377 | self.assertIdenticalChunks(actual, expected)
378 |
379 | def test_dataset_to_chunks_empty(self):
380 | dataset = xarray.Dataset()
381 | expected = [(xbeam.Key({}), dataset)]
382 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
383 | dataset, chunks={}
384 | )
385 | self.assertIdenticalChunks(actual, expected)
386 |
387 | def test_datasets_to_chunks_empty(self):
388 | datasets = [xarray.Dataset() for _ in range(5)]
389 | expected = [(xbeam.Key({}), datasets)]
390 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks(
391 | datasets, chunks={}
392 | )
393 | self.assertIdenticalChunks(actual, expected)
394 |
395 | def test_task_count(self):
396 | dataset = xarray.Dataset({
397 | 'foo': (('x', 'y'), np.zeros((3, 6))),
398 | 'bar': ('x', np.zeros(3)),
399 | 'baz': ('z', np.zeros(10)),
400 | })
401 |
402 | to_chunks = xbeam.DatasetToChunks(dataset, chunks={'x': 1})
403 | self.assertEqual(to_chunks._task_count(), 3)
404 |
405 | to_chunks = xbeam.DatasetToChunks(dataset, chunks={'x': 1}, split_vars=True)
406 | self.assertEqual(to_chunks._task_count(), 7)
407 |
408 | to_chunks = xbeam.DatasetToChunks(dataset, chunks={'y': 1}, split_vars=True)
409 | self.assertEqual(to_chunks._task_count(), 8)
410 |
411 | to_chunks = xbeam.DatasetToChunks(dataset, chunks={'z': 1}, split_vars=True)
412 | self.assertEqual(to_chunks._task_count(), 12)
413 |
414 | datasets = [dataset.copy() for _ in range(9)]
415 |
416 | to_chunks = xbeam.DatasetToChunks(datasets, chunks={'x': 1})
417 | self.assertEqual(to_chunks._task_count(), 3)
418 |
419 | to_chunks = xbeam.DatasetToChunks(
420 | datasets, chunks={'x': 1}, split_vars=True
421 | )
422 | self.assertEqual(to_chunks._task_count(), 7)
423 |
424 | to_chunks = xbeam.DatasetToChunks(
425 | datasets, chunks={'y': 1}, split_vars=True
426 | )
427 | self.assertEqual(to_chunks._task_count(), 8)
428 |
429 | to_chunks = xbeam.DatasetToChunks(
430 | datasets, chunks={'z': 1}, split_vars=True
431 | )
432 | self.assertEqual(to_chunks._task_count(), 12)
433 |
434 | def test_validate(self):
435 | dataset = xarray.Dataset({
436 | 'foo': (('x', 'y'), np.zeros((3, 6))),
437 | 'bar': ('x', np.zeros(3)),
438 | 'baz': ('z', np.zeros(10)),
439 | })
440 |
441 | with self.assertRaisesWithLiteralMatch(
442 | ValueError, 'dataset must be chunked or chunks must be provided'
443 | ):
444 | test_util.EagerPipeline() | xbeam.DatasetToChunks(dataset, chunks=None)
445 |
446 | with self.assertRaisesWithLiteralMatch(
447 | ValueError,
448 | "chunks key 'invalid' is not a dimension on the provided dataset(s)",
449 | ):
450 | test_util.EagerPipeline() | xbeam.DatasetToChunks(
451 | dataset, chunks={'invalid': 1}
452 | )
453 |
454 | with self.assertRaisesWithLiteralMatch(
455 | TypeError,
456 | "'dataset' must be an 'xarray.Dataset' or 'list[xarray.Dataset]'",
457 | ):
458 | test_util.EagerPipeline() | xbeam.DatasetToChunks({'foo': dataset})
459 |
460 | with self.assertRaisesWithLiteralMatch(
461 | ValueError, 'dataset list cannot be empty'
462 | ):
463 | test_util.EagerPipeline() | xbeam.DatasetToChunks([])
464 |
465 | with self.assertRaisesWithLiteralMatch(
466 | ValueError,
467 | 'dimension z does not appear on the first dataset',
468 | ):
469 | test_util.EagerPipeline() | xbeam.DatasetToChunks(
470 | [dataset.drop_dims('z'), dataset], chunks={'x': 1}
471 | )
472 |
473 | with self.assertRaisesWithLiteralMatch(
474 | ValueError,
475 | 'dimension z has an inconsistent size on different datasets',
476 | ):
477 | test_util.EagerPipeline() | xbeam.DatasetToChunks(
478 | [dataset.isel(z=slice(5, 10), drop=True), dataset], chunks={'x': 1}
479 | )
480 |
481 | try:
482 | test_util.EagerPipeline() | xbeam.DatasetToChunks(
483 | [dataset, dataset.isel(z=0, drop=True)], chunks={'x': 1}
484 | )
485 | except ValueError:
486 | self.fail('should allow a pipeline where the first has more dimensions.')
487 |
488 | with self.assertRaisesWithLiteralMatch(
489 | ValueError,
490 | (
491 | 'inconsistent data_vars when splitting variables: '
492 | "('foo', 'bar', 'baz', 'qux') != ('foo', 'bar', 'baz')"
493 | ),
494 | ):
495 | test_util.EagerPipeline() | xbeam.DatasetToChunks(
496 | [dataset, dataset.assign(qux=2 * dataset.bar)],
497 | chunks={'x': 1},
498 | split_vars=True,
499 | )
500 |
501 | try:
502 | test_util.EagerPipeline() | xbeam.DatasetToChunks(
503 | [dataset, dataset.isel(y=0, drop=True)],
504 | chunks={'x': 1},
505 | split_vars=True,
506 | )
507 | except ValueError:
508 | self.fail('should allow a pipeline where the first has more dimensions.')
509 |
510 |
511 | class ValidateEachChunkTest(test_util.TestCase):
512 |
513 | def test_validate_chunk_raises_on_dask_chunked(self):
514 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))}).chunk()
515 | key = xbeam.Key({'x': 0})
516 |
517 | with self.assertRaisesRegex(
518 | ValueError,
519 | re.escape(
520 | "Dataset variable 'foo' corresponding to key Key(offsets={'x': 0},"
521 | ' vars=None) is chunked with Dask. Datasets passed to'
522 | ' validate_chunk must be fully computed (not chunked):'
523 | ),
524 | ):
525 | core.validate_chunk(key, dataset)
526 |
527 | def test_unmatched_dimension_raises_error(self):
528 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
529 | key = xbeam.Key({'x': 0, 'y': 0})
530 | with self.assertRaisesRegex(
531 | ValueError,
532 | re.escape(
533 | "Key offset(s) 'y' in Key(offsets={'x': 0, 'y': 0}, vars=None) not "
534 | 'found in Dataset dimensions'
535 | ),
536 | ):
537 | core.validate_chunk(key, dataset)
538 |
539 | def test_unmatched_variables_raises_error_core(self):
540 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
541 | key = xbeam.Key({'x': 0}, {'bar'})
542 | with self.assertRaisesRegex(
543 | ValueError,
544 | re.escape(
545 | "Key var(s) 'bar' in Key(offsets={'x': 0}, vars={'bar'}) not found"
546 | ' in Dataset data variables'
547 | ),
548 | ):
549 | core.validate_chunk(key, dataset)
550 |
551 | def test_unmatched_variables_multiple_datasets_raises_error_core(self):
552 | datasets = [
553 | xarray.Dataset({'foo': ('x', i + np.arange(6))}) for i in range(11)
554 | ]
555 | datasets[5] = datasets[5].rename({'foo': 'bar'})
556 | key = xbeam.Key({'x': 0}, vars={'foo'})
557 |
558 | with self.assertRaisesRegex(
559 | ValueError,
560 | re.escape(
561 | "Key var(s) 'foo' in Key(offsets={'x': 0}, vars={'foo'}) "
562 | f'not found in Dataset data variables: {datasets[5]}'
563 | ),
564 | ):
565 | core.validate_chunk(key, datasets)
566 |
567 | def test_validate_chunks_compose_in_pipeline(self):
568 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
569 | expected = [(xbeam.Key({'x': 0}), dataset)]
570 | actual = (
571 | test_util.EagerPipeline()
572 | | xbeam.DatasetToChunks(dataset, chunks={'x': -1})
573 | | xbeam.ValidateEachChunk()
574 | )
575 | self.assertIdenticalChunks(actual, expected)
576 |
577 |
578 | if __name__ == '__main__':
579 | absltest.main()
580 |
--------------------------------------------------------------------------------
/xarray_beam/_src/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """A high-level interface for Xarray-Beam datasets.
15 |
16 | Usage example (not fully implemented yet!):
17 |
18 | import xarray_beam as xbeam
19 |
20 | transform = (
21 | xbeam.Dataset.from_zarr(input_path)
22 | .rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
23 | .map_blocks(lambda x: x.median('time'))
24 | .to_zarr(output_path)
25 | )
26 | with beam.Pipeline() as p:
27 | p | transform
28 | """
29 | from __future__ import annotations
30 |
31 | import collections
32 | from collections import abc
33 | import dataclasses
34 | import itertools
35 | import os.path
36 | import tempfile
37 |
38 | import apache_beam as beam
39 | import xarray
40 | from xarray_beam._src import core
41 | from xarray_beam._src import zarr
42 |
43 |
44 | class _CountNamer:
45 |
46 | def __init__(self):
47 | self._counts = collections.defaultdict(itertools.count)
48 |
49 | def apply(self, name: str) -> str:
50 | return f'{name}_{next(self._counts[name])}'
51 |
52 |
53 | _get_label = _CountNamer().apply
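# For example, _get_label('to_zarr') returns 'to_zarr_0' on the first call and
# 'to_zarr_1' on the next, so repeated transforms get unique Beam labels.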
54 |
55 |
56 | @dataclasses.dataclass
57 | class Dataset:
58 | """Experimental high-level representation of an Xarray-Beam dataset."""
59 |
60 | template: xarray.Dataset
61 | chunks: dict[str, int]
62 | split_vars: bool
63 | ptransform: beam.PTransform
64 |
65 | @classmethod
66 | def from_xarray(
67 | cls,
68 | source: xarray.Dataset,
69 | chunks: abc.Mapping[str, int],
70 | split_vars: bool = False,
71 | ) -> Dataset:
72 | """Create an xarray_beam.Dataset from an xarray.Dataset."""
73 | template = zarr.make_template(source)
74 | ptransform = _get_label('from_xarray') >> core.DatasetToChunks(
75 | source, chunks, split_vars
76 | )
77 | return cls(template, dict(chunks), split_vars, ptransform)
78 |
79 | @classmethod
80 | def from_zarr(cls, path: str, split_vars: bool = False) -> Dataset:
81 | """Create an xarray_beam.Dataset from a zarr file."""
82 | source, chunks = zarr.open_zarr(path)
83 | result = cls.from_xarray(source, chunks, split_vars)
84 | result.ptransform = _get_label('from_zarr') >> result.ptransform
85 | return result
86 |
87 | def to_zarr(self, path: str) -> beam.PTransform:
88 | """Write to a Zarr file."""
89 | return self.ptransform | _get_label('to_zarr') >> zarr.ChunksToZarr(
90 | path, self.template, self.chunks
91 | )
92 |
93 | def collect_with_direct_runner(self) -> xarray.Dataset:
94 | """Collect a dataset in memory by writing it to a temp file."""
95 |     # TODO(shoyer): generalize this function to something that supports
96 |     # alternative runners. Can we figure out a suitable temp file location
97 |     # for distributed runners?
98 |
99 | with tempfile.TemporaryDirectory() as temp_dir:
100 | temp_path = os.path.join(temp_dir, 'tmp.zarr')
101 | with beam.Pipeline(runner='DirectRunner') as pipeline:
102 | pipeline |= self.to_zarr(temp_path)
103 | return xarray.open_zarr(temp_path).compute()
104 |
105 | # TODO(shoyer): implement map_blocks, rechunking, merge, rename, mean, etc
106 |
107 | @property
108 | def sizes(self) -> dict[str, int]:
109 | """Size of each dimension on this dataset."""
110 | return dict(self.template.sizes) # pytype: disable=bad-return-type
111 |
112 | def pipe(self, func, *args, **kwargs):
113 |     return func(self, *args, **kwargs)
114 |
115 | def __repr__(self):
116 | base = repr(self.template)
117 | chunks_str = ', '.join(f'{k}: {v}' for k, v in self.chunks.items())
118 | return (
119 |         f'<xarray_beam.Dataset[{chunks_str}]>'
120 | + '\n'
121 | + '\n'.join(base.split('\n')[1:])
122 | )
123 |
--------------------------------------------------------------------------------
/xarray_beam/_src/dataset_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import textwrap
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | import apache_beam as beam
19 | import numpy as np
20 | import xarray
21 | import xarray_beam as xbeam
22 | from xarray_beam._src import test_util
23 |
24 |
25 | class DatasetTest(test_util.TestCase):
26 |
27 | def test_from_xarray(self):
28 | ds = xarray.Dataset({'foo': ('x', np.arange(10))})
29 | beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
30 | self.assertIsInstance(beam_ds, xbeam.Dataset)
31 | self.assertEqual(beam_ds.sizes, {'x': 10})
32 | self.assertEqual(beam_ds.template.keys(), {'foo'})
33 | self.assertEqual(beam_ds.chunks, {'x': 5})
34 | self.assertFalse(beam_ds.split_vars)
35 | self.assertRegex(beam_ds.ptransform.label, r'^from_xarray_\d+$')
36 | self.assertEqual(
37 | repr(beam_ds).split('\n')[0],
38 | "",
39 | )
40 | expected = [
41 | (xbeam.Key({'x': 0}), ds.head(x=5)),
42 | (xbeam.Key({'x': 5}), ds.tail(x=5)),
43 | ]
44 | actual = test_util.EagerPipeline() | beam_ds.ptransform
45 | self.assertIdenticalChunks(expected, actual)
46 |
47 | def test_collect_with_direct_runner(self):
48 | ds = xarray.Dataset({'foo': ('x', np.arange(10))})
49 | beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
50 | collected = beam_ds.collect_with_direct_runner()
51 | xarray.testing.assert_identical(ds, collected)
52 |
53 | @parameterized.parameters(
54 | dict(split_vars=False),
55 | dict(split_vars=True),
56 | )
57 | def test_from_zarr(self, split_vars):
58 | temp_dir = self.create_tempdir().full_path
59 | ds = xarray.Dataset({'foo': ('x', np.arange(10))})
60 | ds.chunk({'x': 5}).to_zarr(temp_dir)
61 |
62 | beam_ds = xbeam.Dataset.from_zarr(temp_dir, split_vars)
63 |
64 | self.assertRegex(beam_ds.ptransform.label, r'^from_zarr_\d+$')
65 | self.assertEqual(beam_ds.chunks, {'x': 5})
66 | self.assertEqual(beam_ds.split_vars, split_vars)
67 |
68 | collected = beam_ds.collect_with_direct_runner()
69 | xarray.testing.assert_identical(ds, collected)
70 |
71 | def test_to_zarr(self):
72 | temp_dir = self.create_tempdir().full_path
73 | ds = xarray.Dataset({'foo': ('x', np.arange(10))})
74 | beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
75 | to_zarr = beam_ds.to_zarr(temp_dir)
76 |
77 | self.assertRegex(to_zarr.label, r'^from_xarray_\d+|to_zarr_\d+$')
78 | with beam.Pipeline() as p:
79 | p |= to_zarr
80 | opened = xarray.open_zarr(temp_dir).compute()
81 | xarray.testing.assert_identical(ds, opened)
82 |
83 |
84 | if __name__ == '__main__':
85 | absltest.main()
86 |
--------------------------------------------------------------------------------
/xarray_beam/_src/integration_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Integration tests for Xarray-Beam."""
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 | import apache_beam as beam
18 | import numpy as np
19 | import xarray
20 | import xarray_beam as xbeam
21 | from xarray_beam._src import test_util
22 |
23 |
24 | # pylint: disable=expression-not-assigned
25 |
26 |
27 | class IntegrationTest(test_util.TestCase):
28 |
29 | @parameterized.named_parameters(
30 | {
31 | 'testcase_name': 'eager_unified',
32 | 'template_method': 'eager',
33 | 'split_vars': False,
34 | },
35 | {
36 | 'testcase_name': 'eager_split',
37 | 'template_method': 'eager',
38 | 'split_vars': True,
39 | },
40 | {
41 | 'testcase_name': 'eager_unified_sharded',
42 | 'template_method': 'eager',
43 | 'split_vars': False,
44 | 'shard_keys_threshold': 20,
45 | },
46 | {
47 | 'testcase_name': 'eager_split_sharded',
48 | 'template_method': 'eager',
49 | 'split_vars': True,
50 | 'shard_keys_threshold': 20,
51 | },
52 | {
53 | 'testcase_name': 'lazy_unified',
54 | 'template_method': 'lazy',
55 | 'split_vars': False,
56 | },
57 | {
58 | 'testcase_name': 'infer_unified',
59 | 'template_method': 'infer',
60 | 'split_vars': False,
61 | },
62 | {
63 | 'testcase_name': 'infer_split',
64 | 'template_method': 'infer',
65 | 'split_vars': True,
66 | },
67 | )
68 | def test_rechunk_zarr_to_zarr(
69 | self, template_method, split_vars, shard_keys_threshold=1_000_000
70 | ):
71 | src_dir = self.create_tempdir('source').full_path
72 | dest_dir = self.create_tempdir('destination').full_path
73 |
74 | source_chunks = {'t': 1, 'x': 100, 'y': 120}
75 | target_chunks = {'t': -1, 'x': 20, 'y': 20}
76 |
77 | rs = np.random.RandomState(0)
78 | raw_data = rs.randint(2**30, size=(60, 100, 120)) # 5.76 MB
79 | dataset = xarray.Dataset(
80 | {
81 | 'foo': (('t', 'x', 'y'), raw_data),
82 | 'bar': (('t', 'x', 'y'), raw_data - 1),
83 | }
84 | )
85 | dataset.chunk(source_chunks).to_zarr(src_dir, consolidated=True)
86 |
87 | on_disk = xarray.open_zarr(src_dir, consolidated=True)
88 | on_disk_chunked = on_disk.chunk(target_chunks)
89 | with beam.Pipeline('DirectRunner') as pipeline:
90 | # make template
91 | if template_method == 'eager':
92 | target_template = on_disk_chunked
93 | elif template_method == 'lazy':
94 | target_template = beam.pvalue.AsSingleton(
95 | pipeline | beam.Create([on_disk_chunked])
96 | )
97 | elif template_method == 'infer':
98 | target_template = None
99 | # run pipeline
100 | (
101 | pipeline
102 | | xbeam.DatasetToChunks(
103 | on_disk,
104 | split_vars=split_vars,
105 | shard_keys_threshold=shard_keys_threshold,
106 | )
107 | | xbeam.Rechunk(
108 | on_disk.sizes,
109 | source_chunks,
110 | target_chunks,
111 | itemsize=8,
112 | max_mem=10_000_000, # require two stages
113 | )
114 | | xbeam.ChunksToZarr(dest_dir, target_template)
115 | )
116 | roundtripped = xarray.open_zarr(dest_dir, consolidated=True, chunks=False)
117 |
118 | xarray.testing.assert_identical(roundtripped, dataset)
119 |
120 | @parameterized.named_parameters(
121 | {
122 | 'testcase_name': 'unified_unsharded',
123 | 'split_vars': False,
124 | 'shard_keys_threshold': 1_000_000,
125 | },
126 | {
127 | 'testcase_name': 'split_unsharded',
128 | 'split_vars': True,
129 | 'shard_keys_threshold': 1_000_000,
130 | },
131 | {
132 | 'testcase_name': 'unified_sharded',
133 | 'split_vars': False,
134 | 'shard_keys_threshold': 3,
135 | },
136 | {
137 | 'testcase_name': 'split_sharded',
138 | 'split_vars': True,
139 | 'shard_keys_threshold': 3,
140 | },
141 | )
142 | def test_dataset_to_zarr_with_irregular_variables(
143 | self, split_vars, shard_keys_threshold
144 | ):
145 | dataset = xarray.Dataset(
146 | {
147 | 'volume1': (
148 | ('t', 'x', 'y', 'z1'),
149 | np.arange(240).reshape(10, 2, 3, 4),
150 | ),
151 | 'volume2': (
152 | ('t', 'x', 'y', 'z2'),
153 | np.arange(300).reshape(10, 2, 3, 5),
154 | ),
155 | 'surface': (('t', 'x', 'y'), np.arange(60).reshape(10, 2, 3)),
156 | 'static': (('x', 'y'), np.arange(6).reshape(2, 3)),
157 | }
158 | )
159 | temp_dir = self.create_tempdir().full_path
160 | template = dataset.chunk()
161 | chunks = {'t': 1, 'z1': 2, 'z2': 3}
162 | (
163 | test_util.EagerPipeline()
164 | | xbeam.DatasetToChunks(
165 | dataset,
166 | chunks,
167 | split_vars=split_vars,
168 | shard_keys_threshold=shard_keys_threshold,
169 | )
170 | | xbeam.ChunksToZarr(temp_dir, template, chunks)
171 | )
172 | actual = xarray.open_zarr(temp_dir, consolidated=True)
173 | xarray.testing.assert_identical(actual, dataset)
174 |
175 |
176 | if __name__ == '__main__':
177 | absltest.main()
178 |
--------------------------------------------------------------------------------
/xarray_beam/_src/rechunk.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Rechunking for xarray.Dataset objets."""
15 | import collections
16 | import dataclasses
17 | import itertools
18 | import logging
19 | import math
20 | import textwrap
21 | from typing import (
22 | Any,
23 | Dict,
24 | Iterable,
25 | Iterator,
26 | List,
27 | Optional,
28 | Mapping,
29 | Sequence,
30 | Tuple,
31 | Union,
32 | )
33 |
34 | import apache_beam as beam
35 | import numpy as np
36 | from rechunker import algorithm
37 | import xarray
38 |
39 | from xarray_beam._src import core
40 |
41 |
42 | # pylint: disable=logging-not-lazy
43 | # pylint: disable=logging-fstring-interpolation
44 |
45 |
46 | def normalize_chunks(
47 | chunks: Mapping[str, Union[int, Tuple[int, ...]]],
48 | dim_sizes: Mapping[str, int],
49 | ) -> Dict[str, int]:
50 | """Normalize a dict of chunks."""
51 | if not chunks.keys() <= dim_sizes.keys():
52 | raise ValueError(
53 | 'all dimensions used in chunks must also have an indicated size: '
54 | f'chunks={chunks} vs dim_sizes={dim_sizes}'
55 | )
56 | result = {}
57 | for dim, size in dim_sizes.items():
58 | if dim not in chunks:
59 | result[dim] = size
60 | elif isinstance(chunks[dim], tuple):
61 | unique_chunks = set(chunks[dim])
62 | if len(unique_chunks) != 1:
63 | raise ValueError(
64 | f'chunks for dimension {dim} are not constant: {unique_chunks}',
65 | )
66 | (result[dim],) = unique_chunks
67 | elif chunks[dim] == -1:
68 | result[dim] = size
69 | else:
70 | result[dim] = chunks[dim]
71 | return result
72 |
73 |
74 | def rechunking_plan(
75 | dim_sizes: Mapping[str, int],
76 | source_chunks: Mapping[str, int],
77 | target_chunks: Mapping[str, int],
78 | itemsize: int,
79 | min_mem: int,
80 | max_mem: int,
81 | ) -> List[List[Dict[str, int]]]:
82 | """Make a rechunking plan."""
83 | stages = algorithm.multistage_rechunking_plan(
84 | shape=tuple(dim_sizes.values()),
85 | source_chunks=tuple(source_chunks[dim] for dim in dim_sizes),
86 | target_chunks=tuple(target_chunks[dim] for dim in dim_sizes),
87 | itemsize=itemsize,
88 | min_mem=min_mem,
89 | max_mem=max_mem,
90 | )
91 | plan = []
92 | for stage in stages:
93 | plan.append([dict(zip(dim_sizes.keys(), shapes)) for shapes in stage])
94 | return plan
95 |
96 |
97 | def _consolidate_chunks_in_var_group(
98 | inputs: Sequence[Tuple[core.Key, xarray.Dataset]],
99 | combine_kwargs: Optional[Mapping[str, Any]],
100 | ) -> Tuple[core.Key, xarray.Dataset]:
101 | """Consolidate chunks across offsets with identical vars."""
102 | unique_offsets = collections.defaultdict(set)
103 | unique_var_groups = set()
104 | for key, chunk in inputs:
105 | for dim, offset in key.offsets.items():
106 | unique_offsets[dim].add(offset)
107 | unique_var_groups.add(key.vars)
108 |
109 | if len(unique_var_groups) != 1:
110 | raise ValueError(
111 | f'expected exactly one unique var group, got {unique_var_groups}'
112 | )
113 | (cur_vars,) = unique_var_groups
114 |
115 | offsets = {k: sorted(v) for k, v in unique_offsets.items()}
116 | combined_offsets = {k: v[0] for k, v in offsets.items()}
117 | combined_key = core.Key(combined_offsets, cur_vars)
118 |
119 | # Consolidate inputs in a single xarray.Dataset.
120 | # `inputs` is a flat list like `[(k_00, ds_00), (k_01, ds_01), ...]` where
121 | # `k_ij` is a Key giving the (multi-dimensional) index of `ds_ij` in a
122 | # virtual larger Dataset.
123 | # Now we want to actually concatenate along all those dimensions, e.g., the
124 | # equivalent of building a large matrix out of sub-matrices:
125 | # ⌈[x_00 x_01] ...⌉ ⌈x_00 x_01 ...⌉
126 | # X = |[x_10 x_11] ...| = |x_10 x_11 ...|
127 | # |[x_20 x_21] ...| |x_20 x_21 ...|
128 | # ⌊ ... ...⌋ ⌊ ... ... ...⌋
129 | # In NumPy, this would be done with `np.block()`.
130 | offset_index = core.compute_offset_index(offsets)
131 | shape = [len(v) for v in offsets.values()]
132 | nested_array = np.empty(dtype=object, shape=shape)
133 | if np.prod(shape) != len(inputs):
134 | raise ValueError(
135 | f'some expected chunks are missing for vars={cur_vars} '
136 | f'shape: {shape} len(inputs): {len(inputs)}'
137 | )
138 |
139 | for key, chunk in inputs:
140 | nested_key = tuple(offset_index[dim][key.offsets[dim]] for dim in offsets)
141 | assert nested_array[nested_key] is None
142 | nested_array[nested_key] = chunk
143 |
144 | kwargs = dict(
145 | data_vars='minimal',
146 | coords='minimal',
147 | join='exact',
148 | combine_attrs='override',
149 | )
150 | if combine_kwargs is not None:
151 | kwargs.update(combine_kwargs)
152 |
153 | try:
154 | combined_dataset = xarray.combine_nested(
155 | nested_array.tolist(), concat_dim=list(offsets), **kwargs
156 | )
157 | return combined_key, combined_dataset
158 | except (ValueError, xarray.MergeError) as original_error:
159 | summaries = []
160 | for axis, dim in enumerate(offsets):
161 | repr_string = '\n'.join(
162 | repr(ds) for ds in nested_array[(0,) * axis + (slice(2),)].tolist()
163 | )
164 | if nested_array.shape[axis] > 2:
165 | repr_string += '\n...'
166 | repr_string = textwrap.indent(repr_string, prefix=' ')
167 | summaries.append(
168 | f'Leading datasets along dimension {dim!r}:\n{repr_string}'
169 | )
170 | summaries_str = '\n'.join(summaries)
171 | raise ValueError(
172 | f'combining nested dataset chunks for vars={cur_vars} with '
173 | f'offsets={offsets} failed.\n'
174 | + summaries_str
175 | ) from original_error
176 |
177 |
178 | def consolidate_chunks(
179 | inputs: Iterable[Tuple[core.Key, xarray.Dataset]],
180 | combine_kwargs: Optional[Mapping[str, Any]] = None,
181 | ) -> Iterable[Tuple[core.Key, xarray.Dataset]]:
182 | """Consolidate chunks across offsets into (Key, Dataset) pairs."""
183 | inputs = list(inputs)
184 | keys = [key for key, _ in inputs]
185 | if len(set(keys)) < len(keys):
186 | raise ValueError(f'chunk keys are not unique: {keys}')
187 |
188 | # Group chunks by variable groups and combine offsets to validate inputs.
189 | inputs_by_vars = collections.defaultdict(list)
190 | combined_offsets_by_dim = collections.defaultdict(set)
191 | combined_offsets_by_vars = collections.defaultdict(set)
192 | for key, chunk in inputs:
193 | inputs_by_vars[key.vars].append((key, chunk))
194 | for dim, offset in key.offsets.items():
195 | combined_offsets_by_dim[dim].add(offset)
196 | combined_offsets_by_vars[(key.vars, dim)].add(offset)
197 |
198 | # All var groups need to have the exact same set of offsets on common
199 | # dimensions.
200 | for (cur_vars, dim), offsets in combined_offsets_by_vars.items():
201 | if offsets != combined_offsets_by_dim[dim]:
202 | raise ValueError(f'some expected chunks are missing for vars={cur_vars}')
203 |
204 | for cur_vars, cur_inputs in inputs_by_vars.items():
205 | combined_key, combined_dataset = _consolidate_chunks_in_var_group(
206 | cur_inputs, combine_kwargs
207 | )
208 | yield combined_key, combined_dataset
209 |
210 |
211 | def consolidate_variables(
212 | inputs: Iterable[Tuple[core.Key, xarray.Dataset]],
213 | merge_kwargs: Optional[Mapping[str, Any]] = None,
214 | ) -> Iterator[Tuple[core.Key, xarray.Dataset]]:
215 | """Consolidate chunks across distinct variables into (Key, Dataset) pairs."""
216 | kwargs = dict(
217 | compat='equals',
218 | join='exact',
219 | combine_attrs='override',
220 | )
221 | if merge_kwargs is not None:
222 | kwargs.update(merge_kwargs)
223 |
224 | chunks_by_offsets = collections.defaultdict(list)
225 | for key, chunk in inputs:
226 | chunks_by_offsets[key.offsets].append(chunk)
227 |
228 | for offsets, chunks in chunks_by_offsets.items():
229 | all_vars = [set(chunk.keys()) for chunk in chunks]
230 | new_vars = set.union(*all_vars)
231 | if len(new_vars) != sum(map(len, all_vars)):
232 | raise ValueError(
233 | f'cannot merge chunks with overlapping variables: {all_vars}'
234 | )
235 | key = core.Key(offsets, new_vars)
236 |
237 | try:
238 | dataset = xarray.merge(chunks, **kwargs)
239 | except (ValueError, xarray.MergeError) as original_error:
240 | repr_string = '\n'.join(repr(ds) for ds in chunks[:2])
241 | if len(chunks) > 2:
242 | repr_string += '\n...'
243 | repr_string = textwrap.indent(repr_string, prefix=' ')
244 | raise ValueError(
245 | f'merging dataset chunks with variables {all_vars} failed.\n'
246 | + repr_string
247 | ) from original_error
248 | yield key, dataset
249 |
250 |
251 | def consolidate_fully(
252 | inputs: Iterable[Tuple[core.Key, xarray.Dataset]],
253 | *,
254 | merge_kwargs: Optional[Mapping[str, Any]] = None,
255 | combine_kwargs: Optional[Mapping[str, Any]] = None,
256 | ) -> Tuple[core.Key, xarray.Dataset]:
257 | """Consolidate chunks via merge/concat into a single (Key, Dataset) pair."""
258 | concatenated_chunks = []
259 | combined_offsets = {}
260 | combined_vars = set()
261 | for key, chunk in consolidate_chunks(inputs, combine_kwargs):
262 | # We expect all chunks to be fully combined in all dimensions and all chunks
263 | # to have the same offset (in each dimension). The chunks from
264 | # consolidate_chunks() should already have this property but we explicitly
265 | # check it here again in case consolidate_chunks changes.
266 | for dim, offset in key.offsets.items():
267 | if dim in combined_offsets and combined_offsets[dim] != offset:
268 | raise ValueError(
269 | 'consolidating chunks fully failed because '
270 | f'chunk\n{chunk}\n has offsets {key.offsets} '
271 | f'that differ from {combined_offsets}'
272 | )
273 | combined_offsets[dim] = offset
274 | concatenated_chunks.append(chunk)
275 | combined_vars.update(chunk.keys())
276 |
277 | # Merge variables, but unlike consolidate_variables, we merge all chunks and
278 | # not just chunks per unique key.
279 | kwargs = dict(
280 | compat='equals',
281 | join='exact',
282 | combine_attrs='override',
283 | )
284 | if merge_kwargs is not None:
285 | kwargs.update(merge_kwargs)
286 |
287 | try:
288 | dataset = xarray.merge(concatenated_chunks, **kwargs)
289 | except (ValueError, xarray.MergeError) as original_error:
290 | repr_string = '\n'.join(repr(ds) for ds in concatenated_chunks[:2])
291 | if len(concatenated_chunks) > 2:
292 | repr_string += '\n...'
293 | repr_string = textwrap.indent(repr_string, prefix=' ')
294 | raise ValueError(
295 | f'merging dataset chunks with variables {combined_vars} failed.\n'
296 | + repr_string
297 | ) from original_error
298 | return core.Key(combined_offsets, combined_vars), dataset # pytype: disable=wrong-arg-types
299 |
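# A minimal usage sketch (hypothetical data): unlike the generators above,
# consolidate_fully returns a single pair covering every input chunk.
#
#   ds = xarray.Dataset({'foo': ('x', list(range(6)))})
#   inputs = [(core.Key({'x': 0}), ds.head(x=3)),
#             (core.Key({'x': 3}), ds.tail(x=3))]
#   key, combined = consolidate_fully(inputs)
#   # key.offsets == {'x': 0}, key.vars == {'foo'}, combined.sizes == {'x': 6}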
300 |
301 | class _ConsolidateBase(beam.PTransform):
302 |
303 | def expand(self, pcoll):
304 | return (
305 | pcoll
306 | | 'PrependTempKey' >> beam.MapTuple(self._prepend_chunk_key)
307 | | 'GroupByTempKeys' >> beam.GroupByKey()
308 | | 'Consolidate' >> beam.MapTuple(self._consolidate_chunks)
309 | )
310 |
311 |
312 | def _round_chunk_key(
313 | key: core.Key,
314 | target_chunks: Mapping[str, int],
315 | ) -> core.Key:
316 | """Round down a chunk-key to offsets corresponding to new chunks."""
317 | new_offsets = {}
318 | for dim, offset in key.offsets.items():
319 | chunk_size = target_chunks.get(dim)
320 | if chunk_size is None:
321 | new_offsets[dim] = offset
322 | elif chunk_size == -1:
323 | new_offsets[dim] = 0
324 | else:
325 | new_offsets[dim] = chunk_size * (offset // chunk_size)
326 | return key.replace(new_offsets)
327 |
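# For example (hypothetical values), with target_chunks={'x': 10} an offset of
# 23 along 'x' rounds down to 20, while a chunk size of -1 rounds every offset
# to 0:
#
#   _round_chunk_key(core.Key({'x': 23}), {'x': 10})  # -> Key({'x': 20})
#   _round_chunk_key(core.Key({'x': 23}), {'x': -1})  # -> Key({'x': 0})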
328 |
329 | @dataclasses.dataclass
330 | class ConsolidateChunks(beam.PTransform):
331 | """Consolidate existing chunks across offsets into bigger chunks."""
332 |
333 | target_chunks: Mapping[str, int]
334 |
335 | def _prepend_chunk_key(self, key, chunk):
336 | rounded_key = _round_chunk_key(key, self.target_chunks)
337 | return rounded_key, (key, chunk)
338 |
339 | def _consolidate(self, key, inputs):
340 | ((consolidated_key, dataset),) = consolidate_chunks(inputs)
341 | assert key == consolidated_key, (key, consolidated_key)
342 | return consolidated_key, dataset
343 |
344 | def expand(self, pcoll):
345 | return (
346 | pcoll
347 | | 'PrependTempKey' >> beam.MapTuple(self._prepend_chunk_key)
348 | | 'GroupByTempKeys' >> beam.GroupByKey()
349 | | 'Consolidate' >> beam.MapTuple(self._consolidate)
350 | )
351 |
352 |
353 | class ConsolidateVariables(beam.PTransform):
354 | """Consolidate existing chunks across variables into bigger chunks."""
355 |
356 | # TODO(shoyer): add support for partial consolidation into explicit sets
357 | # of variables.
358 |
359 | def _prepend_chunk_key(self, key, chunk):
360 | return key.replace(vars=None), (key, chunk)
361 |
362 | def _consolidate(self, key, inputs):
363 | ((consolidated_key, dataset),) = consolidate_variables(inputs)
364 | assert key.offsets == consolidated_key.offsets, (key, consolidated_key)
365 | assert key.vars is None
366 | # TODO(shoyer): consider carefully whether it is better to return key or
367 | # consolidated_key. They are both valid in the xarray-beam data model -- the
368 |     # difference is whether vars is None or an explicit set of variables.
369 |     # For now, conservatively return the version of key with vars=None, so
370 |     # users don't come to rely on getting an explicit set of variables.
371 | return key, dataset
372 |
373 | def expand(self, pcoll):
374 | return (
375 | pcoll
376 | | 'PrependTempKey' >> beam.MapTuple(self._prepend_chunk_key)
377 | | 'GroupByTempKeys' >> beam.GroupByKey()
378 | | 'Consolidate' >> beam.MapTuple(self._consolidate)
379 | )
380 |
381 |
382 | def _split_chunk_bounds(
383 | start: int,
384 | stop: int,
385 | multiple: int,
386 | ) -> List[Tuple[int, int]]:
387 | # pylint: disable=g-doc-args
388 | # pylint: disable=g-doc-return-or-yield
389 |   """Calculate the bounds of divided chunks along a dimension.
390 |
391 | Example usage:
392 |
393 | >>> _split_chunk_bounds(0, 10, 3)
394 | [(0, 3), (3, 6), (6, 9), (9, 10)]
395 | >>> _split_chunk_bounds(5, 10, 3)
396 | [(5, 6), (6, 9), (9, 10)]
397 | >>> _split_chunk_bounds(10, 20, 12)
398 | [(10, 12), (12, 20)]
399 | """
400 | if multiple == -1:
401 | return [(start, stop)]
402 | assert start >= 0 and stop > start and multiple > 0, (start, stop, multiple)
403 | first_multiple = (start // multiple + 1) * multiple
404 | breaks = list(range(first_multiple, stop, multiple))
405 | return list(zip([start] + breaks, breaks + [stop]))
406 |
407 |
408 | def split_chunks(
409 | key: core.Key,
410 | dataset: xarray.Dataset,
411 | target_chunks: Mapping[str, int],
412 | ) -> Iterator[Tuple[core.Key, xarray.Dataset]]:
413 | """Split a single (Key, xarray.Dataset) pair into many chunks."""
414 | # This function splits consolidated arrays into blocks of new sizes, e.g.,
415 | # ⌈x_00 x_01 ...⌉ ⌈⌈x_00⌉ ⌈x_01⌉ ...⌉
416 | # X = |x_10 x_11 ...| = ||x_10| |x_11| ...|
417 | # |x_20 x_21 ...| |⌊x_20⌋ ⌊x_21⌋ ...|
418 | # ⌊ ... ... ...⌋ ⌊ ... ... ...⌋
419 | # and emits them as (Key, xarray.Dataset) pairs.
420 | all_bounds = []
421 | for dim, chunk_size in target_chunks.items():
422 | start = key.offsets.get(dim, 0)
423 | stop = start + dataset.sizes[dim]
424 | all_bounds.append(_split_chunk_bounds(start, stop, chunk_size))
425 |
426 | for bounds in itertools.product(*all_bounds):
427 | new_offsets = dict(key.offsets)
428 | slices = {}
429 | for dim, (start, stop) in zip(target_chunks, bounds):
430 | base = key.offsets.get(dim, 0)
431 | new_offsets[dim] = start
432 | slices[dim] = slice(start - base, stop - base)
433 |
434 | new_key = key.replace(new_offsets)
435 | new_chunk = dataset.isel(slices)
436 | yield new_key, new_chunk
437 |
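# A minimal usage sketch (hypothetical data): a single chunk of six elements
# along 'x' is split into target chunks of size 4, yielding chunks of sizes 4
# and 2 at offsets 0 and 4.
#
#   ds = xarray.Dataset({'foo': ('x', list(range(6)))})
#   chunks = list(split_chunks(core.Key({'x': 0}), ds, {'x': 4}))
#   # [(Key with offsets {'x': 0}, 4-element chunk),
#   #  (Key with offsets {'x': 4}, 2-element chunk)]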
438 |
439 | @dataclasses.dataclass
440 | class SplitChunks(beam.PTransform):
441 | """Split existing chunks into smaller chunks."""
442 |
443 | target_chunks: Mapping[str, int]
444 |
445 | def _split_chunks(self, key, dataset):
446 | target_chunks = {
447 | k: v for k, v in self.target_chunks.items() if k in dataset.dims
448 | }
449 | yield from split_chunks(key, dataset, target_chunks)
450 |
451 | def expand(self, pcoll):
452 | return pcoll | beam.FlatMapTuple(self._split_chunks)
453 |
454 |
455 | def split_variables(
456 | key: core.Key,
457 | dataset: xarray.Dataset,
458 | ) -> Iterator[Tuple[core.Key, xarray.Dataset]]:
459 | """Split a single (Key, xarray.Dataset) pair into separate variables."""
460 | # TODO(shoyer): add support for partial splitting, into explicitly provided
461 | # sets of variables
462 | for var_name in dataset:
463 | new_dataset = dataset[[var_name]]
464 | offsets = {k: v for k, v in key.offsets.items() if k in new_dataset.dims}
465 | new_key = core.Key(offsets, vars={var_name}) # pytype: disable=wrong-arg-types
466 | yield new_key, new_dataset
467 |
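# A minimal usage sketch (hypothetical data): a chunk holding two variables is
# split into one single-variable chunk per data variable, with the variable
# name recorded explicitly on each new key.
#
#   ds = xarray.Dataset({'foo': ('x', [1, 2]), 'bar': ('x', [3, 4])})
#   [(k1, d1), (k2, d2)] = split_variables(core.Key({'x': 0}), ds)
#   # k1.vars == {'foo'} and k2.vars == {'bar'}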
468 |
469 | @dataclasses.dataclass
470 | class SplitVariables(beam.PTransform):
471 | """Split existing chunks into a separate chunk per data variable."""
472 |
473 | def expand(self, pcoll):
474 | return pcoll | beam.FlatMapTuple(split_variables)
475 |
476 |
477 | def in_memory_rechunk(
478 | inputs: List[Tuple[core.Key, xarray.Dataset]],
479 | target_chunks: Mapping[str, int],
480 | ) -> Iterator[Tuple[core.Key, xarray.Dataset]]:
481 | """Rechunk in-memory pairs of (Key, xarray.Dataset)."""
482 | consolidated = consolidate_chunks(inputs)
483 | for key, dataset in consolidated:
484 | yield from split_chunks(key, dataset, target_chunks)
485 |
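# A minimal usage sketch (hypothetical data): two chunks of size 3 along 'x'
# are consolidated and then re-split into chunks of size 2, all in memory.
#
#   ds = xarray.Dataset({'foo': ('x', list(range(6)))})
#   inputs = [(core.Key({'x': 0}), ds.head(x=3)),
#             (core.Key({'x': 3}), ds.tail(x=3))]
#   outputs = list(in_memory_rechunk(inputs, {'x': 2}))
#   # three chunks keyed at offsets 0, 2 and 4, each with two elements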
486 |
487 | @dataclasses.dataclass
488 | class RechunkStage(beam.PTransform):
489 | """A single stage of a rechunking pipeline."""
490 |
491 | source_chunks: Mapping[str, int]
492 | target_chunks: Mapping[str, int]
493 |
494 | def expand(self, pcoll):
495 | source_values = self.source_chunks.values()
496 | target_values = self.target_chunks.values()
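    # Splitting is needed when some target chunk size is not an exact multiple
    # of the source chunk size; consolidation is needed when some source chunk
    # size is not an exact multiple of the target. Both may apply, e.g. when
    # going from chunks of 4 to chunks of 6 along a dimension.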
497 | if any(t % s for s, t in zip(source_values, target_values)):
498 | pcoll |= 'Split' >> SplitChunks(self.target_chunks)
499 | if any(s % t for s, t in zip(source_values, target_values)):
500 | pcoll |= 'Consolidate' >> ConsolidateChunks(self.target_chunks)
501 | return pcoll
502 |
503 |
504 | class Rechunk(beam.PTransform):
505 | """Rechunk to an arbitrary new chunking scheme with bounded memory usage.
506 |
507 | The approach taken here builds on Rechunker [1], but differs in two key ways:
508 |
509 | 1. It is performed via collective Beam operations, instead of writing
510 |      intermediate arrays to disk.
511 | 2. It is performed collectively on full xarray.Dataset objects, instead of
512 | NumPy arrays.
513 |
514 | [1] rechunker.readthedocs.io
515 | """
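  # A usage sketch (not part of the original docstring; dimension and chunk
  # sizes below are hypothetical): rechunk hourly data from "pencils" along
  # time into spatial "pancakes", assuming float32 data (itemsize=4).
  #
  #   rechunk = Rechunk(
  #       dim_sizes={'time': 8760, 'x': 360, 'y': 180},
  #       source_chunks={'time': 8760, 'x': 10, 'y': 10},
  #       target_chunks={'time': 100, 'x': -1, 'y': -1},
  #       itemsize=4,
  #   )
  #   rechunked = chunks | rechunk  # chunks: a PCollection of (Key, Dataset)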
516 |
517 | def __init__(
518 | self,
519 | dim_sizes: Mapping[str, int],
520 | source_chunks: Mapping[str, Union[int, Tuple[int, ...]]],
521 | target_chunks: Mapping[str, Union[int, Tuple[int, ...]]],
522 | itemsize: int,
523 | min_mem: Optional[int] = None,
524 | max_mem: int = 2 ** 30, # 1 GB
525 | ):
526 | """Initialize Rechunk().
527 |
528 | Args:
529 | dim_sizes: size of the full (combined) dataset of all chunks.
530 | source_chunks: sizes of source chunks. Missing keys or values equal to -1
531 | indicate "non-chunked" dimensions.
532 |       target_chunks: sizes of target chunks, like `source_chunks`. Keys must
533 | exactly match those found in source_chunks.
534 | itemsize: approximate number of bytes per xarray.Dataset element, after
535 | indexing out by all dimensions, e.g., `4 * len(dataset)` for float32
536 | data or roughly `dataset.nbytes / np.prod(dataset.sizes)`.
537 | min_mem: minimum memory that a single intermediate chunk must consume.
538 | max_mem: maximum memory that a single intermediate chunk may consume.
539 | """
540 | if source_chunks.keys() != target_chunks.keys():
541 | raise ValueError(
542 | 'source_chunks and target_chunks have different keys: '
543 | f'{source_chunks} vs {target_chunks}'
544 | )
545 | if min_mem is None:
546 | min_mem = max_mem // 100
547 | self.dim_sizes = dim_sizes
548 | self.source_chunks = normalize_chunks(source_chunks, dim_sizes)
549 | self.target_chunks = normalize_chunks(target_chunks, dim_sizes)
550 |
551 | if self.source_chunks == self.target_chunks:
552 | self.stage_in = self.stage_out = []
553 | logging.info(f'Rechunk with chunks {self.source_chunks} is a no-op')
554 | return
555 |
556 | plan = rechunking_plan(
557 | dim_sizes,
558 | self.source_chunks,
559 | self.target_chunks,
560 | itemsize=itemsize,
561 | min_mem=min_mem,
562 | max_mem=max_mem,
563 | )
564 | plan = (
565 | [[self.source_chunks, self.source_chunks, plan[0][0]]]
566 | + plan
567 | + [[plan[-1][-1], self.target_chunks, self.target_chunks]]
568 | )
569 | self.stage_in, (_, *intermediates, _), self.stage_out = zip(*plan)
570 |
571 | logging.info(
572 | 'Rechunking plan:\n'
573 | + '\n'.join(
574 | f'Stage{i}: {s} -> {t}'
575 | for i, (s, t) in enumerate(zip(self.stage_in, self.stage_out))
576 | )
577 | )
578 | min_size = min(
579 | itemsize * math.prod(chunks.values()) for chunks in intermediates
580 | )
581 | logging.info(f'Smallest intermediates have size {min_size:1.3e}')
582 |
583 | def expand(self, pcoll):
584 | for stage, (in_chunks, out_chunks) in enumerate(
585 | zip(self.stage_in, self.stage_out)
586 | ):
587 | pcoll |= f'Stage{stage}' >> RechunkStage(in_chunks, out_chunks)
588 | return pcoll
589 |
--------------------------------------------------------------------------------
/xarray_beam/_src/test_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Testing utilities for Xarray-Beam."""
15 | import pickle
16 | import tempfile
17 |
18 | from absl.testing import parameterized
19 | import apache_beam as beam
20 | import numpy as np
21 | import pandas as pd
22 | import xarray
23 |
24 | # pylint: disable=expression-not-assigned
25 |
26 |
27 | def _write_pickle(value, path):
28 | with open(path, 'wb') as f:
29 | pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)
30 |
31 |
32 | class EagerPipeline:
33 | """A mock Beam pipeline for testing that returns lists of Python objects.
34 |
35 | Example usage:
36 |
37 | >>> EagerPipeline() | beam.Create([1, 2, 3]) | beam.Map(lambda x: x**2)
38 | [1, 4, 9]
39 | """
40 |
41 | def __or__(self, ptransform):
42 | with tempfile.NamedTemporaryFile() as f:
43 | with beam.Pipeline('DirectRunner') as pipeline:
44 | (
45 | pipeline
46 | | ptransform
47 | | beam.combiners.ToList()
48 | | beam.Map(_write_pickle, f.name)
49 | )
50 | pipeline.run()
51 | return pickle.load(f)
52 |
53 |
54 | class TestCase(parameterized.TestCase):
55 | """TestCase for use in internal Xarray-Beam tests."""
56 |
57 | def _assert_chunks(self, array_assert_func, actual, expected):
58 | actual = dict(actual)
59 | expected = dict(expected)
60 | self.assertCountEqual(expected, actual, msg='inconsistent keys')
61 | for key in expected:
62 | actual_chunk, expected_chunk = actual[key], expected[key]
63 | self.assertEqual(type(actual_chunk), type(expected_chunk))
64 | if type(actual_chunk) is not list:
65 | actual_chunk, expected_chunk = (actual_chunk,), (expected_chunk,)
66 | for a, e in zip(actual_chunk, expected_chunk):
67 | array_assert_func(a, e)
68 |
69 | def assertIdenticalChunks(self, actual, expected):
70 | self._assert_chunks(xarray.testing.assert_identical, actual, expected)
71 |
72 | def assertAllCloseChunks(self, actual, expected):
73 | self._assert_chunks(xarray.testing.assert_allclose, actual, expected)
74 |
75 |
76 | def dummy_era5_surface_dataset(
77 | variables=2,
78 | latitudes=73,
79 | longitudes=144,
80 | times=365 * 4,
81 | freq='6H',
82 | ):
83 | """A mock version of the Pangeo ERA5 surface reanalysis dataset."""
84 | # based on: gs://pangeo-era5/reanalysis/spatial-analysis
85 | dims = ('time', 'latitude', 'longitude')
86 | shape = (times, latitudes, longitudes)
87 | var_names = ['asn', 'd2m', 'e', 'mn2t', 'mx2t', 'ptype'][:variables]
88 | rng = np.random.default_rng(0)
89 | data_vars = {
90 | name: (dims, rng.normal(size=shape).astype(np.float32), {'var_index': i})
91 | for i, name in enumerate(var_names)
92 | }
93 |
94 |   latitude = np.linspace(-90, 90, num=latitudes)
95 | longitude = np.linspace(0, 360, num=longitudes, endpoint=False)
96 | time = pd.date_range('1979-01-01T00', periods=times, freq=freq)
97 | coords = {'time': time, 'latitude': latitude, 'longitude': longitude}
98 |
99 | return xarray.Dataset(data_vars, coords, {'global_attr': 'yes'})
100 |
--------------------------------------------------------------------------------
/xarray_beam/_src/threadmap.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Beam Map transforms that execute in parallel using a thread pool.
15 |
16 | This can be a good idea for IO bound tasks, but generally should be avoided for
17 | CPU bound tasks, especially CPU bound tasks that do not release the Python GIL.
18 |
19 | They can be used as drop-in substitutes for the corresponding Beam transforms:
20 | - beam.Map -> ThreadMap
21 | - beam.MapTuple -> ThreadMapTuple
22 | - beam.FlatMap -> FlatThreadMap
23 | - beam.FlatMapTuple -> FlatThreadMapTuple
24 |
25 | By default, 16 threads are used per task. This can be adjusted via the
26 | `num_threads` keyword argument.
27 | """
28 | import concurrent.futures
29 | import functools
30 |
31 | import apache_beam as beam
32 |
33 |
34 | class ThreadDoFn(beam.DoFn):
35 | """A DoFn that executes inputs in a ThreadPool."""
36 |
37 | def __init__(self, func, num_threads):
38 | self.func = func
39 | self.num_threads = num_threads
40 |
41 | def setup(self):
42 | self.executor = concurrent.futures.ThreadPoolExecutor(self.num_threads)
43 |
44 | def teardown(self):
45 | self.executor.shutdown()
46 |
47 | def process(self, element, *args, **kwargs):
48 | futures = []
49 | for x in element:
50 | futures.append(self.executor.submit(self.func, x, *args, **kwargs))
51 | for future in futures:
52 | yield future.result()
53 |
54 |
55 | class FlatThreadDoFn(ThreadDoFn):
56 |
57 | def process(self, element, *args, **kwargs):
58 | for results in super().process(element, *args, **kwargs):
59 | yield from results
60 |
61 |
62 | class _ThreadMap(beam.PTransform):
63 | """Like beam.Map, but executed in a thread-pool."""
64 |
65 | def __init__(self, func, *args, num_threads, **kwargs):
66 | self.func = func
67 | self.args = args
68 | self.kwargs = kwargs
69 | self.num_threads = num_threads
70 |
71 | def get_dofn(self):
72 | return ThreadDoFn(self.func, self.num_threads)
73 |
74 | def expand(self, pcoll):
75 | return (
76 | pcoll
77 | | 'BatchElements'
78 | >> beam.BatchElements(
79 | min_batch_size=self.num_threads,
80 | max_batch_size=self.num_threads,
81 | )
82 | | 'ParDo' >> beam.ParDo(self.get_dofn(), *self.args, **self.kwargs)
83 | )
84 |
85 |
86 | class _ThreadMapTuple(_ThreadMap):
87 | """Like beam.MapTuple, but executed in a thread-pool."""
88 |
89 | def get_dofn(self):
90 | func = lambda xs, **kwargs: self.func(*xs, **kwargs)
91 | return ThreadDoFn(func, self.num_threads)
92 |
93 |
94 | class _FlatThreadMap(_ThreadMap):
95 | """Like beam.FlatMap, but executed in a thread-pool."""
96 |
97 | def get_dofn(self):
98 | return FlatThreadDoFn(self.func, self.num_threads)
99 |
100 |
101 | class _FlatThreadMapTuple(_ThreadMap):
102 | """Like beam.FlatMapTuple, but executed in a thread-pool."""
103 |
104 | def get_dofn(self):
105 | func = lambda xs, **kwargs: self.func(*xs, **kwargs)
106 | return FlatThreadDoFn(func, self.num_threads)
107 |
108 |
109 | def _maybe_threaded(beam_transform, thread_transform):
110 | @functools.wraps(thread_transform)
111 | def create(func, *args, num_threads=16, **kwargs):
112 | if num_threads is None:
113 | return beam_transform(func, *args, **kwargs)
114 | else:
115 | return thread_transform(func, *args, num_threads=num_threads, **kwargs)
116 |
117 | return create
118 |
119 |
120 | # These functions don't use threads if num_threads=None.
121 | ThreadMap = _maybe_threaded(beam.Map, _ThreadMap)
122 | ThreadMapTuple = _maybe_threaded(beam.MapTuple, _ThreadMapTuple)
123 | FlatThreadMap = _maybe_threaded(beam.FlatMap, _FlatThreadMap)
124 | FlatThreadMapTuple = _maybe_threaded(beam.FlatMapTuple, _FlatThreadMapTuple)
125 |
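# A minimal usage sketch (`fetch_url` stands in for any IO-bound function):
#
#   pcoll | ThreadMap(fetch_url, num_threads=32)
#
# With num_threads=None, the transform falls back to a plain beam.Map.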
--------------------------------------------------------------------------------
/xarray_beam/_src/threadmap_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for threadmap."""
15 |
16 | from absl.testing import absltest
17 | import apache_beam as beam
18 | import unittest
19 |
20 | from xarray_beam._src import test_util
21 | from xarray_beam._src import threadmap
22 |
23 |
24 | # pylint: disable=expression-not-assigned
25 | # pylint: disable=pointless-statement
26 |
27 |
28 | class ThreadMapTest(test_util.TestCase):
29 |
30 | def test_map(self):
31 | def f(*args, **kwargs):
32 | return args, kwargs
33 |
34 | expected = [1, 2, 3] | beam.Map(f, 4, y=5)
35 | actual = [1, 2, 3] | threadmap.ThreadMap(f, 4, y=5)
36 | self.assertEqual(expected, actual)
37 |
38 | actual = [1, 2, 3] | threadmap.ThreadMap(f, 4, y=5, num_threads=None)
39 | self.assertEqual(expected, actual)
40 |
41 | @unittest.skip('this is failing with recent Apache Beam releases')
42 | def test_flat_map(self):
43 | def f(*args, **kwargs):
44 | return [(args, kwargs)] * 2
45 |
46 | expected = [1, 2, 3] | beam.FlatMap(f, 4, y=5)
47 | actual = [1, 2, 3] | threadmap.FlatThreadMap(f, 4, y=5)
48 | self.assertEqual(expected, actual)
49 |
50 | actual = [1, 2, 3] | threadmap.FlatThreadMap(f, 4, y=5, num_threads=None)
51 | self.assertEqual(expected, actual)
52 |
53 | def test_map_tuple(self):
54 | def f(a, b, y=None):
55 | return a, b, y
56 |
57 | expected = [(1, 2), (3, 4)] | beam.MapTuple(f, y=5)
58 | actual = [(1, 2), (3, 4)] | threadmap.ThreadMapTuple(f, y=5)
59 | self.assertEqual(expected, actual)
60 |
61 | actual = [(1, 2), (3, 4)] | threadmap.ThreadMapTuple(
62 | f, y=5, num_threads=None
63 | )
64 | self.assertEqual(expected, actual)
65 |
66 | def test_flat_map_tuple(self):
67 | def f(a, b, y=None):
68 | return a, b, y
69 |
70 | expected = [(1, 2), (3, 4)] | beam.FlatMapTuple(f, y=5)
71 | actual = [(1, 2), (3, 4)] | threadmap.FlatThreadMapTuple(f, y=5)
72 | self.assertEqual(expected, actual)
73 |
74 | actual = [(1, 2), (3, 4)] | threadmap.FlatThreadMapTuple(
75 | f, y=5, num_threads=None
76 | )
77 | self.assertEqual(expected, actual)
78 |
79 | def test_maybe_thread_map(self):
80 | transform = threadmap.ThreadMap(lambda x: x)
81 | self.assertIsInstance(transform, threadmap._ThreadMap)
82 |
83 | transform = threadmap.ThreadMap(lambda x: x, num_threads=None)
84 | self.assertIsInstance(transform, beam.ParDo)
85 |
86 | transform = threadmap.ThreadMap(lambda x: x, num_threads=1)
87 | self.assertIsInstance(transform, threadmap._ThreadMap)
88 |
89 |
90 | if __name__ == '__main__':
91 | absltest.main()
92 |
--------------------------------------------------------------------------------
/xarray_beam/_src/zarr_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import re
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | import dask.array as da
19 | import numpy as np
20 | import pandas as pd
21 | import xarray
22 | import xarray_beam as xbeam
23 | from xarray_beam._src import test_util
24 |
25 |
26 | # pylint: disable=expression-not-assigned
27 | # pylint: disable=pointless-statement
28 |
29 |
30 | class DatasetToZarrTest(test_util.TestCase):
31 |
32 | def test_open_zarr(self):
33 | source_ds = xarray.Dataset(
34 | {'foo': ('x', da.arange(0, 60, 10, chunks=2))},
35 | )
36 | temp_dir = self.create_tempdir().full_path
37 | source_ds.to_zarr(temp_dir)
38 | roundtripped_ds, chunks = xbeam.open_zarr(temp_dir)
39 | xarray.testing.assert_identical(roundtripped_ds, source_ds)
40 | self.assertEqual(roundtripped_ds.chunks, {})
41 | self.assertEqual(chunks, {'x': 2})
42 |
43 | def test_open_zarr_inconsistent(self):
44 | source_ds = xarray.Dataset(
45 | {
46 | 'foo': ('x', da.arange(0, 60, 10, chunks=2)),
47 | 'bar': ('x', da.arange(0, 60, 10, chunks=3)),
48 | },
49 | )
50 | temp_dir = self.create_tempdir().full_path
51 | source_ds.to_zarr(temp_dir)
52 | with self.assertRaisesRegex(
53 | ValueError,
54 | "inconsistent chunk sizes on Zarr dataset for dimension 'x': {2, 3}",
55 | ):
56 | xbeam.open_zarr(temp_dir)
57 |
58 | def test_make_template(self):
59 | source = xarray.Dataset(
60 | {
61 | 'foo': ('x', np.ones(3)),
62 | 'bar': ('x', np.ones(3)),
63 | },
64 | )
65 | template = xbeam.make_template(source)
66 | self.assertEqual(list(template.data_vars), ['foo', 'bar'])
67 | self.assertEqual(template.chunks, {'x': (3,)})
68 | self.assertEqual(template.sizes, {'x': 3})
69 | with self.assertRaisesRegex(
70 | ValueError, 'cannot compute array values of xarray.Dataset objects'
71 | ):
72 | template.compute()
73 |
74 | def test_make_template_lazy_vars_on_numpy(self):
75 | source = xarray.Dataset(
76 | {
77 | 'foo': ('x', np.ones(3)),
78 | 'bar': ('x', np.ones(3)),
79 | },
80 | )
81 | template = xbeam.make_template(source, lazy_vars={'foo'})
82 | self.assertEqual(template.foo.chunks, ((3,),))
83 | self.assertIsNone(template.bar.chunks)
84 |
85 | def test_make_template_lazy_vars_on_dask(self):
86 | source = xarray.Dataset(
87 | {
88 | 'foo': ('x', np.ones(3)),
89 | 'bar': ('x', np.ones(3)),
90 | },
91 | ).chunk({'x': 2})
92 | template = xbeam.make_template(source, lazy_vars={'foo'})
93 | self.assertEqual(template.foo.chunks, ((3,),)) # one chunk
94 | self.assertIsInstance(template.bar.data, np.ndarray) # computed
95 |
96 | def test_make_template_from_chunked(self):
97 | source = xarray.Dataset(
98 | {
99 | 'foo': ('x', da.ones(3)),
100 | 'bar': ('x', np.ones(3)),
101 | },
102 | )
103 | template = xbeam._src.zarr._make_template_from_chunked(source)
104 | self.assertEqual(template.foo.chunks, ((3,),))
105 | self.assertIsNone(template.bar.chunks)
106 |
107 | def test_replace_template_dims_with_coords(self):
108 | source = xarray.Dataset(
109 | {'foo': (('x', 'y'), np.zeros((1, 2)))},
110 | coords={'x': [0], 'y': [10, 20]},
111 | )
112 | template = xbeam.make_template(source)
113 | new_x_coords = pd.date_range('2000-01-01', periods=5)
114 | new_template = xbeam.replace_template_dims(template, x=new_x_coords)
115 |
116 | self.assertEqual(new_template.sizes, {'x': 5, 'y': 2})
117 | expected_x_coord = xarray.DataArray(
118 | new_x_coords, dims='x', coords={'x': new_x_coords}
119 | )
120 | xarray.testing.assert_equal(new_template.x, expected_x_coord)
121 | xarray.testing.assert_equal(new_template.y, source.y) # Unchanged coord
122 | self.assertEqual(new_template.foo.shape, (5, 2))
123 | self.assertIsInstance(new_template.foo.data, da.Array) # Still lazy
124 |
125 | def test_replace_template_dims_with_size(self):
126 | source = xarray.Dataset(
127 | {'foo': (('x', 'y'), np.zeros((1, 2)))},
128 | coords={'x': [0], 'y': [10, 20]},
129 | )
130 | template = xbeam.make_template(source)
131 | new_template = xbeam.replace_template_dims(template, x=10)
132 |
133 | self.assertEqual(new_template.sizes, {'x': 10, 'y': 2})
134 | self.assertNotIn(
135 | 'x', new_template.coords
136 | ) # Coord is dropped when replaced by size
137 | xarray.testing.assert_equal(new_template.y, source.y)
138 | self.assertEqual(new_template.foo.shape, (10, 2))
139 | self.assertIsInstance(new_template.foo.data, da.Array)
140 |
141 | def test_replace_template_dims_multiple(self):
142 | source = xarray.Dataset(
143 | {'foo': (('x', 'y'), np.zeros((1, 2)))},
144 | coords={'x': [0], 'y': [10, 20]},
145 | )
146 | template = xbeam.make_template(source)
147 | new_x_coords = pd.date_range('2000-01-01', periods=5)
148 | new_template = xbeam.replace_template_dims(template, x=new_x_coords, y=3)
149 |
150 | self.assertEqual(new_template.sizes, {'x': 5, 'y': 3})
151 | expected_x_coord = xarray.DataArray(
152 | new_x_coords, dims='x', coords={'x': new_x_coords}
153 | )
154 | xarray.testing.assert_equal(new_template.x, expected_x_coord)
155 | self.assertNotIn('y', new_template.coords)
156 | self.assertEqual(new_template.foo.shape, (5, 3))
157 | self.assertIsInstance(new_template.foo.data, da.Array)
158 |
159 | def test_replace_template_dims_multiple_vars(self):
160 | source = xarray.Dataset(
161 | {
162 | 'foo': (('x', 'y'), np.zeros((1, 2))),
163 | 'bar': ('x', np.zeros(1)),
164 | 'baz': ('z', np.zeros(3)), # Unrelated dim
165 | },
166 | coords={'x': [0], 'y': [10, 20], 'z': [1, 2, 3]},
167 | )
168 | template = xbeam.make_template(source)
169 | new_template = xbeam.replace_template_dims(template, x=5)
170 |
171 | self.assertEqual(new_template.sizes, {'x': 5, 'y': 2, 'z': 3})
172 | self.assertNotIn('x', new_template.coords)
173 | xarray.testing.assert_equal(new_template.y, source.y)
174 | xarray.testing.assert_equal(new_template.z, source.z)
175 | self.assertEqual(new_template.foo.shape, (5, 2))
176 | self.assertEqual(new_template.bar.shape, (5,))
177 | self.assertEqual(new_template.baz.shape, (3,)) # Unchanged var
178 | self.assertIsInstance(new_template.foo.data, da.Array)
179 | self.assertIsInstance(new_template.bar.data, da.Array)
180 | self.assertIsInstance(new_template.baz.data, da.Array)
181 |
182 | def test_replace_template_dims_error_on_non_template(self):
183 | source = xarray.Dataset({'foo': ('x', np.zeros(1))}) # Not a template
184 | with self.assertRaisesRegex(ValueError, 'is not chunked with Dask'):
185 | xbeam.replace_template_dims(source, x=5)
186 |
187 | def test_chunks_to_zarr(self):
188 | dataset = xarray.Dataset(
189 | {'foo': ('x', np.arange(0, 60, 10))},
190 | coords={'x': np.arange(6)},
191 | )
192 | chunked = dataset.chunk()
193 | inputs = [
194 | (xbeam.Key({'x': 0}), dataset),
195 | ]
196 | with self.subTest('no template'):
197 | temp_dir = self.create_tempdir().full_path
198 | with self.assertWarnsRegex(FutureWarning, 'No template provided'):
199 | inputs | xbeam.ChunksToZarr(temp_dir, template=None)
200 | result = xarray.open_zarr(temp_dir, consolidated=True)
201 | xarray.testing.assert_identical(dataset, result)
202 | with self.subTest('with template'):
203 | temp_dir = self.create_tempdir().full_path
204 | inputs | xbeam.ChunksToZarr(temp_dir, chunked)
205 | result = xarray.open_zarr(temp_dir, consolidated=True)
206 | xarray.testing.assert_identical(dataset, result)
207 | with self.subTest('with template and needs_setup=False'):
208 | temp_dir = self.create_tempdir().full_path
209 | xbeam.setup_zarr(chunked, temp_dir)
210 | inputs | xbeam.ChunksToZarr(temp_dir, chunked, needs_setup=False)
211 | result = xarray.open_zarr(temp_dir, consolidated=True)
212 | xarray.testing.assert_identical(dataset, result)
213 | with self.subTest('with zarr_chunks and with template'):
214 | temp_dir = self.create_tempdir().full_path
215 | zarr_chunks = {'x': 3}
216 | inputs | xbeam.ChunksToZarr(temp_dir, chunked, zarr_chunks)
217 | result = xarray.open_zarr(temp_dir, consolidated=True)
218 | xarray.testing.assert_identical(dataset, result)
219 | self.assertEqual(result.chunks, {'x': (3, 3)})
220 | with self.subTest('with zarr_chunks and no template'):
221 | temp_dir = self.create_tempdir().full_path
222 | zarr_chunks = {'x': 3}
223 | with self.assertWarnsRegex(FutureWarning, 'No template provided'):
224 | inputs | xbeam.ChunksToZarr(
225 | temp_dir, template=None, zarr_chunks=zarr_chunks
226 | )
227 | result = xarray.open_zarr(temp_dir, consolidated=True)
228 | xarray.testing.assert_identical(dataset, result)
229 | self.assertEqual(result.chunks, {'x': (3, 3)})
230 |
231 | temp_dir = self.create_tempdir().full_path
232 | with self.assertRaisesRegex(
233 | ValueError,
234 | 'template does not have any variables chunked with Dask',
235 | ):
236 | xbeam.ChunksToZarr(temp_dir, dataset)
237 |
238 | temp_dir = self.create_tempdir().full_path
239 | template = chunked.assign_coords(x=np.zeros(6))
240 | with self.assertRaisesRegex(
241 | ValueError,
242 | 'template and chunk indexes do not match',
243 | ):
244 | inputs | xbeam.ChunksToZarr(temp_dir, template)
245 |
246 | inputs2 = [
247 | (xbeam.Key({'x': 0}), dataset.expand_dims(z=[1, 2])),
248 | ]
249 | temp_dir = self.create_tempdir().full_path
250 | with self.assertRaisesRegex(
251 | ValueError,
252 | 'unexpected new indexes found in chunk',
253 | ):
254 | inputs2 | xbeam.ChunksToZarr(temp_dir, template)
255 |
256 | def test_multiple_vars_chunks_to_zarr(self):
257 | dataset = xarray.Dataset(
258 | {
259 | 'foo': ('x', np.arange(0, 60, 10)),
260 | 'bar': ('x', -np.arange(6)),
261 | },
262 | coords={'x': np.arange(6)},
263 | )
264 | chunked = dataset.chunk()
265 | inputs = [
266 | (xbeam.Key({'x': 0}, {'foo'}), dataset[['foo']]),
267 | (xbeam.Key({'x': 0}, {'bar'}), dataset[['bar']]),
268 | ]
269 | with self.subTest('no template'):
270 | temp_dir = self.create_tempdir().full_path
271 | with self.assertWarnsRegex(FutureWarning, 'No template provided'):
272 | inputs | xbeam.ChunksToZarr(temp_dir, template=None)
273 | result = xarray.open_zarr(temp_dir, consolidated=True)
274 | xarray.testing.assert_identical(dataset, result)
275 | with self.subTest('with template'):
276 | temp_dir = self.create_tempdir().full_path
277 | inputs | xbeam.ChunksToZarr(temp_dir, chunked)
278 | result = xarray.open_zarr(temp_dir, consolidated=True)
279 | xarray.testing.assert_identical(dataset, result)
280 |
281 | @parameterized.named_parameters(
282 | {
283 | 'testcase_name': 'combined_coords',
284 | 'coords': {'bar': (('x', 'y'), -np.arange(6).reshape(3, 2))},
285 | },
286 | {
287 | 'testcase_name': 'separate_coords',
288 | 'coords': {'x': np.arange(3), 'y': np.arange(2)},
289 | },
290 | )
291 | def test_2d_chunks_to_zarr(self, coords):
292 | dataset = xarray.Dataset(
293 | {'foo': (('x', 'y'), np.arange(0, 60, 10).reshape(3, 2))},
294 | coords=coords,
295 | )
296 | with self.subTest('partial key'):
297 | inputs = [(xbeam.Key({'x': 0}), dataset)]
298 | temp_dir = self.create_tempdir().full_path
299 | inputs | xbeam.ChunksToZarr(temp_dir, template=dataset.chunk())
300 | result = xarray.open_zarr(temp_dir, consolidated=True)
301 | xarray.testing.assert_identical(dataset, result)
302 | with self.subTest('split along partial key'):
303 | inputs = [(xbeam.Key({'x': 0}), dataset)]
304 | temp_dir = self.create_tempdir().full_path
305 | inputs | xbeam.SplitChunks({'x': 1}) | xbeam.ChunksToZarr(
306 | temp_dir, template=dataset.chunk({'x': 1})
307 | )
308 | result = xarray.open_zarr(temp_dir, consolidated=True)
309 | xarray.testing.assert_identical(dataset, result)
310 | with self.subTest('full key'):
311 | inputs = [(xbeam.Key({'x': 0, 'y': 0}), dataset)]
312 | temp_dir = self.create_tempdir().full_path
313 | inputs | xbeam.ChunksToZarr(temp_dir, template=dataset.chunk())
314 | result = xarray.open_zarr(temp_dir, consolidated=True)
315 | xarray.testing.assert_identical(dataset, result)
316 |
317 | def test_chunks_to_zarr_dask_chunks(self):
318 | dataset = xarray.Dataset(
319 | {'foo': ('x', np.arange(0, 60, 10))},
320 | coords={'x': np.arange(6)},
321 | )
322 | chunked = dataset.chunk()
323 | inputs = [
324 | (xbeam.Key({'x': 0}), dataset.chunk(3)),
325 | ]
326 | temp_dir = self.create_tempdir().full_path
327 | inputs | xbeam.ChunksToZarr(temp_dir, chunked)
328 | result = xarray.open_zarr(temp_dir)
329 | xarray.testing.assert_identical(dataset, result)
330 |
331 | def test_dataset_to_zarr_simple(self):
332 | dataset = xarray.Dataset(
333 | {'foo': ('x', np.arange(0, 60, 10))},
334 | coords={'x': np.arange(6)},
335 | attrs={'meta': 'data'},
336 | )
337 | chunked = dataset.chunk({'x': 3})
338 | temp_dir = self.create_tempdir().full_path
339 | test_util.EagerPipeline() | xbeam.DatasetToZarr(chunked, temp_dir)
340 | actual = xarray.open_zarr(temp_dir, consolidated=True)
341 | xarray.testing.assert_identical(actual, dataset)
342 |
343 | def test_dataset_to_zarr_unchunked(self):
344 | dataset = xarray.Dataset(
345 | {'foo': ('x', np.arange(0, 60, 10))},
346 | )
347 | temp_dir = self.create_tempdir().full_path
348 | with self.assertRaisesRegex(
349 | ValueError, 'dataset must be chunked or chunks must be provided'
350 | ):
351 | test_util.EagerPipeline() | xbeam.DatasetToZarr(dataset, temp_dir)
352 |
353 | def test_validate_zarr_chunk_accepts_partial_key(self):
354 | dataset = xarray.Dataset(
355 | {'foo': (('x', 'y'), np.zeros((3, 2)))},
356 | coords={'x': np.arange(3), 'y': np.arange(2)},
357 | )
358 | # Should not raise an exception:
359 | xbeam.validate_zarr_chunk(
360 | key=xbeam.Key({'x': 0}),
361 | chunk=dataset,
362 | template=dataset.chunk(),
363 | zarr_chunks=None,
364 | )
365 |
366 | def test_to_zarr_wrong_multiple_error(self):
367 | ds = xarray.Dataset({'foo': ('x', np.arange(6))})
368 | inputs = [
369 | (xbeam.Key({'x': 3}), ds.tail(3)),
370 | ]
371 | temp_dir = self.create_tempdir().full_path
372 | with self.assertRaisesRegex(
373 | ValueError,
374 | (
375 | "chunk offset 3 along dimension 'x' is not a multiple of zarr "
376 | "chunks {'x': 4}"
377 | ),
378 | ):
379 | inputs | xbeam.ChunksToZarr(
380 | temp_dir, template=ds.chunk(4), zarr_chunks={'x': 4}
381 | )
382 |
383 | def test_to_zarr_needs_consolidation_error(self):
384 | ds = xarray.Dataset({'foo': ('x', np.arange(6))})
385 | inputs = [
386 | (xbeam.Key({'x': 0}), ds.head(3)),
387 | (xbeam.Key({'x': 3}), ds.tail(3)),
388 | ]
389 | temp_dir = self.create_tempdir().full_path
390 | with self.assertRaisesRegex(
391 | ValueError, 'chunk is smaller than zarr chunks'
392 | ):
393 | inputs | xbeam.ChunksToZarr(
394 | temp_dir, template=ds.chunk(), zarr_chunks={'x': 6}
395 | )
396 | with self.assertRaisesRegex(
397 | ValueError, 'chunk is smaller than zarr chunks'
398 | ):
399 | inputs | xbeam.ChunksToZarr(temp_dir, template=ds.chunk())
400 |
401 | def test_to_zarr_fixed_template(self):
402 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
403 | template = dataset.chunk({'x': 3})
404 | inputs = [
405 | (xbeam.Key({'x': 0}), dataset.head(3)),
406 | (xbeam.Key({'x': 3}), dataset.tail(3)),
407 | ]
408 | temp_dir = self.create_tempdir().full_path
409 | chunks_to_zarr = xbeam.ChunksToZarr(temp_dir, template)
410 | self.assertEqual(chunks_to_zarr.template.chunks, {'x': (6,)})
411 | self.assertEqual(chunks_to_zarr.zarr_chunks, {'x': 3})
412 | inputs | chunks_to_zarr
413 | actual = xarray.open_zarr(temp_dir, consolidated=True)
414 | xarray.testing.assert_identical(actual, dataset)
415 |
416 | def test_infer_zarr_chunks(self):
417 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
418 |
419 | chunks = xbeam._src.zarr._infer_zarr_chunks(dataset)
420 | self.assertEqual(chunks, {})
421 |
422 | chunks = xbeam._src.zarr._infer_zarr_chunks(dataset.chunk())
423 | self.assertEqual(chunks, {'x': 6})
424 |
425 | chunks = xbeam._src.zarr._infer_zarr_chunks(dataset.head(0).chunk())
426 | self.assertEqual(chunks, {'x': 0})
427 |
428 | chunks = xbeam._src.zarr._infer_zarr_chunks(dataset.chunk(3))
429 | self.assertEqual(chunks, {'x': 3})
430 |
431 | chunks = xbeam._src.zarr._infer_zarr_chunks(dataset.chunk(4))
432 | self.assertEqual(chunks, {'x': 4})
433 |
434 | with self.assertRaisesRegex(
435 | ValueError,
436 | re.escape(
437 | "Zarr cannot handle inconsistent chunk sizes along dimension 'x': "
438 | '(2, 4)'
439 | ),
440 | ):
441 | xbeam._src.zarr._infer_zarr_chunks(dataset.chunk({'x': (2, 4)}))
442 |
443 | with self.assertRaisesRegex(
444 | ValueError,
445 | re.escape(
446 | "Zarr cannot handle inconsistent chunk sizes along dimension 'x': "
447 | '(3, 2, 1)'
448 | ),
449 | ):
450 | xbeam._src.zarr._infer_zarr_chunks(dataset.chunk({'x': (3, 2, 1)}))
451 |
452 | def test_chunks_to_zarr_docs_demo(self):
453 |     # verify that the DatasetToChunks + ChunksToZarr demo from our docs works
454 | data = np.random.RandomState(0).randn(2920, 25, 53)
455 | ds = xarray.Dataset({'temperature': (('time', 'lat', 'lon'), data)})
456 | chunks = {'time': 1000, 'lat': 25, 'lon': 53}
457 | temp_dir = self.create_tempdir().full_path
458 | (
459 | test_util.EagerPipeline()
460 | | xbeam.DatasetToChunks(ds, chunks)
461 | | xbeam.ChunksToZarr(
462 | temp_dir, template=xbeam.make_template(ds), zarr_chunks=chunks
463 | )
464 | )
465 | result = xarray.open_zarr(temp_dir)
466 | xarray.testing.assert_identical(result, ds)
467 |
468 |
469 | if __name__ == '__main__':
470 | absltest.main()
471 |
--------------------------------------------------------------------------------