├── .coveragerc ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ ├── feature_request.yml │ └── misc.yml ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── release-drafter.yml └── workflows │ ├── main.yaml │ ├── pypi-release.yaml │ ├── release-drafter.yml │ └── testpypi-release.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── README.rst ├── asv_bench ├── asv.conf.json └── benchmarks │ ├── __init__.py │ ├── accessors.py │ ├── batches.py │ ├── generators.py │ └── loaders.py ├── ci └── requirements │ ├── asv.yml │ ├── doc.yml │ └── environment.yml ├── conftest.py ├── doc ├── Makefile ├── _static │ ├── logo.svg │ └── switcher.json ├── api.rst ├── conf.py ├── contributing.rst ├── demo.ipynb ├── index.rst ├── roadmap.rst ├── tutorials-and-presentations.rst └── user-guide │ ├── caching.ipynb │ ├── create-fashion-mnist-dataset.ipynb │ ├── index.rst │ ├── training-a-neural-network-with-Pytorch-and-xbatcher.ipynb │ └── training-a-neural-network-with-keras-and-xbatcher.ipynb ├── pyproject.toml ├── readthedocs.yaml └── xbatcher ├── __init__.py ├── accessors.py ├── generators.py ├── loaders ├── __init__.py ├── keras.py └── torch.py ├── testing.py ├── tests ├── __init__.py ├── test_accessors.py ├── test_generators.py ├── test_keras_loaders.py ├── test_print_versions.py └── test_torch_loaders.py └── util ├── __init__.py └── print_versions.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = xbatcher/* 4 | include = xbatcher/* 5 | omit = */setup.py 6 | */version.py 7 | xbatcher/tests/* 8 | */__init__.py 9 | [report] 10 | include = xbatcher/* 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 Bug Report 2 | description: File a bug report to help us improve 3 | 
labels: [bug, "needs triage"] 4 | body: 5 | - type: textarea 6 | id: what-happened 7 | attributes: 8 | label: What happened? 9 | description: | 10 | Thanks for reporting a bug! Please describe what you were trying to get done. 11 | Tell us what happened, what went wrong. 12 | validations: 13 | required: true 14 | 15 | - type: textarea 16 | id: what-did-you-expect-to-happen 17 | attributes: 18 | label: What did you expect to happen? 19 | description: | 20 | Describe what you expected to happen. 21 | validations: 22 | required: false 23 | 24 | - type: textarea 25 | id: sample-code 26 | attributes: 27 | label: Minimal Complete Verifiable Example 28 | description: | 29 | Minimal, self-contained copy-pastable example that demonstrates the issue. For more details, check out: 30 | 31 | - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve) 32 | - [Craft Minimal Bug Reports](http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) 33 | 34 | This will be automatically formatted into code, so no need for markdown backticks. 35 | render: Python 36 | 37 | - type: textarea 38 | id: log-output 39 | attributes: 40 | label: Relevant log output 41 | description: Please copy and paste any relevant output. This will be automatically formatted into code, so no need for markdown backticks. 42 | render: Python 43 | 44 | - type: textarea 45 | id: extra 46 | attributes: 47 | label: Anything else we need to know? 48 | description: | 49 | Please describe any other information you want to share. 50 | 51 | - type: textarea 52 | id: show-versions 53 | attributes: 54 | label: Environment 55 | description: Please paste the output of `xbatcher.show_versions()`. This will be automatically formatted into code, so no need for markdown backticks. 
56 | render: Python 57 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: ❓ Usage question 4 | url: https://github.com/xarray-contrib/xbatcher/discussions 5 | about: | 6 | Ask questions and discuss with other community members here. 7 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: 💡 Feature Request 2 | description: Suggest an idea for xbatcher 3 | labels: [feature request] 4 | body: 5 | - type: textarea 6 | id: description 7 | attributes: 8 | label: Is your feature request related to a problem? 9 | description: | 10 | Please provide a clear and concise description of what the problem is. 11 | validations: 12 | required: true 13 | - type: textarea 14 | id: solution 15 | attributes: 16 | label: Describe the solution you'd like 17 | description: | 18 | A clear and concise description of what you want to happen. 19 | - type: textarea 20 | id: alternatives 21 | attributes: 22 | label: Describe alternatives you've considered 23 | description: | 24 | A clear and concise description of any alternative solutions or features you've considered. 25 | validations: 26 | required: false 27 | - type: textarea 28 | id: additional-context 29 | attributes: 30 | label: Additional context 31 | description: | 32 | Add any other context about the feature request here. 33 | validations: 34 | required: false 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/misc.yml: -------------------------------------------------------------------------------- 1 | name: 📝 Issue 2 | description: General issue, that's not a bug report. 
3 | labels: ["needs triage"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Please describe your issue here. 9 | - type: textarea 10 | id: issue-description 11 | attributes: 12 | label: What is your issue? 13 | description: | 14 | Thank you for filing an issue! Please give us further information on how we can help you. 15 | placeholder: Please describe your issue. 16 | validations: 17 | required: true 18 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | **Description of proposed changes** 2 | 3 | 4 | 5 | 6 | 7 | Fixes # 8 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | groups: 8 | pip-updates: 9 | patterns: 10 | - "*" 11 | - package-ecosystem: "github-actions" 12 | directory: "/" 13 | schedule: 14 | interval: "daily" 15 | groups: 16 | gh-actions: 17 | patterns: 18 | - "*" 19 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "v$RESOLVED_VERSION" 2 | tag-template: "v$RESOLVED_VERSION" 3 | categories: 4 | - title: "Features" 5 | label: "feature" 6 | - title: "Enhancement" 7 | label: "enhancement" 8 | - title: "Bug Fixes" 9 | label: "bug" 10 | - title: "Documentation" 11 | label: "documentation" 12 | - title: "Maintenance" 13 | label: "maintenance" 14 | change-template: "- $TITLE @$AUTHOR ([#$NUMBER]($URL))" 15 | change-title-escapes: '\<*_&#@' 16 | version-resolver: 17 | major: 18 | labels: 19 | - "major" 20 | minor: 21 | labels: 22 | - "feature" 23 | - "enhancement" 24 | default: patch 25 | 
exclude-labels: 26 | - "skip-changelog" 27 | template: | 28 | ## Release v$RESOLVED_VERSION (20YY/MM/DD) 29 | 30 | $CHANGES 31 | -------------------------------------------------------------------------------- /.github/workflows/main.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | paths-ignore: 9 | - ".github/workflows/*-release.yaml" 10 | - "asv_bench/**" 11 | - "doc/**" 12 | schedule: 13 | - cron: "0 0 * * *" 14 | 15 | concurrency: 16 | group: ${{ github.workflow }}-${{ github.ref }} 17 | cancel-in-progress: true 18 | 19 | jobs: 20 | test: 21 | name: ${{ matrix.python-version }}-build 22 | runs-on: ubuntu-latest 23 | strategy: 24 | matrix: 25 | python-version: ["3.10", "3.11", "3.12"] 26 | fail-fast: false 27 | steps: 28 | - uses: actions/checkout@v4 29 | - name: Setup Python 30 | uses: actions/setup-python@v5.4.0 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | architecture: x64 34 | - uses: actions/cache@v4 35 | with: 36 | path: ~/.cache/pip 37 | key: ${{ runner.os }}-pip-${{ hashFiles('**/dev-requirements.txt') }} 38 | restore-keys: | 39 | ${{ runner.os }}-pip- 40 | - run: | 41 | python -m pip install -e .[dev] 42 | python -m pip list 43 | - name: Running Tests 44 | run: | 45 | pytest --verbose --cov=. 
--cov-report=xml 46 | - name: Upload coverage to Codecov 47 | uses: codecov/codecov-action@v5.4.0 48 | if: ${{ matrix.python-version == '3.10' }} 49 | with: 50 | file: ./coverage.xml 51 | fail_ci_if_error: false 52 | 53 | test-upstream: 54 | name: ${{ matrix.python-version }}-dev-build 55 | runs-on: ubuntu-latest 56 | strategy: 57 | matrix: 58 | python-version: ["3.10", "3.11", "3.12"] 59 | fail-fast: false 60 | steps: 61 | - uses: actions/checkout@v4 62 | - name: Setup Python 63 | uses: actions/setup-python@v5.4.0 64 | with: 65 | python-version: ${{ matrix.python-version }} 66 | architecture: x64 67 | - run: | 68 | python -m pip install -e .[dev] 69 | python -m pip install --upgrade \ 70 | git+https://github.com/dask/dask \ 71 | git+https://github.com/pydata/xarray 72 | python -m pip list 73 | - name: Running Tests 74 | run: | 75 | pytest --verbose --cov=. 76 | -------------------------------------------------------------------------------- /.github/workflows/pypi-release.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Upload xbatcher to PyPI 2 | on: 3 | release: 4 | types: 5 | - published 6 | # Runs for pull requests should be disabled other than for testing purposes 7 | #pull_request: 8 | # branches: 9 | # - main 10 | 11 | permissions: 12 | contents: read 13 | 14 | jobs: 15 | build-artifacts: 16 | runs-on: ubuntu-latest 17 | if: github.repository == 'xarray-contrib/xbatcher' 18 | steps: 19 | - uses: actions/checkout@v4 20 | with: 21 | fetch-depth: 0 22 | - uses: actions/setup-python@v5.4.0 23 | name: Install Python 24 | with: 25 | python-version: 3.11 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install build twine 31 | 32 | # This step is only necessary for testing purposes and for TestPyPI 33 | - name: Fix up version string for TestPyPI 34 | if: ${{ !startsWith(github.ref, 'refs/tags') }} 35 | run: | 36 | # Change setuptools-scm 
local_scheme to "no-local-version" so the 37 | # local part of the version isn't included, making the version string 38 | # compatible with PyPI. 39 | sed --in-place "s/node-and-date/no-local-version/g" pyproject.toml 40 | 41 | - name: Build tarball and wheels 42 | run: | 43 | git clean -xdf 44 | git restore -SW . 45 | python -m build 46 | - name: Check built artifacts 47 | run: | 48 | python -m twine check --strict dist/* 49 | pwd 50 | if [ -f dist/xbatcher-0.0.0.tar.gz ]; then 51 | echo "❌ INVALID VERSION NUMBER" 52 | exit 1 53 | else 54 | echo "✅ Looks good" 55 | fi 56 | - uses: actions/upload-artifact@v4 57 | with: 58 | name: releases 59 | path: dist 60 | 61 | test-built-dist: 62 | needs: build-artifacts 63 | runs-on: ubuntu-latest 64 | steps: 65 | - uses: actions/setup-python@v5.4.0 66 | name: Install Python 67 | with: 68 | python-version: 3.11 69 | - uses: actions/download-artifact@v4 70 | with: 71 | name: releases 72 | path: dist 73 | - name: List contents of built dist 74 | run: | 75 | ls -ltrh 76 | ls -ltrh dist 77 | - name: Verify the built dist/wheel is valid 78 | run: | 79 | python -m pip install --upgrade pip 80 | python -m pip install dist/xbatcher*.whl 81 | python -m xbatcher.util.print_versions 82 | - name: Publish package to TestPyPI 83 | uses: pypa/gh-action-pypi-publish@v1.12.4 84 | with: 85 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 86 | repository-url: https://test.pypi.org/legacy/ 87 | # verbose: true 88 | 89 | upload-to-pypi: 90 | needs: test-built-dist 91 | if: github.event_name == 'release' 92 | runs-on: ubuntu-latest 93 | steps: 94 | - uses: actions/download-artifact@v4 95 | with: 96 | name: releases 97 | path: dist 98 | - name: Publish package to PyPI 99 | uses: pypa/gh-action-pypi-publish@v1.12.4 100 | with: 101 | password: ${{ secrets.PYPI_API_TOKEN }} 102 | # verbose: true 103 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: 
-------------------------------------------------------------------------------- 1 | name: Release Drafter 2 | 3 | on: 4 | push: 5 | # branches to consider in the event; optional, defaults to all 6 | branches: 7 | - main 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | update_release_draft: 14 | permissions: 15 | # write permission is required to create a github release 16 | contents: write 17 | # write permission is required for autolabeler 18 | # otherwise, read permission is required at least 19 | pull-requests: write 20 | runs-on: ubuntu-latest 21 | steps: 22 | # (Optional) GitHub Enterprise requires GHE_HOST variable set 23 | #- name: Set GHE_HOST 24 | # run: | 25 | # echo "GHE_HOST=${GITHUB_SERVER_URL##https:\/\/}" >> $GITHUB_ENV 26 | 27 | # Drafts your next Release notes as Pull Requests are merged into "main" 28 | - uses: release-drafter/release-drafter@v6 29 | # (Optional) specify config name to use, relative to .github/. Default: release-drafter.yml 30 | # with: 31 | # config-name: my-config.yml 32 | # disable-autolabeler: true 33 | env: 34 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 35 | -------------------------------------------------------------------------------- /.github/workflows/testpypi-release.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Upload xbatcher to TestPyPI 2 | on: 3 | push: 4 | branches: 5 | - main 6 | # pull_request: 7 | # branches: 8 | # - main 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | publish-testpypi: 15 | name: Publish to Test PyPI 16 | runs-on: ubuntu-latest 17 | if: github.repository == 'xarray-contrib/xbatcher' 18 | 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@v4 22 | with: 23 | # fetch all history so that setuptools-scm works 24 | fetch-depth: 0 25 | 26 | - name: Set up Python 27 | uses: actions/setup-python@v5.4.0 28 | with: 29 | python-version: "3.11" 30 | 31 | - name: Install dependencies 32 | run: python -m pip 
install build 33 | 34 | - name: Fix up version string for TestPyPI 35 | if: ${{ !startsWith(github.ref, 'refs/tags') }} 36 | run: | 37 | sed --in-place "s/node-and-date/no-local-version/g" pyproject.toml 38 | 39 | - name: Build tarball and wheels 40 | run: | 41 | python -m build 42 | echo "Generated files:" 43 | ls -lh dist/ 44 | 45 | - name: Verify the built dist/wheel is valid 46 | run: | 47 | python -m pip install --upgrade pip 48 | python -m pip install dist/xbatcher*.whl 49 | python -m xbatcher.util.print_versions 50 | 51 | - name: Publish package to TestPyPI 52 | uses: pypa/gh-action-pypi-publish@v1.12.4 53 | with: 54 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 55 | repository-url: https://test.pypi.org/legacy/ 56 | # verbose: true 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *,cover 45 | 46 | # asv environments 47 | .asv 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | 56 | # Sphinx documentation 57 | doc/_build/ 58 | doc/generated/ 59 | 60 | # PyBuilder 61 | target/ 62 | 63 | # notebook 64 | */.ipynb_checkpoints/* 65 | 66 | # tests 67 | .pytest_cache/* 68 | .mypy_cache/ 69 | .ipynb_checkpoints/ 70 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_schedule: quarterly 3 | autofix_prs: false 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v5.0.0 8 | hooks: 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-json 13 | exclude: "asv_bench/asv.conf.json" 14 | - id: check-yaml 15 | 16 | - repo: https://github.com/astral-sh/ruff-pre-commit 17 | rev: "v0.8.6" 18 | hooks: 19 | - id: ruff 20 | args: ["--fix"] 21 | - id: ruff-format 22 | 23 | - repo: https://github.com/pre-commit/mirrors-prettier 24 | rev: v4.0.0-alpha.8 25 | hooks: 26 | - id: prettier 27 | 28 | - repo: https://github.com/pre-commit/mirrors-mypy 29 | rev: v1.14.1 30 | hooks: 31 | - id: mypy 32 | additional_dependencies: [ 33 | # Type stubs 34 | types-setuptools, 35 | # Dependencies that are typed 36 | numpy, 37 | xarray, 38 | ] 39 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | 3 | title: xbatcher 4 | doi: 10.5281/zenodo.13776824 5 | type: software 6 | url: 
"https://github.com/xarray-contrib/xbatcher" 7 | abstract: >- 8 | Xbatcher is a small library for iterating Xarray DataArrays and Datasets in 9 | batches. The goal is to make it easy to feed Xarray objects to machine 10 | learning libraries such as PyTorch or TensorFlow. 11 | 12 | keywords: 13 | - Xarray 14 | - Machine Learning 15 | - Deep Learning 16 | - PyTorch 17 | - TensorFlow 18 | - Dask 19 | version: 0.4.0 20 | date-released: 2024-09-17 21 | message: "If you use this software, please cite it as below." 22 | authors: 23 | - family-names: Jones 24 | given-names: Max 25 | orcid: https://orcid.org/0000-0003-0180-8928 26 | - family-names: Abernathey 27 | given-names: Ryan 28 | orcid: https://orcid.org/0000-0001-5999-4917 29 | - family-names: Hamman 30 | given-names: Joseph 31 | orcid: https://orcid.org/0000-0001-7479-8439 32 | - family-names: Banihirwe 33 | given-names: Anderson 34 | orcid: https://orcid.org/0000-0001-6583-571X 35 | - family-names: Leong 36 | given-names: Wei Ji 37 | orcid: https://orcid.org/0000-0003-2354-1988 38 | - family-names: Cindy 39 | given-names: Chiao 40 | - family-names: Bell 41 | given-names: Ryan 42 | - family-names: Hagen 43 | given-names: Raphael 44 | - family-names: Scott 45 | given-names: Richard 46 | - family-names: Bednar 47 | given-names: James 48 | - family-names: Vandal 49 | given-names: TJ 50 | - family-names: Bourbeau 51 | given-names: James 52 | - family-names: Jackson 53 | given-names: Robert 54 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Xbatcher's contributor guidelines [can be found in the online documentation](https://xbatcher.readthedocs.io/en/latest/contributing.html). 
2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Xbatcher contributors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | xbatcher: Batch Generation from Xarray Datasets 2 | =============================================== 3 | 4 | |Build Status| |codecov| |docs| |pypi| |conda-forge| |license| |zenodo| 5 | 6 | 7 | Xbatcher is a small library for iterating Xarray DataArrays and Datasets in 8 | batches. The goal is to make it easy to feed Xarray objects to machine 9 | learning libraries such as PyTorch_ or TensorFlow_. View the |docs| for more 10 | info. 11 | 12 | .. _TensorFlow: https://www.tensorflow.org/ 13 | 14 | .. _PyTorch: https://pytorch.org/ 15 | 16 | 17 | .. |Build Status| image:: https://github.com/xarray-contrib/xbatcher/workflows/CI/badge.svg 18 | :target: https://github.com/xarray-contrib/xbatcher/actions 19 | :alt: github actions build status 20 | .. |codecov| image:: https://codecov.io/gh/xarray-contrib/xbatcher/branch/main/graph/badge.svg 21 | :target: https://codecov.io/gh/xarray-contrib/xbatcher 22 | :alt: code coverage 23 | .. 
|docs| image:: http://readthedocs.org/projects/xbatcher/badge/?version=latest 24 | :target: http://xbatcher.readthedocs.org/en/latest/?badge=latest 25 | :alt: docs 26 | .. |pypi| image:: https://img.shields.io/pypi/v/xbatcher.svg 27 | :target: https://pypi.python.org/pypi/xbatcher 28 | :alt: pypi 29 | .. |conda-forge| image:: https://img.shields.io/conda/vn/conda-forge/xbatcher.svg 30 | :target: https://anaconda.org/conda-forge/xbatcher 31 | :alt: conda-forge 32 | .. |license| image:: https://img.shields.io/github/license/xarray-contrib/xbatcher.svg 33 | :target: https://github.com/xarray-contrib/xbatcher 34 | :alt: license 35 | .. |zenodo| image:: https://zenodo.org/badge/DOI/10.5281/zenodo.13776824.svg 36 | :target: https://doi.org/10.5281/zenodo.13776824 37 | :alt: zenodo 38 | 39 | Installation 40 | ------------ 41 | 42 | Xbatcher can be installed from PyPI as:: 43 | 44 | python -m pip install xbatcher 45 | 46 | Or via Conda as:: 47 | 48 | conda install -c conda-forge xbatcher 49 | 50 | Or from source as:: 51 | 52 | python -m pip install git+https://github.com/xarray-contrib/xbatcher.git 53 | 54 | .. note:: 55 | The required dependencies installed with Xbatcher are `Xarray `_, 56 | `Dask `_, and `NumPy `_. 57 | You will need to separately install `TensorFlow `_ 58 | or `PyTorch `_ to use those data loaders or 59 | Xarray accessors. `Review the installation instructions `_ 60 | for more details. 61 | 62 | Documentation 63 | ------------- 64 | 65 | Documentation is hosted on ReadTheDocs: https://xbatcher.readthedocs.org 66 | 67 | License 68 | ------------ 69 | 70 | Apache License 2.0, see LICENSE file. 
71 | 72 | Acknowledgements 73 | ---------------- 74 | 75 | This work was funded in part by: 76 | 77 | NASA ACCESS19-0049: Pangeo ML: Open Source Tools and Pipelines for Scalable Machine Learning Using NASA Earth Observation Data 78 | 79 | This work was motivated by many conversations in the Pangeo community and Pangeo ML working group 80 | -------------------------------------------------------------------------------- /asv_bench/asv.conf.json: -------------------------------------------------------------------------------- 1 | { 2 | // The version of the config file format. Do not change, unless 3 | // you know what you are doing. 4 | "version": 1, 5 | 6 | // The name of the project being benchmarked 7 | "project": "xbatcher", 8 | 9 | // The project's homepage 10 | "project_url": "https://xbatcher.readthedocs.io/", 11 | 12 | // The URL or local path of the source code repository for the 13 | // project being benchmarked 14 | "repo": "..", 15 | 16 | // The Python project's subdirectory in your repo. If missing or 17 | // the empty string, the project is assumed to be located at the root 18 | // of the repository. 19 | // "repo_subdir": "", 20 | 21 | // Customizable commands for building, installing, and 22 | // uninstalling the project. See asv.conf.json documentation. 23 | // 24 | // "install_command": ["in-dir={env_dir} python -mpip install {wheel_file}"], 25 | // "uninstall_command": ["return-code=any python -mpip uninstall -y {project}"], 26 | "build_command": [ 27 | "python -m pip install build", 28 | "python -m build --wheel -o {build_cache_dir} {build_dir}" 29 | //"PIP_NO_BUILD_ISOLATION=false python -m pip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}" 30 | ], 31 | 32 | // List of branches to benchmark. If not provided, defaults to "master" 33 | // (for git) or "default" (for mercurial). 34 | "branches": ["main"], // for git 35 | 36 | // The DVCS being used. 
If not set, it will be automatically 37 | // determined from "repo" by looking at the protocol in the URL 38 | // (if remote), or by looking for special directories, such as 39 | // ".git" (if local). 40 | "dvcs": "git", 41 | 42 | // The tool to use to create environments. May be "conda", 43 | // "virtualenv" or other value depending on the plugins in use. 44 | // If missing or the empty string, the tool will be automatically 45 | // determined by looking for tools on the PATH environment 46 | // variable. 47 | "environment_type": "conda", 48 | 49 | // timeout in seconds for installing any dependencies in environment 50 | // defaults to 10 min 51 | "install_timeout": 600, 52 | 53 | // the base URL to show a commit for the project. 54 | // "show_commit_url": "http://github.com/xarray-contrib/xbatcher/commit/", 55 | 56 | // The Pythons you'd like to test against. If not provided, defaults 57 | // to the current version of Python used to run `asv`. 58 | // "pythons": ["3.8"], 59 | 60 | // The list of conda channel names to be searched for benchmark 61 | // dependency packages in the specified order 62 | "conda_channels": ["conda-forge"], 63 | 64 | // A conda environment file that is used for environment creation. 65 | "conda_environment_file": "../ci/requirements/asv.yml", 66 | 67 | // The matrix of dependencies to test. Each key of the "req" 68 | // requirements dictionary is the name of a package (in PyPI) and 69 | // the values are version numbers. An empty list or empty string 70 | // indicates to just test against the default (latest) 71 | // version. null indicates that the package is to not be 72 | // installed. If the package to be tested is only available from 73 | // PyPi, and the 'environment_type' is conda, then you can preface 74 | // the package name by 'pip+', and the package will be installed 75 | // via pip (with all the conda available packages installed first, 76 | // followed by the pip installed packages). 
77 | // 78 | // The ``@env`` and ``@env_nobuild`` keys contain the matrix of 79 | // environment variables to pass to build and benchmark commands. 80 | // An environment will be created for every combination of the 81 | // cartesian product of the "@env" variables in this matrix. 82 | // Variables in "@env_nobuild" will be passed to every environment 83 | // during the benchmark phase, but will not trigger creation of 84 | // new environments. A value of ``null`` means that the variable 85 | // will not be set for the current combination. 86 | // 87 | // "matrix": { 88 | // "req": { 89 | // "numpy": ["1.6", "1.7"], 90 | // "six": ["", null], // test with and without six installed 91 | // "pip+emcee": [""] // emcee is only available for install with pip. 92 | // }, 93 | // "env": {"ENV_VAR_1": ["val1", "val2"]}, 94 | // "env_nobuild": {"ENV_VAR_2": ["val3", null]}, 95 | // }, 96 | // "matrix": { 97 | // "xarray": [""], 98 | // "numpy": [""], 99 | // "dask": [""], 100 | // }, 101 | 102 | // Combinations of libraries/python versions can be excluded/included 103 | // from the set to test. Each entry is a dictionary containing additional 104 | // key-value pairs to include/exclude. 105 | // 106 | // An exclude entry excludes entries where all values match. The 107 | // values are regexps that should match the whole string. 108 | // 109 | // An include entry adds an environment. Only the packages listed 110 | // are installed. The 'python' key is required. The exclude rules 111 | // do not apply to includes. 112 | // 113 | // In addition to package names, the following keys are available: 114 | // 115 | // - python 116 | // Python version, as in the *pythons* variable above. 117 | // - environment_type 118 | // Environment type, as above. 119 | // - sys_platform 120 | // Platform, as in sys.platform. Possible values for the common 121 | // cases: 'linux2', 'win32', 'cygwin', 'darwin'. 
122 | // - req 123 | // Required packages 124 | // - env 125 | // Environment variables 126 | // - env_nobuild 127 | // Non-build environment variables 128 | // 129 | // "exclude": [ 130 | // {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows 131 | // {"environment_type": "conda", "req": {"six": null}}, // don't run without six on conda 132 | // {"env": {"ENV_VAR_1": "val2"}}, // skip val2 for ENV_VAR_1 133 | // ], 134 | // 135 | // "include": [ 136 | // // additional env for python2.7 137 | // {"python": "2.7", "req": {"numpy": "1.8"}, "env_nobuild": {"FOO": "123"}}, 138 | // // additional env if run on windows+conda 139 | // {"platform": "win32", "environment_type": "conda", "python": "2.7", "req": {"libpython": ""}}, 140 | // ], 141 | 142 | // The directory (relative to the current directory) that benchmarks are 143 | // stored in. If not provided, defaults to "benchmarks" 144 | "benchmark_dir": "benchmarks", 145 | 146 | // The directory (relative to the current directory) to cache the Python 147 | // environments in. If not provided, defaults to "env" 148 | "env_dir": ".asv/env", 149 | 150 | // The directory (relative to the current directory) that raw benchmark 151 | // results are stored in. If not provided, defaults to "results". 152 | "results_dir": ".asv/results", 153 | 154 | // The directory (relative to the current directory) that the html tree 155 | // should be written to. If not provided, defaults to "html". 156 | "html_dir": ".asv/html" 157 | 158 | // The number of characters to retain in the commit hashes. 159 | // "hash_length": 8, 160 | 161 | // `asv` will cache results of the recent builds in each 162 | // environment, making them faster to install next time. This is 163 | // the number of builds to keep, per environment. 164 | // "build_cache_size": 2, 165 | 166 | // The commits after which the regression search in `asv publish` 167 | // should start looking for regressions. 
Dictionary whose keys are 168 | // regexps matching to benchmark names, and values corresponding to 169 | // the commit (exclusive) after which to start looking for 170 | // regressions. The default is to start from the first commit 171 | // with results. If the commit is `null`, regression detection is 172 | // skipped for the matching benchmark. 173 | // 174 | // "regressions_first_commits": { 175 | // "some_benchmark": "352cdf", // Consider regressions only after this commit 176 | // "another_benchmark": null, // Skip regression detection altogether 177 | // }, 178 | 179 | // The thresholds for relative change in results, after which `asv 180 | // publish` starts reporting regressions. Dictionary of the same 181 | // form as in ``regressions_first_commits``, with values 182 | // indicating the thresholds. If multiple entries match, the 183 | // maximum is taken. If no entry matches, the default is 5%. 184 | // 185 | // "regressions_thresholds": { 186 | // "some_benchmark": 0.01, // Threshold of 1% 187 | // "another_benchmark": 0.5, // Threshold of 50% 188 | // }, 189 | } 190 | -------------------------------------------------------------------------------- /asv_bench/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def parameterized(names, params): 5 | """ 6 | Copied from xarray benchmarks: 7 | https://github.com/pydata/xarray/blob/main/asv_bench/benchmarks/__init__.py#L9-L15 8 | """ 9 | 10 | def decorator(func): 11 | func.param_names = names 12 | func.params = params 13 | return func 14 | 15 | return decorator 16 | 17 | 18 | def randn(shape, frac_nan=None, chunks=None, seed=0): 19 | """ 20 | Copied from xarray benchmarks: 21 | https://github.com/pydata/xarray/blob/main/asv_bench/benchmarks/__init__.py#L32-L46 22 | """ 23 | rng = np.random.RandomState(seed) 24 | if chunks is None: 25 | x = rng.standard_normal(shape) 26 | else: 27 | import dask.array as da 28 | 29 | rng = 
da.random.RandomState(seed) 30 | x = rng.standard_normal(shape, chunks=chunks) 31 | 32 | if frac_nan is not None: 33 | inds = rng.choice(range(x.size), int(x.size * frac_nan)) 34 | x.flat[inds] = np.nan 35 | 36 | return x 37 | -------------------------------------------------------------------------------- /asv_bench/benchmarks/accessors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import xarray as xr 4 | 5 | import xbatcher # noqa: F401 6 | 7 | from . import parameterized, randn 8 | 9 | nx = 250 10 | ny = 50 11 | nt = 10 12 | 13 | randn_xyt = randn((nx, ny, nt), frac_nan=0.1) 14 | 15 | 16 | class Accessor: 17 | def setup(self, *args, **kwargs): 18 | self.ds = xr.Dataset( 19 | { 20 | 'var1': (('x', 'y', 't'), randn_xyt), 21 | }, 22 | coords={ 23 | 'x': np.arange(nx), 24 | 'y': np.linspace(0, 1, ny), 25 | 't': pd.date_range('1970-01-01', periods=nt, freq='D'), 26 | }, 27 | ) 28 | 29 | @parameterized( 30 | ['input_dims'], 31 | ([{'x': 10}, {'x': 10, 'y': 5}, {'x': 10, 'y': 5, 't': 2}],), 32 | ) 33 | def time_input_dims(self, input_dims): 34 | """ 35 | Benchmark simple batch generation case using xarray accessor 36 | Equivalent to subset of ``time_batch_input()``. 37 | """ 38 | bg = self.ds.batch.generator(input_dims=input_dims) 39 | bg[0] 40 | -------------------------------------------------------------------------------- /asv_bench/benchmarks/batches.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import xarray as xr 4 | 5 | from xbatcher import BatchGenerator 6 | 7 | from . 
import randn 8 | 9 | nx = 250 10 | ny = 50 11 | nt = 10 12 | 13 | randn_xyt = randn((nx, ny, nt), frac_nan=0.1) 14 | 15 | 16 | class Base: 17 | def setup(self): 18 | self.ds = xr.Dataset( 19 | { 20 | 'var1': (('x', 'y', 't'), randn_xyt), 21 | }, 22 | coords={ 23 | 'x': np.arange(nx), 24 | 'y': np.linspace(0, 1, ny), 25 | 't': pd.date_range('1970-01-01', periods=nt, freq='D'), 26 | }, 27 | ) 28 | 29 | 30 | class NoPreload(Base): 31 | """ 32 | Get a batch from the generator without computing dask arrays. 33 | """ 34 | 35 | def setup(self): 36 | super().setup() 37 | ds_dask = self.ds.chunk({'t': 2}) 38 | self.bg = BatchGenerator(ds_dask, input_dims={'t': 2}, preload_batch=False) 39 | 40 | def time_next_batch(self): 41 | """ 42 | Get a batch 43 | """ 44 | next(iter(self.bg)) 45 | 46 | 47 | class OneInputDim(Base): 48 | """ 49 | Get a batch from the generator with one input_dim specified. 50 | """ 51 | 52 | def setup(self): 53 | super().setup() 54 | self.bg = BatchGenerator(self.ds, input_dims={'x': 10}) 55 | 56 | def time_next_batch(self): 57 | """ 58 | Get a batch 59 | """ 60 | next(iter(self.bg)) 61 | 62 | 63 | class AllInputDim(Base): 64 | """ 65 | Get a batch from the generator with all dimensions specified in input_dims. 66 | """ 67 | 68 | def setup(self): 69 | super().setup() 70 | self.bg = BatchGenerator(self.ds, input_dims={'x': 10, 'y': 10, 't': 5}) 71 | 72 | def time_next_batch(self): 73 | """ 74 | Get a batch 75 | """ 76 | next(iter(self.bg)) 77 | 78 | 79 | class InputDimInputOverlap(Base): 80 | """ 81 | Get a batch from the generator using input_dims and input_overlap. 
82 | """ 83 | 84 | def setup(self): 85 | super().setup() 86 | self.bg = BatchGenerator( 87 | self.ds, input_dims={'x': 10, 'y': 10}, input_overlap={'x': 5, 'y': 5} 88 | ) 89 | 90 | def time_next_batch(self): 91 | """ 92 | Get a batch 93 | """ 94 | next(iter(self.bg)) 95 | 96 | 97 | class InputDimConcat(Base): 98 | """ 99 | Get a batch from the generator with input_dims and concat_input_dims 100 | """ 101 | 102 | def setup(self): 103 | super().setup() 104 | self.bg = BatchGenerator( 105 | self.ds, input_dims={'x': 10, 'y': 10}, concat_input_dims=True 106 | ) 107 | 108 | def time_next_batch(self): 109 | """ 110 | Get a batch 111 | """ 112 | next(iter(self.bg)) 113 | 114 | 115 | class InputDimBatchDim(Base): 116 | """ 117 | Get a batch from the generator with input_dims and batch_dims 118 | """ 119 | 120 | def setup(self): 121 | super().setup() 122 | self.bg = BatchGenerator( 123 | self.ds, input_dims={'x': 10, 'y': 10}, batch_dims={'t': 2} 124 | ) 125 | 126 | def time_next_batch(self): 127 | """ 128 | Get a batch 129 | """ 130 | next(iter(self.bg)) 131 | 132 | 133 | class InputDimBatchDimConcat(Base): 134 | """ 135 | Get a batch from the generator with input_dims, batch_dims and concat_input_dim 136 | """ 137 | 138 | def setup(self): 139 | super().setup() 140 | self.bg = BatchGenerator( 141 | self.ds, 142 | input_dims={'x': 5, 'y': 5}, 143 | batch_dims={'x': 10, 'y': 10}, 144 | concat_input_dims=True, 145 | ) 146 | 147 | def time_next_batch(self): 148 | """ 149 | Get a batch 150 | """ 151 | next(iter(self.bg)) 152 | 153 | 154 | class InputDimInputOverlapConcat(Base): 155 | """ 156 | Get a batch from the generator with input_dims, input_overlap and concat_input_dim 157 | """ 158 | 159 | def setup(self): 160 | super().setup() 161 | self.bg = BatchGenerator( 162 | self.ds, 163 | input_dims={'x': 10, 'y': 10}, 164 | input_overlap={'x': 5, 'y': 5}, 165 | concat_input_dims=True, 166 | ) 167 | 168 | def time_next_batch(self): 169 | """ 170 | Get a batch 171 | """ 172 | 
next(iter(self.bg)) 173 | -------------------------------------------------------------------------------- /asv_bench/benchmarks/generators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import xarray as xr 4 | 5 | from xbatcher import BatchGenerator 6 | 7 | from . import parameterized, randn 8 | 9 | nx = 250 10 | ny = 50 11 | nt = 10 12 | 13 | randn_xyt = randn((nx, ny, nt), frac_nan=0.1) 14 | 15 | 16 | class Generator: 17 | def setup(self, *args, **kwargs): 18 | self.ds = xr.Dataset( 19 | { 20 | 'var1': (('x', 'y', 't'), randn_xyt), 21 | }, 22 | coords={ 23 | 'x': np.arange(nx), 24 | 'y': np.linspace(0, 1, ny), 25 | 't': pd.date_range('1970-01-01', periods=nt, freq='D'), 26 | }, 27 | ) 28 | 29 | @parameterized(['preload_batch'], ([True, False])) 30 | def time_batch_preload(self, preload_batch): 31 | """ 32 | Construct a generator on a chunked DataSet with and without preloading 33 | batches. 34 | """ 35 | ds_dask = self.ds.chunk({'t': 2}) 36 | BatchGenerator(ds_dask, input_dims={'t': 2}, preload_batch=preload_batch) 37 | 38 | @parameterized( 39 | ['input_dims'], 40 | ([{'x': 10}, {'x': 10, 'y': 5}, {'x': 10, 'y': 5, 't': 2}],), 41 | ) 42 | def time_input_dims(self, input_dims): 43 | """ 44 | Benchmark simple batch generation case. 45 | """ 46 | BatchGenerator( 47 | self.ds, 48 | input_dims=input_dims, 49 | ) 50 | 51 | def time_input_dims_and_input_overlap(self): 52 | """ 53 | Benchmark simple batch generation case. 
54 | """ 55 | BatchGenerator( 56 | self.ds, input_dims={'x': 10, 'y': 10}, input_overlap={'x': 5, 'y': 5} 57 | ) 58 | 59 | @parameterized(['concat_input_dims'], (['True', 'False'])) 60 | def time_input_dims_and_concat_input_dims(self, concat_input_dims): 61 | """ 62 | Benchmark concat_input_dims 63 | """ 64 | BatchGenerator( 65 | self.ds, input_dims={'x': 10, 'y': 5}, concat_input_dims=concat_input_dims 66 | ) 67 | 68 | @parameterized( 69 | ['input_dims', 'batch_dims'], 70 | ([{'x': 10}, {'x': 10, 'y': 5}],), 71 | ) 72 | def time_input_dims_and_batch_dims(self, input_dims): 73 | """ 74 | Benchmark batch generator with input_dims and batch_dims. 75 | """ 76 | BatchGenerator(self.ds, input_dims=input_dims, batch_dims={'t': 2}) 77 | 78 | @parameterized( 79 | ['concat_input_dims'], 80 | ([True, False]), 81 | ) 82 | def time_input_dims_batch_dims_and_concat_input_dims(self, concat_input_dims): 83 | """ 84 | Construct a generator on a DataSet with and without concatenating 85 | chunks specified by ``input_dims`` into the batch dimension. 86 | """ 87 | BatchGenerator( 88 | self.ds, 89 | input_dims={'x': 10, 'y': 5}, 90 | batch_dims={'x': 20, 'y': 10}, 91 | concat_input_dims=concat_input_dims, 92 | ) 93 | 94 | @parameterized( 95 | ['concat_input_dims'], 96 | ([True, False]), 97 | ) 98 | def time_input_dims_input_overlap_and_concat_input_dims(self, concat_input_dims): 99 | """ 100 | Construct a generator on a DataSet with and without concatenating 101 | chunks specified by ``input_dims`` into the batch dimension. 
102 | """ 103 | BatchGenerator( 104 | self.ds, 105 | input_dims={'x': 10, 'y': 10}, 106 | input_overlap={'x': 5, 'y': 5}, 107 | concat_input_dims=concat_input_dims, 108 | ) 109 | -------------------------------------------------------------------------------- /asv_bench/benchmarks/loaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import xarray as xr 4 | 5 | from xbatcher import BatchGenerator 6 | from xbatcher.loaders.torch import IterableDataset, MapDataset 7 | 8 | from . import randn 9 | 10 | nx = 250 11 | ny = 50 12 | 13 | randn_xy = randn((nx, ny), frac_nan=0.1) 14 | randn_y = randn((ny), frac_nan=0.1) 15 | 16 | 17 | class TorchLoader: 18 | def setup(self, *args, **kwargs): 19 | self.ds = xr.Dataset( 20 | { 21 | 'var1': (('x', 'y'), randn_xy), 22 | 'var2': (('y'), randn_y), 23 | }, 24 | coords={ 25 | 'x': np.arange(nx), 26 | 'y': np.linspace(0, 1, ny), 27 | }, 28 | ) 29 | self.x_gen = BatchGenerator(self.ds['var1'], {'y': 10}) 30 | self.y_gen = BatchGenerator(self.ds['var2'], {'y': 10}) 31 | 32 | def time_map_dataset(self): 33 | """ 34 | Benchmark MapDataset integration with torch DataLoader. 35 | """ 36 | dataset = MapDataset(self.x_gen, self.y_gen) 37 | loader = torch.utils.data.DataLoader(dataset) 38 | next(iter(loader)) 39 | 40 | def time_iterable_dataset(self): 41 | """ 42 | Benchmark IterableDataset integration with torch DataLoader. 
43 | """ 44 | dataset = IterableDataset(self.x_gen, self.y_gen) 45 | loader = torch.utils.data.DataLoader(dataset) 46 | next(iter(loader)) 47 | -------------------------------------------------------------------------------- /ci/requirements/asv.yml: -------------------------------------------------------------------------------- 1 | name: xbatcher-benchmarks 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | # Required dependencies 7 | - python=3.10 8 | - dask 9 | - numpy 10 | - xarray 11 | - pytorch 12 | - tensorflow 13 | # Xbatcher installation 14 | - pip 15 | -------------------------------------------------------------------------------- /ci/requirements/doc.yml: -------------------------------------------------------------------------------- 1 | name: xbatcher-docs 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | - python=3.10 7 | - dask 8 | - pydata-sphinx-theme 9 | - ipython 10 | - matplotlib 11 | - numpy 12 | - numpydoc 13 | - pytest 14 | - sphinx<6 15 | - sphinx-autosummary-accessors 16 | - sphinx-copybutton 17 | - sphinx-design 18 | - xarray 19 | # For examples 20 | - s3fs 21 | - ipykernel 22 | - nbsphinx 23 | - netcdf4 24 | - pooch 25 | - zarr 26 | - pytorch 27 | - keras 28 | - tensorflow 29 | # Editable xbatcher installation 30 | - pip 31 | 32 | - pip: 33 | # relative to this file. Needs to be editable to be accepted. 34 | - -e ../.. 
35 | -------------------------------------------------------------------------------- /ci/requirements/environment.yml: -------------------------------------------------------------------------------- 1 | name: xbatcher-tests 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | # Required dependencies 7 | - python=3.10 8 | - dask 9 | - numpy 10 | - xarray 11 | # Dev dependencies 12 | - s3fs 13 | - asv 14 | - pre-commit 15 | - pytest 16 | - pytest-cov 17 | - pytorch 18 | - tensorflow 19 | - zarr 20 | # Style checks 21 | - black 22 | - blackdoc 23 | - docformatter 24 | - flake8 25 | - isort>=5 26 | - pylint 27 | # Xbatcher installation 28 | - pip 29 | - pip: 30 | # relative to this file. Needs to be editable to be accepted. 31 | - -e ../.. 32 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | import pytest 3 | 4 | 5 | @pytest.fixture(autouse=True) 6 | def add_standard_imports(doctest_namespace, tmpdir): 7 | import numpy as np 8 | 9 | # always seed numpy.random to make the examples deterministic 10 | np.random.seed(0) 11 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. 
If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | rm -rf 
$(BUILDDIR)/* 51 | rm -rf generated/* 52 | 53 | html: 54 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 55 | @echo 56 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 57 | 58 | dirhtml: 59 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 60 | @echo 61 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 62 | 63 | singlehtml: 64 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 65 | @echo 66 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 67 | 68 | pickle: 69 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 70 | @echo 71 | @echo "Build finished; now you can process the pickle files." 72 | 73 | json: 74 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 75 | @echo 76 | @echo "Build finished; now you can process the JSON files." 77 | 78 | htmlhelp: 79 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 80 | @echo 81 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 82 | ".hhp project file in $(BUILDDIR)/htmlhelp." 83 | 84 | qthelp: 85 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 86 | @echo 87 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 88 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 89 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/xgcm.qhcp" 90 | @echo "To view the help file:" 91 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/xgcm.qhc" 92 | 93 | devhelp: 94 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 95 | @echo 96 | @echo "Build finished." 97 | @echo "To view the help file:" 98 | @echo "# mkdir -p $$HOME/.local/share/devhelp/xgcm" 99 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/xgcm" 100 | @echo "# devhelp" 101 | 102 | epub: 103 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 104 | @echo 105 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 
106 | 107 | latex: 108 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 109 | @echo 110 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 111 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 112 | "(use \`make latexpdf' here to do that automatically)." 113 | 114 | latexpdf: 115 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 116 | @echo "Running LaTeX files through pdflatex..." 117 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 118 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 119 | 120 | latexpdfja: 121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 122 | @echo "Running LaTeX files through platex and dvipdfmx..." 123 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 124 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 125 | 126 | text: 127 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 128 | @echo 129 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 130 | 131 | man: 132 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 133 | @echo 134 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 135 | 136 | texinfo: 137 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 138 | @echo 139 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 140 | @echo "Run \`make' in that directory to run these through makeinfo" \ 141 | "(use \`make info' here to do that automatically)." 142 | 143 | info: 144 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 145 | @echo "Running Texinfo files through makeinfo..." 146 | make -C $(BUILDDIR)/texinfo info 147 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 148 | 149 | gettext: 150 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 151 | @echo 152 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 
153 | 154 | changes: 155 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 156 | @echo 157 | @echo "The overview file is in $(BUILDDIR)/changes." 158 | 159 | linkcheck: 160 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 161 | @echo 162 | @echo "Link check complete; look for any errors in the above output " \ 163 | "or in $(BUILDDIR)/linkcheck/output.txt." 164 | 165 | doctest: 166 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 167 | @echo "Testing of doctests in the sources finished, look at the " \ 168 | "results in $(BUILDDIR)/doctest/output.txt." 169 | 170 | xml: 171 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 172 | @echo 173 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 174 | 175 | pseudoxml: 176 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 177 | @echo 178 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 179 | -------------------------------------------------------------------------------- /doc/_static/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /doc/_static/switcher.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "dev", 4 | "version": "latest", 5 | "url": "https://xbatcher.readthedocs.io/en/latest/" 6 | }, 7 | { 8 | "version": "0.4.0", 9 | "url": "https://xbatcher.readthedocs.io/en/v0.4.0/" 10 | }, 11 | { 12 | "version": "0.3.0", 13 | "url": "https://xbatcher.readthedocs.io/en/v0.3.0/" 14 | }, 15 | { 16 | "version": "0.2.0", 17 | "url": 
"https://xbatcher.readthedocs.io/en/v0.2.0/" 18 | }, 19 | { 20 | "version": "0.1.0", 21 | "url": "https://xbatcher.readthedocs.io/en/0.1.0/" 22 | } 23 | ] 24 | -------------------------------------------------------------------------------- /doc/api.rst: -------------------------------------------------------------------------------- 1 | .. _api: 2 | 3 | API reference 4 | ------------- 5 | 6 | This page provides an auto-generated summary of Xbatcher's API. 7 | 8 | Core 9 | ==== 10 | 11 | .. currentmodule:: xbatcher 12 | 13 | .. autosummary:: 14 | :toctree: generated/ 15 | 16 | BatchGenerator 17 | BatchSchema 18 | 19 | Xbatcher Xarray accessors 20 | ========================= 21 | 22 | .. currentmodule:: xarray 23 | 24 | .. autosummary:: 25 | :toctree: generated/ 26 | :template: autosummary/accessor_method.rst 27 | 28 | Dataset.batch.generator 29 | DataArray.batch.generator 30 | 31 | Dataloaders 32 | =========== 33 | 34 | .. currentmodule:: xbatcher 35 | 36 | .. autosummary:: 37 | :toctree: generated/ 38 | 39 | loaders.torch.MapDataset 40 | loaders.torch.IterableDataset 41 | loaders.keras.CustomTFDataset 42 | :members: 43 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # 2 | # xbatcher documentation build configuration file, created by 3 | # sphinx-quickstart on Sat Aug 29 00:18:20 2015. 4 | # 5 | # This file is execfile()d with the current directory set to its 6 | # containing dir. 7 | # 8 | # Note that not all possible configuration values are present in this 9 | # autogenerated file. 10 | # 11 | # All configuration values have a default; values that are commented out 12 | # serve to show the default. 
13 | 14 | # type: ignore 15 | 16 | import datetime 17 | import os 18 | import sys 19 | 20 | import sphinx_autosummary_accessors 21 | 22 | import xbatcher 23 | 24 | # If extensions (or modules to document with autodoc) are in another directory, 25 | # add these directories to sys.path here. If the directory is relative to the 26 | # documentation root, use os.path.abspath to make it absolute, like shown here. 27 | # sys.path.insert(0, os.path.abspath('.')) 28 | # sys.path.insert(os.path.abspath('..')) 29 | 30 | print('python exec:', sys.executable) 31 | print('sys.path:', sys.path) 32 | print('xbatcher.version:', xbatcher.__version__) 33 | 34 | 35 | # -- General configuration ------------------------------------------------ 36 | 37 | # If your documentation needs a minimal Sphinx version, state it here. 38 | # needs_sphinx = '1.0' 39 | 40 | # Add any Sphinx extension module names here, as strings. They can be 41 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 42 | # ones. 43 | extensions = [ 44 | 'sphinx.ext.mathjax', 45 | 'sphinx.ext.autodoc', 46 | 'sphinx.ext.autosummary', 47 | 'sphinx.ext.extlinks', 48 | 'sphinx.ext.viewcode', 49 | 'sphinx.ext.intersphinx', 50 | 'numpydoc', 51 | 'nbsphinx', 52 | 'IPython.sphinxext.ipython_directive', 53 | 'IPython.sphinxext.ipython_console_highlighting', 54 | 'sphinx_autosummary_accessors', 55 | 'sphinx_copybutton', 56 | 'sphinx_design', 57 | ] 58 | 59 | nbsphinx_execute = 'auto' 60 | 61 | 62 | autodoc_mock_imports = ['torch', 'tensorflow'] 63 | 64 | # link to github issues 65 | extlinks = {'issue': ('https://github.com/xarray-contrib/xbatcher/issues/%s', '#%s')} 66 | 67 | # sphinx-copybutton configurations (from https://github.com/pydata/xarray/blob/main/doc/conf.py) 68 | copybutton_prompt_text = r'>>> |\.\.\. 
|\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: ' 69 | copybutton_prompt_is_regexp = True 70 | 71 | autosummary_generate = True 72 | numpydoc_class_members_toctree = False 73 | numpydoc_show_class_members = False 74 | 75 | # Add any paths that contain templates here, relative to this directory. 76 | templates_path = ['_templates', sphinx_autosummary_accessors.templates_path] 77 | 78 | # The suffix of source filenames. 79 | source_suffix = '.rst' 80 | 81 | # The encoding of source files. 82 | # source_encoding = 'utf-8-sig' 83 | 84 | # The master toctree document. 85 | master_doc = 'index' 86 | 87 | # General information about the project. 88 | project = 'xbatcher' 89 | copyright = f'2016-{datetime.datetime.now().year}, xbatcher developers' 90 | 91 | # The version info for the project you're documenting, acts as replacement for 92 | # |version| and |release|, also used in various other places throughout the 93 | # built documents. 94 | # 95 | # The short X.Y version. 96 | version = xbatcher.__version__ 97 | # The full version, including alpha/beta/rc tags. 98 | release = xbatcher.__version__ 99 | 100 | # The language for content autogenerated by Sphinx. Refer to documentation 101 | # for a list of supported languages. 102 | # language = None 103 | 104 | # There are two options for replacing |today|: either, you set today to some 105 | # non-false value, then it is used: 106 | # today = '' 107 | # Else, today_fmt is used as the format for a strftime call. 108 | # today_fmt = '%B %d, %Y' 109 | 110 | # List of patterns, relative to source directory, that match files and 111 | # directories to ignore when looking for source files. 112 | exclude_patterns = [ 113 | '_build', 114 | '**.ipynb_checkpoints', 115 | 'user-guide/create-fashion-mnist-dataset.ipynb', 116 | ] 117 | 118 | # The reST default role (used for this markup: `text`) to use for all 119 | # documents. 120 | # default_role = None 121 | 122 | # If true, '()' will be appended to :func: etc. cross-reference text. 
123 | # add_function_parentheses = True 124 | 125 | # If true, the current module name will be prepended to all description 126 | # unit titles (such as .. function::). 127 | # add_module_names = True 128 | 129 | # If true, sectionauthor and moduleauthor directives will be shown in the 130 | # output. They are ignored by default. 131 | # show_authors = False 132 | 133 | # The name of the Pygments (syntax highlighting) style to use. 134 | pygments_style = 'sphinx' 135 | 136 | # A list of ignored prefixes for module index sorting. 137 | # modindex_common_prefix = [] 138 | 139 | # If true, keep warnings as "system message" paragraphs in the built documents. 140 | # keep_warnings = False 141 | 142 | 143 | # -- Options for HTML output ---------------------------------------------- 144 | 145 | # The theme to use for HTML and HTML Help pages. See the documentation for 146 | # a list of builtin themes. 147 | # tml_theme = 'default' 148 | html_theme = 'pydata_sphinx_theme' 149 | html_logo = '_static/logo.svg' 150 | html_favicon = '_static/logo.svg' 151 | 152 | # The following is from the pydata-sphinx-theme settings (https://github.com/pydata/pydata-sphinx-theme/blob/main/docs/conf.py) 153 | # Define the json_url for our version switcher. 154 | json_url = 'https://xbatcher.readthedocs.io/en/latest/_static/switcher.json' 155 | 156 | # Define the version we use for matching in the version switcher. 157 | version_match = os.environ.get('READTHEDOCS_VERSION') 158 | # If READTHEDOCS_VERSION doesn't exist, we're not on RTD 159 | # If it is an integer, we're in a PR build and the version isn't correct. 160 | if not version_match or version_match.isdigit(): 161 | # For local development, infer the version to match from the package. 
162 | release = xbatcher.__version__ 163 | if 'dev' in release or 'post' in release or 'rc' in release: 164 | version_match = 'latest' 165 | # We want to keep the relative reference if we are in dev mode 166 | # but we want the whole url if we are effectively in a released version 167 | json_url = '_static/switcher.json' 168 | else: 169 | version_match = 'v' + release 170 | 171 | print(f'release: {release}') 172 | 173 | 174 | # Theme options are theme-specific and customize the look and feel of a theme 175 | # further. For a list of options available for each theme, see the 176 | # documentation. 177 | html_theme_options = { 178 | 'github_url': 'https://github.com/xarray-contrib/xbatcher', 179 | 'switcher': { 180 | 'json_url': json_url, 181 | 'version_match': version_match, 182 | }, 183 | 'logo': { 184 | 'text': 'Xbatcher', 185 | 'alt_text': 'Xbatcher', 186 | }, 187 | 'navbar_align': 'left', # [left, content, right] For testing that the navbar items align properly 188 | 'navbar_center': ['version-switcher', 'navbar-nav'], 189 | } 190 | 191 | # Add any paths that contain custom themes here, relative to this directory. 192 | # html_theme_path = [] 193 | 194 | # The name for this set of Sphinx documents. If None, it defaults to 195 | # " v documentation". 196 | # html_title = None 197 | 198 | # A shorter title for the navigation bar. Default is the same as html_title. 199 | # html_short_title = None 200 | 201 | # The name of an image file (relative to this directory) to place at the top 202 | # of the sidebar. 203 | # html_logo = None 204 | 205 | # The name of an image file (within the static path) to use as favicon of the 206 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 207 | # pixels large. 208 | # html_favicon = None 209 | 210 | # Add any paths that contain custom static files (such as style sheets) here, 211 | # relative to this directory. 
They are copied after the builtin static files, 212 | # so a file named "default.css" will overwrite the builtin "default.css". 213 | html_static_path = ['_static'] 214 | 215 | # Add any extra paths that contain custom files (such as robots.txt or 216 | # .htaccess) here, relative to this directory. These files are copied 217 | # directly to the root of the documentation. 218 | # html_extra_path = [] 219 | 220 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 221 | # using the given strftime format. 222 | # html_last_updated_fmt = '%b %d, %Y' 223 | 224 | # If true, SmartyPants will be used to convert quotes and dashes to 225 | # typographically correct entities. 226 | # html_use_smartypants = True 227 | 228 | # Custom sidebar templates, maps document names to template names. 229 | # html_sidebars = {} 230 | 231 | # Additional templates that should be rendered to pages, maps page names to 232 | # template names. 233 | # html_additional_pages = {} 234 | 235 | # If false, no module index is generated. 236 | # html_domain_indices = True 237 | 238 | # If false, no index is generated. 239 | # html_use_index = True 240 | 241 | # If true, the index is split into individual pages for each letter. 242 | # html_split_index = False 243 | 244 | # If true, links to the reST sources are added to the pages. 245 | # html_show_sourcelink = True 246 | 247 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 248 | # html_show_sphinx = True 249 | 250 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 251 | # html_show_copyright = True 252 | 253 | # If true, an OpenSearch description file will be output, and all pages will 254 | # contain a tag referring to it. The value of this option must be the 255 | # base URL from which the finished HTML is served. 256 | # html_use_opensearch = '' 257 | 258 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 
259 | # html_file_suffix = None 260 | 261 | # Output file base name for HTML help builder. 262 | htmlhelp_basename = 'xbatcherdoc' 263 | 264 | 265 | # -- Options for LaTeX output --------------------------------------------- 266 | 267 | latex_elements = { 268 | # The paper size ('letterpaper' or 'a4paper'). 269 | # 'papersize': 'letterpaper', 270 | # The font size ('10pt', '11pt' or '12pt'). 271 | # 'pointsize': '10pt', 272 | # Additional stuff for the LaTeX preamble. 273 | # 'preamble': '', 274 | } 275 | 276 | # Grouping the document tree into LaTeX files. List of tuples 277 | # (source start file, target name, title, 278 | # author, documentclass [howto, manual, or own class]). 279 | latex_documents = [ 280 | ( 281 | 'index', 282 | 'xbatcher.tex', 283 | 'xbatcher Documentation', 284 | 'xbatcher developers', 285 | 'manual', 286 | ), 287 | ] 288 | 289 | # The name of an image file (relative to this directory) to place at the top of 290 | # the title page. 291 | # latex_logo = None 292 | 293 | # For "manual" documents, if this is true, then toplevel headings are parts, 294 | # not chapters. 295 | # latex_use_parts = False 296 | 297 | # If true, show page references after internal links. 298 | # latex_show_pagerefs = False 299 | 300 | # If true, show URL addresses after external links. 301 | # latex_show_urls = False 302 | 303 | # Documents to append as an appendix to all manuals. 304 | # latex_appendices = [] 305 | 306 | # If false, no module index is generated. 307 | # latex_domain_indices = True 308 | 309 | 310 | # -- Options for manual page output --------------------------------------- 311 | 312 | # One entry per manual page. List of tuples 313 | # (source start file, name, description, authors, manual section). 314 | man_pages = [ 315 | ( 316 | 'index', 317 | 'xbatcher', 318 | 'xbatcher Documentation', 319 | ['xbatcher developers'], 320 | 1, 321 | ) 322 | ] 323 | 324 | # If true, show URL addresses after external links. 
325 | # man_show_urls = False 326 | 327 | 328 | # -- Options for Texinfo output ------------------------------------------- 329 | 330 | # Grouping the document tree into Texinfo files. List of tuples 331 | # (source start file, target name, title, author, 332 | # dir menu entry, description, category) 333 | texinfo_documents = [ 334 | ( 335 | 'index', 336 | 'xbatcher', 337 | 'xbatcher Documentation', 338 | 'xbatcher developers', 339 | 'xbatcher', 340 | 'One line description of project.', 341 | 'Miscellaneous', 342 | ), 343 | ] 344 | 345 | # Documents to append as an appendix to all manuals. 346 | # texinfo_appendices = [] 347 | 348 | # If false, no module index is generated. 349 | # texinfo_domain_indices = True 350 | 351 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 352 | # texinfo_show_urls = 'footnote' 353 | 354 | # If true, do not generate a @detailmenu in the "Top" node's menu. 355 | # texinfo_no_detailmenu = False 356 | 357 | 358 | # Example configuration for intersphinx: refer to the Python standard library. 359 | intersphinx_mapping = { 360 | 'python': ('https://docs.python.org/3/', None), 361 | 'xarray': ('http://xarray.pydata.org/en/stable/', None), 362 | } 363 | -------------------------------------------------------------------------------- /doc/contributing.rst: -------------------------------------------------------------------------------- 1 | .. _contributing: 2 | 3 | ****************** 4 | Contributing Guide 5 | ****************** 6 | 7 | .. note:: 8 | 9 | Large parts of this document came from the `Xarray Contributing 10 | Guide `_, which is based 11 | on the `Pandas Contributing Guide 12 | `_. 13 | 14 | Bug reports and feature requests 15 | ================================ 16 | 17 | To report bugs or request new features, head over to the `xbatcher repository 18 | `_. 19 | 20 | Contributing code 21 | ================== 22 | 23 | `GitHub has instructions `__ for 24 | installing git, setting up your SSH key, and configuring git. 
All these steps 25 | need to be completed for you to work between your local repository and GitHub. 26 | 27 | .. _contributing.forking: 28 | 29 | Forking 30 | ------- 31 | 32 | You will need your own fork to work on the code. Go to the `xbatcher project 33 | page `_ and hit the ``Fork`` button. 34 | You will need to clone your fork to your machine:: 35 | 36 | git clone git@github.com:yourusername/xbatcher.git 37 | cd xbatcher 38 | git remote add upstream git@github.com:xarray-contrib/xbatcher.git 39 | 40 | This creates the directory ``xbatcher`` and connects your repository to 41 | the upstream (main project) *xbatcher* repository. 42 | 43 | .. _contributing.dev_env: 44 | 45 | Creating a development environment 46 | ---------------------------------- 47 | 48 | To test out code changes, you'll need to build *xbatcher* from source, which 49 | requires a Python environment. If you're making documentation changes, you can 50 | skip to :ref:`contributing.documentation` but you won't be able to build the 51 | documentation locally before pushing your changes. 52 | 53 | .. _contributiong.dev_python: 54 | 55 | Creating a Python Environment 56 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 57 | 58 | Before starting any development, you'll need to create an isolated xbatcher 59 | development environment: 60 | 61 | - Install either `Anaconda `_ or `miniconda 62 | `_ 63 | - Make sure your conda is up to date (``conda update conda``) 64 | - Make sure that you have :ref:`cloned the repository ` 65 | - ``cd`` to the *xbatcher* source directory 66 | 67 | First we'll create and activate the build environment: 68 | 69 | .. code-block:: sh 70 | 71 | conda env create --file ci/requirements/environment.yml 72 | conda activate xbatcher-tests 73 | 74 | At this point you should be able to import *xbatcher* from your locally 75 | built version: 76 | 77 | .. 
code-block:: sh 78 | 79 | $ python # start an interpreter 80 | >>> import xbatcher 81 | >>> xbatcher.__version__ 82 | 83 | This will create the new environment, and not touch any of your existing environments, 84 | nor any existing Python installation. 85 | 86 | To view your environments:: 87 | 88 | conda info --envs 89 | 90 | To return to your base environment:: 91 | 92 | conda deactivate 93 | 94 | See the full conda docs `here `__. 95 | 96 | Setting up pre-commit 97 | ~~~~~~~~~~~~~~~~~~~~~ 98 | 99 | We use `pre-commit `_ to manage code linting and style. 100 | To set up pre-commit after activating your conda environment, run: 101 | 102 | .. code-block:: sh 103 | 104 | pre-commit install 105 | 106 | Creating a branch 107 | ----------------- 108 | 109 | You want your ``main`` branch to reflect only production-ready code, so create a 110 | feature branch before making your changes. For example:: 111 | 112 | git branch shiny-new-feature 113 | git checkout shiny-new-feature 114 | 115 | The above can be simplified to:: 116 | 117 | git checkout -b shiny-new-feature 118 | 119 | This changes your working directory to the shiny-new-feature branch. Keep any 120 | changes in this branch specific to one bug or feature so it is clear 121 | what the branch brings to *xbatcher*. You can have many "shiny-new-features" 122 | and switch in between them using the ``git checkout`` command. 123 | 124 | To update this branch, you need to retrieve the changes from the ``main`` branch:: 125 | 126 | git fetch upstream 127 | git merge upstream/main 128 | 129 | This will combine your commits with the latest *xbatcher* git ``main``. If this 130 | leads to merge conflicts, you must resolve these before submitting your pull 131 | request. If you have uncommitted changes, you will need to ``git stash`` them 132 | prior to updating. This will effectively store your changes, which can be 133 | reapplied after updating. 
134 | 135 | Running the test suite 136 | ---------------------- 137 | 138 | *xbatcher* uses the `pytest `_ 139 | framework for testing. You can run the test suite using:: 140 | 141 | pytest xbatcher 142 | 143 | 144 | 145 | Running the performance test suite 146 | ---------------------------------- 147 | 148 | *xbatcher* is starting a suite of benchmarking tests using 149 | `asv `__ to enable easy monitoring of 150 | the performance of critical operations. These benchmarks are all found in the 151 | ``asv_bench`` directory. 152 | 153 | To use all features of asv, you will need either ``conda`` or ``virtualenv``. 154 | For more details please check the `asv installation webpage 155 | `_. 156 | 157 | To install asv:: 158 | 159 | pip install git+https://github.com/airspeed-velocity/asv 160 | 161 | If you need to run a benchmark, change your directory to ``asv_bench/`` and run:: 162 | 163 | asv continuous -f 1.1 main 164 | 165 | You can replace ``my-branch`` with the name of the branch you are working on. 166 | The output will include "BENCHMARKS NOT SIGNIFICANTLY CHANGED" if the 167 | benchmarks did not change by more than 10%. 168 | 169 | The command uses ``conda`` by default for creating the benchmark 170 | environments. If you want to use virtualenv instead, write:: 171 | 172 | asv continuous -f 1.1 -E virtualenv main 173 | 174 | The ``-E virtualenv`` option should be added to all ``asv`` commands 175 | that run benchmarks. The default value is defined in ``asv.conf.json``. 176 | 177 | If you want to only run a specific group of tests from a file, you can do it 178 | using ``.`` as a separator. For example:: 179 | 180 | asv continuous -f 1.1 main HEAD -b benchmarks.Generator.time_batch_preload 181 | 182 | will only run the ``Generator.time_batch_preload`` benchmark defined in 183 | ``benchmarks.py``. 184 | 185 | Information on how to write a benchmark and how to use asv can be found in the 186 | `asv documentation `_. 
187 | 188 | Contributing documentation 189 | ========================== 190 | 191 | We greatly appreciate documentation improvements. The docs are built from the docstrings 192 | in the code and the docs in the ``doc`` directory. 193 | 194 | To build the documentation, you will need to requirements listed in ``ci/requirements/doc.yml``. 195 | You can create an environment for building the documentation using:: 196 | 197 | conda env create --file ci/requirements/docs.yml 198 | conda activate xbatcher-docs 199 | 200 | You can then build the documentation using:: 201 | 202 | cd docs 203 | make html 204 | 205 | Contributing changes 206 | ==================== 207 | 208 | Once you've made changes, you can see them by typing:: 209 | 210 | git status 211 | 212 | If you have created a new file, it is not being tracked by git. Add it by typing:: 213 | 214 | git add path/to/file-to-be-added.py 215 | 216 | The following defines how a commit message should be structured: 217 | 218 | * A subject line with `< 72` chars. 219 | * One blank line. 220 | * Optionally, a commit message body. 221 | 222 | Now you can commit your changes in your local repository:: 223 | 224 | git commit -m 225 | 226 | When you want your changes to appear publicly on your GitHub page, push your 227 | commits to a branch off your fork:: 228 | 229 | git push origin shiny-new-feature 230 | 231 | Here ``origin`` is the default name given to your remote repository on GitHub. 232 | You can see the remote repositories:: 233 | 234 | git remote -v 235 | 236 | If you navigate to your branch on GitHub, you should see a banner to submit a pull 237 | request to the *xbatcher* repository. 238 | 239 | .. _contributing.ci: 240 | 241 | Continuous integration 242 | ====================== 243 | 244 | Continuous integration is done with `GitHub Actions `_. 245 | 246 | There are currently 3 workflows configured: 247 | 248 | - `main.yaml `_ - Run test suite with pytest. 
249 | - `pypi-release.yaml `_ - Publish 250 | wheels to TestPyPI and PyPI on a tagged release. The pull request trigger can be uncommented to test a release using Test PyPI. 251 | - `release-drafter.yml `_ - Draft 252 | release notes based on PR titles and labels. 253 | -------------------------------------------------------------------------------- /doc/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "id": "sticky-exhibit", 7 | "metadata": {}, 8 | "source": [ 9 | "# Demo\n", 10 | "\n", 11 | "Author: Cindy Chiao\n", 12 | "\n", 13 | "## What is xbatcher? \n", 14 | "Xbatcher is a small library for iterating through Xarray objects (DataArrays and Datasets) in batches. The goal is to make it easy to feed Xarray objects to machine learning libraries such as Keras and PyTorch. \n", 15 | "\n", 16 | "## What is included in this notebook?\n", 17 | "* showcase current abilities with example data \n", 18 | "* brief discussion of current development track and ideas for future work " 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "banner-importance", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import xarray as xr\n", 29 | "\n", 30 | "import xbatcher" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "equipped-sense", 36 | "metadata": {}, 37 | "source": [ 38 | "## Example data\n", 39 | "\n", 40 | "Here we will load an example dataset from a global climate model. The data is from the _historical_ experiment from CMIP6 and represents 60 days of daily max air temperature. 
" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "dutch-grave", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "store = 's3://carbonplan-share/xbatcher/example_cmip6_data.zarr'\n", 51 | "ds = xr.open_dataset(\n", 52 | " store, engine='zarr', chunks={}, backend_kwargs={'storage_options': {'anon': True}}\n", 53 | ")\n", 54 | "\n", 55 | "# inspect the dataset\n", 56 | "ds" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "applicable-diesel", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "# plot the first time dimension\n", 67 | "ds.isel(time=0).tasmax.plot();" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "animated-marsh", 73 | "metadata": {}, 74 | "source": [ 75 | "## Batch generation\n", 76 | "\n", 77 | "Xbatcher's `BatchGenerator` can be used to generate batches with several arguments controlling the exact behavior.\n", 78 | "\n", 79 | "The `input_dims` argument takes a dictionary specifying the size of the inputs in each dimension. For example, `{'time': 10}` means that each of the input sample will have 10 time points, while all other dimensions are flattened to a \"sample\" dimension\n", 80 | "\n", 81 | "Note that even though `ds` in this case only has one variable, the function can operate on multiple variables at the same time." 
82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "id": "attempted-cooling", 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "n_timepoint_in_each_sample = 10\n", 92 | "\n", 93 | "bgen = xbatcher.BatchGenerator(\n", 94 | " ds=ds,\n", 95 | " input_dims={'time': n_timepoint_in_each_sample},\n", 96 | ")\n", 97 | "\n", 98 | "print(f'{len(bgen)} batches')" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "id": "546aed21-3931-46b5-910e-c43498b51e23", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "batch = bgen[0]\n", 109 | "batch" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "digital-night", 115 | "metadata": {}, 116 | "source": [ 117 | "We can verify that the outputs have the expected shapes. \n", 118 | "\n", 119 | "For example, there are 60 time points in our input dataset, we're asking 10 timepoints in each batch, thus expecting 6 batches " 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "integral-theta", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "expected_n_batch = len(ds.time) / n_timepoint_in_each_sample\n", 130 | "print(f'Expecting {expected_n_batch} batches, getting {len(bgen)} batches')" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "id": "usual-kennedy", 136 | "metadata": {}, 137 | "source": [ 138 | "There are 145 lat points and 192 lon points, thus we're expecting 145 * 192 = 27840 samples in a batch." 
139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "incomplete-native", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "expected_batch_size = len(ds.lat) * len(ds.lon)\n", 149 | "print(\n", 150 | " f'Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch'\n", 151 | ")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "id": "durable-gazette", 157 | "metadata": {}, 158 | "source": [ 159 | "## Controlling the size/shape of batches\n", 160 | "\n", 161 | "We can use `batch_dims` and `concat_input_dims` options to control how many sample ends up in each batch. For example, we can specify 10 time points for each sample, but 20 time points in each batch this should yield half as many batches and twice as many samples in a batch as the example above note the difference in dimension name in this case " 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "id": "sophisticated-legislation", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "n_timepoint_in_each_sample = 10\n", 172 | "n_timepoint_in_each_batch = 20\n", 173 | "\n", 174 | "bgen = xbatcher.BatchGenerator(\n", 175 | " ds=ds,\n", 176 | " input_dims={'time': n_timepoint_in_each_sample},\n", 177 | " batch_dims={'time': n_timepoint_in_each_batch},\n", 178 | " concat_input_dims=True,\n", 179 | ")\n", 180 | "\n", 181 | "print(f'{len(bgen)} batches')" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "857d962e-2b9e-4e25-95e3-a922bfac3d6f", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "bgen[0]" 192 | ] 193 | }, 194 | { 195 | "attachments": {}, 196 | "cell_type": "markdown", 197 | "id": "spectacular-reading", 198 | "metadata": {}, 199 | "source": [ 200 | "## Last batch behavior\n", 201 | "\n", 202 | "If the input ds is not divisible by the specified `input_dims`, the remainder will be 
discarded instead of having a fractional batch. See https://github.com/xarray-contrib/xbatcher/discussions/82 for more on this topic." 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "id": "residential-income", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "n_timepoint_in_batch = 31\n", 213 | "\n", 214 | "bgen = xbatcher.BatchGenerator(ds=ds, input_dims={'time': n_timepoint_in_batch})\n", 215 | "\n", 216 | "for batch in bgen:\n", 217 | " print(f'last time point in ds is {ds.time[-1].values}')\n", 218 | " print(f'last time point in batch is {batch.time[-1].values}')\n", 219 | "batch" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "id": "competitive-islam", 225 | "metadata": {}, 226 | "source": [ 227 | "## Overlapping inputs\n", 228 | "\n", 229 | "In the example above, all samples have distinct time points. That is, for any lat/lon pixel, sample 1 has time points 1-10, sample 2 has time point 11-20, and they do not overlap \n", 230 | "however, in many machine learning applications, we will want overlapping samples (e.g. sample 1 has time points 1-10, sample 2 has time points 2-11, and so on). We can use the `input_overlap` argument to get this behavior." 
231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "id": "cleared-custody", 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "n_timepoint_in_each_sample = 10\n", 241 | "n_timepoint_in_each_batch = 20\n", 242 | "input_overlap = 9\n", 243 | "\n", 244 | "bgen = xbatcher.BatchGenerator(\n", 245 | " ds=ds,\n", 246 | " input_dims={'time': n_timepoint_in_each_sample},\n", 247 | " batch_dims={'time': n_timepoint_in_each_batch},\n", 248 | " concat_input_dims=True,\n", 249 | " input_overlap={'time': input_overlap},\n", 250 | ")\n", 251 | "\n", 252 | "batch = bgen[0]\n", 253 | "\n", 254 | "print(f'{len(bgen)} batches')\n", 255 | "batch" 256 | ] 257 | }, 258 | { 259 | "attachments": {}, 260 | "cell_type": "markdown", 261 | "id": "harmful-benefit", 262 | "metadata": {}, 263 | "source": [ 264 | "We can inspect the samples in a batch for a lat/lon pixel, noting that the overlap applies across batches." 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "id": "earlier-warehouse", 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "lat = -90\n", 275 | "lon = 0\n", 276 | "pixel = batch.sel(lat=lat, lon=lon)\n", 277 | "display(pixel)\n", 278 | "\n", 279 | "print(\n", 280 | " f'sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}'\n", 281 | ")\n", 282 | "print(\n", 283 | " f'sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}'\n", 284 | ")" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "id": "arranged-telephone", 290 | "metadata": {}, 291 | "source": [ 292 | "## Example applications\n", 293 | "\n", 294 | "These batches can then be used to train a downstream machine learning model while preserving the indices of these sample. 
\n", 295 | "\n", 296 | "As an example, let's say we want to train a simple CNN model to predict the max air temprature for each day at each lat/lon pixel. To predict the temperature at lat/lon/time of (i, j, t), we'll use features including the temperature of a 9 x 9 grid centered at (i, j), from times t-10 to t-1 (shape of input should be (n_samples_in_each_batch, 9, 9, 9)). Note that in this example, we subset the dataset to a smaller domain for efficiency." 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "id": "consolidated-chocolate", 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "bgen = xbatcher.BatchGenerator(\n", 307 | " ds=ds[['tasmax']].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n", 308 | " input_dims={'lat': 9, 'lon': 9, 'time': 10},\n", 309 | " batch_dims={'lat': 18, 'lon': 18, 'time': 15},\n", 310 | " concat_input_dims=True,\n", 311 | " input_overlap={'lat': 8, 'lon': 8, 'time': 9},\n", 312 | ")\n", 313 | "\n", 314 | "for i, batch in enumerate(bgen):\n", 315 | " print(f'batch {i}')\n", 316 | " # make sure the ordering of dimension is consistent\n", 317 | " batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n", 318 | "\n", 319 | " # only use the first 9 time points as features, since the last time point is the label to be predicted\n", 320 | " features = batch.tasmax.isel(time_input=slice(0, 9))\n", 321 | " # select the center pixel at the last time point to be the label to be predicted\n", 322 | " # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n", 323 | " labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n", 324 | "\n", 325 | " print('feature shape', features.shape)\n", 326 | " print('label shape', labels.shape)\n", 327 | " print('shape of lat of each sample', labels.coords['lat'].shape)\n", 328 | " print('')" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "id": 
"legislative-closer", 334 | "metadata": {}, 335 | "source": [ 336 | "We can also use the Xarray's \"stack\" method to transform these into 2D inputs (n_samples, n_features) suitable for other machine learning algorithms implemented in libraries such as [sklearn](https://scikit-learn.org/stable/) and [xgboost](https://xgboost.readthedocs.io/en/stable/). In this case, we are expecting 9 x 9 x 9 = 729 features total." 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "id": "advisory-chicken", 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "for i, batch in enumerate(bgen):\n", 347 | " print(f'batch {i}')\n", 348 | " # make sure the ordering of dimension is consistent\n", 349 | " batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n", 350 | "\n", 351 | " # only use the first 9 time points as features, since the last time point is the label to be predicted\n", 352 | " features = batch.tasmax.isel(time_input=slice(0, 9))\n", 353 | " features = features.stack(features=['lat_input', 'lon_input', 'time_input'])\n", 354 | "\n", 355 | " # select the center pixel at the last time point to be the label to be predicted\n", 356 | " # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n", 357 | " labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n", 358 | "\n", 359 | " print('feature shape', features.shape)\n", 360 | " print('label shape', labels.shape)\n", 361 | " print('shape of lat of each sample', labels.coords['lat'].shape, '\\n')" 362 | ] 363 | }, 364 | { 365 | "attachments": {}, 366 | "cell_type": "markdown", 367 | "id": "persistent-culture", 368 | "metadata": {}, 369 | "source": [ 370 | "## What's next?\n", 371 | "\n", 372 | "There are many additional useful features that were yet to be implemented in the context of batch generation for downstream machine learning model training purposes. 
One of the current efforts is to improve the set of data loaders. \n", 373 | "\n", 374 | "Additional features of interest can include: \n", 375 | "\n", 376 | "1. Shuffling/randomization of samples across batches. It is often desirable for each batch to be grouped randomly instead of along a specific dimension. \n", 377 | "\n", 378 | "2. Be efficient in terms of memory usage. In the case where overlap is enabled, each sample would be comprised of mostly repetitive values compared to adjacent samples. It would be beneficial if each batch/sample is generated lazily to avoid storing these extra duplicative values. \n", 379 | "\n", 380 | "3. Handling preprocessing steps. For example, data augmentation, scaling/normalization, outlier detection, etc. \n", 381 | "\n", 382 | "\n", 383 | "More thoughts on 1. can be found in [this discussion](https://github.com/xarray-contrib/xbatcher/discussions/78). Interested users are welcome to comment or submit other issues in GitHub. " 384 | ] 385 | } 386 | ], 387 | "metadata": { 388 | "kernelspec": { 389 | "display_name": "Python 3 (ipykernel)", 390 | "language": "python", 391 | "name": "python3" 392 | }, 393 | "language_info": { 394 | "codemirror_mode": { 395 | "name": "ipython", 396 | "version": 3 397 | }, 398 | "file_extension": ".py", 399 | "mimetype": "text/x-python", 400 | "name": "python", 401 | "nbconvert_exporter": "python", 402 | "pygments_lexer": "ipython3", 403 | "version": "3.11.9" 404 | }, 405 | "vscode": { 406 | "interpreter": { 407 | "hash": "64c578a0a9f6dde4e1dfaddaa39417770d5e50fec039804eaf1eb97ef756c00c" 408 | } 409 | } 410 | }, 411 | "nbformat": 4, 412 | "nbformat_minor": 5 413 | } 414 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | xbatcher: Batch Generation from Xarray Datasets 2 | =============================================== 3 | 4 | Xbatcher is a small library for iterating 
Xarray DataArrays and Datasets in 5 | batches. The goal is to make it easy to feed Xarray objects to machine learning 6 | libraries such as Keras_. 7 | 8 | .. _Keras: https://keras.io/ 9 | 10 | Installation 11 | ------------ 12 | 13 | Xbatcher can be installed from PyPI as:: 14 | 15 | python -m pip install xbatcher 16 | 17 | Or via Conda as:: 18 | 19 | conda install -c conda-forge xbatcher 20 | 21 | Or from source as:: 22 | 23 | python -m pip install git+https://github.com/xarray-contrib/xbatcher.git 24 | 25 | Optional Dependencies 26 | ~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | .. note:: 29 | The required dependencies installed with Xbatcher are `Xarray `_, 30 | `Dask `_, and `NumPy `_. 31 | You will need to separately install `TensorFlow `_ 32 | or `PyTorch `_ to use those data loaders or 33 | Xarray accessors. 34 | 35 | To install Xbatcher and PyTorch via `Conda `_:: 36 | 37 | conda install -c conda-forge xbatcher pytorch 38 | 39 | Or via PyPI:: 40 | 41 | python -m pip install xbatcher[torch] 42 | 43 | To install Xbatcher and TensorFlow via `Conda `_:: 44 | 45 | conda install -c conda-forge xbatcher tensorflow 46 | 47 | Or via PyPI:: 48 | 49 | python -m pip install xbatcher[tensorflow] 50 | 51 | Basic Usage 52 | ----------- 53 | 54 | Let's say we have an Xarray Dataset 55 | 56 | .. ipython:: python 57 | 58 | import xarray as xr 59 | import numpy as np 60 | da = xr.DataArray(np.random.rand(1000, 100, 100), name='foo', 61 | dims=['time', 'y', 'x']).chunk({'time': 1}) 62 | da 63 | 64 | and we want to create batches along the time dimension. We can do it like this 65 | 66 | .. ipython:: python 67 | 68 | import xbatcher 69 | bgen = xbatcher.BatchGenerator(da, {'time': 10}) 70 | for batch in bgen: 71 | pass 72 | # actually feed to machine learning library 73 | batch 74 | 75 | or via a built-in `Xarray accessor `_: 76 | 77 | .. 
ipython:: python 78 | 79 | import xbatcher 80 | 81 | for batch in da.batch.generator({'time': 10}): 82 | pass 83 | # actually feed to machine learning library 84 | batch 85 | 86 | .. toctree:: 87 | :maxdepth: 2 88 | :caption: Contents: 89 | 90 | api 91 | user-guide/index 92 | tutorials-and-presentations 93 | roadmap 94 | contributing 95 | -------------------------------------------------------------------------------- /doc/roadmap.rst: -------------------------------------------------------------------------------- 1 | .. _roadmap: 2 | 3 | Roadmap 4 | ======= 5 | 6 | Authors: Joe Hamman and Ryan Abernathey 7 | Date: February 7, 2019 8 | 9 | Background and scope 10 | -------------------- 11 | 12 | Xbatcher is a small library for iterating xarray objects in batches. The 13 | goal is to make it easy to feed xarray datasets to machine learning libraries 14 | such as `Keras`_ or `PyTorch`_. For example, implementing a simple machine 15 | learning workflow may look something like this: 16 | 17 | .. code-block:: Python 18 | 19 | import xarray as xr 20 | import xbatcher as xb 21 | 22 | da = xr.open_dataset(filename, chunks=chunks) # open a dataset and use dask 23 | da_train = preprocess(ds) # perform some preprocessing 24 | bgen = xb.BatchGenerator(da_train, {'time': 10}) # create a generator 25 | 26 | for batch in bgen: # iterate through the generator 27 | model.fit(batch['x'], batch['y']) # fit a deep-learning model 28 | # or 29 | model.predict(batch['x']) # make one batch of predictions 30 | 31 | We are currently envisioning the project growing to support more complex 32 | extract-transform-load components commonly found in machine learning workflows 33 | that use multidimensional data. We note that many of the concepts in Xbatcher 34 | have been developed through collaborations in the `Pangeo Project Machine 35 | Learning Working Group `_. 
36 | 37 | Batch generation 38 | ~~~~~~~~~~~~~~~~ 39 | 40 | At the core of Xbatcher is the ability to define a schema that defines a 41 | selection of a larger dataset. Today, this schema is fairly simple (e.g. 42 | `{'time': 10}`) but this may evolve in the future. As we describe below, 43 | additional utilities for shuffling, sampling, and caching may provide enhanced 44 | batch generation functionality. 45 | 46 | Shuffle and Sampling APIs 47 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 48 | 49 | When training machine-learning models in batches, it is often necessary to 50 | selectively or randomly sample from your training data. Xbatcher can help 51 | facilitate seamless shuffling and sampling by providing APIs that operate on 52 | batches and/or full datasets. This may require working with Xarray and Dask to 53 | facilitate fast, distributed shuffles of Dask arrays. 54 | 55 | Caching APIs 56 | ~~~~~~~~~~~~ 57 | 58 | A common pattern in ML is to perform the ETL tasks once before saving the results 59 | to a local file system. This is an effective approach for speeding up dataset 60 | loading during training but comes with numerous downsides (i.e. requires 61 | sufficient file space, breaks workflow continuity, etc.). We propose the 62 | development of a pluggable cache mechanism in Xbatcher that would help address 63 | these downsides while providing improved performance during model training and 64 | inference. For example, this pluggable cache mechanism may allow choosing 65 | between multiple cache types, such as an LRU in-memory cache, a Zarr filesystem 66 | or S3 bucket, or a Redis database cache. 67 | 68 | Integration with TensorFlow and PyTorch Dataset Loaders 69 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 70 | 71 | Deep-learning libraries like TensorFlow and PyTorch provide high-performance 72 | dataset-generator APIs that facilitate the construction of flexible and 73 | efficient input pipelines. 
In particular, they have been optimized to support 74 | asynchronous data loading and training, transfer to and from GPUs, and batch 75 | caching. Xbatcher will provide compatible dataset APIs that allow users to pass 76 | Xarray datasets directly to deep-learning frameworks. 77 | 78 | Dependencies 79 | ------------ 80 | 81 | - Core: Xarray, Pandas, Dask, Scikit-learn, Numpy, Scipy 82 | - Optional: Keras, PyTorch, Tensorflow, etc. 83 | 84 | .. _Keras: https://keras.io/ 85 | .. _PyTorch: https://pytorch.org/ 86 | -------------------------------------------------------------------------------- /doc/tutorials-and-presentations.rst: -------------------------------------------------------------------------------- 1 | .. _tutorials-and-presentations: 2 | 3 | Tutorials and Presentations 4 | =========================== 5 | 6 | Tutorials 7 | --------- 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :hidden: 12 | 13 | demo 14 | 15 | .. grid:: 1 2 2 2 16 | :gutter: 2 17 | 18 | 19 | .. grid-item-card:: 20 | :text-align: center 21 | :link: demo.html 22 | 23 | .. image:: https://xbatcher.readthedocs.io/en/latest/_images/demo_4_0.png 24 | :alt: Xbatcher demonstration 25 | +++ 26 | Xbatcher demonstration 27 | 28 | Presentations 29 | ------------- 30 | 31 | .. 
card:: Xbatcher - A Python Package That Simplifies Feeding Xarray Data Objects to Machine Learning Libraries 32 | 33 | 2023 AMS Annual Meeting 34 | ^^^ 35 | 36 | 37 | | Presentation Recording (starts at 45:30): `AMS Confex `_ 38 | | DOI: `10.6084/m9.figshare.22264072.v1 `_ 39 | 40 | +++ 41 | Max Jones, Joe Hamman, and Wei Ji Leong 42 | -------------------------------------------------------------------------------- /doc/user-guide/caching.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Xbatcher Caching Feature \n", 8 | "\n", 9 | "This notebook demonstrates the new caching feature added to xbatcher's `BatchGenerator`. This feature allows you to cache batches, potentially improving performance for repeated access to the same batches. \n", 10 | "\n", 11 | "\n", 12 | "## Introduction\n", 13 | "\n", 14 | "The caching feature in xbatcher's `BatchGenerator` allows you to store generated batches in a cache, which can significantly speed up subsequent accesses to the same batches. This is particularly useful in scenarios where you need to iterate over the same dataset multiple times. \n", 15 | "\n", 16 | "\n", 17 | "The cache is pluggable, meaning you can use any dict-like object to store the cache. This flexibility allows for various storage backends, including local storage, distributed storage systems, or cloud storage solutions.\n", 18 | "\n", 19 | "## Installation \n", 20 | "\n", 21 | "To use the caching feature, you'll need to have xbatcher installed, along with zarr for serialization. 
If you haven't already, you can install these using pip:\n", 22 | "\n", 23 | "```bash\n", 24 | "python -m pip install xbatcher zarr\n", 25 | "```\n", 26 | "\n", 27 | "or \n", 28 | "\n", 29 | "using conda:\n", 30 | "\n", 31 | "```bash\n", 32 | "conda install -c conda-forge xbatcher zarr\n", 33 | "```\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Basic Usage \n", 41 | "\n", 42 | "Let's start with a basic example of how to use the caching feature:" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import tempfile\n", 52 | "\n", 53 | "import xarray as xr\n", 54 | "import zarr\n", 55 | "\n", 56 | "import xbatcher" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "# create a cache using Zarr's DirectoryStore\n", 66 | "directory = f'{tempfile.mkdtemp()}/xbatcher-cache'\n", 67 | "print(directory)\n", 68 | "cache = zarr.storage.DirectoryStore(directory)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "In this example, we're using a local directory to store the cache, but you could use any zarr-compatible store, such as S3, Redis, etc." 
76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "# load a sample dataset\n", 85 | "ds = xr.tutorial.open_dataset('air_temperature', chunks={})\n", 86 | "ds" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# create a BatchGenerator with caching enabled\n", 96 | "gen = xbatcher.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10}, cache=cache)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "### Performance Comparison\n", 104 | "\n", 105 | "\n", 106 | "Let's compare the performance with and without caching:\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "import time\n", 116 | "\n", 117 | "\n", 118 | "def time_iteration(gen):\n", 119 | " start = time.time()\n", 120 | " for batch in gen:\n", 121 | " pass\n", 122 | " end = time.time()\n", 123 | " return end - start" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "directory = f'{tempfile.mkdtemp()}/xbatcher-cache'\n", 133 | "cache = zarr.storage.DirectoryStore(directory)\n", 134 | "\n", 135 | "# Without cache\n", 136 | "gen_no_cache = xbatcher.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10})\n", 137 | "time_no_cache = time_iteration(gen_no_cache)\n", 138 | "print(f'Time without cache: {time_no_cache:.2f} seconds')" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "# With cache\n", 148 | "gen_with_cache = xbatcher.BatchGenerator(\n", 149 | " ds, input_dims={'lat': 10, 'lon': 10}, cache=cache\n", 150 | ")\n", 151 | "time_first_run = time_iteration(gen_with_cache)\n", 152 | 
"print(f'Time with cache (first run): {time_first_run:.2f} seconds')\n", 153 | "\n", 154 | "\n", 155 | "time_second_run = time_iteration(gen_with_cache)\n", 156 | "print(f'Time with cache (second run): {time_second_run:.2f} seconds')" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "You should see that the second run with cache is significantly faster than both the first run and the run without cache." 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": {}, 169 | "source": [ 170 | "## Advanced Usage \n", 171 | "\n", 172 | "### Custom Cache Preprocessing\n", 173 | "\n", 174 | "You can also specify a custom preprocessing function to be applied to batches before they are cached:\n" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "# create a cache using Zarr's DirectoryStore\n", 184 | "directory = f'{tempfile.mkdtemp()}/xbatcher-cache'\n", 185 | "cache = zarr.storage.DirectoryStore(directory)\n", 186 | "\n", 187 | "\n", 188 | "def preprocess_batch(batch):\n", 189 | " # example: add a new variable to each batch\n", 190 | " batch['new_var'] = batch['air'] * 2\n", 191 | " return batch\n", 192 | "\n", 193 | "\n", 194 | "gen_with_preprocess = xbatcher.BatchGenerator(\n", 195 | " ds,\n", 196 | " input_dims={'lat': 10, 'lon': 10},\n", 197 | " cache=cache,\n", 198 | " cache_preprocess=preprocess_batch,\n", 199 | ")\n", 200 | "\n", 201 | "# Now, each cached batch will include the 'new_var' variable\n", 202 | "for batch in gen_with_preprocess:\n", 203 | " print(batch)\n", 204 | " break" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "### Using Different Storage Backends\n", 212 | "\n", 213 | "While we've been using a local directory for caching, you can use any dict-like that is compatible with zarr. 
For example, you could use an S3 bucket as the cache storage backend:\n", 214 | "\n", 215 | "```python\n", 216 | "import s3fs\n", 217 | "import zarr \n", 218 | "\n", 219 | "# Set up S3 filesystem (you'll need appropriate credentials)\n", 220 | "s3 = s3fs.S3FileSystem(anon=False)\n", 221 | "cache = s3.get_mapper('s3://my-bucket/my-cache.zarr')\n", 222 | "\n", 223 | "# Use this cache with BatchGenerator\n", 224 | "gen_s3 = xbatcher.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10}, cache=cache)\n", 225 | "```\n" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "## Considerations and Best Practices \n", 233 | "\n", 234 | "- **Storage Space**: Be mindful of the storage space required for your cache, especially when working with large datasets.\n", 235 | "- **Cache Invalidation**: The current implementation doesn't handle cache invalidation. If your source data changes, you'll need to manually clear or update the cache.\n", 236 | "- **Performance Tradeoffs**: While caching can significantly speed up repeated access to the same data, the initial caching process may be slower than processing without a cache. Consider your use case to determine if caching is beneficial.\n", 237 | "- **Storage Backend**: Choose a storage backend that's appropriate for your use case. 
Local storage might be fastest for single-machine applications, while distributed or cloud storage might be necessary for cluster computing or cloud-based workflows.\n", 238 | "\n" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [] 245 | } 246 | ], 247 | "metadata": { 248 | "kernelspec": { 249 | "display_name": "Python 3 (ipykernel)", 250 | "language": "python", 251 | "name": "python3" 252 | }, 253 | "language_info": { 254 | "codemirror_mode": { 255 | "name": "ipython", 256 | "version": 3 257 | }, 258 | "file_extension": ".py", 259 | "mimetype": "text/x-python", 260 | "name": "python", 261 | "nbconvert_exporter": "python", 262 | "pygments_lexer": "ipython3", 263 | "version": "3.11.9" 264 | } 265 | }, 266 | "nbformat": 4, 267 | "nbformat_minor": 4 268 | } 269 | -------------------------------------------------------------------------------- /doc/user-guide/index.rst: -------------------------------------------------------------------------------- 1 | User Guide 2 | =========== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Contents: 7 | 8 | caching 9 | training-a-neural-network-with-Pytorch-and-xbatcher 10 | training-a-neural-network-with-keras-and-xbatcher 11 | -------------------------------------------------------------------------------- /doc/user-guide/training-a-neural-network-with-Pytorch-and-xbatcher.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e579c0e1-bb12-4c7b-8a97-bdb6fad01755", 6 | "metadata": {}, 7 | "source": [ 8 | "# End-to-End Tutorial: Training a Neural Network with PyTorch and Xbatcher\n", 9 | "\n", 10 | "This tutorial demonstrates how to use xarray, xbatcher, and PyTorch to train a simple neural network on the FashionMNIST dataset." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "5aa4bf55-588a-465d-affb-de5d16a54cdd", 16 | "metadata": {}, 17 | "source": [ 18 | "## Step 1: Setup \n", 19 | "\n", 20 | "Import the necessary libraries and load the dataset" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "916bb4a8-d2df-49e8-9109-a92299960886", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import matplotlib.pyplot as plt\n", 31 | "import torch\n", 32 | "import torch.nn as nn\n", 33 | "import torch.optim as optim\n", 34 | "import torch.utils.data\n", 35 | "import xarray as xr\n", 36 | "\n", 37 | "import xbatcher as xb\n", 38 | "import xbatcher.loaders.torch" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "474b2cc1-9991-4060-92fe-559c15d96678", 45 | "metadata": { 46 | "scrolled": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "ds = xr.open_dataset(\n", 51 | " 's3://carbonplan-share/xbatcher/fashion-mnist-train.zarr',\n", 52 | " engine='zarr',\n", 53 | " chunks={},\n", 54 | " backend_kwargs={'storage_options': {'anon': True}},\n", 55 | ")\n", 56 | "ds" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "d647846e-381a-4901-ba1c-d4d47ff7b1fa", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "ds.sel(sample=1).images.plot(cmap='gray');" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "533b9827-8c27-4229-8035-cf39f3e99e54", 72 | "metadata": {}, 73 | "source": [ 74 | "## Step 2: Create batch generator and data loader\n", 75 | "\n", 76 | "We use `xbatcher` to create batch generators for the images (`X_bgen`) and labels (`y_gen`)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "303b2a1d-9126-44e7-b312-fa546eca8f2e", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "# Define batch generators\n", 87 | "X_bgen = xb.BatchGenerator(\n", 88 | " ds['images'],\n", 89 | " 
input_dims={'sample': 2000, 'channel': 1, 'height': 28, 'width': 28},\n", 90 | " preload_batch=False,\n", 91 | ")\n", 92 | "y_bgen = xb.BatchGenerator(\n", 93 | " ds['labels'], input_dims={'sample': 2000}, preload_batch=False\n", 94 | ")\n", 95 | "X_bgen[0]" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "1c1ab1a7-4bf2-4d73-a8e2-a88e7cdf6829", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# Map batches to a PyTorch-compatible dataset\n", 106 | "dataset = xbatcher.loaders.torch.MapDataset(X_bgen, y_bgen)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "942ee9c8-5369-49d3-8038-0658b80ff851", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# Create a DataLoader\n", 117 | "train_dataloader = torch.utils.data.DataLoader(\n", 118 | " dataset,\n", 119 | " batch_size=None, # Using batches defined by the dataset itself (via xbatcher)\n", 120 | " prefetch_factor=3, # Prefetch up to 3 batches in advance to reduce data loading latency\n", 121 | " num_workers=4, # Use 4 parallel worker processes to load data concurrently\n", 122 | " persistent_workers=True, # Keep workers alive between epochs for faster subsequent epochs\n", 123 | " multiprocessing_context='forkserver', # Use \"forkserver\" to spawn subprocesses, ensuring stability in multiprocessing\n", 124 | ")" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "16953fd7-53d8-4d37-80e8-57f5b4daccee", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "train_features, train_labels = next(iter(train_dataloader))" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "119cbac9-a973-4b37-b42d-12e03f105826", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "print(f'Feature batch shape: {train_features.size()}')\n", 145 | "print(f'Labels batch shape: {train_labels.size()}')\n", 146 
| "img = train_features[0].squeeze()\n", 147 | "label = train_labels[0]\n", 148 | "plt.imshow(img, cmap='gray')\n", 149 | "plt.show()\n", 150 | "print(f'Label: {label}')" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "id": "4a6c3219-b281-4782-a77d-58860f7f7c83", 156 | "metadata": {}, 157 | "source": [ 158 | "## Step 3: Define the Neural Network\n", 159 | "\n", 160 | "We define a simple feedforward neural network for classification." 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "09620988-bbda-4508-b39e-e1d81f2374c9", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "class SimpleNN(nn.Module):\n", 171 | " def __init__(self):\n", 172 | " super().__init__()\n", 173 | " self.flatten = nn.Flatten()\n", 174 | " self.fc1 = nn.Linear(28 * 28, 128)\n", 175 | " self.fc2 = nn.Linear(128, 10)\n", 176 | "\n", 177 | " def forward(self, x):\n", 178 | " x = self.flatten(x)\n", 179 | " x = torch.relu(self.fc1(x))\n", 180 | " x = self.fc2(x)\n", 181 | " return x" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "cff27830-a15e-4f4a-b529-7fed8ea7632d", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "# Instantiate the model\n", 192 | "model = SimpleNN()\n", 193 | "model" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "id": "c12073c2-a380-4a28-bc24-68c1811739b5", 199 | "metadata": {}, 200 | "source": [ 201 | "## Step 4: Define Loss Function and Optimizer\n", 202 | "We use Cross-Entropy Loss and the Adam optimizer." 
203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "id": "ab29bcbc-11e6-48ac-8618-8b6da4627b8f", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "loss_fn = nn.CrossEntropyLoss()\n", 213 | "optimizer = optim.Adam(model.parameters(), lr=0.001)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "id": "ec0bd97d-c9fd-42ae-8962-ec9ed6434331", 219 | "metadata": {}, 220 | "source": [ 221 | "## Step 5: Train the Model\n", 222 | "We train the model using the data loader." 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "76fe4f2d-d15d-43ba-a079-0c3c8c4d65a2", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "%%time\n", 233 | "\n", 234 | "epochs = 5\n", 235 | "\n", 236 | "for epoch in range(epochs):\n", 237 | " print(f'Epoch {epoch+1}/{epochs}')\n", 238 | " for batch, (X, y) in enumerate(train_dataloader):\n", 239 | " # Forward pass\n", 240 | " predictions = model(X)\n", 241 | " loss = loss_fn(predictions, y)\n", 242 | "\n", 243 | " # Backward pass\n", 244 | " optimizer.zero_grad()\n", 245 | " loss.backward()\n", 246 | " optimizer.step()\n", 247 | "\n", 248 | " if batch % 10 == 0:\n", 249 | " print(f'Batch {batch}: Loss = {loss.item():.4f}')\n", 250 | "\n", 251 | "print('Training completed!')" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "id": "aa1db735-9670-41dc-a27e-74369e8c320d", 257 | "metadata": {}, 258 | "source": [ 259 | "## Step 6: Evaluate the Model\n", 260 | "You can evaluate the model on the test set or visualize some predictions." 
261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "id": "ea26801a-18a9-4ffe-9d04-92157c42bc8a", 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "# Visualize a sample prediction\n", 271 | "img = train_features[0].squeeze()\n", 272 | "label = train_labels[0]\n", 273 | "predicted_label = torch.argmax(model(train_features[0:1]), dim=1).item()\n", 274 | "\n", 275 | "plt.imshow(img, cmap='gray')\n", 276 | "plt.title(f'True Label: {label}, Predicted: {predicted_label}')\n", 277 | "plt.show()" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "id": "8c52dcc1-d583-4cfe-af34-8030d4b451f0", 283 | "metadata": {}, 284 | "source": [ 285 | "## Key Highlights\n", 286 | "\n", 287 | "- **Data Handling**: We use Xbatcher to create efficient, chunked data pipelines from Xarray datasets.\n", 288 | "- **Integration**: The `xbatcher.loaders.torch.MapDataset` enables direct compatibility with PyTorch's DataLoader.\n", 289 | "- **Training**: PyTorch simplifies the model training loop while leveraging the custom data pipeline.\n" 290 | ] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "Python 3 (ipykernel)", 296 | "language": "python", 297 | "name": "python3" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.11.9" 310 | } 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 5 314 | } 315 | -------------------------------------------------------------------------------- /doc/user-guide/training-a-neural-network-with-keras-and-xbatcher.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b314e777-7ffb-4e62-b4c5-ce8a785c5181", 6 | "metadata": {}, 7 
| "source": [ 8 | "# End-to-End Tutorial: Training a Neural Network with Keras and Xbatcher\n", 9 | "\n", 10 | "## Import Required Libraries" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "5d912ff0-d808-4704-8dea-b9e1b5a53bf1", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import matplotlib.pyplot as plt\n", 21 | "import tensorflow as tf\n", 22 | "import xarray as xr\n", 23 | "from keras import layers, models, optimizers\n", 24 | "\n", 25 | "import xbatcher as xb\n", 26 | "import xbatcher.loaders.keras" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "7fb892c1-50fd-48c8-8567-b150946b53c9", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# Open the dataset stored in Zarr format\n", 37 | "ds = xr.open_dataset(\n", 38 | " 's3://carbonplan-share/xbatcher/fashion-mnist-train.zarr',\n", 39 | " engine='zarr',\n", 40 | " chunks={},\n", 41 | " backend_kwargs={'storage_options': {'anon': True}},\n", 42 | ")" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "id": "c98134fe-581f-412a-93e3-6b07b7706078", 48 | "metadata": {}, 49 | "source": [ 50 | "## Define Batch Generators" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "c680ebd7-0310-4f40-91b5-e7cc1a59e853", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "# Define batch generators for features (X) and labels (y)\n", 61 | "X_bgen = xb.BatchGenerator(\n", 62 | " ds['images'],\n", 63 | " input_dims={'sample': 2000, 'channel': 1, 'height': 28, 'width': 28},\n", 64 | " preload_batch=False, # Load each batch dynamically\n", 65 | ")\n", 66 | "y_bgen = xb.BatchGenerator(\n", 67 | " ds['labels'], input_dims={'sample': 2000}, preload_batch=False\n", 68 | ")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "91d63180-e3a6-49f7-a8e7-67b8b698b08c", 74 | "metadata": {}, 75 | "source": [ 76 | "## Map Batches to a Keras-Compatible Dataset" 77 | ] 
78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "d1195057-269b-44ba-a3e7-aeedaa4ba8df", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "# Use xbatcher's MapDataset to wrap the generators\n", 87 | "dataset = xbatcher.loaders.keras.CustomTFDataset(X_bgen, y_bgen)\n", 88 | "\n", 89 | "# Create a DataLoader using tf.data.Dataset\n", 90 | "train_dataloader = tf.data.Dataset.from_generator(\n", 91 | " lambda: iter(dataset),\n", 92 | " output_signature=(\n", 93 | " tf.TensorSpec(shape=(2000, 1, 28, 28), dtype=tf.float32), # Images\n", 94 | " tf.TensorSpec(shape=(2000,), dtype=tf.int64), # Labels\n", 95 | " ),\n", 96 | ").prefetch(3) # Prefetch 3 batches to improve performance" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "1892411c-ca17-4d7f-b76b-5b5decaa78c1", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "## Visualize a Sample Batch" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "133b24bc-e7bc-4734-ad0a-22a848dd204c", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# Extract a batch from the DataLoader\n", 117 | "for train_features, train_labels in train_dataloader.take(1):\n", 118 | " print(f'Feature batch shape: {train_features.shape}')\n", 119 | " print(f'Labels batch shape: {train_labels.shape}')\n", 120 | "\n", 121 | " img = train_features[0].numpy().squeeze() # Extract the first image\n", 122 | " label = train_labels[0].numpy()\n", 123 | " plt.imshow(img, cmap='gray')\n", 124 | " plt.title(f'Label: {label}')\n", 125 | " plt.show()\n", 126 | " break" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "1e5d6a66-1943-47da-be67-9b54d51defed", 132 | "metadata": {}, 133 | "source": [ 134 | "## Build a Simple Neural Network with Keras" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "8b0490e5-7ccc-47fe-90ec-d41a81c4eb20", 141 
| "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "# Define a simple feedforward neural network\n", 145 | "model = models.Sequential(\n", 146 | " [\n", 147 | " layers.Flatten(input_shape=(1, 28, 28)), # Flatten input images\n", 148 | " layers.Dense(128, activation='relu'), # Fully connected layer with 128 units\n", 149 | " layers.Dense(10, activation='softmax'), # Output layer for 10 classes\n", 150 | " ]\n", 151 | ")\n", 152 | "\n", 153 | "# Compile the model\n", 154 | "model.compile(\n", 155 | " optimizer=optimizers.Adam(learning_rate=0.001),\n", 156 | " loss='sparse_categorical_crossentropy',\n", 157 | " metrics=['accuracy'],\n", 158 | ")\n", 159 | "\n", 160 | "# Display model summary\n", 161 | "model.summary()" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "id": "838df9c6-0753-4120-a0e0-dcc1480416b4", 167 | "metadata": {}, 168 | "source": [ 169 | "## Train the Model " 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "25e86eba-4d4e-47cc-a6a7-9f0be244b009", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "%%time\n", 180 | "\n", 181 | "# Train the model for 5 epochs\n", 182 | "epochs = 5\n", 183 | "\n", 184 | "model.fit(\n", 185 | " train_dataloader, # Pass the DataLoader directly\n", 186 | " epochs=epochs,\n", 187 | " verbose=1, # Print progress during training\n", 188 | ")" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "id": "a0f4246c-6461-4e2a-a49d-df6c1ce770fc", 194 | "metadata": {}, 195 | "source": [ 196 | "## Visualize a Sample Prediction" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "id": "9361cb65-3c0d-40d6-be5c-18b309626817", 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "# Visualize a prediction on a sample image\n", 207 | "for train_features, train_labels in train_dataloader.take(1):\n", 208 | " img = train_features[0].numpy().squeeze()\n", 209 | " label = 
train_labels[0].numpy()\n", 210 | " predicted_label = tf.argmax(model.predict(train_features[:1]), axis=1).numpy()[0]\n", 211 | "\n", 212 | " plt.imshow(img, cmap='gray')\n", 213 | " plt.title(f'True Label: {label}, Predicted: {predicted_label}')\n", 214 | " plt.show()\n", 215 | " break" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "id": "372d0e0a-1542-4aa0-b3b9-9fd4337459ba", 221 | "metadata": {}, 222 | "source": [ 223 | "## Key Highlights \n", 224 | "\n", 225 | "- **Dynamic Batching**: Xbatcher and the MapDataset class allow for dynamic loading of batches, which reduces memory usage and speeds up data processing.\n", 226 | "- **Prefetching**: The prefetch feature in `tf.data.Dataset` overlaps data loading with model training to minimize idle time.\n", 227 | "- **Compatibility**: The pipeline works seamlessly with `keras.Model.fit`, simplifying training workflows." 228 | ] 229 | } 230 | ], 231 | "metadata": { 232 | "kernelspec": { 233 | "display_name": "Python 3 (ipykernel)", 234 | "language": "python", 235 | "name": "python3" 236 | }, 237 | "language_info": { 238 | "codemirror_mode": { 239 | "name": "ipython", 240 | "version": 3 241 | }, 242 | "file_extension": ".py", 243 | "mimetype": "text/x-python", 244 | "name": "python", 245 | "nbconvert_exporter": "python", 246 | "pygments_lexer": "ipython3", 247 | "version": "3.11.9" 248 | } 249 | }, 250 | "nbformat": 4, 251 | "nbformat_minor": 5 252 | } 253 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | requires = ["setuptools-scm[toml]>=6.2", "setuptools>=64"] 4 | 5 | [project] 6 | authors = [ 7 | { name = "xbatcher Developers", email = "rpa@ldeo.columbia.edu" }, 8 | ] 9 | classifiers = [ 10 | "Development Status :: 4 - Beta", 11 | "Intended Audience :: Science/Research", 12 | "License :: OSI Approved :: 
Apache Software License", 13 | "Operating System :: OS Independent", 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3.10", 16 | "Programming Language :: Python :: 3.11", 17 | "Programming Language :: Python :: 3.12", 18 | "Programming Language :: Python", 19 | "Topic :: Scientific/Engineering", 20 | ] 21 | dependencies = ["dask", "numpy", "xarray"] 22 | description = "Batch generation from Xarray objects" 23 | dynamic = ["version"] 24 | license = { text = "Apache" } 25 | name = "xbatcher" 26 | readme = "README.rst" 27 | requires-python = ">=3.10" 28 | [project.optional-dependencies] 29 | dev = [ 30 | "asv", 31 | "coverage", 32 | "pytest", 33 | "pytest-cov", 34 | "s3fs", 35 | "tensorflow", 36 | "torch", 37 | "zarr<3.0", 38 | ] 39 | tensorflow = ["tensorflow"] 40 | torch = ["torch"] 41 | [project.urls] 42 | documentation = "https://xbatcher.readthedocs.io/en/latest/" 43 | repository = "https://github.com/xarray-contrib/xbatcher" 44 | 45 | [tool.setuptools.packages.find] 46 | include = ["xbatcher*"] 47 | 48 | [tool.setuptools_scm] 49 | fallback_version = "999" 50 | local_scheme = "node-and-date" 51 | 52 | [tool.ruff] 53 | extend-include = ["*.ipynb"] 54 | target-version = "py310" 55 | 56 | builtins = ["ellipsis"] 57 | # Exclude a variety of commonly ignored directories. 
58 | exclude = [ 59 | ".bzr", 60 | ".direnv", 61 | ".eggs", 62 | ".git", 63 | ".git-rewrite", 64 | ".hg", 65 | ".ipynb_checkpoints", 66 | ".mypy_cache", 67 | ".nox", 68 | ".pants.d", 69 | ".pyenv", 70 | ".pytest_cache", 71 | ".pytype", 72 | ".ruff_cache", 73 | ".svn", 74 | ".tox", 75 | ".venv", 76 | ".vscode", 77 | "__pypackages__", 78 | "_build", 79 | "buck-out", 80 | "build", 81 | "dist", 82 | "node_modules", 83 | "site-packages", 84 | "venv", 85 | ] 86 | [tool.ruff.lint] 87 | ignore = [ 88 | "E501", # Conflicts with ruff format 89 | "E721", # Comparing types instead of isinstance 90 | "E741", # Ambiguous variable names 91 | ] 92 | per-file-ignores = {} 93 | select = [ 94 | # Pyflakes 95 | "F", 96 | # Pycodestyle 97 | "E", 98 | "W", 99 | # isort 100 | "I", 101 | # Pyupgrade 102 | "UP", 103 | ] 104 | 105 | [tool.ruff.lint.mccabe] 106 | max-complexity = 18 107 | 108 | [tool.ruff.lint.isort] 109 | known-first-party = ["xbatcher"] 110 | known-third-party = [ 111 | "numpy", 112 | "pandas", 113 | "pytest", 114 | "sphinx_autosummary_accessors", 115 | "torch", 116 | "xarray", 117 | ] 118 | 119 | combine-as-imports = true 120 | 121 | [tool.ruff.format] 122 | docstring-code-format = true 123 | quote-style = "single" 124 | 125 | [tool.ruff.lint.pydocstyle] 126 | convention = "numpy" 127 | 128 | [tool.ruff.lint.pyupgrade] 129 | # Preserve types, even if a file imports `from __future__ import annotations`. 
130 | keep-runtime-typing = true 131 | 132 | [tool.pytest.ini_options] 133 | log_cli = true 134 | log_level = "INFO" 135 | -------------------------------------------------------------------------------- /readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | build: 8 | os: "ubuntu-24.04" 9 | tools: 10 | python: "mambaforge-latest" 11 | 12 | # Build documentation in the doc/ directory with Sphinx 13 | sphinx: 14 | configuration: doc/conf.py 15 | fail_on_warning: true 16 | 17 | # Optionally declare the Python requirements required to build your docs 18 | conda: 19 | environment: ci/requirements/doc.yml 20 | -------------------------------------------------------------------------------- /xbatcher/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import ( 2 | PackageNotFoundError as _PackageNotFoundError, 3 | version as _version, 4 | ) 5 | 6 | from . 
from typing import Any

import xarray as xr

from .generators import BatchGenerator


def _as_xarray_dataarray(xr_obj: xr.Dataset | xr.DataArray) -> xr.DataArray:
    """
    Convert xarray.Dataset to xarray.DataArray if needed, so that it can
    be converted into a Tensor object.

    Notes
    -----
    ``squeeze(dim='variable')`` requires the ``variable`` dimension created by
    ``to_array`` to have length 1 — i.e. this presumably expects a
    single-variable Dataset; a multi-variable Dataset would raise here
    (NOTE(review): confirm against callers).
    """
    if isinstance(xr_obj, xr.Dataset):
        xr_obj = xr_obj.to_array().squeeze(dim='variable')

    return xr_obj


@xr.register_dataarray_accessor('batch')
@xr.register_dataset_accessor('batch')
class BatchAccessor:
    # Registered on both Dataset and DataArray as ``.batch``.
    def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
        """
        Batch accessor returning a BatchGenerator object via the ``generator``
        method
        """
        self._obj = xarray_obj

    def generator(self, *args, **kwargs) -> BatchGenerator:
        """
        Return a BatchGenerator via the batch accessor

        Parameters
        ----------
        *args : iterable
            Positional arguments to pass to the `BatchGenerator` constructor.
        **kwargs : dict
            Keyword arguments to pass to the `BatchGenerator` constructor.
        """
        # The accessed object itself becomes the generator's data source.
        return BatchGenerator(self._obj, *args, **kwargs)


@xr.register_dataarray_accessor('tf')
@xr.register_dataset_accessor('tf')
class TFAccessor:
    # Registered on both Dataset and DataArray as ``.tf``.
    def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
        self._obj = xarray_obj

    def to_tensor(self) -> Any:
        """Convert this DataArray to a tensorflow.Tensor"""
        # Imported lazily so xbatcher does not require tensorflow at import time.
        import tensorflow as tf

        dataarray = _as_xarray_dataarray(xr_obj=self._obj)

        return tf.convert_to_tensor(dataarray.data)
"""Classes for iterating through xarray dataarrays / datasets in batches."""

import itertools
import json
import warnings
from collections.abc import Callable, Hashable, Iterator, Sequence
from operator import itemgetter
from typing import Any

import numpy as np
import xarray as xr

# Type aliases for the nested selector structures used throughout this module:
# a "patch" is one dict of per-dimension slices; a "batch" is a list of patches.
PatchGenerator = Iterator[dict[Hashable, slice]]
BatchSelector = list[dict[Hashable, slice]]
BatchSelectorSet = dict[int, BatchSelector]


class BatchSchema:
    """
    A representation of the indices and stacking/transposing parameters needed
    to generate batches from Xarray DataArrays and Datasets using
    xbatcher.BatchGenerator.

    Parameters
    ----------
    ds : ``xarray.Dataset`` or ``xarray.DataArray``
        The data to iterate over. Unlike for the BatchGenerator, the data is
        not retained as a class attribute for the BatchSchema.
    input_dims : dict
        A dictionary specifying the size of the inputs in each dimension,
        e.g. ``{'lat': 30, 'lon': 30}``
        These are the dimensions the ML library will see. All other dimensions
        will be stacked into one dimension called ``sample``.
    input_overlap : dict, optional
        A dictionary specifying the overlap along each dimension
        e.g. ``{'lat': 3, 'lon': 3}``
    batch_dims : dict, optional
        A dictionary specifying the size of the batch along each dimension
        e.g. ``{'time': 10}``. These will always be iterated over.
    concat_input_bins : bool, optional
        If ``True``, the dimension chunks specified in ``input_dims`` will be
        concatenated and stacked into the ``sample`` dimension. The batch index
        will be included as a new level ``input_batch`` in the ``sample``
        coordinate.
        If ``False``, the dimension chunks specified in ``input_dims`` will be
        iterated over.
        NOTE: the keyword is spelled ``concat_input_bins`` (sic); it is kept
        unchanged for backward compatibility and is stored on the instance as
        ``concat_input_dims``.
    preload_batch : bool, optional
        If ``True``, each batch will be loaded into memory before reshaping /
        processing, triggering any dask arrays to be computed.

    Notes
    -----
    The BatchSchema is experimental and subject to change without notice.
    """

    def __init__(
        self,
        ds: xr.Dataset | xr.DataArray,
        input_dims: dict[Hashable, int],
        input_overlap: dict[Hashable, int] | None = None,
        batch_dims: dict[Hashable, int] | None = None,
        concat_input_bins: bool = True,
        preload_batch: bool = True,
    ):
        if input_overlap is None:
            input_overlap = {}
        if batch_dims is None:
            batch_dims = {}
        self.input_dims = dict(input_dims)
        self.input_overlap = input_overlap
        self.batch_dims = dict(batch_dims)
        self.concat_input_dims = concat_input_bins
        self.preload_batch = preload_batch
        # Store helpful information based on arguments:
        # dims that appear in both batch_dims and input_dims ...
        self._duplicate_batch_dims: dict[Hashable, int] = {
            dim: length
            for dim, length in self.batch_dims.items()
            if self.input_dims.get(dim) is not None
        }
        # ... and batch dims that do not also appear in input_dims.
        self._unique_batch_dims: dict[Hashable, int] = {
            dim: length
            for dim, length in self.batch_dims.items()
            if self.input_dims.get(dim) is None
        }
        # Step between successive patches along each input dim.
        self._input_stride: dict[Hashable, int] = {
            dim: length - self.input_overlap.get(dim, 0)
            for dim, length in self.input_dims.items()
        }
        # Every dimension that patch slicing iterates over.
        self._all_sliced_dims: dict[Hashable, int] = dict(
            **self._unique_batch_dims, **self.input_dims
        )
        self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds)

    def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet:
        """
        Create batch selectors dict, which can be used to create a batch
        from an Xarray data object.
        """
        # Create an iterator that returns an object usable for .isel in xarray
        patch_selectors = self._gen_patch_selectors(ds)
        # Create the Dict containing batch selectors
        if self.concat_input_dims:  # Combine the patches into batches
            return self._combine_patches_into_batch(ds, patch_selectors)
        else:  # Each patch gets its own batch
            return {ind: [value] for ind, value in enumerate(patch_selectors)}

    def _gen_patch_selectors(self, ds: xr.DataArray | xr.Dataset) -> PatchGenerator:
        """
        Create an iterator that can be used to index an Xarray Dataset/DataArray.
        """
        if self._duplicate_batch_dims and not self.concat_input_dims:
            warnings.warn(
                'The following dimensions were included in both ``input_dims`` '
                'and ``batch_dims``. Since ``concat_input_dims`` is ``False``, '
                f'these dimensions will not impact batch generation: {self._duplicate_batch_dims}'
            )
        # Generate the slices by iterating over batch_dims and input_dims
        all_slices = _iterate_through_dimensions(
            ds,
            dims=self._all_sliced_dims,
            overlap=self.input_overlap,
        )
        return all_slices

    def _combine_patches_into_batch(
        self, ds: xr.DataArray | xr.Dataset, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Combine the patch selectors to form a batch
        """
        # Check that patches are only combined with concat_input_dims
        if not self.concat_input_dims:
            raise AssertionError(
                'Patches should only be combined into batches when ``concat_input_dims`` is ``True``'
            )
        if not self.batch_dims:
            return self._combine_patches_into_one_batch(patch_selectors)
        elif self._duplicate_batch_dims:
            return self._combine_patches_grouped_by_input_and_batch_dims(
                ds=ds, patch_selectors=patch_selectors
            )
        else:
            return self._combine_patches_grouped_by_batch_dims(patch_selectors)

    def _combine_patches_into_one_batch(
        self, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Group all patches into a single batch
        """
        return dict(enumerate([list(patch_selectors)]))

    def _combine_patches_grouped_by_batch_dims(
        self, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Group patches based on the unique slices for dimensions in ``batch_dims``
        """
        # groupby relies on patch_selectors being emitted in batch_dims order,
        # which _iterate_through_dimensions guarantees.
        batch_selectors = [
            list(value)
            for _, value in itertools.groupby(
                patch_selectors, key=itemgetter(*self.batch_dims)
            )
        ]
        return dict(enumerate(batch_selectors))

    def _combine_patches_grouped_by_input_and_batch_dims(
        self, ds: xr.DataArray | xr.Dataset, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Combine patches with multiple slices along ``batch_dims`` grouped into
        each patch. Required when a dimension is duplicated between ``batch_dims``
        and ``input_dims``.
        """
        self._gen_patch_numbers(ds)
        self._gen_batch_numbers(ds)
        batch_id_per_patch = self._get_batch_multi_index_per_patch()
        patch_in_range = self._get_batch_in_range_per_batch(
            batch_multi_index=batch_id_per_patch
        )
        batch_id_per_patch = self._ravel_batch_multi_index(batch_id_per_patch)
        batch_selectors = self._gen_empty_batch_selectors()
        for i, patch in enumerate(patch_selectors):
            if patch_in_range[i]:
                batch_selectors[batch_id_per_patch[i]].append(patch)
        return batch_selectors

    def _gen_empty_batch_selectors(self) -> BatchSelectorSet:
        """
        Create an empty batch selector set that can be populated by appending
        patches to each batch.
        """
        # int() guards against np.prod returning a float for an empty iterable.
        n_batches = int(np.prod(list(self._n_batches_per_dim.values())))
        return {k: [] for k in range(n_batches)}

    def _gen_patch_numbers(self, ds: xr.DataArray | xr.Dataset):
        """
        Calculate the number of patches per dimension and the number of patches
        in each batch per dimension.
        """
        self._n_patches_per_batch: dict[Hashable, int] = {
            dim: int(np.ceil(length / self._input_stride.get(dim, length)))
            for dim, length in self.batch_dims.items()
        }
        self._n_patches_per_dim: dict[Hashable, int] = {
            dim: int(
                (ds.sizes[dim] - self.input_overlap.get(dim, 0))
                // (length - self.input_overlap.get(dim, 0))
            )
            for dim, length in self._all_sliced_dims.items()
        }

    def _gen_batch_numbers(self, ds: xr.DataArray | xr.Dataset):
        """
        Calculate the number of batches per dimension
        """
        self._n_batches_per_dim: dict[Hashable, int] = {
            dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim]))
            for dim in self._all_sliced_dims.keys()
        }

    def _get_batch_multi_index_per_patch(self):
        """
        Calculate the batch multi-index for each patch
        """
        batch_id_per_dim: dict[Hashable, Any] = {
            dim: np.floor(
                np.arange(0, n_patches)
                / self._n_patches_per_batch.get(dim, n_patches + 1)
            ).astype(np.int64)
            for dim, n_patches in self._n_patches_per_dim.items()
        }
        batch_id_per_patch = np.array(
            list(itertools.product(*batch_id_per_dim.values()))
        ).transpose()
        return batch_id_per_patch

    def _ravel_batch_multi_index(self, batch_multi_index):
        """
        Convert the batch multi-index to a flat index for each patch
        """
        return np.ravel_multi_index(
            multi_index=batch_multi_index,
            dims=tuple(self._n_batches_per_dim.values()),
            mode='clip',
        )

    def _get_batch_in_range_per_batch(self, batch_multi_index):
        """
        Determine whether each patch is contained within any of the batches.
        """
        batch_id_maximum = np.fromiter(self._n_batches_per_dim.values(), dtype=int)
        batch_id_maximum = np.pad(
            batch_id_maximum,
            (0, (len(self._n_patches_per_dim) - len(self._n_batches_per_dim))),
            constant_values=(1),
        )
        batch_id_maximum = batch_id_maximum[:, np.newaxis]
        batch_in_range_per_patch = np.all(batch_multi_index < batch_id_maximum, axis=0)
        return batch_in_range_per_patch

    def to_json(self) -> str:
        """
        Dump the BatchSchema properties to a JSON string.

        Returns
        -------
        out_json : str
            The JSON representation of the BatchSchema, including every batch
            selector under the ``selectors`` key.
        """
        out_dict: dict[str, Any] = {
            'input_dims': self.input_dims,
            'input_overlap': self.input_overlap,
            'batch_dims': self.batch_dims,
            # Bug fix: this previously serialized ``self.input_dims`` (a dict)
            # under the ``concat_input_dims`` key instead of the boolean flag.
            'concat_input_dims': self.concat_input_dims,
            'preload_batch': self.preload_batch,
        }
        # Slices are not JSON serializable, so each one is expanded into its
        # start/stop/step components. Bug fix: the previous implementation
        # overwrote a single ``selector`` entry on every loop iteration, so
        # only the last patch of the last batch survived serialization; all
        # selectors are now emitted, keyed by batch index.
        out_dict['selectors'] = {
            batch_id: [
                {
                    str(dim): {
                        'start': dim_slice.start,
                        'stop': dim_slice.stop,
                        'step': dim_slice.step,
                    }
                    for dim, dim_slice in patch.items()
                }
                for patch in patches
            ]
            for batch_id, patches in self.selectors.items()
        }
        return json.dumps(out_dict)

    def to_file(self, out_file_name: str):
        """
        Dumps the JSON representation of the BatchSchema object to a file.

        Parameters
        ----------
        out_file_name : str
            The path to the json file to write to.
        """
        out_json = self.to_json()
        with open(out_file_name, mode='w') as out_file:
            out_file.write(out_json)
302 | """ 303 | out_json = self.to_json() 304 | with open(out_file_name, mode='w') as out_file: 305 | out_file.write(out_json) 306 | 307 | 308 | def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> list[slice]: 309 | # return a list of slices to chop up a single dimension 310 | if overlap >= slice_size: 311 | raise ValueError( 312 | 'input overlap must be less than the input sample length, but ' 313 | f'the input sample length is {slice_size} and the overlap is {overlap}' 314 | ) 315 | slices = [] 316 | stride = slice_size - overlap 317 | for start in range(0, dim_size, stride): 318 | end = start + slice_size 319 | if end <= dim_size: 320 | slices.append(slice(start, end)) 321 | return slices 322 | 323 | 324 | def _iterate_through_dimensions( 325 | ds: xr.Dataset | xr.DataArray, 326 | *, 327 | dims: dict[Hashable, int], 328 | overlap: dict[Hashable, int] | None = None, 329 | ) -> Iterator[dict[Hashable, slice]]: 330 | if overlap is None: 331 | overlap = {} 332 | dim_slices = [] 333 | for dim, slice_size in dims.items(): 334 | dim_size = ds.sizes[dim] 335 | slice_overlap = overlap.get(dim, 0) 336 | if slice_size > dim_size: 337 | raise ValueError( 338 | 'input sample length must be less than or equal to the ' 339 | f'dimension length, but the sample length of {slice_size} ' 340 | f'is greater than the dimension length of {dim_size} ' 341 | f'for {dim}' 342 | ) 343 | dim_slices.append( 344 | _gen_slices(dim_size=dim_size, slice_size=slice_size, overlap=slice_overlap) 345 | ) 346 | for slices in itertools.product(*dim_slices): 347 | selector = dict(zip(dims, slices)) 348 | yield selector 349 | 350 | 351 | def _drop_input_dims( 352 | ds: xr.Dataset | xr.DataArray, 353 | input_dims: dict[Hashable, int], 354 | suffix: str = '_input', 355 | ) -> xr.Dataset | xr.DataArray: 356 | # remove input_dims coordinates from datasets, rename the dimensions 357 | # then put intput_dims back in as coordinates 358 | out = ds.copy() 359 | for dim in input_dims: 360 | 
def _drop_input_dims(
    ds: xr.Dataset | xr.DataArray,
    input_dims: dict[Hashable, int],
    suffix: str = '_input',
) -> xr.Dataset | xr.DataArray:
    # Remove input_dims coordinates from datasets, rename the dimensions,
    # then put input_dims back in as coordinates (indexed on the renamed dim).
    out = ds.copy()
    for dim in input_dims:
        newdim = f'{dim}{suffix}'
        out = out.rename({dim: newdim})
        # extra steps needed if there is a coordinate
        if newdim in out:
            out = out.drop_vars(newdim)
            out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs
    return out


def _maybe_stack_batch_dims(
    ds: xr.Dataset | xr.DataArray,
    input_dims: Sequence[Hashable],
) -> xr.Dataset | xr.DataArray:
    """Stack all non-input dimensions into a single ``sample`` dimension and
    transpose so ``sample`` leads, followed by ``input_dims`` in order.

    If fewer than two non-input dimensions exist there is nothing to stack,
    so the object is returned unchanged.
    """
    batch_dims = [d for d in ds.sizes if d not in input_dims]
    if len(batch_dims) < 2:
        return ds
    ds_stack = ds.stack(sample=batch_dims)
    # ensure correct order: sample first, then the input dims
    dim_order = ('sample',) + tuple(input_dims)
    return ds_stack.transpose(*dim_order)


class BatchGenerator:
    """Create generator for iterating through Xarray DataArrays / Datasets in
    batches.

    Parameters
    ----------
    ds : ``xarray.Dataset`` or ``xarray.DataArray``
        The data to iterate over
    input_dims : dict
        A dictionary specifying the size of the inputs in each dimension,
        e.g. ``{'lat': 30, 'lon': 30}``
        These are the dimensions the ML library will see. All other dimensions
        will be stacked into one dimension called ``sample``.
    input_overlap : dict, optional
        A dictionary specifying the overlap along each dimension
        e.g. ``{'lat': 3, 'lon': 3}``
    batch_dims : dict, optional
        A dictionary specifying the size of the batch along each dimension
        e.g. ``{'time': 10}``. These will always be iterated over.
    concat_input_dims : bool, optional
        If ``True``, the dimension chunks specified in ``input_dims`` will be
        concatenated and stacked into the ``sample`` dimension. The batch index
        will be included as a new level ``input_batch`` in the ``sample``
        coordinate.
        If ``False``, the dimension chunks specified in ``input_dims`` will be
        iterated over.
    preload_batch : bool, optional
        If ``True``, each batch will be loaded into memory before reshaping /
        processing, triggering any dask arrays to be computed.
    cache : dict, optional
        Dict-like object to cache batches in (e.g., Zarr DirectoryStore). Note:
        The caching API is experimental and subject to change.
    cache_preprocess: callable, optional
        A function to apply to batches prior to caching.
        Note: The caching API is experimental and subject to change.

    Yields
    ------
    ds_slice : ``xarray.Dataset`` or ``xarray.DataArray``
        Slices of the array matching the given batch size specification.
    """

    def __init__(
        self,
        ds: xr.Dataset | xr.DataArray,
        input_dims: dict[Hashable, int],
        input_overlap: dict[Hashable, int] | None = None,
        batch_dims: dict[Hashable, int] | None = None,
        concat_input_dims: bool = False,
        preload_batch: bool = True,
        cache: dict[str, Any] | None = None,
        cache_preprocess: Callable | None = None,
    ):
        if input_overlap is None:
            input_overlap = {}
        if batch_dims is None:
            batch_dims = {}
        self.ds = ds
        self.cache = cache
        self.cache_preprocess = cache_preprocess

        # All indexing logic is delegated to BatchSchema; note its keyword is
        # spelled ``concat_input_bins`` (sic) for backward compatibility.
        self._batch_selectors: BatchSchema = BatchSchema(
            ds,
            input_dims=input_dims,
            input_overlap=input_overlap,
            batch_dims=batch_dims,
            concat_input_bins=concat_input_dims,
            preload_batch=preload_batch,
        )

    # The following properties proxy the parameters stored on the schema so
    # callers can read them back from the generator.
    @property
    def input_dims(self):
        return self._batch_selectors.input_dims

    @property
    def input_overlap(self):
        return self._batch_selectors.input_overlap

    @property
    def batch_dims(self):
        return self._batch_selectors.batch_dims

    @property
    def concat_input_dims(self):
        return self._batch_selectors.concat_input_dims

    @property
    def preload_batch(self):
        return self._batch_selectors.preload_batch

    def __iter__(self) -> Iterator[xr.DataArray | xr.Dataset]:
        # Yield batches in selector order by delegating to __getitem__.
        for idx in self._batch_selectors.selectors:
            yield self[idx]

    def __len__(self) -> int:
        return len(self._batch_selectors.selectors)

    def __getitem__(self, idx: int) -> xr.Dataset | xr.DataArray:
        """Return the batch at position ``idx`` (negative indices supported)."""
        if not isinstance(idx, int):
            raise NotImplementedError(
                f'{type(self).__name__}.__getitem__ currently requires a single integer key'
            )

        # Map a negative index onto the corresponding selector key; this
        # relies on the selector keys being the integers 0..len-1 in order.
        if idx < 0:
            idx = list(self._batch_selectors.selectors)[idx]

        if self.cache and self._batch_in_cache(idx):
            return self._get_cached_batch(idx)

        if idx not in self._batch_selectors.selectors:
            raise IndexError('list index out of range')

        if self.concat_input_dims:
            new_dim_suffix = '_input'
            all_dsets: list = []
            batch_selector = {}
            # Load the bounding region spanning all patches in this batch
            # (one isel/load instead of one per patch).
            for dim in self._batch_selectors.batch_dims.keys():
                starts = [x[dim].start for x in self._batch_selectors.selectors[idx]]
                stops = [x[dim].stop for x in self._batch_selectors.selectors[idx]]
                batch_selector[dim] = slice(min(starts), max(stops))
            batch_ds = self.ds.isel(batch_selector)
            if self.preload_batch:
                batch_ds.load()
            # Extract each patch, rename its input dims (``*_input``), then
            # concatenate the patches along a new ``input_batch`` dimension.
            for selector in self._batch_selectors.selectors[idx]:
                patch_ds = self.ds.isel(selector)
                all_dsets.append(
                    _drop_input_dims(
                        patch_ds,
                        self.input_dims,
                        suffix=new_dim_suffix,
                    )
                )
            dsc = xr.concat(all_dsets, dim='input_batch')
            new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
            batch = _maybe_stack_batch_dims(dsc, new_input_dims)
        else:
            # Without concatenation each batch has exactly one selector.
            batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0])
            if self.preload_batch:
                batch_ds.load()
            batch = _maybe_stack_batch_dims(
                batch_ds,
                list(self.input_dims),
            )
        if self.cache is not None and self.cache_preprocess is not None:
            batch = self.cache_preprocess(batch)
        if self.cache is not None:
            self._cache_batch(idx, batch)

        return batch

    def _batch_in_cache(self, idx: int) -> bool:
        # Presence of the ``.zgroup`` marker indicates the batch was written;
        # assumes a zarr v2 store layout — TODO confirm for zarr>=3.
        return self.cache is not None and f'{idx}/.zgroup' in self.cache

    def _cache_batch(self, idx: int, batch: xr.Dataset | xr.DataArray) -> None:
        # Each batch is stored under its own zarr group named by its index.
        batch.to_zarr(self.cache, group=str(idx), mode='a')

    def _get_cached_batch(self, idx: int) -> xr.Dataset:
        ds = xr.open_zarr(self.cache, group=str(idx))
        if self.preload_batch:
            ds = ds.load()
        return ds
14 | # - The CustomTFDataset provides an indexable interface 15 | # Assumptions made: 16 | # - The dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset) 17 | 18 | 19 | class CustomTFDataset(tf.keras.utils.Sequence): 20 | def __init__( 21 | self, 22 | X_generator, 23 | y_generator, 24 | *, 25 | transform: Callable | None = None, 26 | target_transform: Callable | None = None, 27 | ) -> None: 28 | """ 29 | Keras Dataset adapter for Xbatcher 30 | 31 | Parameters 32 | ---------- 33 | X_generator : xbatcher.BatchGenerator 34 | y_generator : xbatcher.BatchGenerator 35 | transform : callable, optional 36 | A function/transform that takes in an array and returns a transformed version. 37 | target_transform : callable, optional 38 | A function/transform that takes in the target and transforms it. 39 | """ 40 | self.X_generator = X_generator 41 | self.y_generator = y_generator 42 | self.transform = transform 43 | self.target_transform = target_transform 44 | 45 | def __len__(self) -> int: 46 | return len(self.X_generator) 47 | 48 | def __getitem__(self, idx: int) -> tuple[Any, Any]: 49 | X_batch = tf.convert_to_tensor(self.X_generator[idx].data) 50 | y_batch = tf.convert_to_tensor(self.y_generator[idx].data) 51 | 52 | # TODO: Should the transformations be applied before tensor conversion? 
53 | if self.transform: 54 | X_batch = self.transform(X_batch) 55 | 56 | if self.target_transform: 57 | y_batch = self.target_transform(y_batch) 58 | return X_batch, y_batch 59 | -------------------------------------------------------------------------------- /xbatcher/loaders/torch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from types import ModuleType 5 | 6 | import xarray as xr 7 | 8 | from xbatcher import BatchGenerator 9 | 10 | try: 11 | import torch 12 | except ImportError as exc: 13 | raise ImportError( 14 | 'The Xbatcher PyTorch Dataset API depends on PyTorch. Please ' 15 | 'install PyTorch to proceed.' 16 | ) from exc 17 | 18 | try: 19 | import dask 20 | except ImportError: 21 | dask: ModuleType | None = None # type: ignore[no-redef] 22 | 23 | T_DataArrayOrSet = xr.DataArray | xr.Dataset 24 | 25 | # Notes: 26 | # This module includes two PyTorch datasets. 27 | # - The MapDataset provides an indexable interface 28 | # - The IterableDataset provides a simple iterable interface 29 | # Both can be provided as arguments to the the Torch DataLoader 30 | # Assumptions made: 31 | # - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset) 32 | # TODOs: 33 | # - need to test with additional dataset parameters (e.g. 
def to_tensor(xr_obj: T_DataArrayOrSet) -> torch.Tensor:
    """Convert this DataArray or Dataset to a torch.Tensor"""
    if isinstance(xr_obj, xr.Dataset):
        # Collapse data variables into one array dimension, then drop it.
        xr_obj = xr_obj.to_array().squeeze(dim='variable')
    if isinstance(xr_obj, xr.DataArray):
        xr_obj = xr_obj.data
    return torch.tensor(xr_obj)


class MapDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        X_generator: BatchGenerator,
        y_generator: BatchGenerator | None = None,
        transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
        target_transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
    ) -> None:
        """
        PyTorch Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
        y_generator : xbatcher.BatchGenerator
        transform, target_transform : callable, optional
            A function/transform that takes in an Xarray object and returns a
            transformed version in the form of a torch.Tensor.
        """
        self.X_generator = X_generator
        self.y_generator = y_generator
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.X_generator)

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        if torch.is_tensor(idx):
            idx = idx.tolist()
            # BUGFIX: ``tolist()`` returns a plain int for 0-d tensors, on
            # which the previous unconditional ``len(idx)`` raised TypeError.
            if isinstance(idx, list):
                if len(idx) == 1:
                    idx = idx[0]
                else:
                    raise NotImplementedError(
                        f'{type(self).__name__}.__getitem__ currently requires a single integer key'
                    )

        # generate batch (or batches)
        if self.y_generator is not None:
            X_batch, y_batch = self.X_generator[idx], self.y_generator[idx]
        else:
            X_batch, y_batch = self.X_generator[idx], None

        # load batch (or batches) with dask if possible
        if dask is not None:
            X_batch, y_batch = dask.compute(X_batch, y_batch)

        # apply transformation(s)
        X_batch_tensor = self.transform(X_batch)
        if y_batch is not None:
            y_batch_tensor = self.target_transform(y_batch)

        assert isinstance(X_batch_tensor, torch.Tensor), self.transform

        if y_batch is None:
            return X_batch_tensor
        assert isinstance(y_batch_tensor, torch.Tensor)
        return X_batch_tensor, y_batch_tensor


class IterableDataset(torch.utils.data.IterableDataset):
    def __init__(
        self,
        X_generator,
        y_generator,
    ) -> None:
        """
        PyTorch Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
        y_generator : xbatcher.BatchGenerator
        """
        self.X_generator = X_generator
        self.y_generator = y_generator

    def __iter__(self):
        # Pair X/y batches positionally; stops at the shorter generator.
        for xb, yb in zip(self.X_generator, self.y_generator):
            yield (xb.torch.to_tensor(), yb.torch.to_tensor())
# --------------------------------------------------------------------------
# xbatcher/testing.py
# --------------------------------------------------------------------------
from collections.abc import Hashable
from unittest import TestCase

import numpy as np
import xarray as xr

from .generators import BatchGenerator


def _get_non_specified_dims(generator: BatchGenerator) -> dict[Hashable, int]:
    """
    Return all dimensions that are in the input dataset but not ``input_dims``
    or ``batch_dims``.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.

    Returns
    -------
    d : dict
        Dict containing all dimensions in the input dataset that are not
        in the input_dims or batch_dims attributes of the batch generator.
    """
    return {
        dim: size
        for dim, size in generator.ds.sizes.items()
        if generator.input_dims.get(dim) is None
        and generator.batch_dims.get(dim) is None
    }


def _get_non_input_batch_dims(generator: BatchGenerator) -> dict[Hashable, int]:
    """
    Return all dimensions that are in batch_dims but not input_dims.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.

    Returns
    -------
    d : dict
        Dict containing all dimensions specified in batch_dims that are
        not also in input_dims.
    """
    return {
        dim: size
        for dim, size in generator.batch_dims.items()
        if generator.input_dims.get(dim) is None
    }


def _get_duplicate_batch_dims(generator: BatchGenerator) -> dict[Hashable, int]:
    """
    Return all dimensions that are in both batch_dims and input_dims.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.

    Returns
    -------
    d : dict
        Dict containing all dimensions duplicated between batch_dims and
        input_dims.
    """
    return {
        dim: size
        for dim, size in generator.batch_dims.items()
        if generator.input_dims.get(dim) is not None
    }


def _get_sample_length(
    *,
    generator: BatchGenerator,
    non_specified_ds_dims: dict[Hashable, int],
    non_input_batch_dims: dict[Hashable, int],
) -> int:
    """
    Return the expected length of the sample dimension.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.
    non_specified_ds_dims : dict
        Dict containing all dimensions in the input dataset that are not
        in the input_dims or batch_dims attributes of the batch generator.
    non_input_batch_dims : dict
        Dict containing all dimensions specified in batch_dims that are
        not also in input_dims.

    Returns
    -------
    s : int
        Expected length of the sample dimension.
    """
    if generator.concat_input_dims:
        # Number of chunks per input dim; batch_dims bounds the count when a
        # (truthy) entry is present, otherwise the full dataset extent does.
        per_dim_counts = [
            (generator.batch_dims.get(dim) or generator.ds.sizes.get(dim)) // length
            for dim, length in generator.input_dims.items()
        ]
    else:
        per_dim_counts = []
    return int(
        np.prod(list(non_specified_ds_dims.values()))
        * np.prod(list(non_input_batch_dims.values()))
        * np.prod(per_dim_counts)
    )
136 | """ 137 | # dimensions that are in the input dataset but not input_dims or batch_dims 138 | non_specified_ds_dims = _get_non_specified_dims(generator) 139 | # dimensions that are in batch_dims but not input_dims 140 | non_input_batch_dims = _get_non_input_batch_dims(generator) 141 | expected_sample_length = _get_sample_length( 142 | generator=generator, 143 | non_specified_ds_dims=non_specified_ds_dims, 144 | non_input_batch_dims=non_input_batch_dims, 145 | ) 146 | # input_dims stay the same, possibly with a new suffix 147 | expected_dims = { 148 | f'{k}_input' if generator.concat_input_dims else k: v 149 | for k, v in generator.input_dims.items() 150 | } 151 | # Add a sample dimension if there's anything to get stacked 152 | if ( 153 | generator.concat_input_dims 154 | and (len(generator.ds.sizes) - len(generator.input_dims)) == 0 155 | ): 156 | expected_dims = {**{'input_batch': expected_sample_length}, **expected_dims} 157 | elif ( 158 | generator.concat_input_dims 159 | or (len(generator.ds.sizes) - len(generator.input_dims)) > 1 160 | ): 161 | expected_dims = {**{'sample': expected_sample_length}, **expected_dims} 162 | else: 163 | expected_dims = dict( 164 | **non_specified_ds_dims, 165 | **non_input_batch_dims, 166 | **expected_dims, 167 | ) 168 | return expected_dims 169 | 170 | 171 | def validate_batch_dimensions( 172 | *, expected_dims: dict[Hashable, int], batch: xr.Dataset | xr.DataArray 173 | ) -> None: 174 | """ 175 | Raises an AssertionError if the shape and dimensions of a batch do not 176 | match expected_dims. 177 | 178 | Parameters 179 | ---------- 180 | expected_dims : Dict 181 | Dict containing the expected dimensions for batches. 182 | batch : xarray.Dataset or xarray.DataArray 183 | The xarray data object returned by the batch generator. 
184 | """ 185 | 186 | # Check the names and lengths of the dimensions are equal 187 | TestCase().assertDictEqual( 188 | expected_dims, dict(batch.sizes), msg='Dimension names and/or lengths differ' 189 | ) 190 | # Check the dimension order is equal 191 | for var in batch.data_vars: 192 | TestCase().assertEqual( 193 | tuple(expected_dims.values()), 194 | batch[var].shape, 195 | msg=f'Order differs for dimensions of: {expected_dims}', 196 | ) 197 | 198 | 199 | def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int: 200 | """ 201 | Calculate the number of batches expected based on ``input_dims`` and 202 | ``input_overlap``. 203 | 204 | Parameters 205 | ---------- 206 | generator : xbatcher.BatchGenerator 207 | The batch generator object. 208 | 209 | Returns 210 | ------- 211 | s : int 212 | Number of batches expected given ``input_dims`` and ``input_overlap``. 213 | """ 214 | nbatches_from_input_dims = np.prod( 215 | [ 216 | generator.ds.sizes[dim] // length 217 | for dim, length in generator.input_dims.items() 218 | if generator.input_overlap.get(dim) is None 219 | and generator.batch_dims.get(dim) is None 220 | ] 221 | ) 222 | if generator.input_overlap: 223 | nbatches_from_input_overlap = np.prod( 224 | [ 225 | (generator.ds.sizes[dim] - overlap) 226 | // (generator.input_dims[dim] - overlap) 227 | for dim, overlap in generator.input_overlap.items() 228 | ] 229 | ) 230 | return int(nbatches_from_input_overlap * nbatches_from_input_dims) 231 | else: 232 | return int(nbatches_from_input_dims) 233 | 234 | 235 | def validate_generator_length(generator: BatchGenerator) -> None: 236 | """ 237 | Raises an AssertionError if the generator length does not match 238 | expectations based on the input Dataset and ``input_dims``. 239 | 240 | Parameters 241 | ---------- 242 | generator : xbatcher.BatchGenerator 243 | The batch generator object. 
244 | """ 245 | non_input_batch_dims = _get_non_input_batch_dims(generator) 246 | duplicate_batch_dims = _get_duplicate_batch_dims(generator) 247 | nbatches_from_unique_batch_dims = np.prod( 248 | [ 249 | generator.ds.sizes[dim] // length 250 | for dim, length in non_input_batch_dims.items() 251 | ] 252 | ) 253 | nbatches_from_duplicate_batch_dims = np.prod( 254 | [ 255 | generator.ds.sizes[dim] // length 256 | for dim, length in duplicate_batch_dims.items() 257 | ] 258 | ) 259 | if generator.concat_input_dims: 260 | expected_length = int( 261 | nbatches_from_unique_batch_dims * nbatches_from_duplicate_batch_dims 262 | ) 263 | else: 264 | nbatches_from_input_dims = _get_nbatches_from_input_dims(generator) 265 | expected_length = int( 266 | nbatches_from_unique_batch_dims 267 | * nbatches_from_duplicate_batch_dims 268 | * nbatches_from_input_dims 269 | ) 270 | TestCase().assertEqual( 271 | expected_length, 272 | len(generator), 273 | msg='Batch generator length differs', 274 | ) 275 | -------------------------------------------------------------------------------- /xbatcher/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xarray-contrib/xbatcher/5898a76ce88200a28eed036f0fbec9890a4280b5/xbatcher/tests/__init__.py -------------------------------------------------------------------------------- /xbatcher/tests/test_accessors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import xarray as xr 4 | 5 | import xbatcher # noqa: F401 6 | from xbatcher import BatchGenerator 7 | 8 | 9 | @pytest.fixture(scope='module') 10 | def sample_ds_3d(): 11 | shape = (10, 50, 100) 12 | ds = xr.Dataset( 13 | { 14 | 'foo': (['time', 'y', 'x'], np.random.rand(*shape)), 15 | 'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape)), 16 | }, 17 | { 18 | 'x': (['x'], np.arange(shape[-1])), 19 | 'y': (['y'], 
# --------------------------------------------------------------------------
# xbatcher/tests/test_accessors.py
# --------------------------------------------------------------------------
import numpy as np
import pytest
import xarray as xr

import xbatcher  # noqa: F401
from xbatcher import BatchGenerator


@pytest.fixture(scope='module')
def sample_ds_3d():
    """Module-scoped 3D dataset with two variables and x/y coordinates."""
    shape = (10, 50, 100)
    return xr.Dataset(
        {
            'foo': (['time', 'y', 'x'], np.random.rand(*shape)),
            'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape)),
        },
        {
            'x': (['x'], np.arange(shape[-1])),
            'y': (['y'], np.arange(shape[-2])),
        },
    )


@pytest.fixture(scope='module')
def sample_dataArray():
    return xr.DataArray(np.zeros((2, 4), dtype='i4'), dims=('x', 'y'), name='foo')


@pytest.fixture(scope='module')
def sample_Dataset():
    return xr.Dataset(
        {
            'x': xr.DataArray(np.arange(10), dims='x'),
            'foo': xr.DataArray(np.ones(10, dtype='float'), dims='x'),
        }
    )


def test_as_xarray_dataarray(sample_dataArray, sample_Dataset):
    """Both DataArray and Dataset inputs are coerced to a DataArray."""
    for obj in (sample_dataArray, sample_Dataset):
        assert isinstance(xbatcher.accessors._as_xarray_dataarray(obj), xr.DataArray)


def test_batch_accessor_ds(sample_ds_3d):
    """The .batch accessor mirrors direct BatchGenerator use on a Dataset."""
    from_class = BatchGenerator(sample_ds_3d, input_dims={'x': 5})
    from_accessor = sample_ds_3d.batch.generator(input_dims={'x': 5})
    assert isinstance(from_accessor, BatchGenerator)
    for expected, actual in zip(from_class, from_accessor):
        assert isinstance(actual, xr.Dataset)
        assert expected.equals(actual)


def test_batch_accessor_da(sample_ds_3d):
    """The .batch accessor mirrors direct BatchGenerator use on a DataArray."""
    da = sample_ds_3d['foo']
    from_class = BatchGenerator(da, input_dims={'x': 5})
    from_accessor = da.batch.generator(input_dims={'x': 5})
    assert isinstance(from_accessor, BatchGenerator)
    for expected, actual in zip(from_class, from_accessor):
        assert expected.equals(actual)


@pytest.mark.parametrize(
    'foo_var',
    [
        'foo',  # xr.DataArray
        ['foo'],  # xr.Dataset
    ],
)
def test_tf_to_tensor(sample_ds_3d, foo_var):
    tf = pytest.importorskip('tensorflow')

    foo = sample_ds_3d[foo_var]
    tensor = foo.tf.to_tensor()
    assert isinstance(tensor, tf.Tensor)
    assert tensor.shape == tuple(foo.sizes.values())

    foo_array = foo.to_array().squeeze() if hasattr(foo, 'to_array') else foo
    np.testing.assert_array_equal(tensor, foo_array.values)


@pytest.mark.parametrize(
    'foo_var',
    [
        'foo',  # xr.DataArray
        ['foo'],  # xr.Dataset
    ],
)
def test_torch_to_tensor(sample_ds_3d, foo_var):
    torch = pytest.importorskip('torch')

    foo = sample_ds_3d[foo_var]
    tensor = foo.torch.to_tensor()
    assert isinstance(tensor, torch.Tensor)
    assert tensor.names == (None, None, None)
    assert tensor.shape == tuple(foo.sizes.values())

    foo_array = foo.to_array().squeeze() if hasattr(foo, 'to_array') else foo
    np.testing.assert_array_equal(tensor, foo_array.values)


@pytest.mark.parametrize(
    'foo_var',
    [
        'foo',  # xr.DataArray
        ['foo'],  # xr.Dataset
    ],
)
def test_torch_to_named_tensor(sample_ds_3d, foo_var):
    torch = pytest.importorskip('torch')

    foo = sample_ds_3d[foo_var]
    tensor = foo.torch.to_named_tensor()
    assert isinstance(tensor, torch.Tensor)
    assert tensor.names == tuple(foo.dims)
    assert tensor.shape == tuple(foo.sizes.values())

    foo_array = foo.to_array().squeeze() if hasattr(foo, 'to_array') else foo
    np.testing.assert_array_equal(tensor, foo_array.values)


# --------------------------------------------------------------------------
# xbatcher/tests/test_generators.py
# --------------------------------------------------------------------------
import json
import tempfile
from typing import Any

import numpy as np
import pytest
import xarray as xr

from xbatcher import BatchGenerator, BatchSchema
from xbatcher.testing import (
    get_batch_dimensions,
    validate_batch_dimensions,
    validate_generator_length,
)
21 | """ 22 | size = 100 23 | ds = xr.Dataset( 24 | { 25 | 'foo': (['x'], np.random.rand(size)), 26 | 'bar': (['x'], np.random.randint(0, 10, size)), 27 | }, 28 | {'x': (['x'], np.arange(size))}, 29 | ) 30 | return ds 31 | 32 | 33 | @pytest.fixture(scope='module') 34 | def sample_ds_3d(): 35 | """ 36 | Sample 3D xarray.Dataset for testing. 37 | """ 38 | shape = (10, 50, 100) 39 | ds = xr.Dataset( 40 | { 41 | 'foo': (['time', 'y', 'x'], np.random.rand(*shape)), 42 | 'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape)), 43 | }, 44 | { 45 | 'x': (['x'], np.arange(shape[-1])), 46 | 'y': (['y'], np.arange(shape[-2])), 47 | }, 48 | ) 49 | return ds 50 | 51 | 52 | def test_constructor_dataarray(): 53 | """ 54 | Test that the xarray.DataArray passed to the batch generator is stored 55 | in the .ds attribute. 56 | """ 57 | da = xr.DataArray(np.random.rand(10), dims='x', name='foo') 58 | bg = BatchGenerator(da, input_dims={'x': 2}) 59 | xr.testing.assert_identical(da, bg.ds) 60 | 61 | 62 | @pytest.mark.parametrize('input_size', [5, 6]) 63 | def test_generator_length(sample_ds_1d, input_size): 64 | """ " 65 | Test the length of the batch generator. 66 | """ 67 | bg = BatchGenerator(sample_ds_1d, input_dims={'x': input_size}) 68 | validate_generator_length(bg) 69 | 70 | 71 | def test_generator_getitem(sample_ds_1d): 72 | """ 73 | Test indexing on the batch generator. 
74 | """ 75 | bg = BatchGenerator(sample_ds_1d, input_dims={'x': 10}) 76 | first_batch = bg[0] 77 | last_batch = bg[-1] 78 | expected_dims = get_batch_dimensions(bg) 79 | validate_batch_dimensions(expected_dims=expected_dims, batch=first_batch) 80 | validate_batch_dimensions(expected_dims=expected_dims, batch=last_batch) 81 | # raises IndexError for out of range index 82 | with pytest.raises(IndexError, match=r'list index out of range'): 83 | bg[9999999] 84 | 85 | # raises NotImplementedError for iterable index 86 | with pytest.raises(NotImplementedError): 87 | bg[[1, 2, 3]] 88 | 89 | 90 | @pytest.mark.parametrize('input_size', [5, 10]) 91 | def test_batch_1d(sample_ds_1d, input_size): 92 | """ 93 | Test batch generation for a 1D dataset using ``input_dims``. 94 | """ 95 | bg = BatchGenerator(sample_ds_1d, input_dims={'x': input_size}) 96 | validate_generator_length(bg) 97 | expected_dims = get_batch_dimensions(bg) 98 | for n, ds_batch in enumerate(bg): 99 | assert ds_batch.dims['x'] == input_size 100 | expected_slice = slice(input_size * n, input_size * (n + 1)) 101 | ds_batch_expected = sample_ds_1d.isel(x=expected_slice) 102 | xr.testing.assert_identical(ds_batch_expected, ds_batch) 103 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 104 | 105 | 106 | @pytest.mark.parametrize('input_size', [5, 10]) 107 | def test_batch_1d_concat(sample_ds_1d, input_size): 108 | """ 109 | Test batch generation for a 1D dataset using ``input_dims`` and concat_input_dims``. 
110 | """ 111 | bg = BatchGenerator( 112 | sample_ds_1d, input_dims={'x': input_size}, concat_input_dims=True 113 | ) 114 | validate_generator_length(bg) 115 | expected_dims = get_batch_dimensions(bg) 116 | for ds_batch in bg: 117 | assert isinstance(ds_batch, xr.Dataset) 118 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 119 | assert 'x' in ds_batch.coords 120 | 121 | 122 | def test_batch_1d_concat_duplicate_dim(sample_ds_1d): 123 | """ 124 | Test batch generation for a 1D dataset using ``concat_input_dims`` when 125 | the same dimension occurs in ``input_dims`` and `batch_dims`` 126 | """ 127 | bg = BatchGenerator( 128 | sample_ds_1d, input_dims={'x': 5}, batch_dims={'x': 10}, concat_input_dims=True 129 | ) 130 | validate_generator_length(bg) 131 | expected_dims = get_batch_dimensions(bg) 132 | for ds_batch in bg: 133 | assert isinstance(ds_batch, xr.Dataset) 134 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 135 | 136 | 137 | @pytest.mark.parametrize('input_size', [5, 10]) 138 | def test_batch_1d_no_coordinate(sample_ds_1d, input_size): 139 | """ 140 | Test batch generation for a 1D dataset without coordinates using ``input_dims``. 141 | 142 | Fix for https://github.com/xarray-contrib/xbatcher/issues/3. 
143 | """ 144 | ds_dropped = sample_ds_1d.drop_vars('x') 145 | bg = BatchGenerator(ds_dropped, input_dims={'x': input_size}) 146 | validate_generator_length(bg) 147 | expected_dims = get_batch_dimensions(bg) 148 | for n, ds_batch in enumerate(bg): 149 | assert ds_batch.dims['x'] == input_size 150 | expected_slice = slice(input_size * n, input_size * (n + 1)) 151 | ds_batch_expected = ds_dropped.isel(x=expected_slice) 152 | xr.testing.assert_identical(ds_batch_expected, ds_batch) 153 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 154 | 155 | 156 | @pytest.mark.parametrize('input_size', [5, 10]) 157 | def test_batch_1d_concat_no_coordinate(sample_ds_1d, input_size): 158 | """ 159 | Test batch generation for a 1D dataset without coordinates using ``input_dims`` 160 | and ``concat_input_dims``. 161 | 162 | Fix for https://github.com/xarray-contrib/xbatcher/issues/3. 163 | """ 164 | ds_dropped = sample_ds_1d.drop_vars('x') 165 | bg = BatchGenerator( 166 | ds_dropped, input_dims={'x': input_size}, concat_input_dims=True 167 | ) 168 | validate_generator_length(bg) 169 | expected_dims = get_batch_dimensions(bg) 170 | for ds_batch in bg: 171 | assert isinstance(ds_batch, xr.Dataset) 172 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 173 | assert 'x' not in ds_batch.coords 174 | 175 | 176 | @pytest.mark.parametrize('input_overlap', [1, 4]) 177 | def test_batch_1d_overlap(sample_ds_1d, input_overlap): 178 | """ 179 | Test batch generation for a 1D dataset without coordinates using ``input_dims`` 180 | and ``input_overlap``. 
181 | """ 182 | input_size = 10 183 | bg = BatchGenerator( 184 | sample_ds_1d, input_dims={'x': input_size}, input_overlap={'x': input_overlap} 185 | ) 186 | validate_generator_length(bg) 187 | expected_dims = get_batch_dimensions(bg) 188 | stride = input_size - input_overlap 189 | for n, ds_batch in enumerate(bg): 190 | assert ds_batch.dims['x'] == input_size 191 | expected_slice = slice(stride * n, stride * n + input_size) 192 | ds_batch_expected = sample_ds_1d.isel(x=expected_slice) 193 | xr.testing.assert_identical(ds_batch_expected, ds_batch) 194 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 195 | 196 | 197 | @pytest.mark.parametrize('input_size', [5, 10]) 198 | def test_batch_3d_1d_input(sample_ds_3d, input_size): 199 | """ 200 | Test batch generation for a 3D dataset with 1 dimension 201 | specified in ``input_dims``. 202 | """ 203 | bg = BatchGenerator(sample_ds_3d, input_dims={'x': input_size}) 204 | validate_generator_length(bg) 205 | expected_dims = get_batch_dimensions(bg) 206 | for n, ds_batch in enumerate(bg): 207 | assert ds_batch.dims['x'] == input_size 208 | # time and y should be collapsed into batch dimension 209 | assert ( 210 | ds_batch.dims['sample'] 211 | == sample_ds_3d.dims['y'] * sample_ds_3d.dims['time'] 212 | ) 213 | expected_slice = slice(input_size * n, input_size * (n + 1)) 214 | ds_batch_expected = ( 215 | sample_ds_3d.isel(x=expected_slice) 216 | .stack(sample=['time', 'y']) 217 | .transpose('sample', 'x') 218 | ) 219 | xr.testing.assert_identical(ds_batch_expected, ds_batch) 220 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 221 | 222 | 223 | @pytest.mark.parametrize( 224 | 'concat', 225 | [ 226 | True, 227 | pytest.param( 228 | False, 229 | marks=pytest.mark.xfail( 230 | reason='Bug described in https://github.com/xarray-contrib/xbatcher/issues/126' 231 | ), 232 | ), 233 | ], 234 | ) 235 | def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat): 236 | """ 237 | Test batch 
generation for a 3D dataset using ``input_dims`` and batch_dims``. 238 | """ 239 | bg = BatchGenerator( 240 | sample_ds_3d, 241 | input_dims={'x': 5, 'y': 10}, 242 | batch_dims={'time': 2}, 243 | concat_input_dims=concat, 244 | ) 245 | validate_generator_length(bg) 246 | expected_dims = get_batch_dimensions(bg) 247 | for ds_batch in bg: 248 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 249 | 250 | 251 | def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d): 252 | """ 253 | Test batch generation for a 3D dataset using ``concat_input_dims`` when 254 | the same dimension occurs in ``input_dims`` and batch_dims``. 255 | """ 256 | bg = BatchGenerator( 257 | sample_ds_3d, 258 | input_dims={'x': 5, 'y': 10}, 259 | batch_dims={'x': 10, 'y': 20}, 260 | concat_input_dims=True, 261 | ) 262 | validate_generator_length(bg) 263 | expected_dims = get_batch_dimensions(bg) 264 | for ds_batch in bg: 265 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 266 | 267 | 268 | @pytest.mark.parametrize('input_size', [5, 10]) 269 | def test_batch_3d_2d_input(sample_ds_3d, input_size): 270 | """ 271 | Test batch generation for a 3D dataset with 2 dimensions 272 | specified in ``input_dims``. 
273 | """ 274 | x_input_size = 20 275 | bg = BatchGenerator(sample_ds_3d, input_dims={'y': input_size, 'x': x_input_size}) 276 | validate_generator_length(bg) 277 | expected_dims = get_batch_dimensions(bg) 278 | for n, ds_batch in enumerate(bg): 279 | yn, xn = np.unravel_index( 280 | n, 281 | ( 282 | (sample_ds_3d.dims['y'] // input_size), 283 | (sample_ds_3d.dims['x'] // x_input_size), 284 | ), 285 | ) 286 | expected_xslice = slice(x_input_size * xn, x_input_size * (xn + 1)) 287 | expected_yslice = slice(input_size * yn, input_size * (yn + 1)) 288 | ds_batch_expected = sample_ds_3d.isel(x=expected_xslice, y=expected_yslice) 289 | xr.testing.assert_identical(ds_batch_expected, ds_batch) 290 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 291 | 292 | 293 | @pytest.mark.parametrize('input_size', [5, 10]) 294 | def test_batch_3d_2d_input_concat(sample_ds_3d, input_size): 295 | """ 296 | Test batch generation for a 3D dataset with 2 dimensions 297 | specified in ``input_dims`` using ``concat_input_dims``. 298 | """ 299 | x_input_size = 20 300 | bg = BatchGenerator( 301 | sample_ds_3d, 302 | input_dims={'y': input_size, 'x': x_input_size}, 303 | concat_input_dims=True, 304 | ) 305 | validate_generator_length(bg) 306 | expected_dims = get_batch_dimensions(bg) 307 | for ds_batch in bg: 308 | assert isinstance(ds_batch, xr.Dataset) 309 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 310 | 311 | bg = BatchGenerator( 312 | sample_ds_3d, 313 | input_dims={'time': input_size, 'x': x_input_size}, 314 | concat_input_dims=True, 315 | ) 316 | validate_generator_length(bg) 317 | expected_dims = get_batch_dimensions(bg) 318 | for ds_batch in bg: 319 | assert isinstance(ds_batch, xr.Dataset) 320 | validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) 321 | 322 | 323 | def test_preload_batch_false(sample_ds_1d): 324 | """ 325 | Test ``preload_batch=False`` does not compute Dask arrays. 
326 | """ 327 | sample_ds_1d_dask = sample_ds_1d.chunk({'x': 2}) 328 | bg = BatchGenerator(sample_ds_1d_dask, input_dims={'x': 2}, preload_batch=False) 329 | assert bg.preload_batch is False 330 | for ds_batch in bg: 331 | assert isinstance(ds_batch, xr.Dataset) 332 | assert ds_batch.chunks 333 | 334 | 335 | def test_preload_batch_true(sample_ds_1d): 336 | """ 337 | Test ``preload_batch=True`` does computes Dask arrays. 338 | """ 339 | sample_ds_1d_dask = sample_ds_1d.chunk({'x': 2}) 340 | bg = BatchGenerator(sample_ds_1d_dask, input_dims={'x': 2}, preload_batch=True) 341 | assert bg.preload_batch is True 342 | for ds_batch in bg: 343 | assert isinstance(ds_batch, xr.Dataset) 344 | assert not ds_batch.chunks 345 | 346 | 347 | def test_input_dim_exceptions(sample_ds_1d): 348 | """ 349 | Test that a ValueError is raised when input_dim[dim] > ds.sizes[dim] 350 | """ 351 | with pytest.raises(ValueError) as e: 352 | BatchGenerator(sample_ds_1d, input_dims={'x': 110}) 353 | assert len(e) == 1 354 | 355 | 356 | def test_input_overlap_exceptions(sample_ds_1d): 357 | """ 358 | Test that a ValueError is raised when input_overlap[dim] > input_dim[dim] 359 | """ 360 | with pytest.raises(ValueError) as e: 361 | BatchGenerator(sample_ds_1d, input_dims={'x': 10}, input_overlap={'x': 20}) 362 | assert len(e) == 1 363 | 364 | 365 | @pytest.mark.parametrize('input_size', [5, 10]) 366 | def test_to_json(sample_ds_3d, input_size): 367 | x_input_size = 20 368 | bg = BatchSchema( 369 | sample_ds_3d, 370 | input_dims={'time': input_size, 'x': x_input_size}, 371 | ) 372 | out_file = tempfile.NamedTemporaryFile(mode='w+b') 373 | bg.to_file(out_file.name) 374 | in_dict = json.load(out_file) 375 | assert in_dict['input_dims']['time'] == input_size 376 | assert in_dict['input_dims']['x'] == x_input_size 377 | out_file.close() 378 | 379 | 380 | @pytest.mark.parametrize('preload', [True, False]) 381 | def test_batcher_cached_getitem(sample_ds_1d, preload) -> None: 382 | 
pytest.importorskip('zarr') 383 | cache: dict[str, Any] = {} 384 | 385 | def preproc(ds): 386 | processed = ds.load().chunk(-1) 387 | processed.attrs['foo'] = 'bar' 388 | return processed 389 | 390 | bg = BatchGenerator( 391 | sample_ds_1d, 392 | input_dims={'x': 10}, 393 | cache=cache, 394 | cache_preprocess=preproc, 395 | preload_batch=preload, 396 | ) 397 | 398 | # first batch 399 | assert bg[0].sizes['x'] == 10 400 | ds_no_cache = bg[1] 401 | # last batch 402 | assert bg[-1].sizes['x'] == 10 403 | 404 | assert '0/.zgroup' in cache 405 | 406 | # now from cache 407 | # first batch 408 | assert bg[0].sizes['x'] == 10 409 | # last batch 410 | assert bg[-1].sizes['x'] == 10 411 | ds_cache = bg[1] 412 | 413 | assert ds_no_cache.attrs['foo'] == 'bar' 414 | assert ds_cache.attrs['foo'] == 'bar' 415 | 416 | xr.testing.assert_equal(ds_no_cache, ds_cache) 417 | xr.testing.assert_identical(ds_no_cache, ds_cache) 418 | 419 | # without preprocess func 420 | bg = BatchGenerator( 421 | sample_ds_1d, input_dims={'x': 10}, cache=cache, preload_batch=preload 422 | ) 423 | assert bg.cache_preprocess is None 424 | assert bg[0].sizes['x'] == 10 425 | ds_no_cache = bg[1] 426 | assert '1/.zgroup' in cache 427 | ds_cache = bg[1] 428 | xr.testing.assert_equal(ds_no_cache, ds_cache) 429 | xr.testing.assert_identical(ds_no_cache, ds_cache) 430 | -------------------------------------------------------------------------------- /xbatcher/tests/test_keras_loaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import xarray as xr 4 | 5 | from xbatcher import BatchGenerator 6 | from xbatcher.loaders.keras import CustomTFDataset 7 | 8 | tf = pytest.importorskip('tensorflow') 9 | 10 | 11 | @pytest.fixture(scope='module') 12 | def ds_xy(): 13 | n_samples = 100 14 | n_features = 5 15 | ds = xr.Dataset( 16 | { 17 | 'x': ( 18 | ['sample', 'feature'], 19 | np.random.random((n_samples, n_features)), 20 | ), 21 | 'y': 
(['sample'], np.random.random(n_samples)), 22 | }, 23 | ) 24 | return ds 25 | 26 | 27 | def test_custom_dataarray(ds_xy): 28 | x = ds_xy['x'] 29 | y = ds_xy['y'] 30 | 31 | x_gen = BatchGenerator(x, {'sample': 10}) 32 | y_gen = BatchGenerator(y, {'sample': 10}) 33 | 34 | dataset = CustomTFDataset(x_gen, y_gen) 35 | 36 | # test __getitem__ 37 | x_batch, y_batch = dataset[0] 38 | assert x_batch.shape == (10, 5) 39 | assert y_batch.shape == (10,) 40 | assert tf.is_tensor(x_batch) 41 | assert tf.is_tensor(y_batch) 42 | 43 | # test __len__ 44 | assert len(dataset) == len(x_gen) 45 | 46 | 47 | def test_custom_dataarray_with_transform(ds_xy): 48 | x = ds_xy['x'] 49 | y = ds_xy['y'] 50 | 51 | x_gen = BatchGenerator(x, {'sample': 10}) 52 | y_gen = BatchGenerator(y, {'sample': 10}) 53 | 54 | def x_transform(batch): 55 | return batch * 0 + 1 56 | 57 | def y_transform(batch): 58 | return batch * 0 - 1 59 | 60 | dataset = CustomTFDataset( 61 | x_gen, y_gen, transform=x_transform, target_transform=y_transform 62 | ) 63 | x_batch, y_batch = dataset[0] 64 | assert x_batch.shape == (10, 5) 65 | assert y_batch.shape == (10,) 66 | assert tf.is_tensor(x_batch) 67 | assert tf.is_tensor(y_batch) 68 | assert tf.experimental.numpy.all(x_batch == 1) 69 | assert tf.experimental.numpy.all(y_batch == -1) 70 | -------------------------------------------------------------------------------- /xbatcher/tests/test_print_versions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | 5 | import xbatcher 6 | 7 | 8 | def test_show_versions() -> None: 9 | """ 10 | Test xbatcher.show_versions() 11 | 12 | Based on https://github.com/pydata/xarray/blob/main/xarray/tests/test_print_versions.py 13 | """ 14 | f = io.StringIO() 15 | xbatcher.show_versions(file=f) 16 | assert 'xbatcher information' in f.getvalue() 17 | -------------------------------------------------------------------------------- 
/xbatcher/tests/test_torch_loaders.py: -------------------------------------------------------------------------------- 1 | from importlib import reload 2 | 3 | import numpy as np 4 | import pytest 5 | import xarray as xr 6 | 7 | from xbatcher import BatchGenerator 8 | from xbatcher.loaders.torch import IterableDataset, MapDataset, to_tensor 9 | 10 | torch = pytest.importorskip('torch') 11 | 12 | 13 | def test_import_torch_failure(monkeypatch): 14 | import sys 15 | 16 | import xbatcher.loaders 17 | 18 | monkeypatch.setitem(sys.modules, 'torch', None) 19 | 20 | with pytest.raises(ImportError) as excinfo: 21 | reload(xbatcher.loaders.torch) 22 | 23 | assert 'install PyTorch to proceed' in str(excinfo.value) 24 | 25 | 26 | def test_import_dask_failure(monkeypatch): 27 | import sys 28 | 29 | import xbatcher.loaders 30 | 31 | monkeypatch.setitem(sys.modules, 'dask', None) 32 | reload(xbatcher.loaders.torch) 33 | 34 | assert xbatcher.loaders.torch.dask is None 35 | 36 | 37 | @pytest.fixture(scope='module', params=[True, False]) 38 | def ds_xy(request): 39 | n_samples = 100 40 | n_features = 5 41 | ds = xr.Dataset( 42 | { 43 | 'x': ( 44 | ['sample', 'feature'], 45 | np.random.random((n_samples, n_features)), 46 | ), 47 | 'y': (['sample'], np.random.random(n_samples)), 48 | }, 49 | ) 50 | 51 | if request.param: 52 | ds = ds.chunk({'sample': 10}) 53 | 54 | return ds 55 | 56 | 57 | @pytest.mark.parametrize('x_var', ['x', ['x']]) 58 | def test_map_dataset_without_y(ds_xy, x_var) -> None: 59 | x = ds_xy[x_var] 60 | 61 | x_gen = BatchGenerator(x, {'sample': 10}) 62 | 63 | dataset = MapDataset(x_gen) 64 | 65 | # test __getitem__ 66 | x_batch = dataset[0] 67 | assert x_batch.shape == (10, 5) # type: ignore[union-attr] 68 | assert isinstance(x_batch, torch.Tensor) 69 | 70 | idx = torch.tensor([0]) 71 | x_batch = dataset[idx] 72 | assert x_batch.shape == (10, 5) 73 | assert isinstance(x_batch, torch.Tensor) 74 | 75 | with pytest.raises(NotImplementedError): 76 | idx = 
torch.tensor([0, 1]) 77 | x_batch = dataset[idx] 78 | 79 | # test __len__ 80 | assert len(dataset) == len(x_gen) 81 | 82 | # test integration with torch DataLoader 83 | loader = torch.utils.data.DataLoader(dataset, batch_size=None) 84 | 85 | for x_batch in loader: 86 | assert x_batch.shape == (10, 5) # type: ignore[union-attr] 87 | assert isinstance(x_batch, torch.Tensor) 88 | 89 | # Check that array shape of last item in generator is same as the batch image 90 | assert tuple(x_gen[-1].sizes.values()) == x_batch.shape # type: ignore[union-attr] 91 | # Check that array values from last item in generator and batch are the same 92 | gen_array = ( 93 | x_gen[-1].to_array().squeeze() if hasattr(x_gen[-1], 'to_array') else x_gen[-1] 94 | ) 95 | np.testing.assert_array_equal(gen_array, x_batch) # type: ignore 96 | 97 | 98 | @pytest.mark.parametrize( 99 | ('x_var', 'y_var'), 100 | [ 101 | ('x', 'y'), # xr.DataArray 102 | (['x'], ['y']), # xr.Dataset 103 | ], 104 | ) 105 | def test_map_dataset(ds_xy, x_var, y_var) -> None: 106 | x = ds_xy[x_var] 107 | y = ds_xy[y_var] 108 | 109 | x_gen = BatchGenerator(x, {'sample': 10}) 110 | y_gen = BatchGenerator(y, {'sample': 10}) 111 | 112 | dataset = MapDataset(x_gen, y_gen) 113 | 114 | # test __getitem__ 115 | x_batch, y_batch = dataset[0] 116 | assert x_batch.shape == (10, 5) 117 | assert y_batch.shape == (10,) 118 | assert isinstance(x_batch, torch.Tensor) 119 | 120 | idx = torch.tensor([0]) 121 | x_batch, y_batch = dataset[idx] 122 | assert x_batch.shape == (10, 5) 123 | assert y_batch.shape == (10,) 124 | assert isinstance(x_batch, torch.Tensor) 125 | 126 | with pytest.raises(NotImplementedError): 127 | idx = torch.tensor([0, 1]) 128 | x_batch, y_batch = dataset[idx] 129 | 130 | # test __len__ 131 | assert len(dataset) == len(x_gen) 132 | 133 | # test integration with torch DataLoader 134 | loader = torch.utils.data.DataLoader(dataset, batch_size=None) 135 | 136 | for x_batch, y_batch in loader: 137 | assert x_batch.shape == (10, 
5) 138 | assert y_batch.shape == (10,) 139 | assert isinstance(x_batch, torch.Tensor) 140 | 141 | # Check that array shape of last item in generator is same as the batch image 142 | assert tuple(x_gen[-1].sizes.values()) == x_batch.shape 143 | # Check that array values from last item in generator and batch are the same 144 | gen_array = ( 145 | x_gen[-1].to_array().squeeze() if hasattr(x_gen[-1], 'to_array') else x_gen[-1] 146 | ) 147 | np.testing.assert_array_equal(gen_array, x_batch) # type: ignore 148 | 149 | 150 | @pytest.mark.parametrize( 151 | ('x_var', 'y_var'), 152 | [ 153 | ('x', 'y'), # xr.DataArray 154 | (['x'], ['y']), # xr.Dataset 155 | ], 156 | ) 157 | def test_map_dataset_with_transform(ds_xy, x_var, y_var) -> None: 158 | x = ds_xy[x_var] 159 | y = ds_xy[y_var] 160 | 161 | x_gen = BatchGenerator(x, {'sample': 10}) 162 | y_gen = BatchGenerator(y, {'sample': 10}) 163 | 164 | def x_transform(batch): 165 | return to_tensor(batch * 0 + 1) 166 | 167 | def y_transform(batch): 168 | return to_tensor(batch * 0 - 1) 169 | 170 | dataset = MapDataset( 171 | x_gen, y_gen, transform=x_transform, target_transform=y_transform 172 | ) 173 | x_batch, y_batch = dataset[0] 174 | assert x_batch.shape == (10, 5) 175 | assert y_batch.shape == (10,) 176 | assert isinstance(x_batch, torch.Tensor) 177 | assert (x_batch == 1).all() 178 | assert (y_batch == -1).all() 179 | 180 | 181 | @pytest.mark.parametrize( 182 | ('x_var', 'y_var'), 183 | [ 184 | ('x', 'y'), # xr.DataArray 185 | (['x'], ['y']), # xr.Dataset 186 | ], 187 | ) 188 | def test_iterable_dataset(ds_xy, x_var, y_var): 189 | x = ds_xy[x_var] 190 | y = ds_xy[y_var] 191 | 192 | x_gen = BatchGenerator(x, {'sample': 10}) 193 | y_gen = BatchGenerator(y, {'sample': 10}) 194 | 195 | dataset = IterableDataset(x_gen, y_gen) 196 | 197 | # test integration with torch DataLoader 198 | loader = torch.utils.data.DataLoader(dataset, batch_size=None) 199 | 200 | for x_batch, y_batch in loader: 201 | assert x_batch.shape == (10, 5) 
202 | assert y_batch.shape == (10,) 203 | assert isinstance(x_batch, torch.Tensor) 204 | 205 | # Check that array shape of last item in generator is same as the batch image 206 | assert tuple(x_gen[-1].sizes.values()) == x_batch.shape 207 | # Check that array values from last item in generator and batch are the same 208 | gen_array = ( 209 | x_gen[-1].to_array().squeeze() if hasattr(x_gen[-1], 'to_array') else x_gen[-1] 210 | ) 211 | np.testing.assert_array_equal(gen_array, x_batch) 212 | -------------------------------------------------------------------------------- /xbatcher/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xarray-contrib/xbatcher/5898a76ce88200a28eed036f0fbec9890a4280b5/xbatcher/util/__init__.py -------------------------------------------------------------------------------- /xbatcher/util/print_versions.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import platform 3 | import sys 4 | 5 | 6 | def show_versions(file=sys.stdout): 7 | """ 8 | Print various dependency versions, including information about: 9 | 10 | - xbatcher 11 | - System information (Python version, Operating System) 12 | - Dependency versions (Xarray, etc) 13 | 14 | Based on https://github.com/GenericMappingTools/pygmt/blob/09f9e65ebebfa929f9ddc2af90e05f3302c2239d/pygmt/__init__.py#L95 15 | """ 16 | 17 | def _get_module_version(modname): 18 | """ 19 | Get version information of a Python module. 
20 | 21 | Copied from https://github.com/GenericMappingTools/pygmt/blob/09f9e65ebebfa929f9ddc2af90e05f3302c2239d/pygmt/__init__.py#L111 22 | """ 23 | try: 24 | if modname in sys.modules: 25 | module = sys.modules[modname] 26 | else: 27 | module = importlib.import_module(modname) 28 | 29 | return getattr(module, '__version__', 'installed') 30 | except ImportError: 31 | return None 32 | 33 | sys_info = { 34 | 'python': sys.version.replace('\n', ' '), 35 | 'executable': sys.executable, 36 | 'machine': platform.platform(), 37 | } 38 | 39 | deps = [ 40 | # Required 41 | 'dask', 42 | 'numpy', 43 | 'xarray', 44 | # Optional 45 | 'torch', 46 | # Setup/test 47 | 'pip', 48 | 'conda', 49 | 'pytest', 50 | # Misc. 51 | 'IPython', 52 | 'sphinx', 53 | ] 54 | __version__ = f'v{importlib.metadata.version("xbatcher")}' 55 | 56 | print('xbatcher information:', file=file) 57 | print(f' version: {__version__}', file=file) 58 | 59 | print('System information:', file=file) 60 | for key, val in sys_info.items(): 61 | print(f' {key}: {val}', file=file) 62 | 63 | print('Dependency information:', file=file) 64 | for modname in deps: 65 | print(f' {modname}: {_get_module_version(modname)}', file=file) 66 | 67 | 68 | if __name__ == '__main__': # pragma: no cover 69 | show_versions() 70 | --------------------------------------------------------------------------------