├── .github └── workflows │ └── ci-build.yml ├── .gitignore ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── conftest.py ├── docs ├── Makefile ├── _static │ ├── README.md │ ├── xarray-beam-logo.png │ └── xarray-beam-vs-xarray-dask.png ├── aggregation.ipynb ├── api.md ├── conf.py ├── data-model.ipynb ├── index.md ├── make.bat ├── read-write.ipynb ├── rechunking.ipynb ├── requirements.txt └── why-xarray-beam.md ├── examples ├── README.md ├── __init__.py ├── era5_climatology.py ├── era5_climatology_test.py ├── xbeam_rechunk.py └── xbeam_rechunk_test.py ├── pyproject.toml ├── setup.py └── xarray_beam ├── __init__.py └── _src ├── __init__.py ├── combiners.py ├── combiners_test.py ├── core.py ├── core_test.py ├── dataset.py ├── dataset_test.py ├── integration_test.py ├── rechunk.py ├── rechunk_test.py ├── test_util.py ├── threadmap.py ├── threadmap_test.py ├── zarr.py └── zarr_test.py /.github/workflows/ci-build.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | # Triggers the workflow on push or pull request events but only for the main branch 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | jobs: 13 | build: 14 | name: "python ${{ matrix.python-version }}" 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.9", "3.10", "3.11"] 20 | steps: 21 | - name: Cancel previous 22 | uses: styfle/cancel-workflow-action@0.7.0 23 | with: 24 | access_token: ${{ github.token }} 25 | if: ${{github.ref != 'refs/head/main'}} 26 | - uses: actions/checkout@v4 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v5 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | - name: Get pip cache dir 32 | id: pip-cache 33 | run: | 34 | python -m pip install --upgrade pip wheel 35 | echo "::set-output name=dir::$(pip cache dir)" 36 | - name: pip cache 37 | uses: actions/cache@v4 38 | with: 39 | path: ${{ steps.pip-cache.outputs.dir }} 40 | key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }} 41 | restore-keys: | 42 | ${{ runner.os }}-pip- 43 | - name: Install Xarray-Beam 44 | run: | 45 | pip install -e .[tests] 46 | - name: Run unit tests 47 | # TODO(shoyer): remove the pangeo-forge module? the tests no longer pass 48 | run: | 49 | pytest xarray_beam 50 | - name: Run example tests 51 | # The examples define some of the same flags, so we run pytest in separate processes. 52 | run: | 53 | pytest examples/era5_climatology_test.py 54 | pytest examples/xbeam_rechunk_test.py 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | .DS_Store 3 | build 4 | dist 5 | docs/.ipynb_checkpoints 6 | docs/_build 7 | docs/_autosummary 8 | docs/*.zarr 9 | __pycache__ 10 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.10" 12 | 13 | # Build documentation in the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: docs/conf.py 16 | 17 | # Optionally declare the Python requirements required to build your docs 18 | python: 19 | install: 20 | - method: pip 21 | path: . 22 | - requirements: docs/requirements.txt 23 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Xarray-Beam 2 | 3 | Xarray-Beam is a Python library for building 4 | [Apache Beam](https://beam.apache.org/) pipelines with 5 | [Xarray](http://xarray.pydata.org/en/stable/) datasets. 6 | 7 | The project aims to facilitate data transformations and analysis on large-scale 8 | multi-dimensional labeled arrays, such as: 9 | 10 | - Ad-hoc computation on Xarray data, by dividing a `xarray.Dataset` into many 11 | smaller pieces ("chunks"). 12 | - Adjusting array chunks, using the 13 | [Rechunker algorithm](https://rechunker.readthedocs.io/en/latest/algorithm.html). 14 | - Ingesting large, multi-dimensional array datasets into an analysis-ready, 15 | cloud-optimized format, namely [Zarr](https://zarr.readthedocs.io/) (see 16 | also [Pangeo Forge](https://github.com/pangeo-forge/pangeo-forge-recipes)). 17 | - Calculating statistics (e.g., "climatology") across distributed datasets 18 | with arbitrary groups. 19 | 20 | For more about our approach and how to get started, 21 | **[read the documentation](https://xarray-beam.readthedocs.io/)**! 22 | 23 | **Warning: Xarray-Beam is a sharp tool 🔪** 24 | 25 | Xarray-Beam is relatively new, and focused on expert users: 26 | 27 | - We use it extensively at Google for processing large-scale weather datasets, 28 | but there is not yet a vibrant external community. 29 | - It provides low-level abstractions that facilitate writing very large 30 | scale data pipelines (e.g., 100+ TB), but by design it requires explicitly 31 | thinking about how every operation is parallelized. 32 | 33 | ## Installation 34 | 35 | Xarray-Beam requires recent versions of immutabledict, Xarray, Dask, Rechunker, 36 | Zarr, and Apache Beam. For best performance when writing Zarr files, use Xarray 37 | 0.19.0 or later. 38 | 39 | ## Disclaimer 40 | 41 | Xarray-Beam is an experiment that we are sharing with the outside world in the 42 | hope that it will be useful. It is not a supported Google product. We welcome 43 | feedback, bug reports and code contributions, but cannot guarantee they will be 44 | addressed. 45 | 46 | See the "Contribution guidelines" for more. 47 | 48 | ## Credits 49 | 50 | Contributors: 51 | 52 | - Stephan Hoyer 53 | - Jason Hickey 54 | - Cenk Gazen 55 | - Alex Merose 56 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configure FLAGS with default values for absltest.""" 15 | from absl import app 16 | 17 | try: 18 | app.run(lambda argv: None) 19 | except SystemExit: 20 | pass 21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/README.md: -------------------------------------------------------------------------------- 1 | These images were created in Google drawings. 2 | 3 | - [Logo](https://docs.google.com/drawings/d/1gBt0tOOnY7pMDB1g45T7ARuwCc1CZjojZT2IfdvA5FI/edit?usp=sharing) 4 | - [Xarray-Beam vs Xarray-Dask](https://docs.google.com/drawings/d/1LBM4AKHadzuS8tsHQ5TcTAWaBM7E6NeCd7Aq16hpIxk/edit?usp=sharing) 5 | -------------------------------------------------------------------------------- /docs/_static/xarray-beam-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/xarray-beam/da2ff60f48209e9c82ea1760d1d98540556e7271/docs/_static/xarray-beam-logo.png -------------------------------------------------------------------------------- /docs/_static/xarray-beam-vs-xarray-dask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/xarray-beam/da2ff60f48209e9c82ea1760d1d98540556e7271/docs/_static/xarray-beam-vs-xarray-dask.png -------------------------------------------------------------------------------- /docs/aggregation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "aaf2f2cb", 6 | "metadata": {}, 7 | "source": [ 8 | "# Aggregation\n", 9 | "\n", 10 | "Xarray-Beam can perform efficient distributed data aggregation in the \"map-reduce\" model. \n", 11 | "\n", 12 | "This currently only includes `Mean`, but we would welcome contributions of other aggregation functions such as `Sum`, `Std`, `Var`, `Min`, `Max`, etc." 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "4b77b953", 18 | "metadata": {}, 19 | "source": [ 20 | "## High-level API\n", 21 | "\n", 22 | "The `Mean` transformation comes in three forms: {py:class}`Mean `, {py:class}`Mean.Globally `, and {py:class}`Mean.PerKey `. The implementation is highly scalable, based on a Beam's [`CombineFn`](https://beam.apache.org/documentation/transforms/python/aggregation/combineglobally/#example-4-combining-with-a-combinefn)." 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "fca45196", 28 | "metadata": {}, 29 | "source": [ 30 | "The high-level `Mean` transform can be used to aggregate a distributed dataset across an existing dimension or dimensions, similar to Xarray's `.mean()` method:" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "id": "3e387dd1", 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "\n", 44 | "Dimensions: (lat: 25, time: 2920, lon: 53)\n", 45 | "Coordinates:\n", 46 | " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n", 47 | " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", 48 | " * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n", 49 | "Data variables:\n", 50 | " air (time, lat, lon) float32 241.2 242.5 243.5 ... 296.5 296.2 295.7\n", 51 | "Attributes:\n", 52 | " Conventions: COARDS\n", 53 | " title: 4x daily NMC reanalysis (1948)\n", 54 | " description: Data is from NMC initialized reanalysis\\n(4x/day). These a...\n", 55 | " platform: Model\n", 56 | " references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "import apache_beam as beam\n", 62 | "import numpy as np\n", 63 | "import xarray_beam as xbeam\n", 64 | "import xarray\n", 65 | "\n", 66 | "ds = xarray.tutorial.load_dataset('air_temperature')\n", 67 | "print(ds)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 2, 73 | "id": "5967340a", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "application/javascript": [ 79 | "\n", 80 | " if (typeof window.interactive_beam_jquery == 'undefined') {\n", 81 | " var jqueryScript = document.createElement('script');\n", 82 | " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", 83 | " jqueryScript.type = 'text/javascript';\n", 84 | " jqueryScript.onload = function() {\n", 85 | " var datatableScript = document.createElement('script');\n", 86 | " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", 87 | " datatableScript.type = 'text/javascript';\n", 88 | " datatableScript.onload = function() {\n", 89 | " window.interactive_beam_jquery = jQuery.noConflict(true);\n", 90 | " window.interactive_beam_jquery(document).ready(function($){\n", 91 | " \n", 92 | " });\n", 93 | " }\n", 94 | " document.head.appendChild(datatableScript);\n", 95 | " };\n", 96 | " document.head.appendChild(jqueryScript);\n", 97 | " } else {\n", 98 | " window.interactive_beam_jquery(document).ready(function($){\n", 99 | " \n", 100 | " });\n", 101 | " }" 102 | ] 103 | }, 104 | "metadata": {}, 105 | "output_type": "display_data" 106 | }, 107 | { 108 | "name": "stderr", 109 | "output_type": "stream", 110 | "text": [ 111 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[2]: Mean/PerKey/CombinePerKey(MeanCombineFn)/GroupByKey'. \n", 112 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[2]: Mean/PerKey/CombinePerKey(MeanCombineFn)/GroupByKey'. \n", 113 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[2]: Mean/PerKey/CombinePerKey(MeanCombineFn)'. \n" 114 | ] 115 | }, 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "(Key(offsets={'lat': 0, 'lon': 0}, vars=None), \n", 121 | "Dimensions: (lat: 25, lon: 53)\n", 122 | "Coordinates:\n", 123 | " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n", 124 | " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", 125 | "Data variables:\n", 126 | " air (lat, lon) float64 260.4 260.2 259.9 259.5 ... 297.3 297.3 297.3)\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "with beam.Pipeline() as p:\n", 132 | " p | xbeam.DatasetToChunks(ds, chunks={'time': 1000}) | xbeam.Mean('time') | beam.Map(print)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "id": "7a584485", 138 | "metadata": {}, 139 | "source": [ 140 | "## Lower-level API" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "id": "6f7585a9", 146 | "metadata": {}, 147 | "source": [ 148 | "Xarray-Beam also includes lower-level transforations modelled off of [`beam.Mean`](https://beam.apache.org/documentation/transforms/python/aggregation/mean/) rather than {py:meth}`xarray.Dataset.mean`: they compute averages over sequences of `xarray.Dataset` objects or (`key`, `xarray.Dataset`) pairs, rather than calculating an average over an existing Xarray dimension or based on `xarray_beam.Key` objects, e.g.," 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 3, 154 | "id": "86a925af", 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "name": "stderr", 159 | "output_type": "stream", 160 | "text": [ 161 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-b19331c2-e263-4737-bd64-012081154884.json']\n" 162 | ] 163 | }, 164 | { 165 | "data": { 166 | "text/plain": [ 167 | "[\n", 168 | " Dimensions: (x: 3)\n", 169 | " Dimensions without coordinates: x\n", 170 | " Data variables:\n", 171 | " foo (x) float64 0.05667 -0.02306 -0.1648]" 172 | ] 173 | }, 174 | "execution_count": 3, 175 | "metadata": {}, 176 | "output_type": "execute_result" 177 | } 178 | ], 179 | "source": [ 180 | "datasets = [\n", 181 | " xarray.Dataset({'foo': ('x', np.random.randn(3))})\n", 182 | " for _ in range(100)\n", 183 | "]\n", 184 | "datasets | xbeam.Mean.Globally()" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "id": "5670130b", 190 | "metadata": {}, 191 | "source": [ 192 | "Notice how existing dimensions on each datasets are unchanged by the transformation. If you want to average over existing dimensions, use the high-level `Mean` transform or do that aggregation yourself, e.g., by averaging inside each chunk before combining the data." 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "id": "6ab152b6", 198 | "metadata": {}, 199 | "source": [ 200 | "Similarly, the keys fed into `xbeam.Mean.PerKey` can be any hashables, including but not limited to `xbeam.Key`:" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 4, 206 | "id": "b2399483", 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "name": "stderr", 211 | "output_type": "stream", 212 | "text": [ 213 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-b19331c2-e263-4737-bd64-012081154884.json']\n" 214 | ] 215 | }, 216 | { 217 | "data": { 218 | "text/plain": [ 219 | "[('DJF',\n", 220 | " \n", 221 | " Dimensions: ()\n", 222 | " Data variables:\n", 223 | " air float64 273.6),\n", 224 | " ('MAM',\n", 225 | " \n", 226 | " Dimensions: ()\n", 227 | " Data variables:\n", 228 | " air float64 279.0),\n", 229 | " ('JJA',\n", 230 | " \n", 231 | " Dimensions: ()\n", 232 | " Data variables:\n", 233 | " air float64 289.2),\n", 234 | " ('SON',\n", 235 | " \n", 236 | " Dimensions: ()\n", 237 | " Data variables:\n", 238 | " air float64 283.0)]" 239 | ] 240 | }, 241 | "execution_count": 4, 242 | "metadata": {}, 243 | "output_type": "execute_result" 244 | } 245 | ], 246 | "source": [ 247 | "datasets = [\n", 248 | " (time.dt.season.item(), ds.sel(time=time).mean())\n", 249 | " for time in ds.time\n", 250 | "]\n", 251 | "datasets | xbeam.Mean.PerKey()" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "id": "c9db1627", 257 | "metadata": {}, 258 | "source": [ 259 | "`Mean.PerKey` is particularly useful in combination with {class}`beam.GroupByKey` for performing large-scale \"group by\" operations. For example, that a look at the [ERA5 climatology example](https://github.com/google/xarray-beam/blob/main/examples/era5_climatology.py)." 260 | ] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "id": "0fb6fc6a", 265 | "metadata": {}, 266 | "source": [ 267 | "## Custom aggregations\n", 268 | "\n", 269 | "The \"tree reduction\" algorithm used by the combiner inside `Mean` is great, but it isn't the only way to aggregate a dataset with Xarray-Beam.\n", 270 | "\n", 271 | "In many cases, the easiest way to scale up an aggregation pipeline is to make use of [rechunking](rechunking.ipynb) to convert the many small datasets inside your pipeline into a form that is easier to calculate in a scalable way. However, rechunking is much less efficient than using combiner, because each use of `Rechunk` requires a complete shuffle of the input data (i.e., writing all data in the pipepilne to temporary files on disk).\n", 272 | "\n", 273 | "For example, here's how one could compute the `median`, which is a notoriously difficult statistic to calculate with distributed algorithms:" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 5, 279 | "id": "ef1ef099", 280 | "metadata": { 281 | "scrolled": false 282 | }, 283 | "outputs": [ 284 | { 285 | "name": "stderr", 286 | "output_type": "stream", 287 | "text": [ 288 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: Rechunk/Stage0/Consolidate/GroupByTempKeys'. \n", 289 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: Rechunk/Stage0/Consolidate/GroupByTempKeys'. \n", 290 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: Rechunk/Stage2/Consolidate/GroupByTempKeys'. \n", 291 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: Rechunk/Stage2/Consolidate/GroupByTempKeys'. \n", 292 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: ConsolidateChunks/GroupByTempKeys'. \n", 293 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in '[5]: ConsolidateChunks/GroupByTempKeys'. \n" 294 | ] 295 | }, 296 | { 297 | "name": "stdout", 298 | "output_type": "stream", 299 | "text": [ 300 | "\n", 301 | "Dimensions: (lat: 25, lon: 53)\n", 302 | "Coordinates:\n", 303 | " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n", 304 | " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", 305 | "Data variables:\n", 306 | " air (lat, lon) float32 261.3 261.1 260.9 260.3 ... 297.3 297.3 297.3\n" 307 | ] 308 | } 309 | ], 310 | "source": [ 311 | "source_chunks = {'time': 100, 'lat': -1, 'lon': -1}\n", 312 | "working_chunks = {'lat': 10, 'lon': 10, 'time': -1}\n", 313 | "\n", 314 | "with beam.Pipeline() as p:\n", 315 | " (\n", 316 | " p\n", 317 | " | xbeam.DatasetToChunks(ds, source_chunks)\n", 318 | " | xbeam.Rechunk(ds.sizes, source_chunks, working_chunks, itemsize=4)\n", 319 | " | beam.MapTuple(lambda k, v: (k.with_offsets(time=None), v.median('time')))\n", 320 | " | xbeam.ConsolidateChunks({'lat': -1, 'lon': -1})\n", 321 | " | beam.MapTuple(lambda k, v: print(v))\n", 322 | " )" 323 | ] 324 | } 325 | ], 326 | "metadata": { 327 | "kernelspec": { 328 | "display_name": "Python 3 (ipykernel)", 329 | "language": "python", 330 | "name": "python3" 331 | }, 332 | "language_info": { 333 | "codemirror_mode": { 334 | "name": "ipython", 335 | "version": 3 336 | }, 337 | "file_extension": ".py", 338 | "mimetype": "text/x-python", 339 | "name": "python", 340 | "nbconvert_exporter": "python", 341 | "pygments_lexer": "ipython3", 342 | "version": "3.9.13" 343 | } 344 | }, 345 | "nbformat": 4, 346 | "nbformat_minor": 5 347 | } 348 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API docs 2 | 3 | ```{eval-rst} 4 | .. currentmodule:: xarray_beam 5 | ``` 6 | 7 | ## Core data model 8 | 9 | ```{eval-rst} 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | 13 | Key 14 | ``` 15 | 16 | ## Reading and writing data 17 | 18 | ```{eval-rst} 19 | .. autosummary:: 20 | :toctree: _autosummary 21 | 22 | open_zarr 23 | DatasetToChunks 24 | ChunksToZarr 25 | DatasetToZarr 26 | make_template 27 | replace_template_dims 28 | setup_zarr 29 | validate_zarr_chunk 30 | write_chunk_to_zarr 31 | ``` 32 | 33 | ## Aggregation 34 | 35 | ```{eval-rst} 36 | .. autosummary:: 37 | :toctree: _autosummary 38 | 39 | Mean 40 | Mean.Globally 41 | Mean.PerKey 42 | MeanCombineFn 43 | ``` 44 | 45 | ## Rechunking 46 | 47 | ```{eval-rst} 48 | .. autosummary:: 49 | :toctree: _autosummary 50 | 51 | ConsolidateChunks 52 | ConsolidateVariables 53 | SplitChunks 54 | SplitVariables 55 | Rechunk 56 | ``` 57 | 58 | ## Utility transforms 59 | 60 | ```{eval-rst} 61 | .. autosummary:: 62 | :toctree: _autosummary 63 | 64 | ValidateEachChunk 65 | ``` 66 | 67 | ## Utility functions 68 | 69 | ```{eval-rst} 70 | .. autosummary:: 71 | :toctree: _autosummary 72 | 73 | offsets_to_slices 74 | validate_chunk 75 | consolidate_chunks 76 | consolidate_variables 77 | consolidate_fully 78 | split_chunks 79 | split_variables 80 | in_memory_rechunk 81 | ``` -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | # Print Python environment info for easier debugging on ReadTheDocs 18 | 19 | import sys 20 | import subprocess 21 | import xarray_beam # verify this works 22 | 23 | print("python exec:", sys.executable) 24 | print("sys.path:", sys.path) 25 | print("pip environment:") 26 | subprocess.run([sys.executable, "-m", "pip", "list"]) 27 | 28 | print(f"xarray_beam: {xarray_beam.__version__}, {xarray_beam.__file__}") 29 | 30 | # -- Project information ----------------------------------------------------- 31 | 32 | project = 'Xarray-Beam' 33 | copyright = '2021, Google LCC' 34 | author = 'The Xarray-Beam authors' 35 | 36 | 37 | # -- General configuration --------------------------------------------------- 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = [ 43 | 'sphinx.ext.autodoc', 44 | 'sphinx.ext.autosummary', 45 | 'sphinx.ext.napoleon', 46 | 'myst_nb', 47 | ] 48 | 49 | # Add any paths that contain templates here, relative to this directory. 50 | templates_path = ['_templates'] 51 | 52 | # List of patterns, relative to source directory, that match files and 53 | # directories to ignore when looking for source files. 54 | # This pattern also affects html_static_path and html_extra_path. 55 | exclude_patterns = ['_build', '_templates', 'Thumbs.db', '.DS_Store'] 56 | 57 | intersphinx_mapping = { 58 | "xarray": ("https://xarray.pydata.org/en/latest/", None), 59 | } 60 | 61 | # -- Options for HTML output ------------------------------------------------- 62 | 63 | # The theme to use for HTML and HTML Help pages. See the documentation for 64 | # a list of builtin themes. 65 | # 66 | html_theme = 'sphinx_rtd_theme' 67 | 68 | # Add any paths that contain custom static files (such as style sheets) here, 69 | # relative to this directory. They are copied after the builtin static files, 70 | # so a file named "default.css" will overwrite the builtin "default.css". 71 | html_static_path = ['_static'] 72 | 73 | # -- Extension config 74 | 75 | autosummary_generate = True 76 | 77 | # https://myst-nb.readthedocs.io/en/latest/use/execute.html 78 | jupyter_execute_notebooks = "cache" 79 | # https://myst-nb.readthedocs.io/en/latest/use/formatting_outputs.html#removing-stdout-and-stderr 80 | nb_output_stderr = "remove-warn" 81 | 82 | # https://stackoverflow.com/a/66295922/809705 83 | autodoc_typehints = "description" 84 | -------------------------------------------------------------------------------- /docs/data-model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Core data model" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Xarray-Beam tries to make it _straightforward_ to write distributed pipelines with Xarray objects, but unlike libraries like [Xarray with Dask](http://xarray.pydata.org/en/stable/user-guide/dask.html) or Dask/Spark DataFrames, it doesn't hide the distributed magic inside high-level objects.\n", 15 | "\n", 16 | "Xarray-Beam is a lower-level tool. You will be manipulating large datasets piece-by-piece yourself, and you as the developer will be responsible for maintaining Xarray-Beam's internal invariants. This means that to successfully use Xarray-Beam, **you will need to understand how how it represents distributed datasets**.\n", 17 | "\n", 18 | "This responsibility requires a bit more coding and understanding, but offers benefits in performance and flexibility. This brief tutorial will show you how.\n", 19 | "\n", 20 | "We'll start off with some standard imports:" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import apache_beam as beam\n", 30 | "import numpy as np\n", 31 | "import xarray_beam as xbeam\n", 32 | "import xarray" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## Keys in Xarray-Beam\n", 40 | "\n", 41 | "Xarray-Beam is designed around the model that every stage in your Beam pipeline _could_ be stored in a single `xarray.Dataset` object, but is instead represented by a distributed beam `PCollection` of smaller `xarray.Dataset` objects, distributed in two possible ways:\n", 42 | "\n", 43 | "- Distinct _variables_ in a Dataset may be separated across multiple records.\n", 44 | "- Individual arrays can also be split into multiple _chunks_, similar to those used by [dask.array](https://docs.dask.org/en/latest/array.html).\n", 45 | "\n", 46 | "To keep track of how individual records could be combined into a larger (virtual) dataset, Xarray-Beam defines a {py:class}`~xarray_beam.Key` object. Key objects consist of:\n", 47 | "\n", 48 | "1. `offsets`: integer offests for chunks from the origin in an `immutabledict`\n", 49 | "2. `vars`: The subset of variables included in each chunk, either as a `frozenset`, or as `None` to indicate \"all variables\"." 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "Making a {py:class}`~xarray_beam.Key` from scratch is simple:" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "Key(offsets={'x': 0, 'y': 10}, vars=None)" 68 | ] 69 | }, 70 | "execution_count": 2, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "key = xbeam.Key({'x': 0, 'y': 10}, vars=None)\n", 77 | "key" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "Or given an existing {py:class}`~xarray_beam.Key`, you can easily modify it with `replace()` or `with_offsets()`:" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "data": { 94 | "text/plain": [ 95 | "Key(offsets={'x': 0, 'y': 10}, vars={'bar', 'foo'})" 96 | ] 97 | }, 98 | "execution_count": 3, 99 | "metadata": {}, 100 | "output_type": "execute_result" 101 | } 102 | ], 103 | "source": [ 104 | "key.replace(vars={'foo', 'bar'})" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 4, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "data": { 114 | "text/plain": [ 115 | "Key(offsets={'y': 10, 'z': 1}, vars=None)" 116 | ] 117 | }, 118 | "execution_count": 4, 119 | "metadata": {}, 120 | "output_type": "execute_result" 121 | } 122 | ], 123 | "source": [ 124 | "key.with_offsets(x=None, z=1)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "{py:class}`~xarray_beam.Key` objects don't do very much. They are just simple structs with two attributes, along with various special methods required to use them as `dict` keys or as keys in Beam pipelines. You can find a more examples of manipulating keys in the docstring of {py:class}`~xarray_beam.Key`." 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "## Creating PCollections" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "The standard inputs & outputs for Xarray-Beam are PCollections of `(xbeam.Key, xarray.Dataset)` pairs. Xarray-Beam provides a bunch of PCollections for typical tasks, but many pipelines will still involve some manual manipulation of `Key` and `Dataset` objects, e.g., with builtin Beam transforms like `beam.Map`.\n", 146 | "\n", 147 | "To start off, let's write a helper functions for creating our first collection from scratch:" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 5, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "def create_records():\n", 157 | " for offset in [0, 4]:\n", 158 | " key = xbeam.Key({'x': offset, 'y': 0})\n", 159 | " data = 2 * offset + np.arange(8).reshape(4, 2)\n", 160 | " chunk = xarray.Dataset({\n", 161 | " 'foo': (('x', 'y'), data),\n", 162 | " 'bar': (('x', 'y'), 100 + data),\n", 163 | " })\n", 164 | " yield key, chunk" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "Let's take a look the entries, which are lazily constructed with the generator:" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 6, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "inputs = list(create_records())" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 7, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "data": { 190 | "text/plain": [ 191 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n", 192 | " \n", 193 | " Dimensions: (x: 4, y: 2)\n", 194 | " Dimensions without coordinates: x, y\n", 195 | " Data variables:\n", 196 | " foo (x, y) int64 0 1 2 3 4 5 6 7\n", 197 | " bar (x, y) int64 100 101 102 103 104 105 106 107),\n", 198 | " (Key(offsets={'x': 4, 'y': 0}, vars=None),\n", 199 | " \n", 200 | " Dimensions: (x: 4, y: 2)\n", 201 | " Dimensions without coordinates: x, y\n", 202 | " Data variables:\n", 203 | " foo (x, y) int64 8 9 10 11 12 13 14 15\n", 204 | " bar (x, y) int64 108 109 110 111 112 113 114 115)]" 205 | ] 206 | }, 207 | "execution_count": 7, 208 | "metadata": {}, 209 | "output_type": "execute_result" 210 | } 211 | ], 212 | "source": [ 213 | "inputs" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "```{note}\n", 221 | "There are multiple valid ways to represent a chunk of a larger dataset with a `Key`.\n", 222 | "\n", 223 | "- **Offsets for unchunked dimensions are optional**. Because all chunks have the same offset along the `y` axis, including `y` in `offsets` is not required as long as we don't need to create multiple chunks along that dimension.\n", 224 | "- **Indicating variables is optional, if all chunks have the same variables**. We could have set `vars={'foo', 'bar'}` on each of these `Key` objects instead of `vars=None`. This would be an equally valid representation of the same records, since all of our datasets have the same variables.\n", 225 | "```" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "We now have the inputs we need to use Xarray-Beam's helper functions and PTransforms. For example, we can fully consolidate chunks & variables to see what single `xarray.Dataset` these values would correspond to:" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 8, 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "data": { 242 | "text/plain": [ 243 | "(Key(offsets={'x': 0, 'y': 0}, vars={'bar', 'foo'}),\n", 244 | " \n", 245 | " Dimensions: (x: 8, y: 2)\n", 246 | " Dimensions without coordinates: x, y\n", 247 | " Data variables:\n", 248 | " foo (x, y) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15\n", 249 | " bar (x, y) int64 100 101 102 103 104 105 ... 110 111 112 113 114 115)" 250 | ] 251 | }, 252 | "execution_count": 8, 253 | "metadata": {}, 254 | "output_type": "execute_result" 255 | } 256 | ], 257 | "source": [ 258 | "xbeam.consolidate_fully(inputs)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": {}, 264 | "source": [ 265 | "To execute with Beam, of course, we need to turn Python lists/generators into Beam PCollections, e.g., with `beam.Create()`:" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 9, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "data": { 275 | "application/javascript": [ 276 | "\n", 277 | " if (typeof window.interactive_beam_jquery == 'undefined') {\n", 278 | " var jqueryScript = document.createElement('script');\n", 279 | " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", 280 | " jqueryScript.type = 'text/javascript';\n", 281 | " jqueryScript.onload = function() {\n", 282 | " var datatableScript = document.createElement('script');\n", 283 | " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", 284 | " datatableScript.type = 'text/javascript';\n", 285 | " datatableScript.onload = function() {\n", 286 | " window.interactive_beam_jquery = jQuery.noConflict(true);\n", 287 | " window.interactive_beam_jquery(document).ready(function($){\n", 288 | " \n", 289 | " });\n", 290 | " }\n", 291 | " document.head.appendChild(datatableScript);\n", 292 | " };\n", 293 | " document.head.appendChild(jqueryScript);\n", 294 | " } else {\n", 295 | " window.interactive_beam_jquery(document).ready(function($){\n", 296 | " \n", 297 | " });\n", 298 | " }" 299 | ] 300 | }, 301 | "metadata": {}, 302 | "output_type": "display_data" 303 | }, 304 | { 305 | "name": "stdout", 306 | "output_type": "stream", 307 | "text": [ 308 | "(Key(offsets={'x': 0, 'y': 0}, vars=None), \n", 309 | "Dimensions: (x: 4, y: 2)\n", 310 | "Dimensions without coordinates: x, y\n", 311 | "Data variables:\n", 312 | " foo (x, y) int64 0 1 2 3 4 5 6 7\n", 313 | " bar (x, y) int64 100 101 102 103 104 105 106 107)\n", 314 | "(Key(offsets={'x': 4, 'y': 0}, vars=None), \n", 315 | "Dimensions: (x: 4, y: 2)\n", 316 | "Dimensions without coordinates: x, y\n", 317 | "Data variables:\n", 318 | " foo (x, y) int64 8 9 10 11 12 13 14 15\n", 319 | " bar (x, y) int64 108 109 110 111 112 113 114 115)\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "with beam.Pipeline() as p:\n", 325 | " p | beam.Create(create_records()) | beam.Map(print)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": {}, 331 | "source": [ 332 | "## Writing pipelines" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "Transforms in Xarray-Beam typically act on (key, value) pairs of `(xbeam.Key, xarray.Dataset)`. For example, we can dump our dataset on disk in the scalable [Zarr](https://zarr.readthedocs.io/) format using {py:class}`~xarray_beam.ChunksToZarr`:" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 10, 345 | "metadata": {}, 346 | "outputs": [ 347 | { 348 | "name": "stderr", 349 | "output_type": "stream", 350 | "text": [ 351 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2d587cc9-3460-4867-a777-2f0c8bc4e743.json']\n" 352 | ] 353 | }, 354 | { 355 | "data": { 356 | "text/plain": [ 357 | "[None, None]" 358 | ] 359 | }, 360 | "execution_count": 10, 361 | "metadata": {}, 362 | "output_type": "execute_result" 363 | } 364 | ], 365 | "source": [ 366 | "inputs | xbeam.ChunksToZarr('my-data.zarr')" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": {}, 372 | "source": [ 373 | "Xarray-Beam doesn't try to provide transformations for everything. In particular, it omits most [embarrassingly parallel](https://en.wikipedia.org/wiki/Embarrassingly_parallel) operations that can be performed independently on each chunk of a larger dataset. You can write these yourself using [`beam.Map`](https://beam.apache.org/documentation/transforms/python/elementwise/map/).\n", 374 | "\n", 375 | "For example, consider elementwise arithmetic. We can write a `lambda` function that acts on each key-value pair updating the xarray.Dataset objects appropriately, and put it into an Xarray-Beam pipeline using `beam.MapTuple`:" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 11, 381 | "metadata": {}, 382 | "outputs": [ 383 | { 384 | "name": "stderr", 385 | "output_type": "stream", 386 | "text": [ 387 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-beb6b630-5659-435b-8a62-c213f6177340.json']\n" 388 | ] 389 | }, 390 | { 391 | "data": { 392 | "text/plain": [ 393 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n", 394 | " \n", 395 | " Dimensions: (x: 4, y: 2)\n", 396 | " Dimensions without coordinates: x, y\n", 397 | " Data variables:\n", 398 | " foo (x, y) int64 1 2 3 4 5 6 7 8\n", 399 | " bar (x, y) int64 101 102 103 104 105 106 107 108),\n", 400 | " (Key(offsets={'x': 4, 'y': 0}, vars=None),\n", 401 | " \n", 402 | " Dimensions: (x: 4, y: 2)\n", 403 | " Dimensions without coordinates: x, y\n", 404 | " Data variables:\n", 405 | " foo (x, y) int64 9 10 11 12 13 14 15 16\n", 406 | " bar (x, y) int64 109 110 111 112 113 114 115 116)]" 407 | ] 408 | }, 409 | "execution_count": 11, 410 | "metadata": {}, 411 | "output_type": "execute_result" 412 | } 413 | ], 414 | "source": [ 415 | "inputs | beam.MapTuple(lambda k, v: (k, v + 1))" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "metadata": {}, 421 | "source": [ 422 | "For operations that add or remove (unchunked) dimensions, you may need to update `Key` objects as well to maintain the Xarray-Beam invariants, e.g., if we want to remove the `y` dimension entirely:" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 12, 428 | "metadata": {}, 429 | "outputs": [ 430 | { 431 | "name": "stderr", 432 | "output_type": "stream", 433 | "text": [ 434 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-beb6b630-5659-435b-8a62-c213f6177340.json']\n" 435 | ] 436 | }, 437 | { 438 | "data": { 439 | "text/plain": [ 440 | "[(Key(offsets={'x': 0}, vars=None),\n", 441 | " \n", 442 | " Dimensions: (x: 4)\n", 443 | " Dimensions without coordinates: x\n", 444 | " Data variables:\n", 445 | " foo (x) float64 0.5 2.5 4.5 6.5\n", 446 | " bar (x) float64 100.5 102.5 104.5 106.5),\n", 447 | " (Key(offsets={'x': 4}, vars=None),\n", 448 | " \n", 449 | " Dimensions: (x: 4)\n", 450 | " Dimensions without coordinates: x\n", 451 | " Data variables:\n", 452 | " foo (x) float64 8.5 10.5 12.5 14.5\n", 453 | " bar (x) float64 108.5 110.5 112.5 114.5)]" 454 | ] 455 | }, 456 | "execution_count": 12, 457 | "metadata": {}, 458 | "output_type": "execute_result" 459 | } 460 | ], 461 | "source": [ 462 | "inputs | beam.MapTuple(lambda k, v: (k.with_offsets(y=None), v.mean('y')))" 463 | ] 464 | }, 465 | { 466 | "cell_type": "markdown", 467 | "metadata": {}, 468 | "source": [ 469 | "```{note}\n", 470 | "Missing transformations in Xarray-Beam is partially an intentional design decision to reduce scope, and partially just a reflection of what we've gotten around to implementing. If after reading through the rest of docs you notice missing transformations or are wondering how to compute something in Xarray-Beam, please [open an issue](https://github.com/google/xarray-beam/issues) to discuss.\n", 471 | "```" 472 | ] 473 | } 474 | ], 475 | "metadata": { 476 | "kernelspec": { 477 | "display_name": "Python 3 (ipykernel)", 478 | "language": "python", 479 | "name": "python3" 480 | }, 481 | "language_info": { 482 | "codemirror_mode": { 483 | "name": "ipython", 484 | "version": 3 485 | }, 486 | "file_extension": ".py", 487 | "mimetype": "text/x-python", 488 | "name": "python", 489 | "nbconvert_exporter": "python", 490 | "pygments_lexer": "ipython3", 491 | "version": "3.9.13" 492 | } 493 | }, 494 | "nbformat": 4, 495 | "nbformat_minor": 2 496 | } 497 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Xarray-Beam: distributed Xarray with Apache Beam 2 | 3 | Xarray-Beam is a library for writing [Apache Beam](http://beam.apache.org/) pipelines consisting of [xarray](http://xarray.pydata.org) Dataset objects. This documentation (and Xarray-Beam itself) assumes basic familiarity with both Beam and Xarray. 4 | 5 | The documentation includes narrative documentation that will walk you through the basics of writing a pipeline with Xarray-Beam, and also comprehensive API docs. 6 | 7 | We recommend reading both, as well as a few [end to end examples](https://github.com/google/xarray-beam/tree/main/examples) to understand what code using Xarray-Beam typically looks like. 8 | 9 | ## Contents 10 | 11 | ```{toctree} 12 | :maxdepth: 1 13 | why-xarray-beam.md 14 | data-model.ipynb 15 | read-write.ipynb 16 | aggregation.ipynb 17 | rechunking.ipynb 18 | api.md 19 | ``` -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/rechunking.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Rechunking" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Rechunking lets us re-distribute how datasets are split between variables and chunks across a Beam PCollection." 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "To get started we'll recreate our dummy data from the data model tutorial:" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": { 28 | "tags": [ 29 | "hide-inputs" 30 | ] 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "import apache_beam as beam\n", 35 | "import numpy as np\n", 36 | "import xarray_beam as xbeam\n", 37 | "import xarray\n", 38 | "\n", 39 | "def create_records():\n", 40 | " for offset in [0, 4]:\n", 41 | " key = xbeam.Key({'x': offset, 'y': 0})\n", 42 | " data = 2 * offset + np.arange(8).reshape(4, 2)\n", 43 | " chunk = xarray.Dataset({\n", 44 | " 'foo': (('x', 'y'), data),\n", 45 | " 'bar': (('x', 'y'), 100 + data),\n", 46 | " })\n", 47 | " yield key, chunk\n", 48 | " \n", 49 | "inputs = list(create_records())" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## Choosing chunks\n", 57 | "\n", 58 | "Chunking can be essential for some operations. Some operations are very hard or impossible to perform with certain chunking schemes. For example, to make a plot all the data needs to come toether on a single machine. Other calculations such as calculating a median are _possible_ to perform on distributed data, but require tricky algorithms and/or approximation.\n", 59 | "\n", 60 | "More broadly, chunking can have critical performance implications, similar to [those for Xarray and Dask](http://xarray.pydata.org/en/stable/user-guide/dask.html#chunking-and-performance). As a rule of thumb, chunk sizes of 10-100 MB work well. The optimal chunk size is a balance among a number of considerations, adapted here [from Dask docs](https://docs.dask.org/en/latest/array-chunks.html):\n", 61 | "\n", 62 | "1. Chunks should be small enough to fit comfortably into memory on a single machine. As an upper limit, chunks over roughly 2 GB in size will not fit into the protocol buffers Beam uses to pass data between workers. \n", 63 | "2. There should be enough chunks for Beam runners (like Cloud Dataflow) to elastically shard work over many workers.\n", 64 | "3. Chunks should be large enough to amortize the overhead of networking and the Python interpreter, which starts to become noticeable for arrays with fewer than 1 million elements.\n", 65 | "\n", 66 | "The `nbytes` attribute on both NumPy arrays and `xarray.Dataset` objects is a good easy way to figure out how larger chunks are." 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## Adjusting variables" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "The simplest transformation is splitting (or consoldating) different _variables_ in a Dataset with {py:class}`~xarray_beam.SplitVariables()` and {py:class}`~xarray_beam.ConsolidateVariables()`, e.g.," 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 3, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "name": "stderr", 90 | "output_type": "stream", 91 | "text": [ 92 | "WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.\n" 93 | ] 94 | }, 95 | { 96 | "data": { 97 | "application/javascript": [ 98 | "\n", 99 | " if (typeof window.interactive_beam_jquery == 'undefined') {\n", 100 | " var jqueryScript = document.createElement('script');\n", 101 | " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", 102 | " jqueryScript.type = 'text/javascript';\n", 103 | " jqueryScript.onload = function() {\n", 104 | " var datatableScript = document.createElement('script');\n", 105 | " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", 106 | " datatableScript.type = 'text/javascript';\n", 107 | " datatableScript.onload = function() {\n", 108 | " window.interactive_beam_jquery = jQuery.noConflict(true);\n", 109 | " window.interactive_beam_jquery(document).ready(function($){\n", 110 | " \n", 111 | " });\n", 112 | " }\n", 113 | " document.head.appendChild(datatableScript);\n", 114 | " };\n", 115 | " document.head.appendChild(jqueryScript);\n", 116 | " } else {\n", 117 | " window.interactive_beam_jquery(document).ready(function($){\n", 118 | " \n", 119 | " });\n", 120 | " }" 121 | ] 122 | }, 123 | "metadata": {}, 124 | "output_type": "display_data" 125 | }, 126 | { 127 | "name": "stderr", 128 | "output_type": "stream", 129 | "text": [ 130 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n", 131 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n" 132 | ] 133 | }, 134 | { 135 | "data": { 136 | "text/plain": [ 137 | "[(Key(offsets={'x': 0, 'y': 0}, vars={'foo'}),\n", 138 | " \n", 139 | " Dimensions: (x: 4, y: 2)\n", 140 | " Dimensions without coordinates: x, y\n", 141 | " Data variables:\n", 142 | " foo (x, y) int64 0 1 2 3 4 5 6 7),\n", 143 | " (Key(offsets={'x': 0, 'y': 0}, vars={'bar'}),\n", 144 | " \n", 145 | " Dimensions: (x: 4, y: 2)\n", 146 | " Dimensions without coordinates: x, y\n", 147 | " Data variables:\n", 148 | " bar (x, y) int64 100 101 102 103 104 105 106 107),\n", 149 | " (Key(offsets={'x': 4, 'y': 0}, vars={'foo'}),\n", 150 | " \n", 151 | " Dimensions: (x: 4, y: 2)\n", 152 | " Dimensions without coordinates: x, y\n", 153 | " Data variables:\n", 154 | " foo (x, y) int64 8 9 10 11 12 13 14 15),\n", 155 | " (Key(offsets={'x': 4, 'y': 0}, vars={'bar'}),\n", 156 | " \n", 157 | " Dimensions: (x: 4, y: 2)\n", 158 | " Dimensions without coordinates: x, y\n", 159 | " Data variables:\n", 160 | " bar (x, y) int64 108 109 110 111 112 113 114 115)]" 161 | ] 162 | }, 163 | "execution_count": 3, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "inputs | xbeam.SplitVariables()" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "```{tip}\n", 177 | "Instead of a separate transform for splitting variables, you can also set `split_vars=True` in {py:class}`~xarray_beam.DatasetToChunks`.\n", 178 | "```" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "## Adjusting chunks" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "You can also adjust _chunks_ in a dataset to distribute arrays of different sizes. Here you have two choices of API:\n", 193 | "\n", 194 | "1. The lower level {py:class}`~xarray_beam.SplitChunks` and {py:class}`~xarray_beam.ConsolidateChunks`. These transformations apply a single splitting (with indexing) or consolidation (with {py:function}`xarray.concat`) function to array elements.\n", 195 | "2. The high level {py:class}`~xarray_beam.Rechunk`, which uses a pipeline of multiple split/consolidate steps (as needed) to efficiently rechunk a dataset.\n" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "### Low level rechunking" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "For minor adjustments (e.g., mostly along a single dimension), the more explicit `SplitChunks()` and `ConsolidateChunks()` are good options. They take a dict of _desired_ chunk sizes as a parameter, which can also be `-1` to indicate \"no chunking\" along a dimension:" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 4, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stderr", 219 | "output_type": "stream", 220 | "text": [ 221 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n", 222 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n", 223 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'ConsolidateChunks/GroupByTempKeys'. \n", 224 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'ConsolidateChunks/GroupByTempKeys'. \n" 225 | ] 226 | }, 227 | { 228 | "data": { 229 | "text/plain": [ 230 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n", 231 | " \n", 232 | " Dimensions: (x: 8, y: 2)\n", 233 | " Dimensions without coordinates: x, y\n", 234 | " Data variables:\n", 235 | " foo (x, y) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15\n", 236 | " bar (x, y) int64 100 101 102 103 104 105 ... 110 111 112 113 114 115)]" 237 | ] 238 | }, 239 | "execution_count": 4, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "inputs | xbeam.ConsolidateChunks({'x': -1})" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "Note that because these transformations only split _or_ consolidate, they cannot necessary fully rechunk a dataset in a single step if the new chunk sizes are not multiples of old chunks (with consolidate) or do not even divide the old chunks (with split), e.g.," 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 5, 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "name": "stderr", 262 | "output_type": "stream", 263 | "text": [ 264 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n", 265 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n" 266 | ] 267 | }, 268 | { 269 | "data": { 270 | "text/plain": [ 271 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n", 272 | " \n", 273 | " Dimensions: (x: 4, y: 2)\n", 274 | " Dimensions without coordinates: x, y\n", 275 | " Data variables:\n", 276 | " foo (x, y) int64 0 1 2 3 4 5 6 7\n", 277 | " bar (x, y) int64 100 101 102 103 104 105 106 107),\n", 278 | " (Key(offsets={'x': 4, 'y': 0}, vars=None),\n", 279 | " \n", 280 | " Dimensions: (x: 1, y: 2)\n", 281 | " Dimensions without coordinates: x, y\n", 282 | " Data variables:\n", 283 | " foo (x, y) int64 8 9\n", 284 | " bar (x, y) int64 108 109),\n", 285 | " (Key(offsets={'x': 5, 'y': 0}, vars=None),\n", 286 | " \n", 287 | " Dimensions: (x: 3, y: 2)\n", 288 | " Dimensions without coordinates: x, y\n", 289 | " Data variables:\n", 290 | " foo (x, y) int64 10 11 12 13 14 15\n", 291 | " bar (x, y) int64 110 111 112 113 114 115)]" 292 | ] 293 | }, 294 | "execution_count": 5, 295 | "metadata": {}, 296 | "output_type": "execute_result" 297 | } 298 | ], 299 | "source": [ 300 | "inputs | xbeam.SplitChunks({'x': 5}) # notice that the first two chunks are still separate!" 301 | ] 302 | }, 303 | { 304 | "cell_type": "markdown", 305 | "metadata": {}, 306 | "source": [ 307 | "For such uneven cases, you'll need to use split followed by consolidate:" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 6, 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "name": "stderr", 317 | "output_type": "stream", 318 | "text": [ 319 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n", 320 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n", 321 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n", 322 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n", 323 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'ConsolidateChunks/GroupByTempKeys'. \n", 324 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'ConsolidateChunks/GroupByTempKeys'. \n" 325 | ] 326 | }, 327 | { 328 | "data": { 329 | "text/plain": [ 330 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n", 331 | " \n", 332 | " Dimensions: (x: 5, y: 2)\n", 333 | " Dimensions without coordinates: x, y\n", 334 | " Data variables:\n", 335 | " foo (x, y) int64 0 1 2 3 4 5 6 7 8 9\n", 336 | " bar (x, y) int64 100 101 102 103 104 105 106 107 108 109),\n", 337 | " (Key(offsets={'x': 5, 'y': 0}, vars=None),\n", 338 | " \n", 339 | " Dimensions: (x: 3, y: 2)\n", 340 | " Dimensions without coordinates: x, y\n", 341 | " Data variables:\n", 342 | " foo (x, y) int64 10 11 12 13 14 15\n", 343 | " bar (x, y) int64 110 111 112 113 114 115)]" 344 | ] 345 | }, 346 | "execution_count": 6, 347 | "metadata": {}, 348 | "output_type": "execute_result" 349 | } 350 | ], 351 | "source": [ 352 | "inputs | xbeam.SplitChunks({'x': 5}) | xbeam.ConsolidateChunks({'x': 5})" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "metadata": {}, 358 | "source": [ 359 | "### High level rechunking" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "metadata": {}, 365 | "source": [ 366 | "Alternatively, the high-level `Rechunk()` method applies multiple split and consolidate steps based on the [Rechunker](https://github.com/pangeo-data/rechunker) algorithm:" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 7, 372 | "metadata": {}, 373 | "outputs": [ 374 | { 375 | "name": "stderr", 376 | "output_type": "stream", 377 | "text": [ 378 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/shoyer/miniconda3/envs/xarray-beam/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/Users/shoyer/Library/Jupyter/runtime/kernel-2b08a4f7-b064-490e-9ac1-e22ceca4c6cd.json']\n", 379 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.9 interpreter.\n", 380 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'Rechunk/Stage0/Consolidate/GroupByTempKeys'. \n", 381 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'Rechunk/Stage0/Consolidate/GroupByTempKeys'. \n", 382 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'Rechunk/Stage2/Consolidate/GroupByTempKeys'. \n", 383 | "WARNING:apache_beam.coders.coder_impl:Using fallback deterministic coder for type '' in 'Rechunk/Stage2/Consolidate/GroupByTempKeys'. \n" 384 | ] 385 | }, 386 | { 387 | "data": { 388 | "text/plain": [ 389 | "[(Key(offsets={'x': 0, 'y': 0}, vars=None),\n", 390 | " \n", 391 | " Dimensions: (x: 5, y: 2)\n", 392 | " Dimensions without coordinates: x, y\n", 393 | " Data variables:\n", 394 | " foo (x, y) int64 0 1 2 3 4 5 6 7 8 9\n", 395 | " bar (x, y) int64 100 101 102 103 104 105 106 107 108 109),\n", 396 | " (Key(offsets={'x': 5, 'y': 0}, vars=None),\n", 397 | " \n", 398 | " Dimensions: (x: 3, y: 2)\n", 399 | " Dimensions without coordinates: x, y\n", 400 | " Data variables:\n", 401 | " foo (x, y) int64 10 11 12 13 14 15\n", 402 | " bar (x, y) int64 110 111 112 113 114 115)]" 403 | ] 404 | }, 405 | "execution_count": 7, 406 | "metadata": {}, 407 | "output_type": "execute_result" 408 | } 409 | ], 410 | "source": [ 411 | "inputs | xbeam.Rechunk(dim_sizes={'x': 6}, source_chunks={'x': 3}, target_chunks={'x': 5}, itemsize=8)" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": {}, 417 | "source": [ 418 | "`Rechunk` requires specifying a few more parameters, but based on that information **it can be _much_ more efficient for more complex rechunking tasks**, particular in cases where data needs to be distributed into a very different shape (e.g., distributing a matrix across rows vs. columns).\n", 419 | "\n", 420 | "The naive \"splitting\" approach in such cases may divide datasets into extremely small tasks corresponding to individual array elements, which adds a huge amount of overhead." 421 | ] 422 | } 423 | ], 424 | "metadata": { 425 | "celltoolbar": "Tags", 426 | "interpreter": { 427 | "hash": "aef148d7ea0dbd1f91630322dd5bc9e24a2135d95f24fe1a9dab9696856be2b9" 428 | }, 429 | "kernelspec": { 430 | "display_name": "Python 3 (ipykernel)", 431 | "language": "python", 432 | "name": "python3" 433 | }, 434 | "language_info": { 435 | "codemirror_mode": { 436 | "name": "ipython", 437 | "version": 3 438 | }, 439 | "file_extension": ".py", 440 | "mimetype": "text/x-python", 441 | "name": "python", 442 | "nbconvert_exporter": "python", 443 | "pygments_lexer": "ipython3", 444 | "version": "3.9.13" 445 | } 446 | }, 447 | "nbformat": 4, 448 | "nbformat_minor": 2 449 | } 450 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # doc requirements 2 | Jinja2==3.1.3 3 | myst-nb==0.17.2 4 | myst-parser==0.18.1 5 | sphinx_rtd_theme==1.2.1 6 | sphinx==5.3.0 7 | scipy==1.10.1 8 | 9 | # xarray-beam requirements 10 | apache-beam==2.47.0 11 | dask==2023.5.1 12 | immutabledict==2.2.4 13 | numpy==1.24.3 14 | pandas==1.5.3 15 | pooch==1.7.0 16 | rechunker==0.5.1 17 | xarray==2023.5.0 18 | zarr==2.14.2 19 | -------------------------------------------------------------------------------- /docs/why-xarray-beam.md: -------------------------------------------------------------------------------- 1 | # Why Xarray-Beam 2 | 3 | ## Our goals 4 | 5 | Xarray-Beam is a Python library for building 6 | [Apache Beam](https://beam.apache.org/) pipelines with 7 | [Xarray](http://xarray.pydata.org/en/stable/) datasets. 8 | 9 | The project aims to facilitate data transformations and analysis on large-scale 10 | multi-dimensional labeled arrays, such as: 11 | 12 | - Ad-hoc computation on Xarray data, by dividing a `xarray.Dataset` into many 13 | smaller pieces ("chunks"). 14 | - Adjusting array chunks, using the 15 | [Rechunker algorithm](https://rechunker.readthedocs.io/en/latest/algorithm.html). 16 | - Ingesting large, multi-dimensional array datasets into an analysis-ready, 17 | cloud-optimized format, namely [Zarr](https://zarr.readthedocs.io/) (see 18 | also [Pangeo Forge](https://github.com/pangeo-forge/pangeo-forge-recipes)). 19 | - Calculating statistics (e.g., "climatology") across distributed datasets 20 | with arbitrary groups. 21 | 22 | ## Our approach 23 | 24 | In Xarray-Beam, distributed Xarray datasets are represented by Beam PCollections 25 | of `(xarray_beam.Key, xarray.Dataset)` pairs, corresponding to a "chunk" of a 26 | larger (virtual) dataset. The {py:class}`~xarray_beam.Key` provides sufficient 27 | metadata for Beam PTransforms like those included in Xarray-Beam to perform 28 | collective operations on the entire dataset. This chunking model is highly 29 | flexible, allowing datasets to be split across multiple variables and/or 30 | into orthogonal, contiguous "chunks" along dimensions. 31 | 32 | Xarray-Beam does not (yet) include high-level abstrations like a "distributed 33 | dataset" object. Users need to have a mental model for how their data pipeline 34 | is distributed across many machines, which is facilitated by its direct 35 | representation as a Beam pipeline. (In our experience, building such a mental 36 | model is basically required to get good performance out of large-scale 37 | pipelines, anyways.) 38 | 39 | Implementation wise, Xarray-Beam is a _thin layer_ on top of existing libraries 40 | for working with large-scale Xarray datasets. For example, it leverages 41 | [Dask](https://dask.org/) for describing lazy arrays and for executing 42 | multi-threaded computation on a single machine. 43 | 44 | ## How does Dask compare? 45 | 46 | We love Dask! Xarray-Beam explores a different part of the design space for 47 | distributed data pipelines than Xarray's built-in Dask integration: 48 | 49 | - Xarray-Beam is built around explicit manipulation of `(xarray_beam.Key, 50 | xarray.Dataset)`. This requires more boilerplate but is also 51 | more robust than generating distributed computation graphs in Dask using 52 | Xarray's built-in API. 53 | - Xarray-Beam distributes datasets by splitting them into many 54 | `xarray.Dataset` chunks, rather than the chunks of NumPy arrays typically 55 | used by Xarray with Dask (unless using 56 | [xarray.map_blocks](http://xarray.pydata.org/en/stable/user-guide/dask.html#automatic-parallelization-with-apply-ufunc-and-map-blocks)). 57 | Chunks of datasets is a more convenient data-model for writing ad-hoc whole 58 | dataset transformations, but is potentially a bit less efficient. 59 | - Beam ([like Spark](https://docs.dask.org/en/latest/spark.html)) was designed 60 | around a higher-level model for distributed computation than Dask (although 61 | Dask has been making 62 | [progress in this direction](https://coiled.io/blog/dask-under-the-hood-scheduler-refactor/)). 63 | Roughly speaking, this trade-off favors scalability over flexibility. 64 | - Beam allows for executing distributed computation using multiple runners, 65 | notably including Google Cloud Dataflow and Apache Spark. These runners are 66 | more mature than Dask, and in many cases are supported as a service by major 67 | commercial cloud providers. 68 | 69 | ![Xarray-Beam datamodel vs Xarray-Dask](./_static/xarray-beam-vs-xarray-dask.png) 70 | 71 | These design choices are not set in stone. In particular, in the future we 72 | _could_ imagine writing a high-level `xarray_beam.Dataset` that emulates the 73 | `xarray.Dataset` API, similar to the popular high-level DataFrame APIs in Beam, 74 | Spark and Dask. This could be built on top of the lower-level transformations 75 | currently in Xarray-Beam, or alternatively could use a "chunks of NumPy arrays" 76 | representation similar to that used by dask.array. 77 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Xarray-Beam examples 2 | 3 | The examples in this directory use the ERA5 surface dataset (as assembled by 4 | Pangeo), which consists of 19 data variables stored in float32 precision. It 5 | totals 24.8 TB in size: 6 | 7 | ``` 8 | >>> xarray.open_zarr('gs://pangeo-era5/reanalysis/spatial-analysis', consolidated=True) 9 | 10 | Dimensions: (latitude: 721, longitude: 1440, time: 350640) 11 | Coordinates: 12 | * latitude (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0 13 | * longitude (longitude) float32 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8 14 | * time (time) datetime64[ns] 1979-01-01 ... 2018-12-31T23:00:00 15 | Data variables: (12/17) 16 | asn (time, latitude, longitude) float32 dask.array 17 | d2m (time, latitude, longitude) float32 dask.array 18 | e (time, latitude, longitude) float32 dask.array 19 | mn2t (time, latitude, longitude) float32 dask.array 20 | mx2t (time, latitude, longitude) float32 dask.array 21 | ptype (time, latitude, longitude) float32 dask.array 22 | ... ... 23 | tcc (time, latitude, longitude) float32 dask.array 24 | tcrw (time, latitude, longitude) float32 dask.array 25 | tp (time, latitude, longitude) float32 dask.array 26 | tsn (time, latitude, longitude) float32 dask.array 27 | u10 (time, latitude, longitude) float32 dask.array 28 | v10 (time, latitude, longitude) float32 dask.array 29 | Attributes: 30 | Conventions: CF-1.6 31 | history: 2019-09-20 05:15:01 GMT by grib_to_netcdf-2.10.0: /opt/ecmw... 32 | ``` 33 | 34 | TODO(shoyer): add instructions for running these examples using Google Cloud 35 | DataFlow. 36 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /examples/era5_climatology.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Calculate climatology for the Pangeo ERA5 surface dataset.""" 15 | from typing import Tuple 16 | 17 | from absl import app 18 | from absl import flags 19 | import apache_beam as beam 20 | import numpy as np 21 | import xarray 22 | import xarray_beam as xbeam 23 | 24 | 25 | INPUT_PATH = flags.DEFINE_string('input_path', None, help='Input Zarr path') 26 | OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='Output Zarr path') 27 | RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') 28 | 29 | 30 | # pylint: disable=expression-not-assigned 31 | 32 | 33 | def rekey_chunk_on_month_hour( 34 | key: xbeam.Key, dataset: xarray.Dataset 35 | ) -> Tuple[xbeam.Key, xarray.Dataset]: 36 | """Replace the 'time' dimension with 'month'/'hour'.""" 37 | month = dataset.time.dt.month.item() 38 | hour = dataset.time.dt.hour.item() 39 | new_key = key.with_offsets(time=None, month=month - 1, hour=hour) 40 | new_dataset = dataset.squeeze('time', drop=True).expand_dims( 41 | month=[month], hour=[hour] 42 | ) 43 | return new_key, new_dataset 44 | 45 | 46 | def main(argv): 47 | source_dataset, source_chunks = xbeam.open_zarr(INPUT_PATH.value) 48 | 49 | # This lazy "template" allows us to setup the Zarr outputs before running the 50 | # pipeline. 51 | max_month = source_dataset.time.dt.month.max().item() # normally 12 52 | template = ( 53 | xbeam.make_template(source_dataset) 54 | .isel(time=0, drop=True) 55 | .expand_dims(month=np.arange(1, max_month + 1), hour=np.arange(24)) 56 | ) 57 | output_chunks = {'hour': 1, 'month': 1} 58 | 59 | with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: 60 | ( 61 | root 62 | | xbeam.DatasetToChunks(source_dataset, source_chunks) 63 | | xbeam.SplitChunks({'time': 1}) 64 | | beam.MapTuple(rekey_chunk_on_month_hour) 65 | | xbeam.Mean.PerKey() 66 | | xbeam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks) 67 | ) 68 | 69 | 70 | if __name__ == '__main__': 71 | app.run(main) 72 | -------------------------------------------------------------------------------- /examples/era5_climatology_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for era5_climatology.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import flagsaver 18 | import xarray 19 | 20 | from . import era5_climatology 21 | from xarray_beam._src import test_util 22 | 23 | 24 | class Era5ClimatologyTest(test_util.TestCase): 25 | 26 | def test(self): 27 | input_path = self.create_tempdir('source').full_path 28 | output_path = self.create_tempdir('destination').full_path 29 | 30 | input_ds = test_util.dummy_era5_surface_dataset(times=90 * 24, freq='1H') 31 | input_ds.chunk({'time': 31}).to_zarr(input_path) 32 | 33 | expected = input_ds.groupby('time.month').apply( 34 | lambda x: x.groupby('time.hour').mean('time') 35 | ) 36 | 37 | with flagsaver.flagsaver( 38 | input_path=input_path, 39 | output_path=output_path, 40 | ): 41 | era5_climatology.main([]) 42 | 43 | actual = xarray.open_zarr(output_path) 44 | xarray.testing.assert_allclose(actual, expected, atol=1e-7) 45 | 46 | 47 | if __name__ == '__main__': 48 | absltest.main() 49 | -------------------------------------------------------------------------------- /examples/xbeam_rechunk.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Rechunk a Zarr dataset.""" 15 | from typing import Dict 16 | 17 | from absl import app 18 | from absl import flags 19 | import apache_beam as beam 20 | import xarray_beam as xbeam 21 | 22 | 23 | INPUT_PATH = flags.DEFINE_string('input_path', None, help='input Zarr path') 24 | OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='output Zarr path') 25 | TARGET_CHUNKS = flags.DEFINE_string( 26 | 'target_chunks', 27 | '', 28 | help=( 29 | 'chunks on the input Zarr dataset to change on the outputs, in the ' 30 | 'form of a comma separated dimension=size pairs, e.g., ' 31 | "--target_chunks='x=10,y=10'. Omitted dimensions are not changed and a " 32 | 'chunksize of -1 indicates not to chunk a dimension.' 33 | ), 34 | ) 35 | RUNNER = flags.DEFINE_string('runner', None, help='beam.runners.Runner') 36 | 37 | 38 | # pylint: disable=expression-not-assigned 39 | 40 | 41 | def _parse_chunks_str(chunks_str: str) -> Dict[str, int]: 42 | chunks = {} 43 | parts = chunks_str.split(',') 44 | for part in parts: 45 | k, v = part.split('=') 46 | chunks[k] = int(v) 47 | return chunks 48 | 49 | 50 | def main(argv): 51 | source_dataset, source_chunks = xbeam.open_zarr(INPUT_PATH.value) 52 | template = xbeam.make_template(source_dataset) 53 | target_chunks = dict(source_chunks, **_parse_chunks_str(TARGET_CHUNKS.value)) 54 | itemsize = max(variable.dtype.itemsize for variable in template.values()) 55 | 56 | with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: 57 | ( 58 | root 59 | | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=True) 60 | | xbeam.Rechunk( # pytype: disable=wrong-arg-types 61 | source_dataset.sizes, 62 | source_chunks, 63 | target_chunks, 64 | itemsize=itemsize, 65 | ) 66 | | xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks) 67 | ) 68 | 69 | 70 | if __name__ == '__main__': 71 | app.run(main) 72 | -------------------------------------------------------------------------------- /examples/xbeam_rechunk_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for xbeam_rechunk.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import flagsaver 18 | import xarray 19 | 20 | from . import xbeam_rechunk 21 | from xarray_beam._src import test_util 22 | 23 | 24 | class Era5RechunkTest(test_util.TestCase): 25 | 26 | def test(self): 27 | input_path = self.create_tempdir('source').full_path 28 | output_path = self.create_tempdir('destination').full_path 29 | 30 | input_ds = test_util.dummy_era5_surface_dataset(times=365) 31 | input_ds.chunk({'time': 31}).to_zarr(input_path) 32 | 33 | with flagsaver.flagsaver( 34 | input_path=input_path, 35 | output_path=output_path, 36 | target_chunks='latitude=5,longitude=5,time=-1', 37 | ): 38 | xbeam_rechunk.main([]) 39 | 40 | output_ds = xarray.open_zarr(output_path) 41 | self.assertEqual( 42 | {k: v[0] for k, v in output_ds.chunks.items()}, 43 | {'latitude': 5, 'longitude': 5, 'time': 365} 44 | ) 45 | xarray.testing.assert_identical(input_ds, output_ds) 46 | 47 | 48 | if __name__ == '__main__': 49 | absltest.main() 50 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pyink] 2 | line-length = 80 3 | preview = true 4 | pyink-indentation = 2 5 | pyink-use-majority-quotes = true -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Setup Xarray-Beam.""" 16 | import setuptools 17 | 18 | 19 | base_requires = [ 20 | 'apache_beam>=2.31.0', 21 | 'dask', 22 | 'immutabledict', 23 | 'rechunker>=0.5.1', 24 | 'zarr', 25 | 'xarray', 26 | ] 27 | docs_requires = [ 28 | 'myst-nb', 29 | 'myst-parser', 30 | 'sphinx', 31 | 'sphinx_rtd_theme', 32 | 'scipy', 33 | ] 34 | tests_requires = [ 35 | 'absl-py', 36 | 'pandas', 37 | 'pytest', 38 | 'scipy', 39 | 'h5netcdf' 40 | ] 41 | 42 | setuptools.setup( 43 | name='xarray-beam', 44 | version='0.8.1', # keep in sync with __init__.py 45 | license='Apache 2.0', 46 | author='Google LLC', 47 | author_email='noreply@google.com', 48 | install_requires=base_requires, 49 | extras_require={ 50 | 'tests': tests_requires, 51 | 'docs': docs_requires, 52 | }, 53 | url='https://github.com/google/xarray-beam', 54 | packages=setuptools.find_packages(exclude=['examples']), 55 | python_requires='>=3', 56 | ) 57 | -------------------------------------------------------------------------------- /xarray_beam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Public API for Xarray-Beam.""" 15 | 16 | # pylint: disable=g-multiple-import 17 | from xarray_beam._src.combiners import ( 18 | Mean, 19 | MeanCombineFn, 20 | ) 21 | from xarray_beam._src.core import ( 22 | Key, 23 | DatasetToChunks, 24 | ValidateEachChunk, 25 | offsets_to_slices, 26 | validate_chunk 27 | ) 28 | from xarray_beam._src.dataset import ( 29 | Dataset, 30 | ) 31 | from xarray_beam._src.rechunk import ( 32 | ConsolidateChunks, 33 | ConsolidateVariables, 34 | SplitChunks, 35 | SplitVariables, 36 | Rechunk, 37 | consolidate_chunks, 38 | consolidate_variables, 39 | consolidate_fully, 40 | split_chunks, 41 | split_variables, 42 | in_memory_rechunk, 43 | ) 44 | from xarray_beam._src.zarr import ( 45 | open_zarr, 46 | make_template, 47 | replace_template_dims, 48 | setup_zarr, 49 | validate_zarr_chunk, 50 | write_chunk_to_zarr, 51 | ChunksToZarr, 52 | DatasetToZarr, 53 | ) 54 | 55 | __version__ = '0.8.1' # keep in sync with setup.py 56 | -------------------------------------------------------------------------------- /xarray_beam/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /xarray_beam/_src/combiners.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Combiners for xarray-beam.""" 15 | from __future__ import annotations 16 | import dataclasses 17 | from typing import Optional, Sequence, Union 18 | 19 | import apache_beam as beam 20 | import numpy.typing as npt 21 | import xarray 22 | 23 | from xarray_beam._src import core 24 | 25 | 26 | # TODO(shoyer): add other combiners: sum, std, var, min, max, etc. 27 | 28 | 29 | DimLike = Optional[Union[str, Sequence[str]]] 30 | 31 | 32 | @dataclasses.dataclass 33 | class MeanCombineFn(beam.transforms.CombineFn): 34 | """CombineFn for computing an arithmetic mean of xarray.Dataset objects.""" 35 | 36 | dim: DimLike = None 37 | skipna: bool = True 38 | dtype: Optional[npt.DTypeLike] = None 39 | 40 | def create_accumulator(self): 41 | return (0, 0) 42 | 43 | def add_input(self, sum_count, element): 44 | (sum_, count) = sum_count 45 | 46 | if self.dtype is not None: 47 | element = element.astype(self.dtype) 48 | 49 | if self.skipna: 50 | sum_increment = element.fillna(0) 51 | count_increment = element.notnull() 52 | else: 53 | sum_increment = element 54 | count_increment = xarray.ones_like(element) 55 | 56 | if self.dim is not None: 57 | # unconditionally set skipna=False because we already explictly fill in 58 | # missing values explicitly above 59 | sum_increment = sum_increment.sum(self.dim, skipna=False) 60 | count_increment = count_increment.sum(self.dim) 61 | 62 | new_sum = sum_ + sum_increment 63 | new_count = count + count_increment 64 | 65 | return new_sum, new_count 66 | 67 | def merge_accumulators(self, accumulators): 68 | sums, counts = zip(*accumulators) 69 | return sum(sums), sum(counts) 70 | 71 | def extract_output(self, sum_count): 72 | (sum_, count) = sum_count 73 | return sum_ / count 74 | 75 | def for_input_type(self, input_type): 76 | return self 77 | 78 | 79 | @dataclasses.dataclass 80 | class Mean(beam.PTransform): 81 | """Calculate the mean over one or more distributed dataset dimensions.""" 82 | 83 | dim: Union[str, Sequence[str]] 84 | skipna: bool = True 85 | dtype: Optional[npt.DTypeLike] = None 86 | fanout: Optional[int] = None 87 | 88 | def _update_key( 89 | self, key: core.Key, chunk: xarray.Dataset 90 | ) -> tuple[core.Key, xarray.Dataset]: 91 | dims = [self.dim] if isinstance(self.dim, str) else self.dim 92 | new_key = key.with_offsets(**{d: None for d in dims if d in key.offsets}) 93 | return new_key, chunk 94 | 95 | def expand(self, pcoll): 96 | return ( 97 | pcoll 98 | | beam.MapTuple(self._update_key) 99 | | Mean.PerKey(self.dim, self.skipna, self.dtype, self.fanout) 100 | ) 101 | 102 | @dataclasses.dataclass 103 | class Globally(beam.PTransform): 104 | """Calculate global mean over a pcollection of xarray.Dataset objects.""" 105 | 106 | dim: DimLike = None 107 | skipna: bool = True 108 | dtype: Optional[npt.DTypeLike] = None 109 | fanout: Optional[int] = None 110 | 111 | def expand(self, pcoll): 112 | combine_fn = MeanCombineFn(self.dim, self.skipna, self.dtype) 113 | return pcoll | beam.CombineGlobally(combine_fn).with_fanout(self.fanout) 114 | 115 | @dataclasses.dataclass 116 | class PerKey(beam.PTransform): 117 | """Calculate per-key mean over a pcollection of (hashable, Dataset).""" 118 | 119 | dim: DimLike = None 120 | skipna: bool = True 121 | dtype: Optional[npt.DTypeLike] = None 122 | fanout: Optional[int] = None 123 | 124 | def expand(self, pcoll): 125 | combine_fn = MeanCombineFn(self.dim, self.skipna, self.dtype) 126 | return pcoll | beam.CombinePerKey(combine_fn).with_hot_key_fanout( 127 | self.fanout 128 | ) 129 | -------------------------------------------------------------------------------- /xarray_beam/_src/combiners_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for xarray_beam._src.combiners.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | import numpy as np 19 | import xarray 20 | import xarray_beam as xbeam 21 | from xarray_beam._src import test_util 22 | 23 | 24 | # pylint: disable=expression-not-assigned 25 | # pylint: disable=pointless-statement 26 | 27 | 28 | class MeanTest(test_util.TestCase): 29 | 30 | def test_globally(self): 31 | nan = np.nan 32 | data_with_nans = np.array( 33 | [[1, 2, 3], [4, 5, nan], [6, nan, nan], [nan, nan, nan]] 34 | ) 35 | dataset = xarray.Dataset({'foo': (('x', 'y'), data_with_nans)}) 36 | inputs_x = [dataset.isel(x=i) for i in range(4)] 37 | inputs_y = [dataset.isel(y=i) for i in range(3)] 38 | 39 | with self.subTest('skipna-default'): 40 | expected = dataset.mean('y', skipna=True) 41 | (actual,) = inputs_y | xbeam.Mean.Globally() 42 | xarray.testing.assert_allclose(expected, actual) 43 | 44 | with self.subTest('skipna=True'): 45 | expected = dataset.mean('y', skipna=True) 46 | (actual,) = inputs_y | xbeam.Mean.Globally(skipna=True) 47 | xarray.testing.assert_allclose(expected, actual) 48 | 49 | expected = dataset.mean('x', skipna=True) 50 | (actual,) = inputs_x | xbeam.Mean.Globally(skipna=True) 51 | xarray.testing.assert_allclose(expected, actual) 52 | 53 | with self.subTest('skipna=False', skipna=False): 54 | expected = dataset.mean('y', skipna=False) 55 | (actual,) = inputs_y | xbeam.Mean.Globally(skipna=False) 56 | xarray.testing.assert_allclose(expected, actual) 57 | 58 | expected = dataset.mean('x', skipna=False) 59 | (actual,) = inputs_x | xbeam.Mean.Globally(skipna=False) 60 | xarray.testing.assert_allclose(expected, actual) 61 | 62 | with self.subTest('with-fanout'): 63 | expected = dataset.mean('y', skipna=True) 64 | (actual,) = inputs_y | xbeam.Mean.Globally(fanout=2) 65 | xarray.testing.assert_allclose(expected, actual) 66 | 67 | def test_dim_globally(self): 68 | inputs = [ 69 | xarray.Dataset({'x': ('time', [1, 2])}), 70 | xarray.Dataset({'x': ('time', [3])}), 71 | ] 72 | expected = xarray.Dataset({'x': 2.0}) 73 | (actual,) = inputs | xbeam.Mean.Globally(dim='time') 74 | xarray.testing.assert_allclose(expected, actual) 75 | 76 | def test_per_key(self): 77 | inputs = [ 78 | (0, xarray.Dataset({'x': 1})), 79 | (0, xarray.Dataset({'x': 2})), 80 | (1, xarray.Dataset({'x': 3})), 81 | (1, xarray.Dataset({'x': 4})), 82 | ] 83 | expected = [ 84 | (0, xarray.Dataset({'x': 1.5})), 85 | (1, xarray.Dataset({'x': 3.5})), 86 | ] 87 | actual = inputs | xbeam.Mean.PerKey() 88 | self.assertAllCloseChunks(actual, expected) 89 | 90 | def test_mean_1d(self): 91 | inputs = [ 92 | (xbeam.Key({'x': 0}), xarray.Dataset({'y': ('x', [1, 2, 3])})), 93 | (xbeam.Key({'x': 3}), xarray.Dataset({'y': ('x', [4, 5, 6])})), 94 | ] 95 | expected = [ 96 | (xbeam.Key({}), xarray.Dataset({'y': 3.5})), 97 | ] 98 | actual = inputs | xbeam.Mean('x') 99 | self.assertAllCloseChunks(actual, expected) 100 | actual = inputs | xbeam.Mean(['x']) 101 | self.assertAllCloseChunks(actual, expected) 102 | 103 | def test_mean_many(self): 104 | inputs = [] 105 | for i in range(0, 100, 10): 106 | inputs.append( 107 | (xbeam.Key({'x': i}), xarray.Dataset({'y': ('x', i + np.arange(10))})) 108 | ) 109 | expected = [ 110 | (xbeam.Key({}), xarray.Dataset({'y': 49.5})), 111 | ] 112 | actual = inputs | xbeam.Mean('x', fanout=2) 113 | self.assertAllCloseChunks(actual, expected) 114 | 115 | def test_mean_nans(self): 116 | nan = np.nan 117 | data_with_nans = np.array( 118 | [[1, 2, 3], [4, 5, nan], [6, nan, nan], [nan, nan, nan]] 119 | ) 120 | dataset = xarray.Dataset({'foo': (('x', 'y'), data_with_nans)}) 121 | eager = test_util.EagerPipeline() 122 | chunks = eager | xbeam.DatasetToChunks(dataset, {'x': 1, 'y': 1}) 123 | 124 | expected = eager | xbeam.DatasetToChunks( 125 | dataset.mean('y', skipna=False), {'x': 1} 126 | ) 127 | actual = chunks | xbeam.Mean('y', skipna=False) 128 | self.assertAllCloseChunks(actual, expected) 129 | 130 | expected = eager | xbeam.DatasetToChunks( 131 | dataset.mean('y', skipna=True), {'x': 1} 132 | ) 133 | actual = chunks | xbeam.Mean('y', skipna=True) 134 | self.assertAllCloseChunks(actual, expected) 135 | 136 | expected = eager | xbeam.DatasetToChunks( 137 | dataset.mean('x', skipna=True), {'y': 1} 138 | ) 139 | actual = chunks | xbeam.Mean('x', skipna=True) 140 | self.assertAllCloseChunks(actual, expected) 141 | 142 | expected = eager | xbeam.DatasetToChunks( 143 | dataset.mean('x', skipna=False), {'y': 1} 144 | ) 145 | actual = chunks | xbeam.Mean('x', skipna=False) 146 | self.assertAllCloseChunks(actual, expected) 147 | 148 | def test_mean_2d(self): 149 | inputs = [ 150 | (xbeam.Key({'y': 0}), xarray.Dataset({'z': (('x', 'y'), [[1, 2, 3]])})), 151 | (xbeam.Key({'y': 3}), xarray.Dataset({'z': (('x', 'y'), [[4, 5, 6]])})), 152 | ] 153 | 154 | expected = [ 155 | (xbeam.Key({'y': 0}), xarray.Dataset({'z': ('y', [1, 2, 3])})), 156 | (xbeam.Key({'y': 3}), xarray.Dataset({'z': ('y', [4, 5, 6])})), 157 | ] 158 | actual = inputs | xbeam.Mean('x') 159 | self.assertAllCloseChunks(actual, expected) 160 | 161 | expected = [ 162 | (xbeam.Key({}), xarray.Dataset({'z': (('x',), [3.5])})), 163 | ] 164 | actual = inputs | xbeam.Mean('y') 165 | self.assertAllCloseChunks(actual, expected) 166 | 167 | expected = [ 168 | (xbeam.Key({}), xarray.Dataset({'z': 3.5})), 169 | ] 170 | actual = inputs | xbeam.Mean(['x', 'y']) 171 | self.assertAllCloseChunks(actual, expected) 172 | 173 | 174 | if __name__ == '__main__': 175 | absltest.main() 176 | -------------------------------------------------------------------------------- /xarray_beam/_src/core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Core data model for xarray-beam.""" 15 | import itertools 16 | import math 17 | from typing import ( 18 | AbstractSet, 19 | Dict, 20 | Generic, 21 | Iterator, 22 | List, 23 | Mapping, 24 | Optional, 25 | Sequence, 26 | Tuple, 27 | TypeVar, 28 | Union, 29 | ) 30 | 31 | import apache_beam as beam 32 | import immutabledict 33 | import numpy as np 34 | import xarray 35 | from xarray_beam._src import threadmap 36 | 37 | _DEFAULT = object() 38 | 39 | 40 | class Key: 41 | """A key for keeping track of chunks of a distributed xarray.Dataset. 42 | 43 | Key object in Xarray-Beam include two components: 44 | 45 | - "offsets": an immutable dict indicating integer offsets (total number of 46 | array elements) from the origin along each dimension for this chunk. 47 | - "vars": either an frozenset or None, indicating the subset of Dataset 48 | variables included in this chunk. None means that all variables are 49 | included. 50 | 51 | Key objects are "deterministically encoded" by Beam, which makes them suitable 52 | for use as keys in Beam pipelines, i.e., with beam.GroupByKey. They are also 53 | immutable and hashable, which makes them usable as keys in Python 54 | dictionaries. 55 | 56 | Example usage:: 57 | 58 | >>> key = xarray_beam.Key(offsets={'x': 10}, vars={'foo'}) 59 | 60 | >>> key 61 | xarray_beam.Key(offsets={'x': 10}, vars={'foo'}) 62 | 63 | >>> key.offsets 64 | immutabledict({'x': 10}) 65 | 66 | >>> key.vars 67 | frozenset({'foo'}) 68 | 69 | To replace some offsets:: 70 | 71 | >>> key.with_offsets(y=0) # insert 72 | xarray_beam.Key(offsets={'x': 10, 'y': 0}, vars={'foo'}) 73 | 74 | >>> key.with_offsets(x=20) # override 75 | xarray_beam.Key(offsets={'x': 20}, vars={'foo'}) 76 | 77 | >>> key.with_offsets(x=None) # remove 78 | xarray_beam.Key(offsets={}, vars={'foo'}) 79 | 80 | To entirely replace offsets or variables:: 81 | 82 | >>> key.replace(offsets={'y': 0}) 83 | xarray_beam.Key(offsets={'y': 0}, vars={'foo'}) 84 | 85 | >>> key.replace(vars=None) 86 | xarray_beam.Key(offsets={'x': 10}, vars=None) 87 | """ 88 | 89 | # pylint: disable=redefined-builtin 90 | 91 | def __init__( 92 | self, 93 | offsets: Optional[Mapping[str, int]] = None, 94 | vars: Optional[AbstractSet[str]] = None, 95 | ): 96 | if offsets is None: 97 | offsets = {} 98 | if isinstance(vars, str): 99 | raise TypeError(f"vars must be a set or None, but is {vars!r}") 100 | self.offsets = immutabledict.immutabledict(offsets) 101 | self.vars = None if vars is None else frozenset(vars) 102 | 103 | def replace( 104 | self, 105 | offsets: Union[Mapping[str, int], object] = _DEFAULT, 106 | vars: Union[AbstractSet[str], None, object] = _DEFAULT, 107 | ) -> "Key": 108 | if offsets is _DEFAULT: 109 | offsets = self.offsets 110 | if vars is _DEFAULT: 111 | vars = self.vars 112 | return type(self)(offsets, vars) 113 | 114 | def with_offsets(self, **offsets: Optional[int]) -> "Key": 115 | new_offsets = dict(self.offsets) 116 | for k, v in offsets.items(): 117 | if v is None: 118 | del new_offsets[k] 119 | else: 120 | new_offsets[k] = v 121 | return self.replace(offsets=new_offsets) 122 | 123 | def __repr__(self) -> str: 124 | offsets = dict(self.offsets) 125 | vars = set(self.vars) if self.vars is not None else None 126 | return f"{type(self).__name__}(offsets={offsets}, vars={vars})" 127 | 128 | def __hash__(self) -> int: 129 | return hash((self.offsets, self.vars)) 130 | 131 | def __eq__(self, other) -> bool: 132 | if not isinstance(other, Key): 133 | return NotImplemented 134 | return self.offsets == other.offsets and self.vars == other.vars 135 | 136 | def __ne__(self, other) -> bool: 137 | return not self == other 138 | 139 | # Beam uses these methods (also used for pickling) for "deterministic 140 | # encoding" of groupby keys 141 | def __getstate__(self): 142 | offsets_state = sorted(self.offsets.items()) 143 | vars_state = None if self.vars is None else sorted(self.vars) 144 | return (offsets_state, vars_state) 145 | 146 | def __setstate__(self, state): 147 | self.__init__(*state) 148 | 149 | 150 | def offsets_to_slices( 151 | offsets: Mapping[str, int], 152 | sizes: Mapping[str, int], 153 | base: Optional[Mapping[str, int]] = None, 154 | ) -> Dict[str, slice]: 155 | """Convert offsets into slices with an optional base offset. 156 | 157 | Args: 158 | offsets: integer offsets from the origin along each axis. 159 | sizes: dimension sizes for the corresponding chunks. 160 | base: optional base-offset to subract from this key. This allows for 161 | relative indexing, e.g., into a chunk of a larger Dataset. 162 | 163 | Returns: 164 | Slices suitable for indexing with xarray.Dataset.isel(). 165 | 166 | Raises: 167 | ValueError: if an offset is specified for a dimension where there is no 168 | corresponding size specified. 169 | 170 | Example usage:: 171 | 172 | >>> offsets_to_slices({'x': 100}, sizes={'x': 10}) 173 | {'x': slice(100, 110, 1)} 174 | >>> offsets_to_slices({'x': 100}, sizes={'x': 10}, base={'x': 100}) 175 | {'x': slice(0, 10, 1)} 176 | """ 177 | if base is None: 178 | base = {} 179 | slices = {} 180 | missing_chunk_sizes = [k for k in offsets.keys() if k not in sizes] 181 | if missing_chunk_sizes: 182 | raise ValueError( 183 | "some dimensions have offsets specified but no dimension sizes: " 184 | f"offsets={offsets} and sizes={sizes}" 185 | ) 186 | for k, size in sizes.items(): 187 | offset = offsets.get(k, 0) - base.get(k, 0) 188 | slices[k] = slice(offset, offset + size, 1) 189 | return slices 190 | 191 | 192 | def _chunks_to_offsets( 193 | chunks: Mapping[str, Sequence[int]], 194 | ) -> Dict[str, List[int]]: 195 | return { 196 | dim: np.concatenate([[0], np.cumsum(sizes)[:-1]]).tolist() 197 | for dim, sizes in chunks.items() 198 | } 199 | 200 | 201 | def iter_chunk_keys( 202 | offsets: Mapping[str, Sequence[int]], 203 | vars: Optional[AbstractSet[str]] = None, # pylint: disable=redefined-builtin 204 | ) -> Iterator[Key]: 205 | """Iterate over the Key objects corresponding to the given chunks.""" 206 | chunk_indices = [range(len(sizes)) for sizes in offsets.values()] 207 | for indices in itertools.product(*chunk_indices): 208 | key_offsets = { 209 | dim: offsets[dim][index] for dim, index in zip(offsets, indices) 210 | } 211 | yield Key(key_offsets, vars) 212 | 213 | 214 | def compute_offset_index( 215 | offsets: Mapping[str, Sequence[int]], 216 | ) -> Dict[str, Dict[int, int]]: 217 | """Compute a mapping from chunk offsets to chunk indices.""" 218 | index = {} 219 | for dim, dim_offsets in offsets.items(): 220 | index[dim] = {} 221 | for i, offset in enumerate(dim_offsets): 222 | index[dim][offset] = i 223 | return index 224 | 225 | 226 | def normalize_expanded_chunks( 227 | chunks: Mapping[str, Union[int, Tuple[int, ...]]], 228 | dim_sizes: Mapping[str, int], 229 | ) -> Dict[str, Tuple[int, ...]]: 230 | # pylint: disable=g-doc-args 231 | # pylint: disable=g-doc-return-or-yield 232 | """Normalize a dict of chunks to give the expanded size of each block. 233 | 234 | Example usage:: 235 | 236 | >>> normalize_expanded_chunks({'x': 3}, {'x': 10}) 237 | {'x': (3, 3, 3, 1)} 238 | """ 239 | result = {} 240 | for dim, dim_size in dim_sizes.items(): 241 | if dim not in chunks or chunks[dim] == -1: 242 | result[dim] = (dim_size,) 243 | elif isinstance(chunks[dim], tuple): 244 | total = sum(chunks[dim]) 245 | if total != dim_size: 246 | raise ValueError( 247 | f"sum of provided chunks does not match size of dimension {dim}: " 248 | f"{total} vs {dim_size}" 249 | ) 250 | result[dim] = chunks[dim] 251 | else: 252 | multiple, remainder = divmod(dim_size, chunks[dim]) 253 | result[dim] = multiple * (chunks[dim],) + ( 254 | (remainder,) if remainder else () 255 | ) 256 | return result 257 | 258 | 259 | DatasetOrDatasets = TypeVar( 260 | "DatasetOrDatasets", xarray.Dataset, List[xarray.Dataset] 261 | ) 262 | 263 | 264 | class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]): 265 | """Split one or more xarray.Datasets into keyed chunks.""" 266 | 267 | def __init__( 268 | self, 269 | dataset: DatasetOrDatasets, 270 | chunks: Optional[Mapping[str, Union[int, Tuple[int, ...]]]] = None, 271 | split_vars: bool = False, 272 | num_threads: Optional[int] = None, 273 | shard_keys_threshold: int = 200_000, 274 | ): 275 | """Initialize DatasetToChunks. 276 | 277 | Args: 278 | dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key, 279 | [xarray.Dataset, ...]) pairs. 280 | chunks: optional chunking scheme. Required if the dataset is *not* already 281 | chunked. If the dataset *is* already chunked with Dask, `chunks` takes 282 | precedence over the existing chunks. 283 | split_vars: whether to split the dataset into separate records for each 284 | data variable or to keep all data variables together. This is 285 | recommended if you don't need to perform joint operations on different 286 | dataset variables and individual variable chunks are sufficiently large. 287 | num_threads: optional number of Dataset chunks to load in parallel per 288 | worker. More threads can increase throughput, but also increases memory 289 | usage and makes it harder for Beam runners to shard work. Note that each 290 | variable in a Dataset is already loaded in parallel, so this is most 291 | useful for Datasets with a small number of variables or when using 292 | split_vars=True. 293 | shard_keys_threshold: threshold at which to compute keys on Beam workers, 294 | rather than only on the host process. This is important for scaling 295 | pipelines to millions of tasks. 296 | """ 297 | self.dataset = dataset 298 | self._validate(dataset, split_vars) 299 | if chunks is None: 300 | chunks = self._first.chunks 301 | if not chunks: 302 | raise ValueError("dataset must be chunked or chunks must be provided") 303 | for dim in chunks: 304 | if not any(dim in ds.dims for ds in self._datasets): 305 | raise ValueError( 306 | f"chunks key {dim!r} is not a dimension on the provided dataset(s)" 307 | ) 308 | expanded_chunks = normalize_expanded_chunks(chunks, self._first.sizes) # pytype: disable=wrong-arg-types # always-use-property-annotation 309 | self.expanded_chunks = expanded_chunks 310 | self.split_vars = split_vars 311 | self.num_threads = num_threads 312 | self.shard_keys_threshold = shard_keys_threshold 313 | # TODO(shoyer): consider recalculating these potentially large properties on 314 | # each worker, rather than only once on the host. 315 | self.offsets = _chunks_to_offsets(expanded_chunks) 316 | self.offset_index = compute_offset_index(self.offsets) 317 | # We use the simple heuristic of only sharding inputs along the dimension 318 | # with the most chunks. 319 | lengths = {k: len(v) for k, v in self.offsets.items()} 320 | self.sharded_dim = max(lengths, key=lengths.get) if lengths else None 321 | self.shard_count = self._shard_count() 322 | 323 | @property 324 | def _first(self) -> xarray.Dataset: 325 | return self._datasets[0] 326 | 327 | @property 328 | def _datasets(self) -> List[xarray.Dataset]: 329 | if isinstance(self.dataset, xarray.Dataset): 330 | return [self.dataset] 331 | return list(self.dataset) # pytype: disable=bad-return-type 332 | 333 | def _validate(self, dataset, split_vars): 334 | """Raise errors if input parameters are invalid.""" 335 | if not isinstance(dataset, xarray.Dataset): 336 | if not ( 337 | isinstance(dataset, list) 338 | and all(isinstance(ds, xarray.Dataset) for ds in dataset) 339 | ): 340 | raise TypeError( 341 | "'dataset' must be an 'xarray.Dataset' or 'list[xarray.Dataset]'" 342 | ) 343 | if not dataset: 344 | raise ValueError("dataset list cannot be empty") 345 | for ds in self._datasets[1:]: 346 | for dim, size in ds.sizes.items(): 347 | if dim not in self._first.dims: 348 | raise ValueError( 349 | f"dimension {dim} does not appear on the first dataset" 350 | ) 351 | if size != self._first.sizes[dim]: 352 | raise ValueError( 353 | f"dimension {dim} has an inconsistent size on different datasets" 354 | ) 355 | if split_vars: 356 | for ds in self._datasets: 357 | if not ds.keys() <= self._first.keys(): 358 | raise ValueError( 359 | "inconsistent data_vars when splitting variables:" 360 | f" {tuple(ds.keys())} != {tuple(self._first.keys())}" 361 | ) 362 | 363 | def _task_count(self) -> int: 364 | """Count the number of tasks emitted by this transform.""" 365 | counts = {k: len(v) for k, v in self.expanded_chunks.items()} 366 | if not self.split_vars: 367 | return int(np.prod(list(counts.values()))) 368 | total = 0 369 | for variable in self._first.values(): 370 | count_list = [v for k, v in counts.items() if k in variable.dims] 371 | total += int(np.prod(count_list)) 372 | return total 373 | 374 | def _shard_count(self) -> Optional[int]: 375 | """Determine the number of times to shard input keys.""" 376 | task_count = self._task_count() 377 | if task_count <= self.shard_keys_threshold: 378 | return None # no sharding 379 | 380 | if not self.split_vars: 381 | return math.ceil(task_count / self.shard_keys_threshold) 382 | 383 | var_count = sum( 384 | self.sharded_dim in var.dims for var in self._first.values() 385 | ) 386 | return math.ceil(task_count / (var_count * self.shard_keys_threshold)) 387 | 388 | def _iter_all_keys(self) -> Iterator[Key]: 389 | """Iterate over all Key objects.""" 390 | if not self.split_vars: 391 | yield from iter_chunk_keys(self.offsets) 392 | else: 393 | for name, variable in self._first.items(): 394 | relevant_offsets = { 395 | k: v for k, v in self.offsets.items() if k in variable.dims 396 | } 397 | yield from iter_chunk_keys(relevant_offsets, vars={name}) # pytype: disable=wrong-arg-types # always-use-property-annotation 398 | 399 | def _iter_shard_keys( 400 | self, shard_id: Optional[int], var_name: Optional[str] 401 | ) -> Iterator[Key]: 402 | """Iterate over Key objects for a specific shard and variable.""" 403 | if var_name is None: 404 | offsets = self.offsets 405 | else: 406 | offsets = {dim: self.offsets[dim] for dim in self._first[var_name].dims} 407 | 408 | if shard_id is None: 409 | assert self.split_vars 410 | yield from iter_chunk_keys(offsets, vars={var_name}) 411 | else: 412 | assert self.split_vars == (var_name is not None) 413 | dim = self.sharded_dim 414 | count = math.ceil(len(self.offsets[dim]) / self.shard_count) 415 | dim_slice = slice(shard_id * count, (shard_id + 1) * count) 416 | offsets = {**offsets, dim: offsets[dim][dim_slice]} 417 | vars_ = {var_name} if self.split_vars else None 418 | yield from iter_chunk_keys(offsets, vars=vars_) 419 | 420 | def _shard_inputs(self) -> List[Tuple[Optional[int], Optional[str]]]: 421 | """Create inputs for sharded key iterators.""" 422 | if not self.split_vars: 423 | return [(i, None) for i in range(self.shard_count)] 424 | 425 | inputs = [] 426 | for name, variable in self._first.items(): 427 | if self.sharded_dim in variable.dims: 428 | inputs.extend([(i, name) for i in range(self.shard_count)]) 429 | else: 430 | inputs.append((None, name)) 431 | return inputs # pytype: disable=bad-return-type # always-use-property-annotation 432 | 433 | def _key_to_chunks(self, key: Key) -> Iterator[Tuple[Key, DatasetOrDatasets]]: 434 | """Convert a Key into an in-memory (Key, xarray.Dataset) pair.""" 435 | sizes = { 436 | dim: self.expanded_chunks[dim][self.offset_index[dim][offset]] 437 | for dim, offset in key.offsets.items() 438 | } 439 | slices = offsets_to_slices(key.offsets, sizes) 440 | results = [] 441 | for ds in self._datasets: 442 | dataset = ds if key.vars is None else ds[list(key.vars)] 443 | valid_slices = {k: v for k, v in slices.items() if k in dataset.dims} 444 | chunk = dataset.isel(valid_slices) 445 | # Load the data, using a separate thread for each variable 446 | num_threads = len(dataset) 447 | result = chunk.chunk().compute(num_workers=num_threads) 448 | results.append(result) 449 | 450 | if isinstance(self.dataset, xarray.Dataset): 451 | yield key, results[0] 452 | else: 453 | yield key, list(results) 454 | 455 | def expand(self, pcoll): 456 | if self.shard_count is None: 457 | # Create all keys on the machine launching the Beam pipeline. This is 458 | # faster if the number of keys is small. 459 | key_pcoll = pcoll | beam.Create(self._iter_all_keys()) 460 | else: 461 | # Create keys in separate shards on Beam workers. This is more scalable. 462 | key_pcoll = ( 463 | pcoll 464 | | beam.Create(self._shard_inputs()) 465 | | "GenerateKeys" >> beam.FlatMapTuple(self._iter_shard_keys) 466 | | beam.Reshuffle() 467 | ) 468 | 469 | return key_pcoll | "KeyToChunks" >> threadmap.FlatThreadMap( 470 | self._key_to_chunks, num_threads=self.num_threads 471 | ) 472 | 473 | 474 | def validate_chunk(key: Key, datasets: DatasetOrDatasets) -> None: 475 | """Verify that a key and dataset(s) are valid for xarray-beam transforms.""" 476 | if isinstance(datasets, xarray.Dataset): 477 | datasets: list[xarray.Dataset] = [datasets] 478 | 479 | for dataset in datasets: 480 | # Verify that no variables are chunked with Dask 481 | for var_name, variable in dataset.variables.items(): 482 | if variable.chunks is not None: 483 | raise ValueError( 484 | f"Dataset variable {var_name!r} corresponding to key {key} is" 485 | " chunked with Dask. Datasets passed to validate_chunk must be" 486 | f" fully computed (not chunked): {dataset}\nThis typically arises" 487 | " with datasets originating with `xarray.open_zarr()`, which by" 488 | " default use Dask. If this is the case, you can fix it by passing" 489 | " `chunks=None` or xarray_beam.open_zarr(). Alternatively, you" 490 | " can load datasets explicitly into memory with `.compute()`." 491 | ) 492 | 493 | # Validate key offsets 494 | missing_keys = [ 495 | repr(k) for k in key.offsets.keys() if k not in dataset.dims 496 | ] 497 | if missing_keys: 498 | raise ValueError( 499 | f"Key offset(s) {', '.join(missing_keys)} in {key} not found in" 500 | f" Dataset dimensions: {dataset!r}" 501 | ) 502 | 503 | # Validate key vars 504 | if key.vars is not None: 505 | missing_vars = [repr(v) for v in key.vars if v not in dataset.data_vars] 506 | if missing_vars: 507 | raise ValueError( 508 | f"Key var(s) {', '.join(missing_vars)} in {key} not found in" 509 | f" Dataset data variables: {dataset!r}" 510 | ) 511 | 512 | 513 | class ValidateEachChunk(beam.PTransform): 514 | """Check that keys and dataset(s) are valid for xarray-beam transforms.""" 515 | 516 | def _validate(self, key, dataset): 517 | validate_chunk(key, dataset) 518 | return key, dataset 519 | 520 | def expand(self, pcoll): 521 | return pcoll | beam.MapTuple(self._validate) 522 | -------------------------------------------------------------------------------- /xarray_beam/_src/core_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for xarray_beam._src.core.""" 15 | 16 | import re 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import apache_beam as beam 20 | import immutabledict 21 | import numpy as np 22 | import xarray 23 | import xarray_beam as xbeam 24 | from xarray_beam._src import core 25 | from xarray_beam._src import test_util 26 | 27 | 28 | # pylint: disable=expression-not-assigned 29 | # pylint: disable=pointless-statement 30 | 31 | 32 | class KeyTest(test_util.TestCase): 33 | 34 | def test_constructor(self): 35 | key = xbeam.Key({'x': 0, 'y': 10}) 36 | self.assertIsInstance(key.offsets, immutabledict.immutabledict) 37 | self.assertEqual(dict(key.offsets), {'x': 0, 'y': 10}) 38 | self.assertIsNone(key.vars) 39 | 40 | key = xbeam.Key(vars={'foo'}) 41 | self.assertEqual(dict(key.offsets), {}) 42 | self.assertIsInstance(key.vars, frozenset) 43 | self.assertEqual(set(key.vars), {'foo'}) 44 | 45 | with self.assertRaisesRegex(TypeError, 'vars must be a set or None'): 46 | xbeam.Key(vars='foo') 47 | 48 | def test_replace(self): 49 | key = xbeam.Key({'x': 0}, {'foo'}) 50 | 51 | expected = xbeam.Key({'x': 1}, {'foo'}) 52 | actual = key.replace({'x': 1}) 53 | self.assertEqual(expected, actual) 54 | 55 | expected = xbeam.Key({'y': 1}, {'foo'}) 56 | actual = key.replace({'y': 1}) 57 | self.assertEqual(expected, actual) 58 | 59 | expected = xbeam.Key({'x': 0}) 60 | actual = key.replace(vars=None) 61 | self.assertEqual(expected, actual) 62 | 63 | expected = xbeam.Key({'x': 0}, {'bar'}) 64 | actual = key.replace(vars={'bar'}) 65 | self.assertEqual(expected, actual) 66 | 67 | expected = xbeam.Key({'y': 1}, {'foo'}) 68 | actual = key.replace({'y': 1}, {'foo'}) 69 | self.assertEqual(expected, actual) 70 | 71 | expected = xbeam.Key({'y': 1}, {'bar'}) 72 | actual = key.replace({'y': 1}, {'bar'}) 73 | self.assertEqual(expected, actual) 74 | 75 | def test_with_offsets(self): 76 | key = xbeam.Key({'x': 0}) 77 | 78 | expected = xbeam.Key({'x': 1}) 79 | actual = key.with_offsets(x=1) 80 | self.assertEqual(expected, actual) 81 | 82 | expected = xbeam.Key({'x': 0, 'y': 1}) 83 | actual = key.with_offsets(y=1) 84 | self.assertEqual(expected, actual) 85 | 86 | expected = xbeam.Key() 87 | actual = key.with_offsets(x=None) 88 | self.assertEqual(expected, actual) 89 | 90 | expected = xbeam.Key({'y': 1, 'z': 2}) 91 | actual = key.with_offsets(x=None, y=1, z=2) 92 | self.assertEqual(expected, actual) 93 | 94 | key2 = xbeam.Key({'x': 0}, vars={'foo'}) 95 | expected = xbeam.Key({'x': 1}, vars={'foo'}) 96 | actual = key2.with_offsets(x=1) 97 | self.assertEqual(expected, actual) 98 | 99 | def test_repr(self): 100 | key = xbeam.Key({'x': 0, 'y': 10}) 101 | expected = "Key(offsets={'x': 0, 'y': 10}, vars=None)" 102 | self.assertEqual(repr(key), expected) 103 | 104 | key = xbeam.Key(vars={'foo'}) 105 | expected = "Key(offsets={}, vars={'foo'})" 106 | self.assertEqual(repr(key), expected) 107 | 108 | def test_dict_key(self): 109 | first = {xbeam.Key({'x': 0, 'y': 10}): 1} 110 | second = {xbeam.Key({'x': 0, 'y': 10}): 1} 111 | self.assertEqual(first, second) 112 | 113 | def test_equality(self): 114 | key = xbeam.Key({'x': 0, 'y': 10}) 115 | self.assertEqual(key, key) 116 | self.assertNotEqual(key, None) 117 | 118 | key2 = xbeam.Key({'x': 0, 'y': 10}, {'bar'}) 119 | self.assertEqual(key2, key2) 120 | self.assertNotEqual(key, key2) 121 | self.assertNotEqual(key2, key) 122 | 123 | def test_offsets_as_beam_key(self): 124 | inputs = [ 125 | (xbeam.Key({'x': 0, 'y': 1}), 1), 126 | (xbeam.Key({'x': 0, 'y': 2}), 2), 127 | (xbeam.Key({'y': 1, 'x': 0}), 3), 128 | ] 129 | expected = [ 130 | (xbeam.Key({'x': 0, 'y': 1}), [1, 3]), 131 | (xbeam.Key({'x': 0, 'y': 2}), [2]), 132 | ] 133 | actual = inputs | beam.GroupByKey() 134 | self.assertEqual(actual, expected) 135 | 136 | def test_vars_as_beam_key(self): 137 | inputs = [ 138 | (xbeam.Key(vars={'foo'}), 1), 139 | (xbeam.Key(vars={'bar'}), 2), 140 | (xbeam.Key(vars={'foo'}), 3), 141 | ] 142 | expected = [ 143 | (xbeam.Key(vars={'foo'}), [1, 3]), 144 | (xbeam.Key(vars={'bar'}), [2]), 145 | ] 146 | actual = inputs | beam.GroupByKey() 147 | self.assertEqual(actual, expected) 148 | 149 | 150 | class TestOffsetsToSlices(test_util.TestCase): 151 | 152 | def test_offsets_to_slices(self): 153 | offsets = {'x': 0, 'y': 10} 154 | 155 | expected = {'x': slice(0, 5, 1), 'y': slice(10, 20, 1)} 156 | slices = core.offsets_to_slices(offsets, {'x': 5, 'y': 10}) 157 | self.assertEqual(slices, expected) 158 | 159 | expected = { 160 | 'x': slice(0, 5, 1), 161 | 'y': slice(10, 20, 1), 162 | 'extra_key': slice(0, 100, 1), 163 | } 164 | slices = core.offsets_to_slices( 165 | offsets, {'x': 5, 'y': 10, 'extra_key': 100} 166 | ) 167 | self.assertEqual(slices, expected) 168 | 169 | with self.assertRaises(ValueError): 170 | core.offsets_to_slices(offsets, {'y': 10}) 171 | 172 | def test_offsets_to_slices_base(self): 173 | offsets = {'x': 100, 'y': 210} 174 | 175 | base = {'x': 100, 'y': 200} 176 | expected = {'x': slice(0, 5, 1), 'y': slice(10, 20, 1)} 177 | slices = core.offsets_to_slices(offsets, {'x': 5, 'y': 10}, base=base) 178 | self.assertEqual(slices, expected) 179 | 180 | base = {'x': 100} 181 | expected = {'x': slice(0, 5, 1), 'y': slice(210, 220, 1)} 182 | slices = core.offsets_to_slices(offsets, {'x': 5, 'y': 10}, base=base) 183 | self.assertEqual(slices, expected) 184 | 185 | 186 | class DatasetToChunksTest(test_util.TestCase): 187 | 188 | def test_iter_chunk_keys(self): 189 | actual = list(core.iter_chunk_keys({'x': (0, 3), 'y': (0, 2, 4)})) 190 | expected = [ 191 | xbeam.Key({'x': 0, 'y': 0}), 192 | xbeam.Key({'x': 0, 'y': 2}), 193 | xbeam.Key({'x': 0, 'y': 4}), 194 | xbeam.Key({'x': 3, 'y': 0}), 195 | xbeam.Key({'x': 3, 'y': 2}), 196 | xbeam.Key({'x': 3, 'y': 4}), 197 | ] 198 | self.assertEqual(actual, expected) 199 | 200 | def test_compute_offset_index(self): 201 | actual = core.compute_offset_index({'x': (0, 3), 'y': (0, 2, 4)}) 202 | expected = {'x': {0: 0, 3: 1}, 'y': {0: 0, 2: 1, 4: 2}} 203 | self.assertEqual(actual, expected) 204 | 205 | def test_normalize_expanded_chunks(self): 206 | actual = core.normalize_expanded_chunks({}, {'x': 10}) 207 | expected = {'x': (10,)} 208 | self.assertEqual(actual, expected) 209 | 210 | actual = core.normalize_expanded_chunks({'x': -1}, {'x': 10}) 211 | expected = {'x': (10,)} 212 | self.assertEqual(actual, expected) 213 | 214 | actual = core.normalize_expanded_chunks({'x': (5, 5)}, {'x': 10}) 215 | expected = {'x': (5, 5)} 216 | self.assertEqual(actual, expected) 217 | 218 | with self.assertRaisesRegex( 219 | ValueError, 220 | 'sum of provided chunks does not match', 221 | ): 222 | core.normalize_expanded_chunks({'x': (5, 5, 5)}, {'x': 10}) 223 | 224 | actual = core.normalize_expanded_chunks({'x': 3, 'y': 2}, {'x': 9, 'y': 4}) 225 | expected = {'x': (3, 3, 3), 'y': (2, 2)} 226 | self.assertEqual(actual, expected) 227 | 228 | actual = core.normalize_expanded_chunks({'x': 3}, {'x': 10}) 229 | expected = {'x': (3, 3, 3, 1)} 230 | self.assertEqual(actual, expected) 231 | 232 | def test_dataset_to_chunks_multiple(self): 233 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) 234 | expected = [ 235 | (xbeam.Key({'x': 0}), dataset.head(x=3)), 236 | (xbeam.Key({'x': 3}), dataset.tail(x=3)), 237 | ] 238 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 239 | dataset.chunk({'x': 3}) 240 | ) 241 | self.assertIdenticalChunks(actual, expected) 242 | 243 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 244 | dataset.chunk({'x': 3}), num_threads=2 245 | ) 246 | self.assertIdenticalChunks(actual, expected) 247 | 248 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 249 | dataset, chunks={'x': 3} 250 | ) 251 | self.assertIdenticalChunks(actual, expected) 252 | 253 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 254 | dataset.chunk({'x': 3}), shard_keys_threshold=1 255 | ) 256 | self.assertIdenticalChunks(actual, expected) 257 | 258 | def test_datasets_to_chunks_multiple(self): 259 | datasets = [ 260 | xarray.Dataset({'foo': ('x', i + np.arange(6))}) for i in range(7) 261 | ] 262 | expected = [ 263 | (xbeam.Key({'x': 0}), [ds.head(x=3) for ds in datasets]), 264 | (xbeam.Key({'x': 3}), [ds.tail(x=3) for ds in datasets]), 265 | ] 266 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 267 | [ds.chunk({'x': 3}) for ds in datasets] 268 | ) 269 | self.assertIdenticalChunks(actual, expected) 270 | 271 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 272 | [ds.chunk({'x': 3}) for ds in datasets], num_threads=2 273 | ) 274 | self.assertIdenticalChunks(actual, expected) 275 | 276 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 277 | datasets, chunks={'x': 3} 278 | ) 279 | self.assertIdenticalChunks(actual, expected) 280 | 281 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 282 | [ds.chunk({'x': 3}) for ds in datasets], shard_keys_threshold=1 283 | ) 284 | self.assertIdenticalChunks(actual, expected) 285 | 286 | def test_dataset_to_chunks_whole(self): 287 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) 288 | expected = [(xbeam.Key({'x': 0}), dataset)] 289 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 290 | dataset, chunks={'x': -1} 291 | ) 292 | self.assertIdenticalChunks(actual, expected) 293 | 294 | def test_datasets_to_chunks_whole(self): 295 | datasets = [ 296 | xarray.Dataset({'foo': ('x', i + np.arange(6))}) for i in range(11) 297 | ] 298 | expected = [(xbeam.Key({'x': 0}), datasets)] 299 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 300 | datasets, chunks={'x': -1} 301 | ) 302 | self.assertIdenticalChunks(actual, expected) 303 | 304 | def test_dataset_to_chunks_vars(self): 305 | dataset = xarray.Dataset({ 306 | 'foo': ('x', np.arange(6)), 307 | 'bar': ('x', -np.arange(6)), 308 | }) 309 | expected = [ 310 | (xbeam.Key({'x': 0}, {'foo'}), dataset.head(x=3)[['foo']]), 311 | (xbeam.Key({'x': 0}, {'bar'}), dataset.head(x=3)[['bar']]), 312 | (xbeam.Key({'x': 3}, {'foo'}), dataset.tail(x=3)[['foo']]), 313 | (xbeam.Key({'x': 3}, {'bar'}), dataset.tail(x=3)[['bar']]), 314 | ] 315 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 316 | dataset, chunks={'x': 3}, split_vars=True 317 | ) 318 | self.assertIdenticalChunks(actual, expected) 319 | 320 | def test_datasets_to_chunks_vars(self): 321 | datasets = [ 322 | xarray.Dataset({ 323 | 'foo': ('x', i + np.arange(6)), 324 | 'bar': ('x', i - np.arange(6)), 325 | }) 326 | for i in range(12) 327 | ] 328 | expected = [ 329 | ( 330 | xbeam.Key({'x': 0}, {'foo'}), 331 | [ds.head(x=3)[['foo']] for ds in datasets], 332 | ), 333 | ( 334 | xbeam.Key({'x': 0}, {'bar'}), 335 | [ds.head(x=3)[['bar']] for ds in datasets], 336 | ), 337 | ( 338 | xbeam.Key({'x': 3}, {'foo'}), 339 | [ds.tail(x=3)[['foo']] for ds in datasets], 340 | ), 341 | ( 342 | xbeam.Key({'x': 3}, {'bar'}), 343 | [ds.tail(x=3)[['bar']] for ds in datasets], 344 | ), 345 | ] 346 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 347 | datasets, chunks={'x': 3}, split_vars=True 348 | ) 349 | self.assertIdenticalChunks(actual, expected) 350 | 351 | @parameterized.parameters( 352 | {'shard_keys_threshold': 1}, 353 | {'shard_keys_threshold': 2}, 354 | {'shard_keys_threshold': 10}, 355 | ) 356 | def test_dataset_to_chunks_split_with_different_dims( 357 | self, shard_keys_threshold 358 | ): 359 | dataset = xarray.Dataset({ 360 | 'foo': (('x', 'y'), np.array([[1, 2, 3], [4, 5, 6]])), 361 | 'bar': ('x', np.array([1, 2])), 362 | 'baz': ('z', np.array([1, 2, 3])), 363 | }) 364 | expected = [ 365 | (xbeam.Key({'x': 0, 'y': 0}, {'foo'}), dataset[['foo']].head(x=1)), 366 | (xbeam.Key({'x': 0}, {'bar'}), dataset[['bar']].head(x=1)), 367 | (xbeam.Key({'x': 1, 'y': 0}, {'foo'}), dataset[['foo']].tail(x=1)), 368 | (xbeam.Key({'x': 1}, {'bar'}), dataset[['bar']].tail(x=1)), 369 | (xbeam.Key({'z': 0}, {'baz'}), dataset[['baz']]), 370 | ] 371 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 372 | dataset, 373 | chunks={'x': 1}, 374 | split_vars=True, 375 | shard_keys_threshold=shard_keys_threshold, 376 | ) 377 | self.assertIdenticalChunks(actual, expected) 378 | 379 | def test_dataset_to_chunks_empty(self): 380 | dataset = xarray.Dataset() 381 | expected = [(xbeam.Key({}), dataset)] 382 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 383 | dataset, chunks={} 384 | ) 385 | self.assertIdenticalChunks(actual, expected) 386 | 387 | def test_datasets_to_chunks_empty(self): 388 | datasets = [xarray.Dataset() for _ in range(5)] 389 | expected = [(xbeam.Key({}), datasets)] 390 | actual = test_util.EagerPipeline() | xbeam.DatasetToChunks( 391 | datasets, chunks={} 392 | ) 393 | self.assertIdenticalChunks(actual, expected) 394 | 395 | def test_task_count(self): 396 | dataset = xarray.Dataset({ 397 | 'foo': (('x', 'y'), np.zeros((3, 6))), 398 | 'bar': ('x', np.zeros(3)), 399 | 'baz': ('z', np.zeros(10)), 400 | }) 401 | 402 | to_chunks = xbeam.DatasetToChunks(dataset, chunks={'x': 1}) 403 | self.assertEqual(to_chunks._task_count(), 3) 404 | 405 | to_chunks = xbeam.DatasetToChunks(dataset, chunks={'x': 1}, split_vars=True) 406 | self.assertEqual(to_chunks._task_count(), 7) 407 | 408 | to_chunks = xbeam.DatasetToChunks(dataset, chunks={'y': 1}, split_vars=True) 409 | self.assertEqual(to_chunks._task_count(), 8) 410 | 411 | to_chunks = xbeam.DatasetToChunks(dataset, chunks={'z': 1}, split_vars=True) 412 | self.assertEqual(to_chunks._task_count(), 12) 413 | 414 | datasets = [dataset.copy() for _ in range(9)] 415 | 416 | to_chunks = xbeam.DatasetToChunks(datasets, chunks={'x': 1}) 417 | self.assertEqual(to_chunks._task_count(), 3) 418 | 419 | to_chunks = xbeam.DatasetToChunks( 420 | datasets, chunks={'x': 1}, split_vars=True 421 | ) 422 | self.assertEqual(to_chunks._task_count(), 7) 423 | 424 | to_chunks = xbeam.DatasetToChunks( 425 | datasets, chunks={'y': 1}, split_vars=True 426 | ) 427 | self.assertEqual(to_chunks._task_count(), 8) 428 | 429 | to_chunks = xbeam.DatasetToChunks( 430 | datasets, chunks={'z': 1}, split_vars=True 431 | ) 432 | self.assertEqual(to_chunks._task_count(), 12) 433 | 434 | def test_validate(self): 435 | dataset = xarray.Dataset({ 436 | 'foo': (('x', 'y'), np.zeros((3, 6))), 437 | 'bar': ('x', np.zeros(3)), 438 | 'baz': ('z', np.zeros(10)), 439 | }) 440 | 441 | with self.assertRaisesWithLiteralMatch( 442 | ValueError, 'dataset must be chunked or chunks must be provided' 443 | ): 444 | test_util.EagerPipeline() | xbeam.DatasetToChunks(dataset, chunks=None) 445 | 446 | with self.assertRaisesWithLiteralMatch( 447 | ValueError, 448 | "chunks key 'invalid' is not a dimension on the provided dataset(s)", 449 | ): 450 | test_util.EagerPipeline() | xbeam.DatasetToChunks( 451 | dataset, chunks={'invalid': 1} 452 | ) 453 | 454 | with self.assertRaisesWithLiteralMatch( 455 | TypeError, 456 | "'dataset' must be an 'xarray.Dataset' or 'list[xarray.Dataset]'", 457 | ): 458 | test_util.EagerPipeline() | xbeam.DatasetToChunks({'foo': dataset}) 459 | 460 | with self.assertRaisesWithLiteralMatch( 461 | ValueError, 'dataset list cannot be empty' 462 | ): 463 | test_util.EagerPipeline() | xbeam.DatasetToChunks([]) 464 | 465 | with self.assertRaisesWithLiteralMatch( 466 | ValueError, 467 | 'dimension z does not appear on the first dataset', 468 | ): 469 | test_util.EagerPipeline() | xbeam.DatasetToChunks( 470 | [dataset.drop_dims('z'), dataset], chunks={'x': 1} 471 | ) 472 | 473 | with self.assertRaisesWithLiteralMatch( 474 | ValueError, 475 | 'dimension z has an inconsistent size on different datasets', 476 | ): 477 | test_util.EagerPipeline() | xbeam.DatasetToChunks( 478 | [dataset.isel(z=slice(5, 10), drop=True), dataset], chunks={'x': 1} 479 | ) 480 | 481 | try: 482 | test_util.EagerPipeline() | xbeam.DatasetToChunks( 483 | [dataset, dataset.isel(z=0, drop=True)], chunks={'x': 1} 484 | ) 485 | except ValueError: 486 | self.fail('should allow a pipeline where the first has more dimensions.') 487 | 488 | with self.assertRaisesWithLiteralMatch( 489 | ValueError, 490 | ( 491 | 'inconsistent data_vars when splitting variables: ' 492 | "('foo', 'bar', 'baz', 'qux') != ('foo', 'bar', 'baz')" 493 | ), 494 | ): 495 | test_util.EagerPipeline() | xbeam.DatasetToChunks( 496 | [dataset, dataset.assign(qux=2 * dataset.bar)], 497 | chunks={'x': 1}, 498 | split_vars=True, 499 | ) 500 | 501 | try: 502 | test_util.EagerPipeline() | xbeam.DatasetToChunks( 503 | [dataset, dataset.isel(y=0, drop=True)], 504 | chunks={'x': 1}, 505 | split_vars=True, 506 | ) 507 | except ValueError: 508 | self.fail('should allow a pipeline where the first has more dimensions.') 509 | 510 | 511 | class ValidateEachChunkTest(test_util.TestCase): 512 | 513 | def test_validate_chunk_raises_on_dask_chunked(self): 514 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))}).chunk() 515 | key = xbeam.Key({'x': 0}) 516 | 517 | with self.assertRaisesRegex( 518 | ValueError, 519 | re.escape( 520 | "Dataset variable 'foo' corresponding to key Key(offsets={'x': 0}," 521 | ' vars=None) is chunked with Dask. Datasets passed to' 522 | ' validate_chunk must be fully computed (not chunked):' 523 | ), 524 | ): 525 | core.validate_chunk(key, dataset) 526 | 527 | def test_unmatched_dimension_raises_error(self): 528 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) 529 | key = xbeam.Key({'x': 0, 'y': 0}) 530 | with self.assertRaisesRegex( 531 | ValueError, 532 | re.escape( 533 | "Key offset(s) 'y' in Key(offsets={'x': 0, 'y': 0}, vars=None) not " 534 | 'found in Dataset dimensions' 535 | ), 536 | ): 537 | core.validate_chunk(key, dataset) 538 | 539 | def test_unmatched_variables_raises_error_core(self): 540 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) 541 | key = xbeam.Key({'x': 0}, {'bar'}) 542 | with self.assertRaisesRegex( 543 | ValueError, 544 | re.escape( 545 | "Key var(s) 'bar' in Key(offsets={'x': 0}, vars={'bar'}) not found" 546 | ' in Dataset data variables' 547 | ), 548 | ): 549 | core.validate_chunk(key, dataset) 550 | 551 | def test_unmatched_variables_multiple_datasets_raises_error_core(self): 552 | datasets = [ 553 | xarray.Dataset({'foo': ('x', i + np.arange(6))}) for i in range(11) 554 | ] 555 | datasets[5] = datasets[5].rename({'foo': 'bar'}) 556 | key = xbeam.Key({'x': 0}, vars={'foo'}) 557 | 558 | with self.assertRaisesRegex( 559 | ValueError, 560 | re.escape( 561 | "Key var(s) 'foo' in Key(offsets={'x': 0}, vars={'foo'}) " 562 | f'not found in Dataset data variables: {datasets[5]}' 563 | ), 564 | ): 565 | core.validate_chunk(key, datasets) 566 | 567 | def test_validate_chunks_compose_in_pipeline(self): 568 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) 569 | expected = [(xbeam.Key({'x': 0}), dataset)] 570 | actual = ( 571 | test_util.EagerPipeline() 572 | | xbeam.DatasetToChunks(dataset, chunks={'x': -1}) 573 | | xbeam.ValidateEachChunk() 574 | ) 575 | self.assertIdenticalChunks(actual, expected) 576 | 577 | 578 | if __name__ == '__main__': 579 | absltest.main() 580 | -------------------------------------------------------------------------------- /xarray_beam/_src/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A high-level interface for Xarray-Beam datasets. 15 | 16 | Usage example (not fully implemented yet!): 17 | 18 | import xarray_beam as xbeam 19 | 20 | transform = ( 21 | xbeam.Dataset.from_zarr(input_path) 22 | .rechunk({'time': -1, 'latitude': 10, 'longitude': 10}) 23 | .map_blocks(lambda x: x.median('time')) 24 | .to_zarr(output_path) 25 | ) 26 | with beam.Pipeline() as p: 27 | p | transform 28 | """ 29 | from __future__ import annotations 30 | 31 | import collections 32 | from collections import abc 33 | import dataclasses 34 | import itertools 35 | import os.path 36 | import tempfile 37 | 38 | import apache_beam as beam 39 | import xarray 40 | from xarray_beam._src import core 41 | from xarray_beam._src import zarr 42 | 43 | 44 | class _CountNamer: 45 | 46 | def __init__(self): 47 | self._counts = collections.defaultdict(itertools.count) 48 | 49 | def apply(self, name: str) -> str: 50 | return f'{name}_{next(self._counts[name])}' 51 | 52 | 53 | _get_label = _CountNamer().apply 54 | 55 | 56 | @dataclasses.dataclass 57 | class Dataset: 58 | """Experimental high-level representation of an Xarray-Beam dataset.""" 59 | 60 | template: xarray.Dataset 61 | chunks: dict[str, int] 62 | split_vars: bool 63 | ptransform: beam.PTransform 64 | 65 | @classmethod 66 | def from_xarray( 67 | cls, 68 | source: xarray.Dataset, 69 | chunks: abc.Mapping[str, int], 70 | split_vars: bool = False, 71 | ) -> Dataset: 72 | """Create an xarray_beam.Dataset from an xarray.Dataset.""" 73 | template = zarr.make_template(source) 74 | ptransform = _get_label('from_xarray') >> core.DatasetToChunks( 75 | source, chunks, split_vars 76 | ) 77 | return cls(template, dict(chunks), split_vars, ptransform) 78 | 79 | @classmethod 80 | def from_zarr(cls, path: str, split_vars: bool = False) -> Dataset: 81 | """Create an xarray_beam.Dataset from a zarr file.""" 82 | source, chunks = zarr.open_zarr(path) 83 | result = cls.from_xarray(source, chunks, split_vars) 84 | result.ptransform = _get_label('from_zarr') >> result.ptransform 85 | return result 86 | 87 | def to_zarr(self, path: str) -> beam.PTransform: 88 | """Write to a Zarr file.""" 89 | return self.ptransform | _get_label('to_zarr') >> zarr.ChunksToZarr( 90 | path, self.template, self.chunks 91 | ) 92 | 93 | def collect_with_direct_runner(self) -> xarray.Dataset: 94 | """Collect a dataset in memory by writing it to a temp file.""" 95 | # TODO(shoyer): generalize this function to something that support 96 | # alternative runners can we figure out a suitable temp file location for 97 | # distributed runners? 98 | 99 | with tempfile.TemporaryDirectory() as temp_dir: 100 | temp_path = os.path.join(temp_dir, 'tmp.zarr') 101 | with beam.Pipeline(runner='DirectRunner') as pipeline: 102 | pipeline |= self.to_zarr(temp_path) 103 | return xarray.open_zarr(temp_path).compute() 104 | 105 | # TODO(shoyer): implement map_blocks, rechunking, merge, rename, mean, etc 106 | 107 | @property 108 | def sizes(self) -> dict[str, int]: 109 | """Size of each dimension on this dataset.""" 110 | return dict(self.template.sizes) # pytype: disable=bad-return-type 111 | 112 | def pipe(self, func, *args, **kwargs): 113 | return func(*args, **kwargs) 114 | 115 | def __repr__(self): 116 | base = repr(self.template) 117 | chunks_str = ', '.join(f'{k}: {v}' for k, v in self.chunks.items()) 118 | return ( 119 | f'' 120 | + '\n' 121 | + '\n'.join(base.split('\n')[1:]) 122 | ) 123 | -------------------------------------------------------------------------------- /xarray_beam/_src/dataset_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import textwrap 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | import apache_beam as beam 19 | import numpy as np 20 | import xarray 21 | import xarray_beam as xbeam 22 | from xarray_beam._src import test_util 23 | 24 | 25 | class DatasetTest(test_util.TestCase): 26 | 27 | def test_from_xarray(self): 28 | ds = xarray.Dataset({'foo': ('x', np.arange(10))}) 29 | beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5}) 30 | self.assertIsInstance(beam_ds, xbeam.Dataset) 31 | self.assertEqual(beam_ds.sizes, {'x': 10}) 32 | self.assertEqual(beam_ds.template.keys(), {'foo'}) 33 | self.assertEqual(beam_ds.chunks, {'x': 5}) 34 | self.assertFalse(beam_ds.split_vars) 35 | self.assertRegex(beam_ds.ptransform.label, r'^from_xarray_\d+$') 36 | self.assertEqual( 37 | repr(beam_ds).split('\n')[0], 38 | "", 39 | ) 40 | expected = [ 41 | (xbeam.Key({'x': 0}), ds.head(x=5)), 42 | (xbeam.Key({'x': 5}), ds.tail(x=5)), 43 | ] 44 | actual = test_util.EagerPipeline() | beam_ds.ptransform 45 | self.assertIdenticalChunks(expected, actual) 46 | 47 | def test_collect_with_direct_runner(self): 48 | ds = xarray.Dataset({'foo': ('x', np.arange(10))}) 49 | beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5}) 50 | collected = beam_ds.collect_with_direct_runner() 51 | xarray.testing.assert_identical(ds, collected) 52 | 53 | @parameterized.parameters( 54 | dict(split_vars=False), 55 | dict(split_vars=True), 56 | ) 57 | def test_from_zarr(self, split_vars): 58 | temp_dir = self.create_tempdir().full_path 59 | ds = xarray.Dataset({'foo': ('x', np.arange(10))}) 60 | ds.chunk({'x': 5}).to_zarr(temp_dir) 61 | 62 | beam_ds = xbeam.Dataset.from_zarr(temp_dir, split_vars) 63 | 64 | self.assertRegex(beam_ds.ptransform.label, r'^from_zarr_\d+$') 65 | self.assertEqual(beam_ds.chunks, {'x': 5}) 66 | self.assertEqual(beam_ds.split_vars, split_vars) 67 | 68 | collected = beam_ds.collect_with_direct_runner() 69 | xarray.testing.assert_identical(ds, collected) 70 | 71 | def test_to_zarr(self): 72 | temp_dir = self.create_tempdir().full_path 73 | ds = xarray.Dataset({'foo': ('x', np.arange(10))}) 74 | beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5}) 75 | to_zarr = beam_ds.to_zarr(temp_dir) 76 | 77 | self.assertRegex(to_zarr.label, r'^from_xarray_\d+|to_zarr_\d+$') 78 | with beam.Pipeline() as p: 79 | p |= to_zarr 80 | opened = xarray.open_zarr(temp_dir).compute() 81 | xarray.testing.assert_identical(ds, opened) 82 | 83 | 84 | if __name__ == '__main__': 85 | absltest.main() 86 | -------------------------------------------------------------------------------- /xarray_beam/_src/integration_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Integration tests for Xarray-Beam.""" 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | import apache_beam as beam 18 | import numpy as np 19 | import xarray 20 | import xarray_beam as xbeam 21 | from xarray_beam._src import test_util 22 | 23 | 24 | # pylint: disable=expression-not-assigned 25 | 26 | 27 | class IntegrationTest(test_util.TestCase): 28 | 29 | @parameterized.named_parameters( 30 | { 31 | 'testcase_name': 'eager_unified', 32 | 'template_method': 'eager', 33 | 'split_vars': False, 34 | }, 35 | { 36 | 'testcase_name': 'eager_split', 37 | 'template_method': 'eager', 38 | 'split_vars': True, 39 | }, 40 | { 41 | 'testcase_name': 'eager_unified_sharded', 42 | 'template_method': 'eager', 43 | 'split_vars': False, 44 | 'shard_keys_threshold': 20, 45 | }, 46 | { 47 | 'testcase_name': 'eager_split_sharded', 48 | 'template_method': 'eager', 49 | 'split_vars': True, 50 | 'shard_keys_threshold': 20, 51 | }, 52 | { 53 | 'testcase_name': 'lazy_unified', 54 | 'template_method': 'lazy', 55 | 'split_vars': False, 56 | }, 57 | { 58 | 'testcase_name': 'infer_unified', 59 | 'template_method': 'infer', 60 | 'split_vars': False, 61 | }, 62 | { 63 | 'testcase_name': 'infer_split', 64 | 'template_method': 'infer', 65 | 'split_vars': True, 66 | }, 67 | ) 68 | def test_rechunk_zarr_to_zarr( 69 | self, template_method, split_vars, shard_keys_threshold=1_000_000 70 | ): 71 | src_dir = self.create_tempdir('source').full_path 72 | dest_dir = self.create_tempdir('destination').full_path 73 | 74 | source_chunks = {'t': 1, 'x': 100, 'y': 120} 75 | target_chunks = {'t': -1, 'x': 20, 'y': 20} 76 | 77 | rs = np.random.RandomState(0) 78 | raw_data = rs.randint(2**30, size=(60, 100, 120)) # 5.76 MB 79 | dataset = xarray.Dataset( 80 | { 81 | 'foo': (('t', 'x', 'y'), raw_data), 82 | 'bar': (('t', 'x', 'y'), raw_data - 1), 83 | } 84 | ) 85 | dataset.chunk(source_chunks).to_zarr(src_dir, consolidated=True) 86 | 87 | on_disk = xarray.open_zarr(src_dir, consolidated=True) 88 | on_disk_chunked = on_disk.chunk(target_chunks) 89 | with beam.Pipeline('DirectRunner') as pipeline: 90 | # make template 91 | if template_method == 'eager': 92 | target_template = on_disk_chunked 93 | elif template_method == 'lazy': 94 | target_template = beam.pvalue.AsSingleton( 95 | pipeline | beam.Create([on_disk_chunked]) 96 | ) 97 | elif template_method == 'infer': 98 | target_template = None 99 | # run pipeline 100 | ( 101 | pipeline 102 | | xbeam.DatasetToChunks( 103 | on_disk, 104 | split_vars=split_vars, 105 | shard_keys_threshold=shard_keys_threshold, 106 | ) 107 | | xbeam.Rechunk( 108 | on_disk.sizes, 109 | source_chunks, 110 | target_chunks, 111 | itemsize=8, 112 | max_mem=10_000_000, # require two stages 113 | ) 114 | | xbeam.ChunksToZarr(dest_dir, target_template) 115 | ) 116 | roundtripped = xarray.open_zarr(dest_dir, consolidated=True, chunks=False) 117 | 118 | xarray.testing.assert_identical(roundtripped, dataset) 119 | 120 | @parameterized.named_parameters( 121 | { 122 | 'testcase_name': 'unified_unsharded', 123 | 'split_vars': False, 124 | 'shard_keys_threshold': 1_000_000, 125 | }, 126 | { 127 | 'testcase_name': 'split_unsharded', 128 | 'split_vars': True, 129 | 'shard_keys_threshold': 1_000_000, 130 | }, 131 | { 132 | 'testcase_name': 'unified_sharded', 133 | 'split_vars': False, 134 | 'shard_keys_threshold': 3, 135 | }, 136 | { 137 | 'testcase_name': 'split_sharded', 138 | 'split_vars': True, 139 | 'shard_keys_threshold': 3, 140 | }, 141 | ) 142 | def test_dataset_to_zarr_with_irregular_variables( 143 | self, split_vars, shard_keys_threshold 144 | ): 145 | dataset = xarray.Dataset( 146 | { 147 | 'volume1': ( 148 | ('t', 'x', 'y', 'z1'), 149 | np.arange(240).reshape(10, 2, 3, 4), 150 | ), 151 | 'volume2': ( 152 | ('t', 'x', 'y', 'z2'), 153 | np.arange(300).reshape(10, 2, 3, 5), 154 | ), 155 | 'surface': (('t', 'x', 'y'), np.arange(60).reshape(10, 2, 3)), 156 | 'static': (('x', 'y'), np.arange(6).reshape(2, 3)), 157 | } 158 | ) 159 | temp_dir = self.create_tempdir().full_path 160 | template = dataset.chunk() 161 | chunks = {'t': 1, 'z1': 2, 'z2': 3} 162 | ( 163 | test_util.EagerPipeline() 164 | | xbeam.DatasetToChunks( 165 | dataset, 166 | chunks, 167 | split_vars=split_vars, 168 | shard_keys_threshold=shard_keys_threshold, 169 | ) 170 | | xbeam.ChunksToZarr(temp_dir, template, chunks) 171 | ) 172 | actual = xarray.open_zarr(temp_dir, consolidated=True) 173 | xarray.testing.assert_identical(actual, dataset) 174 | 175 | 176 | if __name__ == '__main__': 177 | absltest.main() 178 | -------------------------------------------------------------------------------- /xarray_beam/_src/rechunk.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Rechunking for xarray.Dataset objets.""" 15 | import collections 16 | import dataclasses 17 | import itertools 18 | import logging 19 | import math 20 | import textwrap 21 | from typing import ( 22 | Any, 23 | Dict, 24 | Iterable, 25 | Iterator, 26 | List, 27 | Optional, 28 | Mapping, 29 | Sequence, 30 | Tuple, 31 | Union, 32 | ) 33 | 34 | import apache_beam as beam 35 | import numpy as np 36 | from rechunker import algorithm 37 | import xarray 38 | 39 | from xarray_beam._src import core 40 | 41 | 42 | # pylint: disable=logging-not-lazy 43 | # pylint: disable=logging-fstring-interpolation 44 | 45 | 46 | def normalize_chunks( 47 | chunks: Mapping[str, Union[int, Tuple[int, ...]]], 48 | dim_sizes: Mapping[str, int], 49 | ) -> Dict[str, int]: 50 | """Normalize a dict of chunks.""" 51 | if not chunks.keys() <= dim_sizes.keys(): 52 | raise ValueError( 53 | 'all dimensions used in chunks must also have an indicated size: ' 54 | f'chunks={chunks} vs dim_sizes={dim_sizes}' 55 | ) 56 | result = {} 57 | for dim, size in dim_sizes.items(): 58 | if dim not in chunks: 59 | result[dim] = size 60 | elif isinstance(chunks[dim], tuple): 61 | unique_chunks = set(chunks[dim]) 62 | if len(unique_chunks) != 1: 63 | raise ValueError( 64 | f'chunks for dimension {dim} are not constant: {unique_chunks}', 65 | ) 66 | (result[dim],) = unique_chunks 67 | elif chunks[dim] == -1: 68 | result[dim] = size 69 | else: 70 | result[dim] = chunks[dim] 71 | return result 72 | 73 | 74 | def rechunking_plan( 75 | dim_sizes: Mapping[str, int], 76 | source_chunks: Mapping[str, int], 77 | target_chunks: Mapping[str, int], 78 | itemsize: int, 79 | min_mem: int, 80 | max_mem: int, 81 | ) -> List[List[Dict[str, int]]]: 82 | """Make a rechunking plan.""" 83 | stages = algorithm.multistage_rechunking_plan( 84 | shape=tuple(dim_sizes.values()), 85 | source_chunks=tuple(source_chunks[dim] for dim in dim_sizes), 86 | target_chunks=tuple(target_chunks[dim] for dim in dim_sizes), 87 | itemsize=itemsize, 88 | min_mem=min_mem, 89 | max_mem=max_mem, 90 | ) 91 | plan = [] 92 | for stage in stages: 93 | plan.append([dict(zip(dim_sizes.keys(), shapes)) for shapes in stage]) 94 | return plan 95 | 96 | 97 | def _consolidate_chunks_in_var_group( 98 | inputs: Sequence[Tuple[core.Key, xarray.Dataset]], 99 | combine_kwargs: Optional[Mapping[str, Any]], 100 | ) -> Tuple[core.Key, xarray.Dataset]: 101 | """Consolidate chunks across offsets with identical vars.""" 102 | unique_offsets = collections.defaultdict(set) 103 | unique_var_groups = set() 104 | for key, chunk in inputs: 105 | for dim, offset in key.offsets.items(): 106 | unique_offsets[dim].add(offset) 107 | unique_var_groups.add(key.vars) 108 | 109 | if len(unique_var_groups) != 1: 110 | raise ValueError( 111 | f'expected exactly one unique var group, got {unique_var_groups}' 112 | ) 113 | (cur_vars,) = unique_var_groups 114 | 115 | offsets = {k: sorted(v) for k, v in unique_offsets.items()} 116 | combined_offsets = {k: v[0] for k, v in offsets.items()} 117 | combined_key = core.Key(combined_offsets, cur_vars) 118 | 119 | # Consolidate inputs in a single xarray.Dataset. 120 | # `inputs` is a flat list like `[(k_00, ds_00), (k_01, ds_01), ...]` where 121 | # `k_ij` is a Key giving the (multi-dimensional) index of `ds_ij` in a 122 | # virtual larger Dataset. 123 | # Now we want to actually concatenate along all those dimensions, e.g., the 124 | # equivalent of building a large matrix out of sub-matrices: 125 | # ⌈[x_00 x_01] ...⌉ ⌈x_00 x_01 ...⌉ 126 | # X = |[x_10 x_11] ...| = |x_10 x_11 ...| 127 | # |[x_20 x_21] ...| |x_20 x_21 ...| 128 | # ⌊ ... ...⌋ ⌊ ... ... ...⌋ 129 | # In NumPy, this would be done with `np.block()`. 130 | offset_index = core.compute_offset_index(offsets) 131 | shape = [len(v) for v in offsets.values()] 132 | nested_array = np.empty(dtype=object, shape=shape) 133 | if np.prod(shape) != len(inputs): 134 | raise ValueError( 135 | f'some expected chunks are missing for vars={cur_vars} ' 136 | f'shape: {shape} len(inputs): {len(inputs)}' 137 | ) 138 | 139 | for key, chunk in inputs: 140 | nested_key = tuple(offset_index[dim][key.offsets[dim]] for dim in offsets) 141 | assert nested_array[nested_key] is None 142 | nested_array[nested_key] = chunk 143 | 144 | kwargs = dict( 145 | data_vars='minimal', 146 | coords='minimal', 147 | join='exact', 148 | combine_attrs='override', 149 | ) 150 | if combine_kwargs is not None: 151 | kwargs.update(combine_kwargs) 152 | 153 | try: 154 | combined_dataset = xarray.combine_nested( 155 | nested_array.tolist(), concat_dim=list(offsets), **kwargs 156 | ) 157 | return combined_key, combined_dataset 158 | except (ValueError, xarray.MergeError) as original_error: 159 | summaries = [] 160 | for axis, dim in enumerate(offsets): 161 | repr_string = '\n'.join( 162 | repr(ds) for ds in nested_array[(0,) * axis + (slice(2),)].tolist() 163 | ) 164 | if nested_array.shape[axis] > 2: 165 | repr_string += '\n...' 166 | repr_string = textwrap.indent(repr_string, prefix=' ') 167 | summaries.append( 168 | f'Leading datasets along dimension {dim!r}:\n{repr_string}' 169 | ) 170 | summaries_str = '\n'.join(summaries) 171 | raise ValueError( 172 | f'combining nested dataset chunks for vars={cur_vars} with ' 173 | f'offsets={offsets} failed.\n' 174 | + summaries_str 175 | ) from original_error 176 | 177 | 178 | def consolidate_chunks( 179 | inputs: Iterable[Tuple[core.Key, xarray.Dataset]], 180 | combine_kwargs: Optional[Mapping[str, Any]] = None, 181 | ) -> Iterable[Tuple[core.Key, xarray.Dataset]]: 182 | """Consolidate chunks across offsets into (Key, Dataset) pairs.""" 183 | inputs = list(inputs) 184 | keys = [key for key, _ in inputs] 185 | if len(set(keys)) < len(keys): 186 | raise ValueError(f'chunk keys are not unique: {keys}') 187 | 188 | # Group chunks by variable groups and combine offsets to validate inputs. 189 | inputs_by_vars = collections.defaultdict(list) 190 | combined_offsets_by_dim = collections.defaultdict(set) 191 | combined_offsets_by_vars = collections.defaultdict(set) 192 | for key, chunk in inputs: 193 | inputs_by_vars[key.vars].append((key, chunk)) 194 | for dim, offset in key.offsets.items(): 195 | combined_offsets_by_dim[dim].add(offset) 196 | combined_offsets_by_vars[(key.vars, dim)].add(offset) 197 | 198 | # All var groups need to have the exact same set of offsets on common 199 | # dimensions. 200 | for (cur_vars, dim), offsets in combined_offsets_by_vars.items(): 201 | if offsets != combined_offsets_by_dim[dim]: 202 | raise ValueError(f'some expected chunks are missing for vars={cur_vars}') 203 | 204 | for cur_vars, cur_inputs in inputs_by_vars.items(): 205 | combined_key, combined_dataset = _consolidate_chunks_in_var_group( 206 | cur_inputs, combine_kwargs 207 | ) 208 | yield combined_key, combined_dataset 209 | 210 | 211 | def consolidate_variables( 212 | inputs: Iterable[Tuple[core.Key, xarray.Dataset]], 213 | merge_kwargs: Optional[Mapping[str, Any]] = None, 214 | ) -> Iterator[Tuple[core.Key, xarray.Dataset]]: 215 | """Consolidate chunks across distinct variables into (Key, Dataset) pairs.""" 216 | kwargs = dict( 217 | compat='equals', 218 | join='exact', 219 | combine_attrs='override', 220 | ) 221 | if merge_kwargs is not None: 222 | kwargs.update(merge_kwargs) 223 | 224 | chunks_by_offsets = collections.defaultdict(list) 225 | for key, chunk in inputs: 226 | chunks_by_offsets[key.offsets].append(chunk) 227 | 228 | for offsets, chunks in chunks_by_offsets.items(): 229 | all_vars = [set(chunk.keys()) for chunk in chunks] 230 | new_vars = set.union(*all_vars) 231 | if len(new_vars) != sum(map(len, all_vars)): 232 | raise ValueError( 233 | f'cannot merge chunks with overlapping variables: {all_vars}' 234 | ) 235 | key = core.Key(offsets, new_vars) 236 | 237 | try: 238 | dataset = xarray.merge(chunks, **kwargs) 239 | except (ValueError, xarray.MergeError) as original_error: 240 | repr_string = '\n'.join(repr(ds) for ds in chunks[:2]) 241 | if len(chunks) > 2: 242 | repr_string += '\n...' 243 | repr_string = textwrap.indent(repr_string, prefix=' ') 244 | raise ValueError( 245 | f'merging dataset chunks with variables {all_vars} failed.\n' 246 | + repr_string 247 | ) from original_error 248 | yield key, dataset 249 | 250 | 251 | def consolidate_fully( 252 | inputs: Iterable[Tuple[core.Key, xarray.Dataset]], 253 | *, 254 | merge_kwargs: Optional[Mapping[str, Any]] = None, 255 | combine_kwargs: Optional[Mapping[str, Any]] = None, 256 | ) -> Tuple[core.Key, xarray.Dataset]: 257 | """Consolidate chunks via merge/concat into a single (Key, Dataset) pair.""" 258 | concatenated_chunks = [] 259 | combined_offsets = {} 260 | combined_vars = set() 261 | for key, chunk in consolidate_chunks(inputs, combine_kwargs): 262 | # We expect all chunks to be fully combined in all dimensions and all chunks 263 | # to have the same offset (in each dimension). The chunks from 264 | # consolidate_chunks() should already have this property but we explicitly 265 | # check it here again in case consolidate_chunks changes. 266 | for dim, offset in key.offsets.items(): 267 | if dim in combined_offsets and combined_offsets[dim] != offset: 268 | raise ValueError( 269 | 'consolidating chunks fully failed because ' 270 | f'chunk\n{chunk}\n has offsets {key.offsets} ' 271 | f'that differ from {combined_offsets}' 272 | ) 273 | combined_offsets[dim] = offset 274 | concatenated_chunks.append(chunk) 275 | combined_vars.update(chunk.keys()) 276 | 277 | # Merge variables, but unlike consolidate_variables, we merge all chunks and 278 | # not just chunks per unique key. 279 | kwargs = dict( 280 | compat='equals', 281 | join='exact', 282 | combine_attrs='override', 283 | ) 284 | if merge_kwargs is not None: 285 | kwargs.update(merge_kwargs) 286 | 287 | try: 288 | dataset = xarray.merge(concatenated_chunks, **kwargs) 289 | except (ValueError, xarray.MergeError) as original_error: 290 | repr_string = '\n'.join(repr(ds) for ds in concatenated_chunks[:2]) 291 | if len(concatenated_chunks) > 2: 292 | repr_string += '\n...' 293 | repr_string = textwrap.indent(repr_string, prefix=' ') 294 | raise ValueError( 295 | f'merging dataset chunks with variables {combined_vars} failed.\n' 296 | + repr_string 297 | ) from original_error 298 | return core.Key(combined_offsets, combined_vars), dataset # pytype: disable=wrong-arg-types 299 | 300 | 301 | class _ConsolidateBase(beam.PTransform): 302 | 303 | def expand(self, pcoll): 304 | return ( 305 | pcoll 306 | | 'PrependTempKey' >> beam.MapTuple(self._prepend_chunk_key) 307 | | 'GroupByTempKeys' >> beam.GroupByKey() 308 | | 'Consolidate' >> beam.MapTuple(self._consolidate_chunks) 309 | ) 310 | 311 | 312 | def _round_chunk_key( 313 | key: core.Key, 314 | target_chunks: Mapping[str, int], 315 | ) -> core.Key: 316 | """Round down a chunk-key to offsets corresponding to new chunks.""" 317 | new_offsets = {} 318 | for dim, offset in key.offsets.items(): 319 | chunk_size = target_chunks.get(dim) 320 | if chunk_size is None: 321 | new_offsets[dim] = offset 322 | elif chunk_size == -1: 323 | new_offsets[dim] = 0 324 | else: 325 | new_offsets[dim] = chunk_size * (offset // chunk_size) 326 | return key.replace(new_offsets) 327 | 328 | 329 | @dataclasses.dataclass 330 | class ConsolidateChunks(beam.PTransform): 331 | """Consolidate existing chunks across offsets into bigger chunks.""" 332 | 333 | target_chunks: Mapping[str, int] 334 | 335 | def _prepend_chunk_key(self, key, chunk): 336 | rounded_key = _round_chunk_key(key, self.target_chunks) 337 | return rounded_key, (key, chunk) 338 | 339 | def _consolidate(self, key, inputs): 340 | ((consolidated_key, dataset),) = consolidate_chunks(inputs) 341 | assert key == consolidated_key, (key, consolidated_key) 342 | return consolidated_key, dataset 343 | 344 | def expand(self, pcoll): 345 | return ( 346 | pcoll 347 | | 'PrependTempKey' >> beam.MapTuple(self._prepend_chunk_key) 348 | | 'GroupByTempKeys' >> beam.GroupByKey() 349 | | 'Consolidate' >> beam.MapTuple(self._consolidate) 350 | ) 351 | 352 | 353 | class ConsolidateVariables(beam.PTransform): 354 | """Consolidate existing chunks across variables into bigger chunks.""" 355 | 356 | # TODO(shoyer): add support for partial consolidation into explicit sets 357 | # of variables. 358 | 359 | def _prepend_chunk_key(self, key, chunk): 360 | return key.replace(vars=None), (key, chunk) 361 | 362 | def _consolidate(self, key, inputs): 363 | ((consolidated_key, dataset),) = consolidate_variables(inputs) 364 | assert key.offsets == consolidated_key.offsets, (key, consolidated_key) 365 | assert key.vars is None 366 | # TODO(shoyer): consider carefully whether it is better to return key or 367 | # consolidated_key. They are both valid in the xarray-beam data model -- the 368 | # difference is whether vars=None or is an explicit set of variables. 369 | # For now, conservatively return the version of key with vars=None so 370 | # users don't rely on it. 371 | return key, dataset 372 | 373 | def expand(self, pcoll): 374 | return ( 375 | pcoll 376 | | 'PrependTempKey' >> beam.MapTuple(self._prepend_chunk_key) 377 | | 'GroupByTempKeys' >> beam.GroupByKey() 378 | | 'Consolidate' >> beam.MapTuple(self._consolidate) 379 | ) 380 | 381 | 382 | def _split_chunk_bounds( 383 | start: int, 384 | stop: int, 385 | multiple: int, 386 | ) -> List[Tuple[int, int]]: 387 | # pylint: disable=g-doc-args 388 | # pylint: disable=g-doc-return-or-yield 389 | """Calculate the size of divided chunks along a dimension. 390 | 391 | Example usage: 392 | 393 | >>> _split_chunk_bounds(0, 10, 3) 394 | [(0, 3), (3, 6), (6, 9), (9, 10)] 395 | >>> _split_chunk_bounds(5, 10, 3) 396 | [(5, 6), (6, 9), (9, 10)] 397 | >>> _split_chunk_bounds(10, 20, 12) 398 | [(10, 12), (12, 20)] 399 | """ 400 | if multiple == -1: 401 | return [(start, stop)] 402 | assert start >= 0 and stop > start and multiple > 0, (start, stop, multiple) 403 | first_multiple = (start // multiple + 1) * multiple 404 | breaks = list(range(first_multiple, stop, multiple)) 405 | return list(zip([start] + breaks, breaks + [stop])) 406 | 407 | 408 | def split_chunks( 409 | key: core.Key, 410 | dataset: xarray.Dataset, 411 | target_chunks: Mapping[str, int], 412 | ) -> Iterator[Tuple[core.Key, xarray.Dataset]]: 413 | """Split a single (Key, xarray.Dataset) pair into many chunks.""" 414 | # This function splits consolidated arrays into blocks of new sizes, e.g., 415 | # ⌈x_00 x_01 ...⌉ ⌈⌈x_00⌉ ⌈x_01⌉ ...⌉ 416 | # X = |x_10 x_11 ...| = ||x_10| |x_11| ...| 417 | # |x_20 x_21 ...| |⌊x_20⌋ ⌊x_21⌋ ...| 418 | # ⌊ ... ... ...⌋ ⌊ ... ... ...⌋ 419 | # and emits them as (Key, xarray.Dataset) pairs. 420 | all_bounds = [] 421 | for dim, chunk_size in target_chunks.items(): 422 | start = key.offsets.get(dim, 0) 423 | stop = start + dataset.sizes[dim] 424 | all_bounds.append(_split_chunk_bounds(start, stop, chunk_size)) 425 | 426 | for bounds in itertools.product(*all_bounds): 427 | new_offsets = dict(key.offsets) 428 | slices = {} 429 | for dim, (start, stop) in zip(target_chunks, bounds): 430 | base = key.offsets.get(dim, 0) 431 | new_offsets[dim] = start 432 | slices[dim] = slice(start - base, stop - base) 433 | 434 | new_key = key.replace(new_offsets) 435 | new_chunk = dataset.isel(slices) 436 | yield new_key, new_chunk 437 | 438 | 439 | @dataclasses.dataclass 440 | class SplitChunks(beam.PTransform): 441 | """Split existing chunks into smaller chunks.""" 442 | 443 | target_chunks: Mapping[str, int] 444 | 445 | def _split_chunks(self, key, dataset): 446 | target_chunks = { 447 | k: v for k, v in self.target_chunks.items() if k in dataset.dims 448 | } 449 | yield from split_chunks(key, dataset, target_chunks) 450 | 451 | def expand(self, pcoll): 452 | return pcoll | beam.FlatMapTuple(self._split_chunks) 453 | 454 | 455 | def split_variables( 456 | key: core.Key, 457 | dataset: xarray.Dataset, 458 | ) -> Iterator[Tuple[core.Key, xarray.Dataset]]: 459 | """Split a single (Key, xarray.Dataset) pair into separate variables.""" 460 | # TODO(shoyer): add support for partial splitting, into explicitly provided 461 | # sets of variables 462 | for var_name in dataset: 463 | new_dataset = dataset[[var_name]] 464 | offsets = {k: v for k, v in key.offsets.items() if k in new_dataset.dims} 465 | new_key = core.Key(offsets, vars={var_name}) # pytype: disable=wrong-arg-types 466 | yield new_key, new_dataset 467 | 468 | 469 | @dataclasses.dataclass 470 | class SplitVariables(beam.PTransform): 471 | """Split existing chunks into a separate chunk per data variable.""" 472 | 473 | def expand(self, pcoll): 474 | return pcoll | beam.FlatMapTuple(split_variables) 475 | 476 | 477 | def in_memory_rechunk( 478 | inputs: List[Tuple[core.Key, xarray.Dataset]], 479 | target_chunks: Mapping[str, int], 480 | ) -> Iterator[Tuple[core.Key, xarray.Dataset]]: 481 | """Rechunk in-memory pairs of (Key, xarray.Dataset).""" 482 | consolidated = consolidate_chunks(inputs) 483 | for key, dataset in consolidated: 484 | yield from split_chunks(key, dataset, target_chunks) 485 | 486 | 487 | @dataclasses.dataclass 488 | class RechunkStage(beam.PTransform): 489 | """A single stage of a rechunking pipeline.""" 490 | 491 | source_chunks: Mapping[str, int] 492 | target_chunks: Mapping[str, int] 493 | 494 | def expand(self, pcoll): 495 | source_values = self.source_chunks.values() 496 | target_values = self.target_chunks.values() 497 | if any(t % s for s, t in zip(source_values, target_values)): 498 | pcoll |= 'Split' >> SplitChunks(self.target_chunks) 499 | if any(s % t for s, t in zip(source_values, target_values)): 500 | pcoll |= 'Consolidate' >> ConsolidateChunks(self.target_chunks) 501 | return pcoll 502 | 503 | 504 | class Rechunk(beam.PTransform): 505 | """Rechunk to an arbitrary new chunking scheme with bounded memory usage. 506 | 507 | The approach taken here builds on Rechunker [1], but differs in two key ways: 508 | 509 | 1. It is performed via collective Beam operations, instead of writing 510 | intermediates arrays to disk. 511 | 2. It is performed collectively on full xarray.Dataset objects, instead of 512 | NumPy arrays. 513 | 514 | [1] rechunker.readthedocs.io 515 | """ 516 | 517 | def __init__( 518 | self, 519 | dim_sizes: Mapping[str, int], 520 | source_chunks: Mapping[str, Union[int, Tuple[int, ...]]], 521 | target_chunks: Mapping[str, Union[int, Tuple[int, ...]]], 522 | itemsize: int, 523 | min_mem: Optional[int] = None, 524 | max_mem: int = 2 ** 30, # 1 GB 525 | ): 526 | """Initialize Rechunk(). 527 | 528 | Args: 529 | dim_sizes: size of the full (combined) dataset of all chunks. 530 | source_chunks: sizes of source chunks. Missing keys or values equal to -1 531 | indicate "non-chunked" dimensions. 532 | target_chunks: sizes of target chunks, like `source_keys`. Keys must 533 | exactly match those found in source_chunks. 534 | itemsize: approximate number of bytes per xarray.Dataset element, after 535 | indexing out by all dimensions, e.g., `4 * len(dataset)` for float32 536 | data or roughly `dataset.nbytes / np.prod(dataset.sizes)`. 537 | min_mem: minimum memory that a single intermediate chunk must consume. 538 | max_mem: maximum memory that a single intermediate chunk may consume. 539 | """ 540 | if source_chunks.keys() != target_chunks.keys(): 541 | raise ValueError( 542 | 'source_chunks and target_chunks have different keys: ' 543 | f'{source_chunks} vs {target_chunks}' 544 | ) 545 | if min_mem is None: 546 | min_mem = max_mem // 100 547 | self.dim_sizes = dim_sizes 548 | self.source_chunks = normalize_chunks(source_chunks, dim_sizes) 549 | self.target_chunks = normalize_chunks(target_chunks, dim_sizes) 550 | 551 | if self.source_chunks == self.target_chunks: 552 | self.stage_in = self.stage_out = [] 553 | logging.info(f'Rechunk with chunks {self.source_chunks} is a no-op') 554 | return 555 | 556 | plan = rechunking_plan( 557 | dim_sizes, 558 | self.source_chunks, 559 | self.target_chunks, 560 | itemsize=itemsize, 561 | min_mem=min_mem, 562 | max_mem=max_mem, 563 | ) 564 | plan = ( 565 | [[self.source_chunks, self.source_chunks, plan[0][0]]] 566 | + plan 567 | + [[plan[-1][-1], self.target_chunks, self.target_chunks]] 568 | ) 569 | self.stage_in, (_, *intermediates, _), self.stage_out = zip(*plan) 570 | 571 | logging.info( 572 | 'Rechunking plan:\n' 573 | + '\n'.join( 574 | f'Stage{i}: {s} -> {t}' 575 | for i, (s, t) in enumerate(zip(self.stage_in, self.stage_out)) 576 | ) 577 | ) 578 | min_size = min( 579 | itemsize * math.prod(chunks.values()) for chunks in intermediates 580 | ) 581 | logging.info(f'Smallest intermediates have size {min_size:1.3e}') 582 | 583 | def expand(self, pcoll): 584 | for stage, (in_chunks, out_chunks) in enumerate( 585 | zip(self.stage_in, self.stage_out) 586 | ): 587 | pcoll |= f'Stage{stage}' >> RechunkStage(in_chunks, out_chunks) 588 | return pcoll 589 | -------------------------------------------------------------------------------- /xarray_beam/_src/test_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Testing utilities for Xarray-Beam.""" 15 | import pickle 16 | import tempfile 17 | 18 | from absl.testing import parameterized 19 | import apache_beam as beam 20 | import numpy as np 21 | import pandas as pd 22 | import xarray 23 | 24 | # pylint: disable=expression-not-assigned 25 | 26 | 27 | def _write_pickle(value, path): 28 | with open(path, 'wb') as f: 29 | pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL) 30 | 31 | 32 | class EagerPipeline: 33 | """A mock Beam pipeline for testing that returns lists of Python objects. 34 | 35 | Example usage: 36 | 37 | >>> EagerPipeline() | beam.Create([1, 2, 3]) | beam.Map(lambda x: x**2) 38 | [1, 4, 9] 39 | """ 40 | 41 | def __or__(self, ptransform): 42 | with tempfile.NamedTemporaryFile() as f: 43 | with beam.Pipeline('DirectRunner') as pipeline: 44 | ( 45 | pipeline 46 | | ptransform 47 | | beam.combiners.ToList() 48 | | beam.Map(_write_pickle, f.name) 49 | ) 50 | pipeline.run() 51 | return pickle.load(f) 52 | 53 | 54 | class TestCase(parameterized.TestCase): 55 | """TestCase for use in internal Xarray-Beam tests.""" 56 | 57 | def _assert_chunks(self, array_assert_func, actual, expected): 58 | actual = dict(actual) 59 | expected = dict(expected) 60 | self.assertCountEqual(expected, actual, msg='inconsistent keys') 61 | for key in expected: 62 | actual_chunk, expected_chunk = actual[key], expected[key] 63 | self.assertEqual(type(actual_chunk), type(expected_chunk)) 64 | if type(actual_chunk) is not list: 65 | actual_chunk, expected_chunk = (actual_chunk,), (expected_chunk,) 66 | for a, e in zip(actual_chunk, expected_chunk): 67 | array_assert_func(a, e) 68 | 69 | def assertIdenticalChunks(self, actual, expected): 70 | self._assert_chunks(xarray.testing.assert_identical, actual, expected) 71 | 72 | def assertAllCloseChunks(self, actual, expected): 73 | self._assert_chunks(xarray.testing.assert_allclose, actual, expected) 74 | 75 | 76 | def dummy_era5_surface_dataset( 77 | variables=2, 78 | latitudes=73, 79 | longitudes=144, 80 | times=365 * 4, 81 | freq='6H', 82 | ): 83 | """A mock version of the Pangeo ERA5 surface reanalysis dataset.""" 84 | # based on: gs://pangeo-era5/reanalysis/spatial-analysis 85 | dims = ('time', 'latitude', 'longitude') 86 | shape = (times, latitudes, longitudes) 87 | var_names = ['asn', 'd2m', 'e', 'mn2t', 'mx2t', 'ptype'][:variables] 88 | rng = np.random.default_rng(0) 89 | data_vars = { 90 | name: (dims, rng.normal(size=shape).astype(np.float32), {'var_index': i}) 91 | for i, name in enumerate(var_names) 92 | } 93 | 94 | latitude = np.linspace(90, 90, num=latitudes) 95 | longitude = np.linspace(0, 360, num=longitudes, endpoint=False) 96 | time = pd.date_range('1979-01-01T00', periods=times, freq=freq) 97 | coords = {'time': time, 'latitude': latitude, 'longitude': longitude} 98 | 99 | return xarray.Dataset(data_vars, coords, {'global_attr': 'yes'}) 100 | -------------------------------------------------------------------------------- /xarray_beam/_src/threadmap.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Beam Map transforms that execute in parallel using a thread pool. 15 | 16 | This can be a good idea for IO bound tasks, but generally should be avoided for 17 | CPU bound tasks, especially CPU bound tasks that do not release the Python GIL. 18 | 19 | They can be used as drop-in substitutes for the corresponding Beam transforms: 20 | - beam.Map -> ThreadMap 21 | - beam.MapTuple -> ThreadMapTuple 22 | - beam.FlatMap -> FlatThreadMap 23 | - beam.FlatMapTuple -> FlatThreadMapTuple 24 | 25 | By default, 16 threads are used per task. This can be adjusted via the 26 | `num_threads` keyword argument. 27 | """ 28 | import concurrent.futures 29 | import functools 30 | 31 | import apache_beam as beam 32 | 33 | 34 | class ThreadDoFn(beam.DoFn): 35 | """A DoFn that executes inputs in a ThreadPool.""" 36 | 37 | def __init__(self, func, num_threads): 38 | self.func = func 39 | self.num_threads = num_threads 40 | 41 | def setup(self): 42 | self.executor = concurrent.futures.ThreadPoolExecutor(self.num_threads) 43 | 44 | def teardown(self): 45 | self.executor.shutdown() 46 | 47 | def process(self, element, *args, **kwargs): 48 | futures = [] 49 | for x in element: 50 | futures.append(self.executor.submit(self.func, x, *args, **kwargs)) 51 | for future in futures: 52 | yield future.result() 53 | 54 | 55 | class FlatThreadDoFn(ThreadDoFn): 56 | 57 | def process(self, element, *args, **kwargs): 58 | for results in super().process(element, *args, **kwargs): 59 | yield from results 60 | 61 | 62 | class _ThreadMap(beam.PTransform): 63 | """Like beam.Map, but executed in a thread-pool.""" 64 | 65 | def __init__(self, func, *args, num_threads, **kwargs): 66 | self.func = func 67 | self.args = args 68 | self.kwargs = kwargs 69 | self.num_threads = num_threads 70 | 71 | def get_dofn(self): 72 | return ThreadDoFn(self.func, self.num_threads) 73 | 74 | def expand(self, pcoll): 75 | return ( 76 | pcoll 77 | | 'BatchElements' 78 | >> beam.BatchElements( 79 | min_batch_size=self.num_threads, 80 | max_batch_size=self.num_threads, 81 | ) 82 | | 'ParDo' >> beam.ParDo(self.get_dofn(), *self.args, **self.kwargs) 83 | ) 84 | 85 | 86 | class _ThreadMapTuple(_ThreadMap): 87 | """Like beam.MapTuple, but executed in a thread-pool.""" 88 | 89 | def get_dofn(self): 90 | func = lambda xs, **kwargs: self.func(*xs, **kwargs) 91 | return ThreadDoFn(func, self.num_threads) 92 | 93 | 94 | class _FlatThreadMap(_ThreadMap): 95 | """Like beam.FlatMap, but executed in a thread-pool.""" 96 | 97 | def get_dofn(self): 98 | return FlatThreadDoFn(self.func, self.num_threads) 99 | 100 | 101 | class _FlatThreadMapTuple(_ThreadMap): 102 | """Like beam.FlatMapTuple, but executed in a thread-pool.""" 103 | 104 | def get_dofn(self): 105 | func = lambda xs, **kwargs: self.func(*xs, **kwargs) 106 | return FlatThreadDoFn(func, self.num_threads) 107 | 108 | 109 | def _maybe_threaded(beam_transform, thread_transform): 110 | @functools.wraps(thread_transform) 111 | def create(func, *args, num_threads=16, **kwargs): 112 | if num_threads is None: 113 | return beam_transform(func, *args, **kwargs) 114 | else: 115 | return thread_transform(func, *args, num_threads=num_threads, **kwargs) 116 | 117 | return create 118 | 119 | 120 | # These functions don't use threads if num_threads=None. 121 | ThreadMap = _maybe_threaded(beam.Map, _ThreadMap) 122 | ThreadMapTuple = _maybe_threaded(beam.MapTuple, _ThreadMapTuple) 123 | FlatThreadMap = _maybe_threaded(beam.FlatMap, _FlatThreadMap) 124 | FlatThreadMapTuple = _maybe_threaded(beam.FlatMapTuple, _FlatThreadMapTuple) 125 | -------------------------------------------------------------------------------- /xarray_beam/_src/threadmap_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for threadmap.""" 15 | 16 | from absl.testing import absltest 17 | import apache_beam as beam 18 | import unittest 19 | 20 | from xarray_beam._src import test_util 21 | from xarray_beam._src import threadmap 22 | 23 | 24 | # pylint: disable=expression-not-assigned 25 | # pylint: disable=pointless-statement 26 | 27 | 28 | class ThreadMapTest(test_util.TestCase): 29 | 30 | def test_map(self): 31 | def f(*args, **kwargs): 32 | return args, kwargs 33 | 34 | expected = [1, 2, 3] | beam.Map(f, 4, y=5) 35 | actual = [1, 2, 3] | threadmap.ThreadMap(f, 4, y=5) 36 | self.assertEqual(expected, actual) 37 | 38 | actual = [1, 2, 3] | threadmap.ThreadMap(f, 4, y=5, num_threads=None) 39 | self.assertEqual(expected, actual) 40 | 41 | @unittest.skip('this is failing with recent Apache Beam releases') 42 | def test_flat_map(self): 43 | def f(*args, **kwargs): 44 | return [(args, kwargs)] * 2 45 | 46 | expected = [1, 2, 3] | beam.FlatMap(f, 4, y=5) 47 | actual = [1, 2, 3] | threadmap.FlatThreadMap(f, 4, y=5) 48 | self.assertEqual(expected, actual) 49 | 50 | actual = [1, 2, 3] | threadmap.FlatThreadMap(f, 4, y=5, num_threads=None) 51 | self.assertEqual(expected, actual) 52 | 53 | def test_map_tuple(self): 54 | def f(a, b, y=None): 55 | return a, b, y 56 | 57 | expected = [(1, 2), (3, 4)] | beam.MapTuple(f, y=5) 58 | actual = [(1, 2), (3, 4)] | threadmap.ThreadMapTuple(f, y=5) 59 | self.assertEqual(expected, actual) 60 | 61 | actual = [(1, 2), (3, 4)] | threadmap.ThreadMapTuple( 62 | f, y=5, num_threads=None 63 | ) 64 | self.assertEqual(expected, actual) 65 | 66 | def test_flat_map_tuple(self): 67 | def f(a, b, y=None): 68 | return a, b, y 69 | 70 | expected = [(1, 2), (3, 4)] | beam.FlatMapTuple(f, y=5) 71 | actual = [(1, 2), (3, 4)] | threadmap.FlatThreadMapTuple(f, y=5) 72 | self.assertEqual(expected, actual) 73 | 74 | actual = [(1, 2), (3, 4)] | threadmap.FlatThreadMapTuple( 75 | f, y=5, num_threads=None 76 | ) 77 | self.assertEqual(expected, actual) 78 | 79 | def test_maybe_thread_map(self): 80 | transform = threadmap.ThreadMap(lambda x: x) 81 | self.assertIsInstance(transform, threadmap._ThreadMap) 82 | 83 | transform = threadmap.ThreadMap(lambda x: x, num_threads=None) 84 | self.assertIsInstance(transform, beam.ParDo) 85 | 86 | transform = threadmap.ThreadMap(lambda x: x, num_threads=1) 87 | self.assertIsInstance(transform, threadmap._ThreadMap) 88 | 89 | 90 | if __name__ == '__main__': 91 | absltest.main() 92 | -------------------------------------------------------------------------------- /xarray_beam/_src/zarr_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import re 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | import dask.array as da 19 | import numpy as np 20 | import pandas as pd 21 | import xarray 22 | import xarray_beam as xbeam 23 | from xarray_beam._src import test_util 24 | 25 | 26 | # pylint: disable=expression-not-assigned 27 | # pylint: disable=pointless-statement 28 | 29 | 30 | class DatasetToZarrTest(test_util.TestCase): 31 | 32 | def test_open_zarr(self): 33 | source_ds = xarray.Dataset( 34 | {'foo': ('x', da.arange(0, 60, 10, chunks=2))}, 35 | ) 36 | temp_dir = self.create_tempdir().full_path 37 | source_ds.to_zarr(temp_dir) 38 | roundtripped_ds, chunks = xbeam.open_zarr(temp_dir) 39 | xarray.testing.assert_identical(roundtripped_ds, source_ds) 40 | self.assertEqual(roundtripped_ds.chunks, {}) 41 | self.assertEqual(chunks, {'x': 2}) 42 | 43 | def test_open_zarr_inconsistent(self): 44 | source_ds = xarray.Dataset( 45 | { 46 | 'foo': ('x', da.arange(0, 60, 10, chunks=2)), 47 | 'bar': ('x', da.arange(0, 60, 10, chunks=3)), 48 | }, 49 | ) 50 | temp_dir = self.create_tempdir().full_path 51 | source_ds.to_zarr(temp_dir) 52 | with self.assertRaisesRegex( 53 | ValueError, 54 | "inconsistent chunk sizes on Zarr dataset for dimension 'x': {2, 3}", 55 | ): 56 | xbeam.open_zarr(temp_dir) 57 | 58 | def test_make_template(self): 59 | source = xarray.Dataset( 60 | { 61 | 'foo': ('x', np.ones(3)), 62 | 'bar': ('x', np.ones(3)), 63 | }, 64 | ) 65 | template = xbeam.make_template(source) 66 | self.assertEqual(list(template.data_vars), ['foo', 'bar']) 67 | self.assertEqual(template.chunks, {'x': (3,)}) 68 | self.assertEqual(template.sizes, {'x': 3}) 69 | with self.assertRaisesRegex( 70 | ValueError, 'cannot compute array values of xarray.Dataset objects' 71 | ): 72 | template.compute() 73 | 74 | def test_make_template_lazy_vars_on_numpy(self): 75 | source = xarray.Dataset( 76 | { 77 | 'foo': ('x', np.ones(3)), 78 | 'bar': ('x', np.ones(3)), 79 | }, 80 | ) 81 | template = xbeam.make_template(source, lazy_vars={'foo'}) 82 | self.assertEqual(template.foo.chunks, ((3,),)) 83 | self.assertIsNone(template.bar.chunks) 84 | 85 | def test_make_template_lazy_vars_on_dask(self): 86 | source = xarray.Dataset( 87 | { 88 | 'foo': ('x', np.ones(3)), 89 | 'bar': ('x', np.ones(3)), 90 | }, 91 | ).chunk({'x': 2}) 92 | template = xbeam.make_template(source, lazy_vars={'foo'}) 93 | self.assertEqual(template.foo.chunks, ((3,),)) # one chunk 94 | self.assertIsInstance(template.bar.data, np.ndarray) # computed 95 | 96 | def test_make_template_from_chunked(self): 97 | source = xarray.Dataset( 98 | { 99 | 'foo': ('x', da.ones(3)), 100 | 'bar': ('x', np.ones(3)), 101 | }, 102 | ) 103 | template = xbeam._src.zarr._make_template_from_chunked(source) 104 | self.assertEqual(template.foo.chunks, ((3,),)) 105 | self.assertIsNone(template.bar.chunks) 106 | 107 | def test_replace_template_dims_with_coords(self): 108 | source = xarray.Dataset( 109 | {'foo': (('x', 'y'), np.zeros((1, 2)))}, 110 | coords={'x': [0], 'y': [10, 20]}, 111 | ) 112 | template = xbeam.make_template(source) 113 | new_x_coords = pd.date_range('2000-01-01', periods=5) 114 | new_template = xbeam.replace_template_dims(template, x=new_x_coords) 115 | 116 | self.assertEqual(new_template.sizes, {'x': 5, 'y': 2}) 117 | expected_x_coord = xarray.DataArray( 118 | new_x_coords, dims='x', coords={'x': new_x_coords} 119 | ) 120 | xarray.testing.assert_equal(new_template.x, expected_x_coord) 121 | xarray.testing.assert_equal(new_template.y, source.y) # Unchanged coord 122 | self.assertEqual(new_template.foo.shape, (5, 2)) 123 | self.assertIsInstance(new_template.foo.data, da.Array) # Still lazy 124 | 125 | def test_replace_template_dims_with_size(self): 126 | source = xarray.Dataset( 127 | {'foo': (('x', 'y'), np.zeros((1, 2)))}, 128 | coords={'x': [0], 'y': [10, 20]}, 129 | ) 130 | template = xbeam.make_template(source) 131 | new_template = xbeam.replace_template_dims(template, x=10) 132 | 133 | self.assertEqual(new_template.sizes, {'x': 10, 'y': 2}) 134 | self.assertNotIn( 135 | 'x', new_template.coords 136 | ) # Coord is dropped when replaced by size 137 | xarray.testing.assert_equal(new_template.y, source.y) 138 | self.assertEqual(new_template.foo.shape, (10, 2)) 139 | self.assertIsInstance(new_template.foo.data, da.Array) 140 | 141 | def test_replace_template_dims_multiple(self): 142 | source = xarray.Dataset( 143 | {'foo': (('x', 'y'), np.zeros((1, 2)))}, 144 | coords={'x': [0], 'y': [10, 20]}, 145 | ) 146 | template = xbeam.make_template(source) 147 | new_x_coords = pd.date_range('2000-01-01', periods=5) 148 | new_template = xbeam.replace_template_dims(template, x=new_x_coords, y=3) 149 | 150 | self.assertEqual(new_template.sizes, {'x': 5, 'y': 3}) 151 | expected_x_coord = xarray.DataArray( 152 | new_x_coords, dims='x', coords={'x': new_x_coords} 153 | ) 154 | xarray.testing.assert_equal(new_template.x, expected_x_coord) 155 | self.assertNotIn('y', new_template.coords) 156 | self.assertEqual(new_template.foo.shape, (5, 3)) 157 | self.assertIsInstance(new_template.foo.data, da.Array) 158 | 159 | def test_replace_template_dims_multiple_vars(self): 160 | source = xarray.Dataset( 161 | { 162 | 'foo': (('x', 'y'), np.zeros((1, 2))), 163 | 'bar': ('x', np.zeros(1)), 164 | 'baz': ('z', np.zeros(3)), # Unrelated dim 165 | }, 166 | coords={'x': [0], 'y': [10, 20], 'z': [1, 2, 3]}, 167 | ) 168 | template = xbeam.make_template(source) 169 | new_template = xbeam.replace_template_dims(template, x=5) 170 | 171 | self.assertEqual(new_template.sizes, {'x': 5, 'y': 2, 'z': 3}) 172 | self.assertNotIn('x', new_template.coords) 173 | xarray.testing.assert_equal(new_template.y, source.y) 174 | xarray.testing.assert_equal(new_template.z, source.z) 175 | self.assertEqual(new_template.foo.shape, (5, 2)) 176 | self.assertEqual(new_template.bar.shape, (5,)) 177 | self.assertEqual(new_template.baz.shape, (3,)) # Unchanged var 178 | self.assertIsInstance(new_template.foo.data, da.Array) 179 | self.assertIsInstance(new_template.bar.data, da.Array) 180 | self.assertIsInstance(new_template.baz.data, da.Array) 181 | 182 | def test_replace_template_dims_error_on_non_template(self): 183 | source = xarray.Dataset({'foo': ('x', np.zeros(1))}) # Not a template 184 | with self.assertRaisesRegex(ValueError, 'is not chunked with Dask'): 185 | xbeam.replace_template_dims(source, x=5) 186 | 187 | def test_chunks_to_zarr(self): 188 | dataset = xarray.Dataset( 189 | {'foo': ('x', np.arange(0, 60, 10))}, 190 | coords={'x': np.arange(6)}, 191 | ) 192 | chunked = dataset.chunk() 193 | inputs = [ 194 | (xbeam.Key({'x': 0}), dataset), 195 | ] 196 | with self.subTest('no template'): 197 | temp_dir = self.create_tempdir().full_path 198 | with self.assertWarnsRegex(FutureWarning, 'No template provided'): 199 | inputs | xbeam.ChunksToZarr(temp_dir, template=None) 200 | result = xarray.open_zarr(temp_dir, consolidated=True) 201 | xarray.testing.assert_identical(dataset, result) 202 | with self.subTest('with template'): 203 | temp_dir = self.create_tempdir().full_path 204 | inputs | xbeam.ChunksToZarr(temp_dir, chunked) 205 | result = xarray.open_zarr(temp_dir, consolidated=True) 206 | xarray.testing.assert_identical(dataset, result) 207 | with self.subTest('with template and needs_setup=False'): 208 | temp_dir = self.create_tempdir().full_path 209 | xbeam.setup_zarr(chunked, temp_dir) 210 | inputs | xbeam.ChunksToZarr(temp_dir, chunked, needs_setup=False) 211 | result = xarray.open_zarr(temp_dir, consolidated=True) 212 | xarray.testing.assert_identical(dataset, result) 213 | with self.subTest('with zarr_chunks and with template'): 214 | temp_dir = self.create_tempdir().full_path 215 | zarr_chunks = {'x': 3} 216 | inputs | xbeam.ChunksToZarr(temp_dir, chunked, zarr_chunks) 217 | result = xarray.open_zarr(temp_dir, consolidated=True) 218 | xarray.testing.assert_identical(dataset, result) 219 | self.assertEqual(result.chunks, {'x': (3, 3)}) 220 | with self.subTest('with zarr_chunks and no template'): 221 | temp_dir = self.create_tempdir().full_path 222 | zarr_chunks = {'x': 3} 223 | with self.assertWarnsRegex(FutureWarning, 'No template provided'): 224 | inputs | xbeam.ChunksToZarr( 225 | temp_dir, template=None, zarr_chunks=zarr_chunks 226 | ) 227 | result = xarray.open_zarr(temp_dir, consolidated=True) 228 | xarray.testing.assert_identical(dataset, result) 229 | self.assertEqual(result.chunks, {'x': (3, 3)}) 230 | 231 | temp_dir = self.create_tempdir().full_path 232 | with self.assertRaisesRegex( 233 | ValueError, 234 | 'template does not have any variables chunked with Dask', 235 | ): 236 | xbeam.ChunksToZarr(temp_dir, dataset) 237 | 238 | temp_dir = self.create_tempdir().full_path 239 | template = chunked.assign_coords(x=np.zeros(6)) 240 | with self.assertRaisesRegex( 241 | ValueError, 242 | 'template and chunk indexes do not match', 243 | ): 244 | inputs | xbeam.ChunksToZarr(temp_dir, template) 245 | 246 | inputs2 = [ 247 | (xbeam.Key({'x': 0}), dataset.expand_dims(z=[1, 2])), 248 | ] 249 | temp_dir = self.create_tempdir().full_path 250 | with self.assertRaisesRegex( 251 | ValueError, 252 | 'unexpected new indexes found in chunk', 253 | ): 254 | inputs2 | xbeam.ChunksToZarr(temp_dir, template) 255 | 256 | def test_multiple_vars_chunks_to_zarr(self): 257 | dataset = xarray.Dataset( 258 | { 259 | 'foo': ('x', np.arange(0, 60, 10)), 260 | 'bar': ('x', -np.arange(6)), 261 | }, 262 | coords={'x': np.arange(6)}, 263 | ) 264 | chunked = dataset.chunk() 265 | inputs = [ 266 | (xbeam.Key({'x': 0}, {'foo'}), dataset[['foo']]), 267 | (xbeam.Key({'x': 0}, {'bar'}), dataset[['bar']]), 268 | ] 269 | with self.subTest('no template'): 270 | temp_dir = self.create_tempdir().full_path 271 | with self.assertWarnsRegex(FutureWarning, 'No template provided'): 272 | inputs | xbeam.ChunksToZarr(temp_dir, template=None) 273 | result = xarray.open_zarr(temp_dir, consolidated=True) 274 | xarray.testing.assert_identical(dataset, result) 275 | with self.subTest('with template'): 276 | temp_dir = self.create_tempdir().full_path 277 | inputs | xbeam.ChunksToZarr(temp_dir, chunked) 278 | result = xarray.open_zarr(temp_dir, consolidated=True) 279 | xarray.testing.assert_identical(dataset, result) 280 | 281 | @parameterized.named_parameters( 282 | { 283 | 'testcase_name': 'combined_coords', 284 | 'coords': {'bar': (('x', 'y'), -np.arange(6).reshape(3, 2))}, 285 | }, 286 | { 287 | 'testcase_name': 'separate_coords', 288 | 'coords': {'x': np.arange(3), 'y': np.arange(2)}, 289 | }, 290 | ) 291 | def test_2d_chunks_to_zarr(self, coords): 292 | dataset = xarray.Dataset( 293 | {'foo': (('x', 'y'), np.arange(0, 60, 10).reshape(3, 2))}, 294 | coords=coords, 295 | ) 296 | with self.subTest('partial key'): 297 | inputs = [(xbeam.Key({'x': 0}), dataset)] 298 | temp_dir = self.create_tempdir().full_path 299 | inputs | xbeam.ChunksToZarr(temp_dir, template=dataset.chunk()) 300 | result = xarray.open_zarr(temp_dir, consolidated=True) 301 | xarray.testing.assert_identical(dataset, result) 302 | with self.subTest('split along partial key'): 303 | inputs = [(xbeam.Key({'x': 0}), dataset)] 304 | temp_dir = self.create_tempdir().full_path 305 | inputs | xbeam.SplitChunks({'x': 1}) | xbeam.ChunksToZarr( 306 | temp_dir, template=dataset.chunk({'x': 1}) 307 | ) 308 | result = xarray.open_zarr(temp_dir, consolidated=True) 309 | xarray.testing.assert_identical(dataset, result) 310 | with self.subTest('full key'): 311 | inputs = [(xbeam.Key({'x': 0, 'y': 0}), dataset)] 312 | temp_dir = self.create_tempdir().full_path 313 | inputs | xbeam.ChunksToZarr(temp_dir, template=dataset.chunk()) 314 | result = xarray.open_zarr(temp_dir, consolidated=True) 315 | xarray.testing.assert_identical(dataset, result) 316 | 317 | def test_chunks_to_zarr_dask_chunks(self): 318 | dataset = xarray.Dataset( 319 | {'foo': ('x', np.arange(0, 60, 10))}, 320 | coords={'x': np.arange(6)}, 321 | ) 322 | chunked = dataset.chunk() 323 | inputs = [ 324 | (xbeam.Key({'x': 0}), dataset.chunk(3)), 325 | ] 326 | temp_dir = self.create_tempdir().full_path 327 | inputs | xbeam.ChunksToZarr(temp_dir, chunked) 328 | result = xarray.open_zarr(temp_dir) 329 | xarray.testing.assert_identical(dataset, result) 330 | 331 | def test_dataset_to_zarr_simple(self): 332 | dataset = xarray.Dataset( 333 | {'foo': ('x', np.arange(0, 60, 10))}, 334 | coords={'x': np.arange(6)}, 335 | attrs={'meta': 'data'}, 336 | ) 337 | chunked = dataset.chunk({'x': 3}) 338 | temp_dir = self.create_tempdir().full_path 339 | test_util.EagerPipeline() | xbeam.DatasetToZarr(chunked, temp_dir) 340 | actual = xarray.open_zarr(temp_dir, consolidated=True) 341 | xarray.testing.assert_identical(actual, dataset) 342 | 343 | def test_dataset_to_zarr_unchunked(self): 344 | dataset = xarray.Dataset( 345 | {'foo': ('x', np.arange(0, 60, 10))}, 346 | ) 347 | temp_dir = self.create_tempdir().full_path 348 | with self.assertRaisesRegex( 349 | ValueError, 'dataset must be chunked or chunks must be provided' 350 | ): 351 | test_util.EagerPipeline() | xbeam.DatasetToZarr(dataset, temp_dir) 352 | 353 | def test_validate_zarr_chunk_accepts_partial_key(self): 354 | dataset = xarray.Dataset( 355 | {'foo': (('x', 'y'), np.zeros((3, 2)))}, 356 | coords={'x': np.arange(3), 'y': np.arange(2)}, 357 | ) 358 | # Should not raise an exception: 359 | xbeam.validate_zarr_chunk( 360 | key=xbeam.Key({'x': 0}), 361 | chunk=dataset, 362 | template=dataset.chunk(), 363 | zarr_chunks=None, 364 | ) 365 | 366 | def test_to_zarr_wrong_multiple_error(self): 367 | ds = xarray.Dataset({'foo': ('x', np.arange(6))}) 368 | inputs = [ 369 | (xbeam.Key({'x': 3}), ds.tail(3)), 370 | ] 371 | temp_dir = self.create_tempdir().full_path 372 | with self.assertRaisesRegex( 373 | ValueError, 374 | ( 375 | "chunk offset 3 along dimension 'x' is not a multiple of zarr " 376 | "chunks {'x': 4}" 377 | ), 378 | ): 379 | inputs | xbeam.ChunksToZarr( 380 | temp_dir, template=ds.chunk(4), zarr_chunks={'x': 4} 381 | ) 382 | 383 | def test_to_zarr_needs_consolidation_error(self): 384 | ds = xarray.Dataset({'foo': ('x', np.arange(6))}) 385 | inputs = [ 386 | (xbeam.Key({'x': 0}), ds.head(3)), 387 | (xbeam.Key({'x': 3}), ds.tail(3)), 388 | ] 389 | temp_dir = self.create_tempdir().full_path 390 | with self.assertRaisesRegex( 391 | ValueError, 'chunk is smaller than zarr chunks' 392 | ): 393 | inputs | xbeam.ChunksToZarr( 394 | temp_dir, template=ds.chunk(), zarr_chunks={'x': 6} 395 | ) 396 | with self.assertRaisesRegex( 397 | ValueError, 'chunk is smaller than zarr chunks' 398 | ): 399 | inputs | xbeam.ChunksToZarr(temp_dir, template=ds.chunk()) 400 | 401 | def test_to_zarr_fixed_template(self): 402 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) 403 | template = dataset.chunk({'x': 3}) 404 | inputs = [ 405 | (xbeam.Key({'x': 0}), dataset.head(3)), 406 | (xbeam.Key({'x': 3}), dataset.tail(3)), 407 | ] 408 | temp_dir = self.create_tempdir().full_path 409 | chunks_to_zarr = xbeam.ChunksToZarr(temp_dir, template) 410 | self.assertEqual(chunks_to_zarr.template.chunks, {'x': (6,)}) 411 | self.assertEqual(chunks_to_zarr.zarr_chunks, {'x': 3}) 412 | inputs | chunks_to_zarr 413 | actual = xarray.open_zarr(temp_dir, consolidated=True) 414 | xarray.testing.assert_identical(actual, dataset) 415 | 416 | def test_infer_zarr_chunks(self): 417 | dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) 418 | 419 | chunks = xbeam._src.zarr._infer_zarr_chunks(dataset) 420 | self.assertEqual(chunks, {}) 421 | 422 | chunks = xbeam._src.zarr._infer_zarr_chunks(dataset.chunk()) 423 | self.assertEqual(chunks, {'x': 6}) 424 | 425 | chunks = xbeam._src.zarr._infer_zarr_chunks(dataset.head(0).chunk()) 426 | self.assertEqual(chunks, {'x': 0}) 427 | 428 | chunks = xbeam._src.zarr._infer_zarr_chunks(dataset.chunk(3)) 429 | self.assertEqual(chunks, {'x': 3}) 430 | 431 | chunks = xbeam._src.zarr._infer_zarr_chunks(dataset.chunk(4)) 432 | self.assertEqual(chunks, {'x': 4}) 433 | 434 | with self.assertRaisesRegex( 435 | ValueError, 436 | re.escape( 437 | "Zarr cannot handle inconsistent chunk sizes along dimension 'x': " 438 | '(2, 4)' 439 | ), 440 | ): 441 | xbeam._src.zarr._infer_zarr_chunks(dataset.chunk({'x': (2, 4)})) 442 | 443 | with self.assertRaisesRegex( 444 | ValueError, 445 | re.escape( 446 | "Zarr cannot handle inconsistent chunk sizes along dimension 'x': " 447 | '(3, 2, 1)' 448 | ), 449 | ): 450 | xbeam._src.zarr._infer_zarr_chunks(dataset.chunk({'x': (3, 2, 1)})) 451 | 452 | def test_chunks_to_zarr_docs_demo(self): 453 | # verify that the ChunksToChunk demo from our docs works 454 | data = np.random.RandomState(0).randn(2920, 25, 53) 455 | ds = xarray.Dataset({'temperature': (('time', 'lat', 'lon'), data)}) 456 | chunks = {'time': 1000, 'lat': 25, 'lon': 53} 457 | temp_dir = self.create_tempdir().full_path 458 | ( 459 | test_util.EagerPipeline() 460 | | xbeam.DatasetToChunks(ds, chunks) 461 | | xbeam.ChunksToZarr( 462 | temp_dir, template=xbeam.make_template(ds), zarr_chunks=chunks 463 | ) 464 | ) 465 | result = xarray.open_zarr(temp_dir) 466 | xarray.testing.assert_identical(result, ds) 467 | 468 | 469 | if __name__ == '__main__': 470 | absltest.main() 471 | --------------------------------------------------------------------------------