├── .coveragerc
├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.yml
│ ├── config.yml
│ ├── feature_request.yml
│ └── misc.yml
├── PULL_REQUEST_TEMPLATE.md
├── dependabot.yml
├── release-drafter.yml
└── workflows
│ ├── main.yaml
│ ├── pypi-release.yaml
│ ├── release-drafter.yml
│ └── testpypi-release.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── README.rst
├── asv_bench
├── asv.conf.json
└── benchmarks
│ ├── __init__.py
│ ├── accessors.py
│ ├── batches.py
│ ├── generators.py
│ └── loaders.py
├── ci
└── requirements
│ ├── asv.yml
│ ├── doc.yml
│ └── environment.yml
├── conftest.py
├── doc
├── Makefile
├── _static
│ ├── logo.svg
│ └── switcher.json
├── api.rst
├── conf.py
├── contributing.rst
├── demo.ipynb
├── index.rst
├── roadmap.rst
├── tutorials-and-presentations.rst
└── user-guide
│ ├── caching.ipynb
│ ├── create-fashion-mnist-dataset.ipynb
│ ├── index.rst
│ ├── training-a-neural-network-with-Pytorch-and-xbatcher.ipynb
│ └── training-a-neural-network-with-keras-and-xbatcher.ipynb
├── pyproject.toml
├── readthedocs.yaml
└── xbatcher
├── __init__.py
├── accessors.py
├── generators.py
├── loaders
├── __init__.py
├── keras.py
└── torch.py
├── testing.py
├── tests
├── __init__.py
├── test_accessors.py
├── test_generators.py
├── test_keras_loaders.py
├── test_print_versions.py
└── test_torch_loaders.py
└── util
├── __init__.py
└── print_versions.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | branch = True
3 | source = xbatcher/*
4 | include = xbatcher/*
5 | omit = */setup.py
6 | */version.py
7 | xbatcher/tests/*
8 | */__init__.py
9 | [report]
10 | include = xbatcher/*
11 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: 🐛 Bug Report
2 | description: File a bug report to help us improve
3 | labels: [bug, "needs triage"]
4 | body:
5 | - type: textarea
6 | id: what-happened
7 | attributes:
8 | label: What happened?
9 | description: |
10 | Thanks for reporting a bug! Please describe what you were trying to get done.
11 | Tell us what happened, what went wrong.
12 | validations:
13 | required: true
14 |
15 | - type: textarea
16 | id: what-did-you-expect-to-happen
17 | attributes:
18 | label: What did you expect to happen?
19 | description: |
20 | Describe what you expected to happen.
21 | validations:
22 | required: false
23 |
24 | - type: textarea
25 | id: sample-code
26 | attributes:
27 | label: Minimal Complete Verifiable Example
28 | description: |
29 | Minimal, self-contained copy-pastable example that demonstrates the issue. For more details, check out:
30 |
31 | - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve)
32 | - [Craft Minimal Bug Reports](http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports)
33 |
34 | This will be automatically formatted into code, so no need for markdown backticks.
35 | render: Python
36 |
37 | - type: textarea
38 | id: log-output
39 | attributes:
40 | label: Relevant log output
41 | description: Please copy and paste any relevant output. This will be automatically formatted into code, so no need for markdown backticks.
42 | render: Python
43 |
44 | - type: textarea
45 | id: extra
46 | attributes:
47 | label: Anything else we need to know?
48 | description: |
49 | Please describe any other information you want to share.
50 |
51 | - type: textarea
52 | id: show-versions
53 | attributes:
54 | label: Environment
55 | description: Please paste the output of `xbatcher.show_versions()`. This will be automatically formatted into code, so no need for markdown backticks.
56 | render: Python
57 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: ❓ Usage question
4 | url: https://github.com/xarray-contrib/xbatcher/discussions
5 | about: |
6 | Ask questions and discuss with other community members here.
7 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: 💡 Feature Request
2 | description: Suggest an idea for xbatcher
3 | labels: [feature request]
4 | body:
5 | - type: textarea
6 | id: description
7 | attributes:
8 | label: Is your feature request related to a problem?
9 | description: |
10 | Please provide a clear and concise description of what the problem is.
11 | validations:
12 | required: true
13 | - type: textarea
14 | id: solution
15 | attributes:
16 | label: Describe the solution you'd like
17 | description: |
18 | A clear and concise description of what you want to happen.
19 | - type: textarea
20 | id: alternatives
21 | attributes:
22 | label: Describe alternatives you've considered
23 | description: |
24 | A clear and concise description of any alternative solutions or features you've considered.
25 | validations:
26 | required: false
27 | - type: textarea
28 | id: additional-context
29 | attributes:
30 | label: Additional context
31 | description: |
32 | Add any other context about the feature request here.
33 | validations:
34 | required: false
35 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/misc.yml:
--------------------------------------------------------------------------------
1 | name: 📝 Issue
2 | description: General issue, that's not a bug report.
3 | labels: ["needs triage"]
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | Please describe your issue here.
9 | - type: textarea
10 | id: issue-description
11 | attributes:
12 | label: What is your issue?
13 | description: |
14 | Thank you for filing an issue! Please give us further information on how we can help you.
15 | placeholder: Please describe your issue.
16 | validations:
17 | required: true
18 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | **Description of proposed changes**
2 |
3 |
4 |
5 |
6 |
7 | Fixes #
8 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: pip
4 | directory: "/"
5 | schedule:
6 | interval: daily
7 | groups:
8 | pip-updates:
9 | patterns:
10 | - "*"
11 | - package-ecosystem: "github-actions"
12 | directory: "/"
13 | schedule:
14 | interval: "daily"
15 | groups:
16 | gh-actions:
17 | patterns:
18 | - "*"
19 |
--------------------------------------------------------------------------------
/.github/release-drafter.yml:
--------------------------------------------------------------------------------
1 | name-template: "v$RESOLVED_VERSION"
2 | tag-template: "v$RESOLVED_VERSION"
3 | categories:
4 | - title: "Features"
5 | label: "feature"
6 | - title: "Enhancement"
7 | label: "enhancement"
8 | - title: "Bug Fixes"
9 | label: "bug"
10 | - title: "Documentation"
11 | label: "documentation"
12 | - title: "Maintenance"
13 | label: "maintenance"
14 | change-template: "- $TITLE @$AUTHOR ([#$NUMBER]($URL))"
15 | change-title-escapes: '\<*_@'
16 | version-resolver:
17 | major:
18 | labels:
19 | - "major"
20 | minor:
21 | labels:
22 | - "feature"
23 | - "enhancement"
24 | default: patch
25 | exclude-labels:
26 | - "skip-changelog"
27 | template: |
28 | ## Release v$RESOLVED_VERSION (20YY/MM/DD)
29 |
30 | $CHANGES
31 |
--------------------------------------------------------------------------------
/.github/workflows/main.yaml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: main
6 | pull_request:
7 | branches: main
8 | paths-ignore:
9 | - ".github/workflows/*-release.yaml"
10 | - "asv_bench/**"
11 | - "doc/**"
12 | schedule:
13 | - cron: "0 0 * * *"
14 |
15 | concurrency:
16 | group: ${{ github.workflow }}-${{ github.ref }}
17 | cancel-in-progress: true
18 |
19 | jobs:
20 | test:
21 | name: ${{ matrix.python-version }}-build
22 | runs-on: ubuntu-latest
23 | strategy:
24 | matrix:
25 | python-version: ["3.10", "3.11", "3.12"]
26 | fail-fast: false
27 | steps:
28 | - uses: actions/checkout@v4
29 | - name: Setup Python
30 | uses: actions/setup-python@v5.4.0
31 | with:
32 | python-version: ${{ matrix.python-version }}
33 | architecture: x64
34 | - uses: actions/cache@v4
35 | with:
36 | path: ~/.cache/pip
37 | key: ${{ runner.os }}-pip-${{ hashFiles('**/dev-requirements.txt') }}
38 | restore-keys: |
39 | ${{ runner.os }}-pip-
40 | - run: |
41 | python -m pip install -e .[dev]
42 | python -m pip list
43 | - name: Running Tests
44 | run: |
45 | pytest --verbose --cov=. --cov-report=xml
46 | - name: Upload coverage to Codecov
47 | uses: codecov/codecov-action@v5.4.0
48 |       if: ${{ matrix.python-version == '3.10' }}
49 | with:
50 | file: ./coverage.xml
51 | fail_ci_if_error: false
52 |
53 | test-upstream:
54 | name: ${{ matrix.python-version }}-dev-build
55 | runs-on: ubuntu-latest
56 | strategy:
57 | matrix:
58 | python-version: ["3.10", "3.11", "3.12"]
59 | fail-fast: false
60 | steps:
61 | - uses: actions/checkout@v4
62 | - name: Setup Python
63 | uses: actions/setup-python@v5.4.0
64 | with:
65 | python-version: ${{ matrix.python-version }}
66 | architecture: x64
67 | - run: |
68 | python -m pip install -e .[dev]
69 | python -m pip install --upgrade \
70 | git+https://github.com/dask/dask \
71 | git+https://github.com/pydata/xarray
72 | python -m pip list
73 | - name: Running Tests
74 | run: |
75 |           pytest --verbose --cov=.
76 |
--------------------------------------------------------------------------------
/.github/workflows/pypi-release.yaml:
--------------------------------------------------------------------------------
1 | name: Build and Upload xbatcher to PyPI
2 | on:
3 | release:
4 | types:
5 | - published
6 | # Runs for pull requests should be disabled other than for testing purposes
7 | #pull_request:
8 | # branches:
9 | # - main
10 |
11 | permissions:
12 | contents: read
13 |
14 | jobs:
15 | build-artifacts:
16 | runs-on: ubuntu-latest
17 | if: github.repository == 'xarray-contrib/xbatcher'
18 | steps:
19 | - uses: actions/checkout@v4
20 | with:
21 | fetch-depth: 0
22 | - uses: actions/setup-python@v5.4.0
23 | name: Install Python
24 | with:
25 | python-version: 3.11
26 |
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install build twine
31 |
32 | # This step is only necessary for testing purposes and for TestPyPI
33 | - name: Fix up version string for TestPyPI
34 | if: ${{ !startsWith(github.ref, 'refs/tags') }}
35 | run: |
36 | # Change setuptools-scm local_scheme to "no-local-version" so the
37 | # local part of the version isn't included, making the version string
38 | # compatible with PyPI.
39 | sed --in-place "s/node-and-date/no-local-version/g" pyproject.toml
40 |
41 | - name: Build tarball and wheels
42 | run: |
43 | git clean -xdf
44 | git restore -SW .
45 | python -m build
46 | - name: Check built artifacts
47 | run: |
48 | python -m twine check --strict dist/*
49 | pwd
50 | if [ -f dist/xbatcher-0.0.0.tar.gz ]; then
51 | echo "❌ INVALID VERSION NUMBER"
52 | exit 1
53 | else
54 | echo "✅ Looks good"
55 | fi
56 | - uses: actions/upload-artifact@v4
57 | with:
58 | name: releases
59 | path: dist
60 |
61 | test-built-dist:
62 | needs: build-artifacts
63 | runs-on: ubuntu-latest
64 | steps:
65 | - uses: actions/setup-python@v5.4.0
66 | name: Install Python
67 | with:
68 | python-version: 3.11
69 | - uses: actions/download-artifact@v4
70 | with:
71 | name: releases
72 | path: dist
73 | - name: List contents of built dist
74 | run: |
75 | ls -ltrh
76 | ls -ltrh dist
77 | - name: Verify the built dist/wheel is valid
78 | run: |
79 | python -m pip install --upgrade pip
80 | python -m pip install dist/xbatcher*.whl
81 | python -m xbatcher.util.print_versions
82 | - name: Publish package to TestPyPI
83 | uses: pypa/gh-action-pypi-publish@v1.12.4
84 | with:
85 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
86 | repository-url: https://test.pypi.org/legacy/
87 | # verbose: true
88 |
89 | upload-to-pypi:
90 | needs: test-built-dist
91 | if: github.event_name == 'release'
92 | runs-on: ubuntu-latest
93 | steps:
94 | - uses: actions/download-artifact@v4
95 | with:
96 | name: releases
97 | path: dist
98 | - name: Publish package to PyPI
99 | uses: pypa/gh-action-pypi-publish@v1.12.4
100 | with:
101 | password: ${{ secrets.PYPI_API_TOKEN }}
102 | # verbose: true
103 |
--------------------------------------------------------------------------------
/.github/workflows/release-drafter.yml:
--------------------------------------------------------------------------------
1 | name: Release Drafter
2 |
3 | on:
4 | push:
5 | # branches to consider in the event; optional, defaults to all
6 | branches:
7 | - main
8 |
9 | permissions:
10 | contents: read
11 |
12 | jobs:
13 | update_release_draft:
14 | permissions:
15 | # write permission is required to create a github release
16 | contents: write
17 | # write permission is required for autolabeler
18 | # otherwise, read permission is required at least
19 | pull-requests: write
20 | runs-on: ubuntu-latest
21 | steps:
22 | # (Optional) GitHub Enterprise requires GHE_HOST variable set
23 | #- name: Set GHE_HOST
24 | # run: |
25 | # echo "GHE_HOST=${GITHUB_SERVER_URL##https:\/\/}" >> $GITHUB_ENV
26 |
27 | # Drafts your next Release notes as Pull Requests are merged into "main"
28 | - uses: release-drafter/release-drafter@v6
29 | # (Optional) specify config name to use, relative to .github/. Default: release-drafter.yml
30 | # with:
31 | # config-name: my-config.yml
32 | # disable-autolabeler: true
33 | env:
34 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
35 |
--------------------------------------------------------------------------------
/.github/workflows/testpypi-release.yaml:
--------------------------------------------------------------------------------
1 | name: Build and Upload xbatcher to TestPyPI
2 | on:
3 | push:
4 | branches:
5 | - main
6 | # pull_request:
7 | # branches:
8 | # - main
9 |
10 | permissions:
11 | contents: read
12 |
13 | jobs:
14 | publish-testpypi:
15 | name: Publish to Test PyPI
16 | runs-on: ubuntu-latest
17 | if: github.repository == 'xarray-contrib/xbatcher'
18 |
19 | steps:
20 | - name: Checkout
21 | uses: actions/checkout@v4
22 | with:
23 | # fetch all history so that setuptools-scm works
24 | fetch-depth: 0
25 |
26 | - name: Set up Python
27 | uses: actions/setup-python@v5.4.0
28 | with:
29 | python-version: "3.11"
30 |
31 | - name: Install dependencies
32 | run: python -m pip install build
33 |
34 | - name: Fix up version string for TestPyPI
35 | if: ${{ !startsWith(github.ref, 'refs/tags') }}
36 | run: |
37 | sed --in-place "s/node-and-date/no-local-version/g" pyproject.toml
38 |
39 | - name: Build tarball and wheels
40 | run: |
41 | python -m build
42 | echo "Generated files:"
43 | ls -lh dist/
44 |
45 | - name: Verify the built dist/wheel is valid
46 | run: |
47 | python -m pip install --upgrade pip
48 | python -m pip install dist/xbatcher*.whl
49 | python -m xbatcher.util.print_versions
50 |
51 | - name: Publish package to TestPyPI
52 | uses: pypa/gh-action-pypi-publish@v1.12.4
53 | with:
54 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
55 | repository-url: https://test.pypi.org/legacy/
56 | # verbose: true
57 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 |
5 | # C extensions
6 | *.so
7 |
8 | # Distribution / packaging
9 | .Python
10 | env/
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | *.egg-info/
23 | .installed.cfg
24 | *.egg
25 |
26 | # PyInstaller
27 | # Usually these files are written by a python script from a template
28 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
29 | *.manifest
30 | *.spec
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 |
36 | # Unit test / coverage reports
37 | htmlcov/
38 | .tox/
39 | .coverage
40 | .coverage.*
41 | .cache
42 | nosetests.xml
43 | coverage.xml
44 | *,cover
45 |
46 | # asv environments
47 | .asv
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 |
56 | # Sphinx documentation
57 | doc/_build/
58 | doc/generated/
59 |
60 | # PyBuilder
61 | target/
62 |
63 | # notebook
64 | */.ipynb_checkpoints/*
65 |
66 | # tests
67 | .pytest_cache/*
68 | .mypy_cache/
69 | .ipynb_checkpoints/
70 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ci:
2 | autoupdate_schedule: quarterly
3 | autofix_prs: false
4 |
5 | repos:
6 | - repo: https://github.com/pre-commit/pre-commit-hooks
7 | rev: v5.0.0
8 | hooks:
9 | - id: trailing-whitespace
10 | - id: end-of-file-fixer
11 | - id: check-docstring-first
12 | - id: check-json
13 | exclude: "asv_bench/asv.conf.json"
14 | - id: check-yaml
15 |
16 | - repo: https://github.com/astral-sh/ruff-pre-commit
17 | rev: "v0.8.6"
18 | hooks:
19 | - id: ruff
20 | args: ["--fix"]
21 | - id: ruff-format
22 |
23 | - repo: https://github.com/pre-commit/mirrors-prettier
24 | rev: v4.0.0-alpha.8
25 | hooks:
26 | - id: prettier
27 |
28 | - repo: https://github.com/pre-commit/mirrors-mypy
29 | rev: v1.14.1
30 | hooks:
31 | - id: mypy
32 | additional_dependencies: [
33 | # Type stubs
34 | types-setuptools,
35 | # Dependencies that are typed
36 | numpy,
37 | xarray,
38 | ]
39 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 |
3 | title: xbatcher
4 | doi: 10.5281/zenodo.13776824
5 | type: software
6 | url: "https://github.com/xarray-contrib/xbatcher"
7 | abstract: >-
8 | Xbatcher is a small library for iterating Xarray DataArrays and Datasets in
9 | batches. The goal is to make it easy to feed Xarray objects to machine
10 | learning libraries such as PyTorch or TensorFlow.
11 |
12 | keywords:
13 | - Xarray
14 | - Machine Learning
15 | - Deep Learning
16 | - PyTorch
17 | - TensorFlow
18 | - Dask
19 | version: 0.4.0
20 | date-released: 2024-09-17
21 | message: "If you use this software, please cite it as below."
22 | authors:
23 | - family-names: Jones
24 | given-names: Max
25 | orcid: https://orcid.org/0000-0003-0180-8928
26 | - family-names: Abernathey
27 | given-names: Ryan
28 | orcid: https://orcid.org/0000-0001-5999-4917
29 | - family-names: Hamman
30 | given-names: Joseph
31 | orcid: https://orcid.org/0000-0001-7479-8439
32 | - family-names: Banihirwe
33 | given-names: Anderson
34 | orcid: https://orcid.org/0000-0001-6583-571X
35 | - family-names: Leong
36 | given-names: Wei Ji
37 | orcid: https://orcid.org/0000-0003-2354-1988
38 |   - family-names: Chiao
39 |     given-names: Cindy
40 | - family-names: Bell
41 | given-names: Ryan
42 | - family-names: Hagen
43 | given-names: Raphael
44 | - family-names: Scott
45 | given-names: Richard
46 | - family-names: Bednar
47 | given-names: James
48 | - family-names: Vandal
49 | given-names: TJ
50 | - family-names: Bourbeau
51 | given-names: James
52 | - family-names: Jackson
53 | given-names: Robert
54 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Xbatcher's contributor guidelines [can be found in the online documentation](https://xbatcher.readthedocs.io/en/latest/contributing.html).
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022 Xbatcher contributors
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | xbatcher: Batch Generation from Xarray Datasets
2 | ===============================================
3 |
4 | |Build Status| |codecov| |docs| |pypi| |conda-forge| |license| |zenodo|
5 |
6 |
7 | Xbatcher is a small library for iterating Xarray DataArrays and Datasets in
8 | batches. The goal is to make it easy to feed Xarray objects to machine
9 | learning libraries such as PyTorch_ or TensorFlow_. View the |docs| for more
10 | info.
11 |
12 | .. _TensorFlow: https://www.tensorflow.org/
13 |
14 | .. _PyTorch: https://pytorch.org/
15 |
16 |
17 | .. |Build Status| image:: https://github.com/xarray-contrib/xbatcher/workflows/CI/badge.svg
18 | :target: https://github.com/xarray-contrib/xbatcher/actions
19 | :alt: github actions build status
20 | .. |codecov| image:: https://codecov.io/gh/xarray-contrib/xbatcher/branch/main/graph/badge.svg
21 | :target: https://codecov.io/gh/xarray-contrib/xbatcher
22 | :alt: code coverage
23 | .. |docs| image:: http://readthedocs.org/projects/xbatcher/badge/?version=latest
24 | :target: http://xbatcher.readthedocs.org/en/latest/?badge=latest
25 | :alt: docs
26 | .. |pypi| image:: https://img.shields.io/pypi/v/xbatcher.svg
27 | :target: https://pypi.python.org/pypi/xbatcher
28 | :alt: pypi
29 | .. |conda-forge| image:: https://img.shields.io/conda/vn/conda-forge/xbatcher.svg
30 | :target: https://anaconda.org/conda-forge/xbatcher
31 | :alt: conda-forge
32 | .. |license| image:: https://img.shields.io/github/license/xarray-contrib/xbatcher.svg
33 | :target: https://github.com/xarray-contrib/xbatcher
34 | :alt: license
35 | .. |zenodo| image:: https://zenodo.org/badge/DOI/10.5281/zenodo.13776824.svg
36 | :target: https://doi.org/10.5281/zenodo.13776824
37 | :alt: zenodo
38 |
39 | Installation
40 | ------------
41 |
42 | Xbatcher can be installed from PyPI as::
43 |
44 | python -m pip install xbatcher
45 |
46 | Or via Conda as::
47 |
48 | conda install -c conda-forge xbatcher
49 |
50 | Or from source as::
51 |
52 | python -m pip install git+https://github.com/xarray-contrib/xbatcher.git
53 |
54 | .. note::
   The required dependencies installed with Xbatcher are `Xarray <https://xarray.dev>`_,
   `Dask <https://www.dask.org>`_, and `NumPy <https://numpy.org>`_.
   You will need to separately install `TensorFlow <https://www.tensorflow.org>`_
   or `PyTorch <https://pytorch.org>`_ to use those data loaders or
   Xarray accessors. `Review the installation instructions <https://xbatcher.readthedocs.io>`_
   for more details.
61 |
62 | Documentation
63 | -------------
64 |
65 | Documentation is hosted on ReadTheDocs: https://xbatcher.readthedocs.org
66 |
67 | License
68 | ------------
69 |
70 | Apache License 2.0, see LICENSE file.
71 |
72 | Acknowledgements
73 | ----------------
74 |
75 | This work was funded in part by:
76 |
77 | NASA ACCESS19-0049: Pangeo ML: Open Source Tools and Pipelines for Scalable Machine Learning Using NASA Earth Observation Data
78 |
79 | This work was motivated by many conversations in the Pangeo community and Pangeo ML working group
80 |
--------------------------------------------------------------------------------
/asv_bench/asv.conf.json:
--------------------------------------------------------------------------------
1 | {
2 | // The version of the config file format. Do not change, unless
3 | // you know what you are doing.
4 | "version": 1,
5 |
6 | // The name of the project being benchmarked
7 | "project": "xbatcher",
8 |
9 | // The project's homepage
10 | "project_url": "https://xbatcher.readthedocs.io/",
11 |
12 | // The URL or local path of the source code repository for the
13 | // project being benchmarked
14 | "repo": "..",
15 |
16 | // The Python project's subdirectory in your repo. If missing or
17 | // the empty string, the project is assumed to be located at the root
18 | // of the repository.
19 | // "repo_subdir": "",
20 |
21 | // Customizable commands for building, installing, and
22 | // uninstalling the project. See asv.conf.json documentation.
23 | //
24 | // "install_command": ["in-dir={env_dir} python -mpip install {wheel_file}"],
25 | // "uninstall_command": ["return-code=any python -mpip uninstall -y {project}"],
26 | "build_command": [
27 | "python -m pip install build",
28 | "python -m build --wheel -o {build_cache_dir} {build_dir}"
29 | //"PIP_NO_BUILD_ISOLATION=false python -m pip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}"
30 | ],
31 |
32 | // List of branches to benchmark. If not provided, defaults to "master"
33 | // (for git) or "default" (for mercurial).
34 | "branches": ["main"], // for git
35 |
36 | // The DVCS being used. If not set, it will be automatically
37 | // determined from "repo" by looking at the protocol in the URL
38 | // (if remote), or by looking for special directories, such as
39 | // ".git" (if local).
40 | "dvcs": "git",
41 |
42 | // The tool to use to create environments. May be "conda",
43 | // "virtualenv" or other value depending on the plugins in use.
44 | // If missing or the empty string, the tool will be automatically
45 | // determined by looking for tools on the PATH environment
46 | // variable.
47 | "environment_type": "conda",
48 |
49 | // timeout in seconds for installing any dependencies in environment
50 | // defaults to 10 min
51 | "install_timeout": 600,
52 |
53 | // the base URL to show a commit for the project.
54 | // "show_commit_url": "http://github.com/xarray-contrib/xbatcher/commit/",
55 |
56 | // The Pythons you'd like to test against. If not provided, defaults
57 | // to the current version of Python used to run `asv`.
58 | // "pythons": ["3.8"],
59 |
60 | // The list of conda channel names to be searched for benchmark
61 | // dependency packages in the specified order
62 | "conda_channels": ["conda-forge"],
63 |
64 | // A conda environment file that is used for environment creation.
65 | "conda_environment_file": "../ci/requirements/asv.yml",
66 |
67 | // The matrix of dependencies to test. Each key of the "req"
68 | // requirements dictionary is the name of a package (in PyPI) and
69 | // the values are version numbers. An empty list or empty string
70 | // indicates to just test against the default (latest)
71 | // version. null indicates that the package is to not be
72 | // installed. If the package to be tested is only available from
73 | // PyPi, and the 'environment_type' is conda, then you can preface
74 | // the package name by 'pip+', and the package will be installed
75 | // via pip (with all the conda available packages installed first,
76 | // followed by the pip installed packages).
77 | //
78 | // The ``@env`` and ``@env_nobuild`` keys contain the matrix of
79 | // environment variables to pass to build and benchmark commands.
80 | // An environment will be created for every combination of the
81 | // cartesian product of the "@env" variables in this matrix.
82 | // Variables in "@env_nobuild" will be passed to every environment
83 | // during the benchmark phase, but will not trigger creation of
84 | // new environments. A value of ``null`` means that the variable
85 | // will not be set for the current combination.
86 | //
87 | // "matrix": {
88 | // "req": {
89 | // "numpy": ["1.6", "1.7"],
90 | // "six": ["", null], // test with and without six installed
91 | // "pip+emcee": [""] // emcee is only available for install with pip.
92 | // },
93 | // "env": {"ENV_VAR_1": ["val1", "val2"]},
94 | // "env_nobuild": {"ENV_VAR_2": ["val3", null]},
95 | // },
96 | // "matrix": {
97 | // "xarray": [""],
98 | // "numpy": [""],
99 | // "dask": [""],
100 | // },
101 |
102 | // Combinations of libraries/python versions can be excluded/included
103 | // from the set to test. Each entry is a dictionary containing additional
104 | // key-value pairs to include/exclude.
105 | //
106 | // An exclude entry excludes entries where all values match. The
107 | // values are regexps that should match the whole string.
108 | //
109 | // An include entry adds an environment. Only the packages listed
110 | // are installed. The 'python' key is required. The exclude rules
111 | // do not apply to includes.
112 | //
113 | // In addition to package names, the following keys are available:
114 | //
115 | // - python
116 | // Python version, as in the *pythons* variable above.
117 | // - environment_type
118 | // Environment type, as above.
119 | // - sys_platform
120 | // Platform, as in sys.platform. Possible values for the common
121 | // cases: 'linux2', 'win32', 'cygwin', 'darwin'.
122 | // - req
123 | // Required packages
124 | // - env
125 | // Environment variables
126 | // - env_nobuild
127 | // Non-build environment variables
128 | //
129 | // "exclude": [
130 | // {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows
131 | // {"environment_type": "conda", "req": {"six": null}}, // don't run without six on conda
132 | // {"env": {"ENV_VAR_1": "val2"}}, // skip val2 for ENV_VAR_1
133 | // ],
134 | //
135 | // "include": [
136 | // // additional env for python2.7
137 | // {"python": "2.7", "req": {"numpy": "1.8"}, "env_nobuild": {"FOO": "123"}},
138 | // // additional env if run on windows+conda
139 | // {"platform": "win32", "environment_type": "conda", "python": "2.7", "req": {"libpython": ""}},
140 | // ],
141 |
142 | // The directory (relative to the current directory) that benchmarks are
143 | // stored in. If not provided, defaults to "benchmarks"
144 | "benchmark_dir": "benchmarks",
145 |
146 | // The directory (relative to the current directory) to cache the Python
147 | // environments in. If not provided, defaults to "env"
148 | "env_dir": ".asv/env",
149 |
150 | // The directory (relative to the current directory) that raw benchmark
151 | // results are stored in. If not provided, defaults to "results".
152 | "results_dir": ".asv/results",
153 |
154 | // The directory (relative to the current directory) that the html tree
155 | // should be written to. If not provided, defaults to "html".
156 | "html_dir": ".asv/html"
157 |
158 | // The number of characters to retain in the commit hashes.
159 | // "hash_length": 8,
160 |
161 | // `asv` will cache results of the recent builds in each
162 | // environment, making them faster to install next time. This is
163 | // the number of builds to keep, per environment.
164 | // "build_cache_size": 2,
165 |
166 | // The commits after which the regression search in `asv publish`
167 | // should start looking for regressions. Dictionary whose keys are
168 | // regexps matching to benchmark names, and values corresponding to
169 | // the commit (exclusive) after which to start looking for
170 | // regressions. The default is to start from the first commit
171 | // with results. If the commit is `null`, regression detection is
172 | // skipped for the matching benchmark.
173 | //
174 | // "regressions_first_commits": {
175 | // "some_benchmark": "352cdf", // Consider regressions only after this commit
176 | // "another_benchmark": null, // Skip regression detection altogether
177 | // },
178 |
179 | // The thresholds for relative change in results, after which `asv
180 | // publish` starts reporting regressions. Dictionary of the same
181 | // form as in ``regressions_first_commits``, with values
182 | // indicating the thresholds. If multiple entries match, the
183 | // maximum is taken. If no entry matches, the default is 5%.
184 | //
185 | // "regressions_thresholds": {
186 | // "some_benchmark": 0.01, // Threshold of 1%
187 | // "another_benchmark": 0.5, // Threshold of 50%
188 | // },
189 | }
190 |
--------------------------------------------------------------------------------
/asv_bench/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def parameterized(names, params):
    """
    Attach asv parameterization metadata to a benchmark function.

    Copied from xarray benchmarks:
    https://github.com/pydata/xarray/blob/main/asv_bench/benchmarks/__init__.py#L9-L15

    Parameters
    ----------
    names : list of str
        Parameter names asv reports (stored as ``func.param_names``).
    params : sequence
        Parameter values asv iterates over (stored as ``func.params``).
    """

    def attach(func):
        # asv reads these attributes off the benchmark callable at collection time.
        func.param_names = names
        func.params = params
        return func

    return attach
16 |
17 |
def randn(shape, frac_nan=None, chunks=None, seed=0):
    """
    Return a seeded standard-normal array, optionally chunked and/or salted
    with NaNs.

    Copied from xarray benchmarks:
    https://github.com/pydata/xarray/blob/main/asv_bench/benchmarks/__init__.py#L32-L46

    Parameters
    ----------
    shape : tuple of int or int
        Shape of the output array.
    frac_nan : float, optional
        Approximate fraction of elements to set to NaN (indices are drawn
        with replacement, so the realized fraction may be slightly lower).
    chunks : optional
        If given, build a dask array with this chunking instead of NumPy.
    seed : int, default 0
        RNG seed, so repeated calls are deterministic.
    """
    if chunks is None:
        rng = np.random.RandomState(seed)
        out = rng.standard_normal(shape)
    else:
        import dask.array as da

        rng = da.random.RandomState(seed)
        out = rng.standard_normal(shape, chunks=chunks)

    if frac_nan is not None:
        # NOTE(review): dask arrays presumably have no ``.flat`` attribute, so
        # combining ``frac_nan`` with ``chunks`` likely fails — confirm before
        # relying on that combination (current benchmarks only use one or the
        # other).
        nan_inds = rng.choice(range(out.size), int(out.size * frac_nan))
        out.flat[nan_inds] = np.nan

    return out
37 |
--------------------------------------------------------------------------------
/asv_bench/benchmarks/accessors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import xarray as xr
4 |
5 | import xbatcher # noqa: F401
6 |
7 | from . import parameterized, randn
8 |
# Dimension sizes for the synthetic benchmark dataset.
nx = 250
ny = 50
nt = 10

# Random (nx, ny, nt) array with ~10% NaNs; built once at import time and
# shared by every benchmark setup in this module.
randn_xyt = randn((nx, ny, nt), frac_nan=0.1)
14 |
15 |
class Accessor:
    """Benchmarks for batch generation via the ``.batch`` Xarray accessor."""

    def setup(self, *args, **kwargs):
        """Build the synthetic (x, y, t) dataset used by each benchmark."""
        coords = {
            'x': np.arange(nx),
            'y': np.linspace(0, 1, ny),
            't': pd.date_range('1970-01-01', periods=nt, freq='D'),
        }
        self.ds = xr.Dataset(
            {'var1': (('x', 'y', 't'), randn_xyt)},
            coords=coords,
        )

    @parameterized(
        ['input_dims'],
        ([{'x': 10}, {'x': 10, 'y': 5}, {'x': 10, 'y': 5, 't': 2}],),
    )
    def time_input_dims(self, input_dims):
        """
        Benchmark simple batch generation case using xarray accessor
        Equivalent to subset of ``time_batch_input()``.
        """
        generator = self.ds.batch.generator(input_dims=input_dims)
        generator[0]
40 |
--------------------------------------------------------------------------------
/asv_bench/benchmarks/batches.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import xarray as xr
4 |
5 | from xbatcher import BatchGenerator
6 |
7 | from . import randn
8 |
# Dimension sizes for the synthetic benchmark dataset.
nx = 250
ny = 50
nt = 10

# Random (nx, ny, nt) array with ~10% NaNs; built once at import time and
# shared by every benchmark setup in this module.
randn_xyt = randn((nx, ny, nt), frac_nan=0.1)
14 |
15 |
class Base:
    """Shared setup: builds the synthetic (x, y, t) dataset every benchmark
    class in this module batches over."""

    def setup(self):
        coords = {
            'x': np.arange(nx),
            'y': np.linspace(0, 1, ny),
            't': pd.date_range('1970-01-01', periods=nt, freq='D'),
        }
        self.ds = xr.Dataset(
            {'var1': (('x', 'y', 't'), randn_xyt)},
            coords=coords,
        )
28 |
29 |
class NoPreload(Base):
    """
    Get a batch from a dask-backed dataset without computing dask arrays.
    """

    def setup(self):
        super().setup()
        chunked = self.ds.chunk({'t': 2})
        self.bg = BatchGenerator(chunked, input_dims=dict(t=2), preload_batch=False)

    def time_next_batch(self):
        """Fetch a single batch."""
        next(iter(self.bg))
45 |
46 |
class OneInputDim(Base):
    """
    Get a batch from a generator with a single dimension in ``input_dims``.
    """

    def setup(self):
        super().setup()
        self.bg = BatchGenerator(self.ds, input_dims=dict(x=10))

    def time_next_batch(self):
        """Fetch a single batch."""
        next(iter(self.bg))
61 |
62 |
class AllInputDim(Base):
    """
    Get a batch from a generator with every dimension listed in ``input_dims``.
    """

    def setup(self):
        super().setup()
        self.bg = BatchGenerator(self.ds, input_dims=dict(x=10, y=10, t=5))

    def time_next_batch(self):
        """Fetch a single batch."""
        next(iter(self.bg))
77 |
78 |
class InputDimInputOverlap(Base):
    """
    Get a batch from a generator configured with both ``input_dims`` and
    ``input_overlap``.
    """

    def setup(self):
        super().setup()
        self.bg = BatchGenerator(
            self.ds,
            input_dims=dict(x=10, y=10),
            input_overlap=dict(x=5, y=5),
        )

    def time_next_batch(self):
        """Fetch a single batch."""
        next(iter(self.bg))
95 |
96 |
class InputDimConcat(Base):
    """
    Get a batch from a generator configured with ``input_dims`` and
    ``concat_input_dims``.
    """

    def setup(self):
        super().setup()
        self.bg = BatchGenerator(
            self.ds,
            input_dims=dict(x=10, y=10),
            concat_input_dims=True,
        )

    def time_next_batch(self):
        """Fetch a single batch."""
        next(iter(self.bg))
113 |
114 |
class InputDimBatchDim(Base):
    """
    Get a batch from a generator configured with ``input_dims`` and
    ``batch_dims``.
    """

    def setup(self):
        super().setup()
        self.bg = BatchGenerator(
            self.ds,
            input_dims=dict(x=10, y=10),
            batch_dims=dict(t=2),
        )

    def time_next_batch(self):
        """Fetch a single batch."""
        next(iter(self.bg))
131 |
132 |
class InputDimBatchDimConcat(Base):
    """
    Get a batch from a generator configured with ``input_dims``,
    ``batch_dims``, and ``concat_input_dims``.
    """

    def setup(self):
        super().setup()
        self.bg = BatchGenerator(
            self.ds,
            input_dims=dict(x=5, y=5),
            batch_dims=dict(x=10, y=10),
            concat_input_dims=True,
        )

    def time_next_batch(self):
        """Fetch a single batch."""
        next(iter(self.bg))
152 |
153 |
class InputDimInputOverlapConcat(Base):
    """
    Get a batch from a generator configured with ``input_dims``,
    ``input_overlap``, and ``concat_input_dims``.
    """

    def setup(self):
        super().setup()
        self.bg = BatchGenerator(
            self.ds,
            input_dims=dict(x=10, y=10),
            input_overlap=dict(x=5, y=5),
            concat_input_dims=True,
        )

    def time_next_batch(self):
        """Fetch a single batch."""
        next(iter(self.bg))
173 |
--------------------------------------------------------------------------------
/asv_bench/benchmarks/generators.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import xarray as xr
4 |
5 | from xbatcher import BatchGenerator
6 |
7 | from . import parameterized, randn
8 |
# Dimension sizes for the synthetic benchmark dataset.
nx = 250
ny = 50
nt = 10

# Random (nx, ny, nt) array with ~10% NaNs; built once at import time and
# shared by every benchmark setup in this module.
randn_xyt = randn((nx, ny, nt), frac_nan=0.1)
14 |
15 |
class Generator:
    """Benchmarks for constructing ``BatchGenerator`` objects with different
    combinations of ``input_dims``, ``input_overlap``, ``batch_dims``, and
    ``concat_input_dims``."""

    def setup(self, *args, **kwargs):
        """Build the synthetic (x, y, t) dataset used by each benchmark."""
        self.ds = xr.Dataset(
            {
                'var1': (('x', 'y', 't'), randn_xyt),
            },
            coords={
                'x': np.arange(nx),
                'y': np.linspace(0, 1, ny),
                't': pd.date_range('1970-01-01', periods=nt, freq='D'),
            },
        )

    @parameterized(['preload_batch'], ([True, False]))
    def time_batch_preload(self, preload_batch):
        """
        Construct a generator on a chunked DataSet with and without preloading
        batches.
        """
        ds_dask = self.ds.chunk({'t': 2})
        BatchGenerator(ds_dask, input_dims={'t': 2}, preload_batch=preload_batch)

    @parameterized(
        ['input_dims'],
        ([{'x': 10}, {'x': 10, 'y': 5}, {'x': 10, 'y': 5, 't': 2}],),
    )
    def time_input_dims(self, input_dims):
        """
        Benchmark simple batch generation case.
        """
        BatchGenerator(
            self.ds,
            input_dims=input_dims,
        )

    def time_input_dims_and_input_overlap(self):
        """
        Benchmark batch generation with overlapping batches.
        """
        BatchGenerator(
            self.ds, input_dims={'x': 10, 'y': 10}, input_overlap={'x': 5, 'y': 5}
        )

    # BUG FIX: the parameter values were the *strings* 'True'/'False', which
    # are both truthy, so both benchmark cases exercised identical code paths.
    # Use real booleans so the False case actually disables concatenation.
    @parameterized(['concat_input_dims'], ([True, False]))
    def time_input_dims_and_concat_input_dims(self, concat_input_dims):
        """
        Benchmark concat_input_dims
        """
        BatchGenerator(
            self.ds, input_dims={'x': 10, 'y': 5}, concat_input_dims=concat_input_dims
        )

    # BUG FIX: param_names previously listed 'batch_dims' as a second
    # parameter, but only one params list was supplied and the method accepts
    # only `input_dims` — asv would fail to collect the benchmark. batch_dims
    # is fixed at {'t': 2} inside the body.
    @parameterized(
        ['input_dims'],
        ([{'x': 10}, {'x': 10, 'y': 5}],),
    )
    def time_input_dims_and_batch_dims(self, input_dims):
        """
        Benchmark batch generator with input_dims and batch_dims.
        """
        BatchGenerator(self.ds, input_dims=input_dims, batch_dims={'t': 2})

    @parameterized(
        ['concat_input_dims'],
        ([True, False]),
    )
    def time_input_dims_batch_dims_and_concat_input_dims(self, concat_input_dims):
        """
        Construct a generator on a DataSet with and without concatenating
        chunks specified by ``input_dims`` into the batch dimension.
        """
        BatchGenerator(
            self.ds,
            input_dims={'x': 10, 'y': 5},
            batch_dims={'x': 20, 'y': 10},
            concat_input_dims=concat_input_dims,
        )

    @parameterized(
        ['concat_input_dims'],
        ([True, False]),
    )
    def time_input_dims_input_overlap_and_concat_input_dims(self, concat_input_dims):
        """
        Construct a generator on a DataSet with and without concatenating
        chunks specified by ``input_dims`` into the batch dimension.
        """
        BatchGenerator(
            self.ds,
            input_dims={'x': 10, 'y': 10},
            input_overlap={'x': 5, 'y': 5},
            concat_input_dims=concat_input_dims,
        )
109 |
--------------------------------------------------------------------------------
/asv_bench/benchmarks/loaders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import xarray as xr
4 |
5 | from xbatcher import BatchGenerator
6 | from xbatcher.loaders.torch import IterableDataset, MapDataset
7 |
8 | from . import randn
9 |
# Dimension sizes for the synthetic benchmark data.
nx = 250
ny = 50

# 2-D predictor array (~10% NaNs) and 1-D target array, built once at import.
randn_xy = randn((nx, ny), frac_nan=0.1)
# NOTE: ``(ny)`` is just the int ``ny`` (not a 1-tuple); NumPy accepts an int
# shape, so this produces a 1-D array of length ny as intended.
randn_y = randn((ny), frac_nan=0.1)
15 |
16 |
class TorchLoader:
    """Benchmarks for feeding xbatcher generators to torch ``DataLoader``s."""

    def setup(self, *args, **kwargs):
        """Build a 2-D predictor / 1-D target dataset and per-variable
        batch generators over the shared ``y`` dimension."""
        coords = {
            'x': np.arange(nx),
            'y': np.linspace(0, 1, ny),
        }
        self.ds = xr.Dataset(
            {
                'var1': (('x', 'y'), randn_xy),
                'var2': (('y'), randn_y),
            },
            coords=coords,
        )
        self.x_gen = BatchGenerator(self.ds['var1'], {'y': 10})
        self.y_gen = BatchGenerator(self.ds['var2'], {'y': 10})

    def time_map_dataset(self):
        """
        Benchmark MapDataset integration with torch DataLoader.
        """
        loader = torch.utils.data.DataLoader(MapDataset(self.x_gen, self.y_gen))
        next(iter(loader))

    def time_iterable_dataset(self):
        """
        Benchmark IterableDataset integration with torch DataLoader.
        """
        loader = torch.utils.data.DataLoader(IterableDataset(self.x_gen, self.y_gen))
        next(iter(loader))
47 |
--------------------------------------------------------------------------------
/ci/requirements/asv.yml:
--------------------------------------------------------------------------------
1 | name: xbatcher-benchmarks
2 | channels:
3 | - conda-forge
4 | - nodefaults
5 | dependencies:
6 | # Required dependencies
7 | - python=3.10
8 | - dask
9 | - numpy
10 | - xarray
11 | - pytorch
12 | - tensorflow
13 | # Xbatcher installation
14 | - pip
15 |
--------------------------------------------------------------------------------
/ci/requirements/doc.yml:
--------------------------------------------------------------------------------
1 | name: xbatcher-docs
2 | channels:
3 | - conda-forge
4 | - nodefaults
5 | dependencies:
6 | - python=3.10
7 | - dask
8 | - pydata-sphinx-theme
9 | - ipython
10 | - matplotlib
11 | - numpy
12 | - numpydoc
13 | - pytest
14 | - sphinx<6
15 | - sphinx-autosummary-accessors
16 | - sphinx-copybutton
17 | - sphinx-design
18 | - xarray
19 | # For examples
20 | - s3fs
21 | - ipykernel
22 | - nbsphinx
23 | - netcdf4
24 | - pooch
25 | - zarr
26 | - pytorch
27 | - keras
28 | - tensorflow
29 | # Editable xbatcher installation
30 | - pip
31 |
32 | - pip:
33 | # relative to this file. Needs to be editable to be accepted.
34 | - -e ../..
35 |
--------------------------------------------------------------------------------
/ci/requirements/environment.yml:
--------------------------------------------------------------------------------
1 | name: xbatcher-tests
2 | channels:
3 | - conda-forge
4 | - nodefaults
5 | dependencies:
6 | # Required dependencies
7 | - python=3.10
8 | - dask
9 | - numpy
10 | - xarray
11 | # Dev dependencies
12 | - s3fs
13 | - asv
14 | - pre-commit
15 | - pytest
16 | - pytest-cov
17 | - pytorch
18 | - tensorflow
19 | - zarr
20 | # Style checks
21 | - black
22 | - blackdoc
23 | - docformatter
24 | - flake8
25 | - isort>=5
26 | - pylint
27 | # Xbatcher installation
28 | - pip
29 | - pip:
30 | # relative to this file. Needs to be editable to be accepted.
31 | - -e ../..
32 |
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import pytest
3 |
4 |
@pytest.fixture(autouse=True)
def add_standard_imports(doctest_namespace, tmpdir):
    """Seed NumPy's global RNG before every test so doctest examples are
    deterministic.

    NOTE(review): ``doctest_namespace`` and ``tmpdir`` are requested but never
    used — presumably the fixture was meant to also inject ``np`` into the
    doctest namespace; confirm before removing the parameters.
    """
    import numpy

    # Always seed numpy.random to make the examples deterministic.
    numpy.random.seed(0)
11 |
--------------------------------------------------------------------------------
/doc/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | PAPER =
8 | BUILDDIR = _build
9 |
10 | # User-friendly check for sphinx-build
11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
13 | endif
14 |
15 | # Internal variables.
16 | PAPEROPT_a4 = -D latex_paper_size=a4
17 | PAPEROPT_letter = -D latex_paper_size=letter
18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
19 | # the i18n builder cannot share the environment and doctrees with the others
20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
21 |
22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext
23 |
24 | help:
25 | @echo "Please use \`make ' where is one of"
26 | @echo " html to make standalone HTML files"
27 | @echo " dirhtml to make HTML files named index.html in directories"
28 | @echo " singlehtml to make a single large HTML file"
29 | @echo " pickle to make pickle files"
30 | @echo " json to make JSON files"
31 | @echo " htmlhelp to make HTML files and a HTML help project"
32 | @echo " qthelp to make HTML files and a qthelp project"
33 | @echo " devhelp to make HTML files and a Devhelp project"
34 | @echo " epub to make an epub"
35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
36 | @echo " latexpdf to make LaTeX files and run them through pdflatex"
37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
38 | @echo " text to make text files"
39 | @echo " man to make manual pages"
40 | @echo " texinfo to make Texinfo files"
41 | @echo " info to make Texinfo files and run them through makeinfo"
42 | @echo " gettext to make PO message catalogs"
43 | @echo " changes to make an overview of all changed/added/deprecated items"
44 | @echo " xml to make Docutils-native XML files"
45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes"
46 | @echo " linkcheck to check all external links for integrity"
47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)"
48 |
49 | clean:
50 | rm -rf $(BUILDDIR)/*
51 | rm -rf generated/*
52 |
53 | html:
54 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
55 | @echo
56 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
57 |
58 | dirhtml:
59 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
60 | @echo
61 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
62 |
63 | singlehtml:
64 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
65 | @echo
66 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
67 |
68 | pickle:
69 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
70 | @echo
71 | @echo "Build finished; now you can process the pickle files."
72 |
73 | json:
74 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
75 | @echo
76 | @echo "Build finished; now you can process the JSON files."
77 |
78 | htmlhelp:
79 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
80 | @echo
81 | @echo "Build finished; now you can run HTML Help Workshop with the" \
82 | ".hhp project file in $(BUILDDIR)/htmlhelp."
83 |
84 | qthelp:
85 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
86 | @echo
87 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \
88 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
89 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/xgcm.qhcp"
90 | @echo "To view the help file:"
91 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/xgcm.qhc"
92 |
93 | devhelp:
94 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
95 | @echo
96 | @echo "Build finished."
97 | @echo "To view the help file:"
98 | @echo "# mkdir -p $$HOME/.local/share/devhelp/xgcm"
99 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/xgcm"
100 | @echo "# devhelp"
101 |
102 | epub:
103 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
104 | @echo
105 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
106 |
107 | latex:
108 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
109 | @echo
110 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
111 | @echo "Run \`make' in that directory to run these through (pdf)latex" \
112 | "(use \`make latexpdf' here to do that automatically)."
113 |
114 | latexpdf:
115 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
116 | @echo "Running LaTeX files through pdflatex..."
117 | $(MAKE) -C $(BUILDDIR)/latex all-pdf
118 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
119 |
120 | latexpdfja:
121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
122 | @echo "Running LaTeX files through platex and dvipdfmx..."
123 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
124 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
125 |
126 | text:
127 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
128 | @echo
129 | @echo "Build finished. The text files are in $(BUILDDIR)/text."
130 |
131 | man:
132 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
133 | @echo
134 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
135 |
136 | texinfo:
137 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
138 | @echo
139 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
140 | @echo "Run \`make' in that directory to run these through makeinfo" \
141 | "(use \`make info' here to do that automatically)."
142 |
143 | info:
144 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
145 | @echo "Running Texinfo files through makeinfo..."
146 | make -C $(BUILDDIR)/texinfo info
147 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
148 |
149 | gettext:
150 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
151 | @echo
152 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
153 |
154 | changes:
155 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
156 | @echo
157 | @echo "The overview file is in $(BUILDDIR)/changes."
158 |
159 | linkcheck:
160 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
161 | @echo
162 | @echo "Link check complete; look for any errors in the above output " \
163 | "or in $(BUILDDIR)/linkcheck/output.txt."
164 |
165 | doctest:
166 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
167 | @echo "Testing of doctests in the sources finished, look at the " \
168 | "results in $(BUILDDIR)/doctest/output.txt."
169 |
170 | xml:
171 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
172 | @echo
173 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml."
174 |
175 | pseudoxml:
176 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
177 | @echo
178 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
179 |
--------------------------------------------------------------------------------
/doc/_static/logo.svg:
--------------------------------------------------------------------------------
1 |
59 |
--------------------------------------------------------------------------------
/doc/_static/switcher.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "name": "dev",
4 | "version": "latest",
5 | "url": "https://xbatcher.readthedocs.io/en/latest/"
6 | },
7 | {
8 | "version": "0.4.0",
9 | "url": "https://xbatcher.readthedocs.io/en/v0.4.0/"
10 | },
11 | {
12 | "version": "0.3.0",
13 | "url": "https://xbatcher.readthedocs.io/en/v0.3.0/"
14 | },
15 | {
16 | "version": "0.2.0",
17 | "url": "https://xbatcher.readthedocs.io/en/v0.2.0/"
18 | },
19 | {
20 | "version": "0.1.0",
21 | "url": "https://xbatcher.readthedocs.io/en/0.1.0/"
22 | }
23 | ]
24 |
--------------------------------------------------------------------------------
/doc/api.rst:
--------------------------------------------------------------------------------
1 | .. _api:
2 |
3 | API reference
4 | -------------
5 |
6 | This page provides an auto-generated summary of Xbatcher's API.
7 |
8 | Core
9 | ====
10 |
11 | .. currentmodule:: xbatcher
12 |
13 | .. autosummary::
14 | :toctree: generated/
15 |
16 | BatchGenerator
17 | BatchSchema
18 |
19 | Xbatcher Xarray accessors
20 | =========================
21 |
22 | .. currentmodule:: xarray
23 |
24 | .. autosummary::
25 | :toctree: generated/
26 | :template: autosummary/accessor_method.rst
27 |
28 | Dataset.batch.generator
29 | DataArray.batch.generator
30 |
31 | Dataloaders
32 | ===========
33 |
34 | .. currentmodule:: xbatcher
35 |
36 | .. autosummary::
37 | :toctree: generated/
38 |
39 | loaders.torch.MapDataset
40 | loaders.torch.IterableDataset
41 | loaders.keras.CustomTFDataset
42 |
43 |
--------------------------------------------------------------------------------
/doc/conf.py:
--------------------------------------------------------------------------------
1 | #
2 | # xbatcher documentation build configuration file, created by
3 | # sphinx-quickstart on Sat Aug 29 00:18:20 2015.
4 | #
5 | # This file is execfile()d with the current directory set to its
6 | # containing dir.
7 | #
8 | # Note that not all possible configuration values are present in this
9 | # autogenerated file.
10 | #
11 | # All configuration values have a default; values that are commented out
12 | # serve to show the default.
13 |
14 | # type: ignore
15 |
16 | import datetime
17 | import os
18 | import sys
19 |
20 | import sphinx_autosummary_accessors
21 |
22 | import xbatcher
23 |
24 | # If extensions (or modules to document with autodoc) are in another directory,
25 | # add these directories to sys.path here. If the directory is relative to the
26 | # documentation root, use os.path.abspath to make it absolute, like shown here.
27 | # sys.path.insert(0, os.path.abspath('.'))
28 | # sys.path.insert(os.path.abspath('..'))
29 |
30 | print('python exec:', sys.executable)
31 | print('sys.path:', sys.path)
32 | print('xbatcher.version:', xbatcher.__version__)
33 |
34 |
35 | # -- General configuration ------------------------------------------------
36 |
37 | # If your documentation needs a minimal Sphinx version, state it here.
38 | # needs_sphinx = '1.0'
39 |
40 | # Add any Sphinx extension module names here, as strings. They can be
41 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
42 | # ones.
43 | extensions = [
44 | 'sphinx.ext.mathjax',
45 | 'sphinx.ext.autodoc',
46 | 'sphinx.ext.autosummary',
47 | 'sphinx.ext.extlinks',
48 | 'sphinx.ext.viewcode',
49 | 'sphinx.ext.intersphinx',
50 | 'numpydoc',
51 | 'nbsphinx',
52 | 'IPython.sphinxext.ipython_directive',
53 | 'IPython.sphinxext.ipython_console_highlighting',
54 | 'sphinx_autosummary_accessors',
55 | 'sphinx_copybutton',
56 | 'sphinx_design',
57 | ]
58 |
59 | nbsphinx_execute = 'auto'
60 |
61 |
62 | autodoc_mock_imports = ['torch', 'tensorflow']
63 |
64 | # link to github issues
65 | extlinks = {'issue': ('https://github.com/xarray-contrib/xbatcher/issues/%s', '#%s')}
66 |
67 | # sphinx-copybutton configurations (from https://github.com/pydata/xarray/blob/main/doc/conf.py)
68 | copybutton_prompt_text = r'>>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: '
69 | copybutton_prompt_is_regexp = True
70 |
71 | autosummary_generate = True
72 | numpydoc_class_members_toctree = False
73 | numpydoc_show_class_members = False
74 |
75 | # Add any paths that contain templates here, relative to this directory.
76 | templates_path = ['_templates', sphinx_autosummary_accessors.templates_path]
77 |
78 | # The suffix of source filenames.
79 | source_suffix = '.rst'
80 |
81 | # The encoding of source files.
82 | # source_encoding = 'utf-8-sig'
83 |
84 | # The master toctree document.
85 | master_doc = 'index'
86 |
87 | # General information about the project.
88 | project = 'xbatcher'
89 | copyright = f'2016-{datetime.datetime.now().year}, xbatcher developers'
90 |
91 | # The version info for the project you're documenting, acts as replacement for
92 | # |version| and |release|, also used in various other places throughout the
93 | # built documents.
94 | #
95 | # The short X.Y version.
96 | version = xbatcher.__version__
97 | # The full version, including alpha/beta/rc tags.
98 | release = xbatcher.__version__
99 |
100 | # The language for content autogenerated by Sphinx. Refer to documentation
101 | # for a list of supported languages.
102 | # language = None
103 |
104 | # There are two options for replacing |today|: either, you set today to some
105 | # non-false value, then it is used:
106 | # today = ''
107 | # Else, today_fmt is used as the format for a strftime call.
108 | # today_fmt = '%B %d, %Y'
109 |
110 | # List of patterns, relative to source directory, that match files and
111 | # directories to ignore when looking for source files.
112 | exclude_patterns = [
113 | '_build',
114 | '**.ipynb_checkpoints',
115 | 'user-guide/create-fashion-mnist-dataset.ipynb',
116 | ]
117 |
118 | # The reST default role (used for this markup: `text`) to use for all
119 | # documents.
120 | # default_role = None
121 |
122 | # If true, '()' will be appended to :func: etc. cross-reference text.
123 | # add_function_parentheses = True
124 |
125 | # If true, the current module name will be prepended to all description
126 | # unit titles (such as .. function::).
127 | # add_module_names = True
128 |
129 | # If true, sectionauthor and moduleauthor directives will be shown in the
130 | # output. They are ignored by default.
131 | # show_authors = False
132 |
133 | # The name of the Pygments (syntax highlighting) style to use.
134 | pygments_style = 'sphinx'
135 |
136 | # A list of ignored prefixes for module index sorting.
137 | # modindex_common_prefix = []
138 |
139 | # If true, keep warnings as "system message" paragraphs in the built documents.
140 | # keep_warnings = False
141 |
142 |
143 | # -- Options for HTML output ----------------------------------------------
144 |
145 | # The theme to use for HTML and HTML Help pages. See the documentation for
146 | # a list of builtin themes.
147 | # html_theme = 'default'
148 | html_theme = 'pydata_sphinx_theme'
149 | html_logo = '_static/logo.svg'
150 | html_favicon = '_static/logo.svg'
151 |
152 | # The following is from the pydata-sphinx-theme settings (https://github.com/pydata/pydata-sphinx-theme/blob/main/docs/conf.py)
153 | # Define the json_url for our version switcher.
154 | json_url = 'https://xbatcher.readthedocs.io/en/latest/_static/switcher.json'
155 |
156 | # Define the version we use for matching in the version switcher.
157 | version_match = os.environ.get('READTHEDOCS_VERSION')
158 | # If READTHEDOCS_VERSION doesn't exist, we're not on RTD
159 | # If it is an integer, we're in a PR build and the version isn't correct.
160 | if not version_match or version_match.isdigit():
161 | # For local development, infer the version to match from the package.
162 | release = xbatcher.__version__
163 | if 'dev' in release or 'post' in release or 'rc' in release:
164 | version_match = 'latest'
165 | # We want to keep the relative reference if we are in dev mode
166 | # but we want the whole url if we are effectively in a released version
167 | json_url = '_static/switcher.json'
168 | else:
169 | version_match = 'v' + release
170 |
171 | print(f'release: {release}')
172 |
173 |
174 | # Theme options are theme-specific and customize the look and feel of a theme
175 | # further. For a list of options available for each theme, see the
176 | # documentation.
177 | html_theme_options = {
178 | 'github_url': 'https://github.com/xarray-contrib/xbatcher',
179 | 'switcher': {
180 | 'json_url': json_url,
181 | 'version_match': version_match,
182 | },
183 | 'logo': {
184 | 'text': 'Xbatcher',
185 | 'alt_text': 'Xbatcher',
186 | },
187 | 'navbar_align': 'left', # [left, content, right] For testing that the navbar items align properly
188 | 'navbar_center': ['version-switcher', 'navbar-nav'],
189 | }
190 |
191 | # Add any paths that contain custom themes here, relative to this directory.
192 | # html_theme_path = []
193 |
194 | # The name for this set of Sphinx documents. If None, it defaults to
195 | # " v documentation".
196 | # html_title = None
197 |
198 | # A shorter title for the navigation bar. Default is the same as html_title.
199 | # html_short_title = None
200 |
201 | # The name of an image file (relative to this directory) to place at the top
202 | # of the sidebar.
203 | # html_logo = None
204 |
205 | # The name of an image file (within the static path) to use as favicon of the
206 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
207 | # pixels large.
208 | # html_favicon = None
209 |
210 | # Add any paths that contain custom static files (such as style sheets) here,
211 | # relative to this directory. They are copied after the builtin static files,
212 | # so a file named "default.css" will overwrite the builtin "default.css".
213 | html_static_path = ['_static']
214 |
215 | # Add any extra paths that contain custom files (such as robots.txt or
216 | # .htaccess) here, relative to this directory. These files are copied
217 | # directly to the root of the documentation.
218 | # html_extra_path = []
219 |
220 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
221 | # using the given strftime format.
222 | # html_last_updated_fmt = '%b %d, %Y'
223 |
224 | # If true, SmartyPants will be used to convert quotes and dashes to
225 | # typographically correct entities.
226 | # html_use_smartypants = True
227 |
228 | # Custom sidebar templates, maps document names to template names.
229 | # html_sidebars = {}
230 |
231 | # Additional templates that should be rendered to pages, maps page names to
232 | # template names.
233 | # html_additional_pages = {}
234 |
235 | # If false, no module index is generated.
236 | # html_domain_indices = True
237 |
238 | # If false, no index is generated.
239 | # html_use_index = True
240 |
241 | # If true, the index is split into individual pages for each letter.
242 | # html_split_index = False
243 |
244 | # If true, links to the reST sources are added to the pages.
245 | # html_show_sourcelink = True
246 |
247 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
248 | # html_show_sphinx = True
249 |
250 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
251 | # html_show_copyright = True
252 |
253 | # If true, an OpenSearch description file will be output, and all pages will
254 | # contain a tag referring to it. The value of this option must be the
255 | # base URL from which the finished HTML is served.
256 | # html_use_opensearch = ''
257 |
258 | # This is the file name suffix for HTML files (e.g. ".xhtml").
259 | # html_file_suffix = None
260 |
261 | # Output file base name for HTML help builder.
262 | htmlhelp_basename = 'xbatcherdoc'
263 |
264 |
265 | # -- Options for LaTeX output ---------------------------------------------
266 |
267 | latex_elements = {
268 | # The paper size ('letterpaper' or 'a4paper').
269 | # 'papersize': 'letterpaper',
270 | # The font size ('10pt', '11pt' or '12pt').
271 | # 'pointsize': '10pt',
272 | # Additional stuff for the LaTeX preamble.
273 | # 'preamble': '',
274 | }
275 |
276 | # Grouping the document tree into LaTeX files. List of tuples
277 | # (source start file, target name, title,
278 | # author, documentclass [howto, manual, or own class]).
279 | latex_documents = [
280 | (
281 | 'index',
282 | 'xbatcher.tex',
283 | 'xbatcher Documentation',
284 | 'xbatcher developers',
285 | 'manual',
286 | ),
287 | ]
288 |
289 | # The name of an image file (relative to this directory) to place at the top of
290 | # the title page.
291 | # latex_logo = None
292 |
293 | # For "manual" documents, if this is true, then toplevel headings are parts,
294 | # not chapters.
295 | # latex_use_parts = False
296 |
297 | # If true, show page references after internal links.
298 | # latex_show_pagerefs = False
299 |
300 | # If true, show URL addresses after external links.
301 | # latex_show_urls = False
302 |
303 | # Documents to append as an appendix to all manuals.
304 | # latex_appendices = []
305 |
306 | # If false, no module index is generated.
307 | # latex_domain_indices = True
308 |
309 |
310 | # -- Options for manual page output ---------------------------------------
311 |
312 | # One entry per manual page. List of tuples
313 | # (source start file, name, description, authors, manual section).
314 | man_pages = [
315 | (
316 | 'index',
317 | 'xbatcher',
318 | 'xbatcher Documentation',
319 | ['xbatcher developers'],
320 | 1,
321 | )
322 | ]
323 |
324 | # If true, show URL addresses after external links.
325 | # man_show_urls = False
326 |
327 |
328 | # -- Options for Texinfo output -------------------------------------------
329 |
330 | # Grouping the document tree into Texinfo files. List of tuples
331 | # (source start file, target name, title, author,
332 | # dir menu entry, description, category)
333 | texinfo_documents = [
334 | (
335 | 'index',
336 | 'xbatcher',
337 | 'xbatcher Documentation',
338 | 'xbatcher developers',
339 | 'xbatcher',
340 | 'One line description of project.',
341 | 'Miscellaneous',
342 | ),
343 | ]
344 |
345 | # Documents to append as an appendix to all manuals.
346 | # texinfo_appendices = []
347 |
348 | # If false, no module index is generated.
349 | # texinfo_domain_indices = True
350 |
351 | # How to display URL addresses: 'footnote', 'no', or 'inline'.
352 | # texinfo_show_urls = 'footnote'
353 |
354 | # If true, do not generate a @detailmenu in the "Top" node's menu.
355 | # texinfo_no_detailmenu = False
356 |
357 |
358 | # Example configuration for intersphinx: refer to the Python standard library.
359 | intersphinx_mapping = {
360 | 'python': ('https://docs.python.org/3/', None),
361 | 'xarray': ('http://xarray.pydata.org/en/stable/', None),
362 | }
363 |
--------------------------------------------------------------------------------
/doc/contributing.rst:
--------------------------------------------------------------------------------
1 | .. _contributing:
2 |
3 | ******************
4 | Contributing Guide
5 | ******************
6 |
7 | .. note::
8 |
9 | Large parts of this document came from the `Xarray Contributing
10 | Guide `_, which is based
11 | on the `Pandas Contributing Guide
12 | `_.
13 |
14 | Bug reports and feature requests
15 | ================================
16 |
17 | To report bugs or request new features, head over to the `xbatcher repository
18 | `_.
19 |
20 | Contributing code
21 | ==================
22 |
23 | `GitHub has instructions `__ for
24 | installing git, setting up your SSH key, and configuring git. All these steps
25 | need to be completed for you to work between your local repository and GitHub.
26 |
27 | .. _contributing.forking:
28 |
29 | Forking
30 | -------
31 |
32 | You will need your own fork to work on the code. Go to the `xbatcher project
33 | page `_ and hit the ``Fork`` button.
34 | You will need to clone your fork to your machine::
35 |
36 | git clone git@github.com:yourusername/xbatcher.git
37 | cd xbatcher
38 | git remote add upstream git@github.com:xarray-contrib/xbatcher.git
39 |
40 | This creates the directory ``xbatcher`` and connects your repository to
41 | the upstream (main project) *xbatcher* repository.
42 |
43 | .. _contributing.dev_env:
44 |
45 | Creating a development environment
46 | ----------------------------------
47 |
48 | To test out code changes, you'll need to build *xbatcher* from source, which
49 | requires a Python environment. If you're making documentation changes, you can
50 | skip to :ref:`contributing.documentation` but you won't be able to build the
51 | documentation locally before pushing your changes.
52 |
53 | .. _contributiong.dev_python:
54 |
55 | Creating a Python Environment
56 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
57 |
58 | Before starting any development, you'll need to create an isolated xbatcher
59 | development environment:
60 |
61 | - Install either `Anaconda `_ or `miniconda
62 | `_
63 | - Make sure your conda is up to date (``conda update conda``)
64 | - Make sure that you have :ref:`cloned the repository `
65 | - ``cd`` to the *xbatcher* source directory
66 |
67 | First we'll create and activate the build environment:
68 |
69 | .. code-block:: sh
70 |
71 | conda env create --file ci/requirements/environment.yml
72 | conda activate xbatcher-tests
73 |
74 | At this point you should be able to import *xbatcher* from your locally
75 | built version:
76 |
77 | .. code-block:: sh
78 |
79 | $ python # start an interpreter
80 | >>> import xbatcher
81 | >>> xbatcher.__version__
82 |
83 | This will create the new environment, and not touch any of your existing environments,
84 | nor any existing Python installation.
85 |
86 | To view your environments::
87 |
88 | conda info --envs
89 |
90 | To return to your base environment::
91 |
92 | conda deactivate
93 |
94 | See the full conda docs `here `__.
95 |
96 | Setting up pre-commit
97 | ~~~~~~~~~~~~~~~~~~~~~
98 |
99 | We use `pre-commit `_ to manage code linting and style.
100 | To set up pre-commit after activating your conda environment, run:
101 |
102 | .. code-block:: sh
103 |
104 | pre-commit install
105 |
106 | Creating a branch
107 | -----------------
108 |
109 | You want your ``main`` branch to reflect only production-ready code, so create a
110 | feature branch before making your changes. For example::
111 |
112 | git branch shiny-new-feature
113 | git checkout shiny-new-feature
114 |
115 | The above can be simplified to::
116 |
117 | git checkout -b shiny-new-feature
118 |
119 | This changes your working directory to the shiny-new-feature branch. Keep any
120 | changes in this branch specific to one bug or feature so it is clear
121 | what the branch brings to *xbatcher*. You can have many "shiny-new-features"
122 | and switch in between them using the ``git checkout`` command.
123 |
124 | To update this branch, you need to retrieve the changes from the ``main`` branch::
125 |
126 | git fetch upstream
127 | git merge upstream/main
128 |
129 | This will combine your commits with the latest *xbatcher* git ``main``. If this
130 | leads to merge conflicts, you must resolve these before submitting your pull
131 | request. If you have uncommitted changes, you will need to ``git stash`` them
132 | prior to updating. This will effectively store your changes, which can be
133 | reapplied after updating.
134 |
135 | Running the test suite
136 | ----------------------
137 |
138 | *xbatcher* uses the `pytest `_
139 | framework for testing. You can run the test suite using::
140 |
141 | pytest xbatcher
142 |
143 |
144 |
145 | Running the performance test suite
146 | ----------------------------------
147 |
148 | *xbatcher* is starting a suite of benchmarking tests using
149 | `asv `__ to enable easy monitoring of
150 | the performance of critical operations. These benchmarks are all found in the
151 | ``asv_bench`` directory.
152 |
153 | To use all features of asv, you will need either ``conda`` or ``virtualenv``.
154 | For more details please check the `asv installation webpage
155 | `_.
156 |
157 | To install asv::
158 |
159 | pip install git+https://github.com/airspeed-velocity/asv
160 |
161 | If you need to run a benchmark, change your directory to ``asv_bench/`` and run::
162 |
163 | asv continuous -f 1.1 main
164 |
165 | You can replace ``main`` with the name of the branch you are working on.
166 | The output will include "BENCHMARKS NOT SIGNIFICANTLY CHANGED" if the
167 | benchmarks did not change by more than 10%.
168 |
169 | The command uses ``conda`` by default for creating the benchmark
170 | environments. If you want to use virtualenv instead, write::
171 |
172 | asv continuous -f 1.1 -E virtualenv main
173 |
174 | The ``-E virtualenv`` option should be added to all ``asv`` commands
175 | that run benchmarks. The default value is defined in ``asv.conf.json``.
176 |
177 | If you want to only run a specific group of tests from a file, you can do it
178 | using ``.`` as a separator. For example::
179 |
180 | asv continuous -f 1.1 main HEAD -b benchmarks.Generator.time_batch_preload
181 |
182 | will only run the ``Generator.time_batch_preload`` benchmark defined in
183 | ``benchmarks.py``.
184 |
185 | Information on how to write a benchmark and how to use asv can be found in the
186 | `asv documentation `_.
187 |
188 | Contributing documentation
189 | ==========================
190 |
191 | We greatly appreciate documentation improvements. The docs are built from the docstrings
192 | in the code and the docs in the ``doc`` directory.
193 |
194 | To build the documentation, you will need the requirements listed in ``ci/requirements/doc.yml``.
195 | You can create an environment for building the documentation using::
196 |
197 | conda env create --file ci/requirements/doc.yml
198 | conda activate xbatcher-docs
199 |
200 | You can then build the documentation using::
201 |
202 | cd doc
203 | make html
204 |
205 | Contributing changes
206 | ====================
207 |
208 | Once you've made changes, you can see them by typing::
209 |
210 | git status
211 |
212 | If you have created a new file, it is not being tracked by git. Add it by typing::
213 |
214 | git add path/to/file-to-be-added.py
215 |
216 | The following defines how a commit message should be structured:
217 |
218 | * A subject line with `< 72` chars.
219 | * One blank line.
220 | * Optionally, a commit message body.
221 |
222 | Now you can commit your changes in your local repository::
223 |
224 | git commit -m "your commit message"
225 |
226 | When you want your changes to appear publicly on your GitHub page, push your
227 | commits to a branch off your fork::
228 |
229 | git push origin shiny-new-feature
230 |
231 | Here ``origin`` is the default name given to your remote repository on GitHub.
232 | You can see the remote repositories::
233 |
234 | git remote -v
235 |
236 | If you navigate to your branch on GitHub, you should see a banner to submit a pull
237 | request to the *xbatcher* repository.
238 |
239 | .. _contributing.ci:
240 |
241 | Continuous integration
242 | ======================
243 |
244 | Continuous integration is done with `GitHub Actions `_.
245 |
246 | There are currently 3 workflows configured:
247 |
248 | - `main.yaml `_ - Run test suite with pytest.
249 | - `pypi-release.yaml `_ - Publish
250 | wheels to TestPyPI and PyPI on a tagged release. The pull request trigger can be uncommented to test a release using Test PyPI.
251 | - `release-drafter.yml `_ - Draft
252 | release notes based on PR titles and labels.
253 |
--------------------------------------------------------------------------------
/doc/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "sticky-exhibit",
7 | "metadata": {},
8 | "source": [
9 | "# Demo\n",
10 | "\n",
11 | "Author: Cindy Chiao\n",
12 | "\n",
13 | "## What is xbatcher? \n",
14 | "Xbatcher is a small library for iterating through Xarray objects (DataArrays and Datasets) in batches. The goal is to make it easy to feed Xarray objects to machine learning libraries such as Keras and PyTorch. \n",
15 | "\n",
16 | "## What is included in this notebook?\n",
17 | "* showcase current abilities with example data \n",
18 | "* brief discussion of current development track and ideas for future work "
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "banner-importance",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import xarray as xr\n",
29 | "\n",
30 | "import xbatcher"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "id": "equipped-sense",
36 | "metadata": {},
37 | "source": [
38 | "## Example data\n",
39 | "\n",
40 | "Here we will load an example dataset from a global climate model. The data is from the _historical_ experiment from CMIP6 and represents 60 days of daily max air temperature. "
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "id": "dutch-grave",
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "store = 's3://carbonplan-share/xbatcher/example_cmip6_data.zarr'\n",
51 | "ds = xr.open_dataset(\n",
52 | " store, engine='zarr', chunks={}, backend_kwargs={'storage_options': {'anon': True}}\n",
53 | ")\n",
54 | "\n",
55 | "# inspect the dataset\n",
56 | "ds"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "id": "applicable-diesel",
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "# plot the first time dimension\n",
67 | "ds.isel(time=0).tasmax.plot();"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "id": "animated-marsh",
73 | "metadata": {},
74 | "source": [
75 | "## Batch generation\n",
76 | "\n",
77 | "Xbatcher's `BatchGenerator` can be used to generate batches with several arguments controlling the exact behavior.\n",
78 | "\n",
79 | "The `input_dims` argument takes a dictionary specifying the size of the inputs in each dimension. For example, `{'time': 10}` means that each of the input sample will have 10 time points, while all other dimensions are flattened to a \"sample\" dimension\n",
80 | "\n",
81 | "Note that even though `ds` in this case only has one variable, the function can operate on multiple variables at the same time."
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "id": "attempted-cooling",
88 | "metadata": {},
89 | "outputs": [],
90 | "source": [
91 | "n_timepoint_in_each_sample = 10\n",
92 | "\n",
93 | "bgen = xbatcher.BatchGenerator(\n",
94 | " ds=ds,\n",
95 | " input_dims={'time': n_timepoint_in_each_sample},\n",
96 | ")\n",
97 | "\n",
98 | "print(f'{len(bgen)} batches')"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "id": "546aed21-3931-46b5-910e-c43498b51e23",
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "batch = bgen[0]\n",
109 | "batch"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "id": "digital-night",
115 | "metadata": {},
116 | "source": [
117 | "We can verify that the outputs have the expected shapes. \n",
118 | "\n",
119 | "For example, there are 60 time points in our input dataset and we're asking for 10 timepoints in each batch, thus expecting 6 batches."
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "id": "integral-theta",
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "expected_n_batch = len(ds.time) / n_timepoint_in_each_sample\n",
130 | "print(f'Expecting {expected_n_batch} batches, getting {len(bgen)} batches')"
131 | ]
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "id": "usual-kennedy",
136 | "metadata": {},
137 | "source": [
138 | "There are 145 lat points and 192 lon points, thus we're expecting 145 * 192 = 27840 samples in a batch."
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "id": "incomplete-native",
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "expected_batch_size = len(ds.lat) * len(ds.lon)\n",
149 | "print(\n",
150 | " f'Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch'\n",
151 | ")"
152 | ]
153 | },
154 | {
155 | "cell_type": "markdown",
156 | "id": "durable-gazette",
157 | "metadata": {},
158 | "source": [
159 | "## Controlling the size/shape of batches\n",
160 | "\n",
161 | "We can use `batch_dims` and `concat_input_dims` options to control how many samples end up in each batch. For example, we can specify 10 time points for each sample, but 20 time points in each batch. This should yield half as many batches and twice as many samples in a batch as the example above. Note the difference in dimension name in this case."
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "id": "sophisticated-legislation",
168 | "metadata": {},
169 | "outputs": [],
170 | "source": [
171 | "n_timepoint_in_each_sample = 10\n",
172 | "n_timepoint_in_each_batch = 20\n",
173 | "\n",
174 | "bgen = xbatcher.BatchGenerator(\n",
175 | " ds=ds,\n",
176 | " input_dims={'time': n_timepoint_in_each_sample},\n",
177 | " batch_dims={'time': n_timepoint_in_each_batch},\n",
178 | " concat_input_dims=True,\n",
179 | ")\n",
180 | "\n",
181 | "print(f'{len(bgen)} batches')"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "id": "857d962e-2b9e-4e25-95e3-a922bfac3d6f",
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "bgen[0]"
192 | ]
193 | },
194 | {
195 | "attachments": {},
196 | "cell_type": "markdown",
197 | "id": "spectacular-reading",
198 | "metadata": {},
199 | "source": [
200 | "## Last batch behavior\n",
201 | "\n",
202 | "If the input ds is not divisible by the specified `input_dims`, the remainder will be discarded instead of having a fractional batch. See https://github.com/xarray-contrib/xbatcher/discussions/82 for more on this topic."
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": null,
208 | "id": "residential-income",
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "n_timepoint_in_batch = 31\n",
213 | "\n",
214 | "bgen = xbatcher.BatchGenerator(ds=ds, input_dims={'time': n_timepoint_in_batch})\n",
215 | "\n",
216 | "for batch in bgen:\n",
217 | " print(f'last time point in ds is {ds.time[-1].values}')\n",
218 | " print(f'last time point in batch is {batch.time[-1].values}')\n",
219 | "batch"
220 | ]
221 | },
222 | {
223 | "cell_type": "markdown",
224 | "id": "competitive-islam",
225 | "metadata": {},
226 | "source": [
227 | "## Overlapping inputs\n",
228 | "\n",
229 | "In the example above, all samples have distinct time points. That is, for any lat/lon pixel, sample 1 has time points 1-10, sample 2 has time points 11-20, and they do not overlap. \n",
230 | "However, in many machine learning applications, we will want overlapping samples (e.g. sample 1 has time points 1-10, sample 2 has time points 2-11, and so on). We can use the `input_overlap` argument to get this behavior."
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": null,
236 | "id": "cleared-custody",
237 | "metadata": {},
238 | "outputs": [],
239 | "source": [
240 | "n_timepoint_in_each_sample = 10\n",
241 | "n_timepoint_in_each_batch = 20\n",
242 | "input_overlap = 9\n",
243 | "\n",
244 | "bgen = xbatcher.BatchGenerator(\n",
245 | " ds=ds,\n",
246 | " input_dims={'time': n_timepoint_in_each_sample},\n",
247 | " batch_dims={'time': n_timepoint_in_each_batch},\n",
248 | " concat_input_dims=True,\n",
249 | " input_overlap={'time': input_overlap},\n",
250 | ")\n",
251 | "\n",
252 | "batch = bgen[0]\n",
253 | "\n",
254 | "print(f'{len(bgen)} batches')\n",
255 | "batch"
256 | ]
257 | },
258 | {
259 | "attachments": {},
260 | "cell_type": "markdown",
261 | "id": "harmful-benefit",
262 | "metadata": {},
263 | "source": [
264 | "We can inspect the samples in a batch for a lat/lon pixel, noting that the overlap applies across batches."
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": null,
270 | "id": "earlier-warehouse",
271 | "metadata": {},
272 | "outputs": [],
273 | "source": [
274 | "lat = -90\n",
275 | "lon = 0\n",
276 | "pixel = batch.sel(lat=lat, lon=lon)\n",
277 | "display(pixel)\n",
278 | "\n",
279 | "print(\n",
280 | " f'sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}'\n",
281 | ")\n",
282 | "print(\n",
283 | " f'sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}'\n",
284 | ")"
285 | ]
286 | },
287 | {
288 | "cell_type": "markdown",
289 | "id": "arranged-telephone",
290 | "metadata": {},
291 | "source": [
292 | "## Example applications\n",
293 | "\n",
294 | "These batches can then be used to train a downstream machine learning model while preserving the indices of these samples. \n",
295 | "\n",
296 | "As an example, let's say we want to train a simple CNN model to predict the max air temperature for each day at each lat/lon pixel. To predict the temperature at lat/lon/time of (i, j, t), we'll use features including the temperature of a 9 x 9 grid centered at (i, j), from times t-10 to t-1 (shape of input should be (n_samples_in_each_batch, 9, 9, 9)). Note that in this example, we subset the dataset to a smaller domain for efficiency."
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": null,
302 | "id": "consolidated-chocolate",
303 | "metadata": {},
304 | "outputs": [],
305 | "source": [
306 | "bgen = xbatcher.BatchGenerator(\n",
307 | " ds=ds[['tasmax']].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n",
308 | " input_dims={'lat': 9, 'lon': 9, 'time': 10},\n",
309 | " batch_dims={'lat': 18, 'lon': 18, 'time': 15},\n",
310 | " concat_input_dims=True,\n",
311 | " input_overlap={'lat': 8, 'lon': 8, 'time': 9},\n",
312 | ")\n",
313 | "\n",
314 | "for i, batch in enumerate(bgen):\n",
315 | " print(f'batch {i}')\n",
316 | " # make sure the ordering of dimension is consistent\n",
317 | " batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n",
318 | "\n",
319 | " # only use the first 9 time points as features, since the last time point is the label to be predicted\n",
320 | " features = batch.tasmax.isel(time_input=slice(0, 9))\n",
321 | " # select the center pixel at the last time point to be the label to be predicted\n",
322 | " # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n",
323 | " labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n",
324 | "\n",
325 | " print('feature shape', features.shape)\n",
326 | " print('label shape', labels.shape)\n",
327 | " print('shape of lat of each sample', labels.coords['lat'].shape)\n",
328 | " print('')"
329 | ]
330 | },
331 | {
332 | "cell_type": "markdown",
333 | "id": "legislative-closer",
334 | "metadata": {},
335 | "source": [
336 | "We can also use the Xarray's \"stack\" method to transform these into 2D inputs (n_samples, n_features) suitable for other machine learning algorithms implemented in libraries such as [sklearn](https://scikit-learn.org/stable/) and [xgboost](https://xgboost.readthedocs.io/en/stable/). In this case, we are expecting 9 x 9 x 9 = 729 features total."
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": null,
342 | "id": "advisory-chicken",
343 | "metadata": {},
344 | "outputs": [],
345 | "source": [
346 | "for i, batch in enumerate(bgen):\n",
347 | " print(f'batch {i}')\n",
348 | " # make sure the ordering of dimension is consistent\n",
349 | " batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n",
350 | "\n",
351 | " # only use the first 9 time points as features, since the last time point is the label to be predicted\n",
352 | " features = batch.tasmax.isel(time_input=slice(0, 9))\n",
353 | " features = features.stack(features=['lat_input', 'lon_input', 'time_input'])\n",
354 | "\n",
355 | " # select the center pixel at the last time point to be the label to be predicted\n",
356 | " # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n",
357 | " labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n",
358 | "\n",
359 | " print('feature shape', features.shape)\n",
360 | " print('label shape', labels.shape)\n",
361 | " print('shape of lat of each sample', labels.coords['lat'].shape, '\\n')"
362 | ]
363 | },
364 | {
365 | "attachments": {},
366 | "cell_type": "markdown",
367 | "id": "persistent-culture",
368 | "metadata": {},
369 | "source": [
370 | "## What's next?\n",
371 | "\n",
372 | "There are many additional useful features that were yet to be implemented in the context of batch generation for downstream machine learning model training purposes. One of the current efforts is to improve the set of data loaders. \n",
373 | "\n",
374 | "Additional features of interest can include: \n",
375 | "\n",
376 | "1. Shuffling/randomization of samples across batches. It is often desirable for each batch to be grouped randomly instead of along a specific dimension. \n",
377 | "\n",
378 | "2. Be efficient in terms of memory usage. In the case where overlap is enabled, each sample would be composed of mostly repetitive values compared to adjacent samples. It would be beneficial if each batch/sample is generated lazily to avoid storing these extra duplicative values. \n",
379 | "\n",
380 | "3. Handling preprocessing steps. For example, data augmentation, scaling/normalization, outlier detection, etc. \n",
381 | "\n",
382 | "\n",
383 | "More thoughts on 1. can be found in [this discussion](https://github.com/xarray-contrib/xbatcher/discussions/78). Interested users are welcome to comment or submit other issues in GitHub. "
384 | ]
385 | }
386 | ],
387 | "metadata": {
388 | "kernelspec": {
389 | "display_name": "Python 3 (ipykernel)",
390 | "language": "python",
391 | "name": "python3"
392 | },
393 | "language_info": {
394 | "codemirror_mode": {
395 | "name": "ipython",
396 | "version": 3
397 | },
398 | "file_extension": ".py",
399 | "mimetype": "text/x-python",
400 | "name": "python",
401 | "nbconvert_exporter": "python",
402 | "pygments_lexer": "ipython3",
403 | "version": "3.11.9"
404 | },
405 | "vscode": {
406 | "interpreter": {
407 | "hash": "64c578a0a9f6dde4e1dfaddaa39417770d5e50fec039804eaf1eb97ef756c00c"
408 | }
409 | }
410 | },
411 | "nbformat": 4,
412 | "nbformat_minor": 5
413 | }
414 |
--------------------------------------------------------------------------------
/doc/index.rst:
--------------------------------------------------------------------------------
1 | xbatcher: Batch Generation from Xarray Datasets
2 | ===============================================
3 |
4 | Xbatcher is a small library for iterating Xarray DataArrays and Datasets in
5 | batches. The goal is to make it easy to feed Xarray objects to machine learning
6 | libraries such as Keras_.
7 |
8 | .. _Keras: https://keras.io/
9 |
10 | Installation
11 | ------------
12 |
13 | Xbatcher can be installed from PyPI as::
14 |
15 | python -m pip install xbatcher
16 |
17 | Or via Conda as::
18 |
19 | conda install -c conda-forge xbatcher
20 |
21 | Or from source as::
22 |
23 | python -m pip install git+https://github.com/xarray-contrib/xbatcher.git
24 |
25 | Optional Dependencies
26 | ~~~~~~~~~~~~~~~~~~~~~
27 |
28 | .. note::
29 | The required dependencies installed with Xbatcher are `Xarray <https://xarray.dev/>`_,
30 | `Dask <https://www.dask.org/>`_, and `NumPy <https://numpy.org/>`_.
31 | You will need to separately install `TensorFlow <https://www.tensorflow.org/>`_
32 | or `PyTorch <https://pytorch.org/>`_ to use those data loaders or
33 | Xarray accessors.
34 |
35 | To install Xbatcher and PyTorch via `Conda <https://docs.conda.io/>`_::
36 |
37 | conda install -c conda-forge xbatcher pytorch
38 |
39 | Or via PyPI::
40 |
41 | python -m pip install xbatcher[torch]
42 |
43 | To install Xbatcher and TensorFlow via `Conda <https://docs.conda.io/>`_::
44 |
45 | conda install -c conda-forge xbatcher tensorflow
46 |
47 | Or via PyPI::
48 |
49 | python -m pip install xbatcher[tensorflow]
50 |
51 | Basic Usage
52 | -----------
53 |
54 | Let's say we have an Xarray Dataset
55 |
56 | .. ipython:: python
57 |
58 | import xarray as xr
59 | import numpy as np
60 | da = xr.DataArray(np.random.rand(1000, 100, 100), name='foo',
61 | dims=['time', 'y', 'x']).chunk({'time': 1})
62 | da
63 |
64 | and we want to create batches along the time dimension. We can do it like this
65 |
66 | .. ipython:: python
67 |
68 | import xbatcher
69 | bgen = xbatcher.BatchGenerator(da, {'time': 10})
70 | for batch in bgen:
71 | pass
72 | # actually feed to machine learning library
73 | batch
74 |
75 | or via a built-in `Xarray accessor <https://docs.xarray.dev/en/stable/internals/extending-xarray.html>`_:
76 |
77 | .. ipython:: python
78 |
79 | import xbatcher
80 |
81 | for batch in da.batch.generator({'time': 10}):
82 | pass
83 | # actually feed to machine learning library
84 | batch
85 |
86 | .. toctree::
87 | :maxdepth: 2
88 | :caption: Contents:
89 |
90 | api
91 | user-guide/index
92 | tutorials-and-presentations
93 | roadmap
94 | contributing
95 |
--------------------------------------------------------------------------------
/doc/roadmap.rst:
--------------------------------------------------------------------------------
1 | .. _roadmap:
2 |
3 | Roadmap
4 | =======
5 |
6 | Authors: Joe Hamman and Ryan Abernathey
7 | Date: February 7, 2019
8 |
9 | Background and scope
10 | --------------------
11 |
12 | Xbatcher is a small library for iterating xarray objects in batches. The
13 | goal is to make it easy to feed xarray datasets to machine learning libraries
14 | such as `Keras`_ or `PyTorch`_. For example, implementing a simple machine
15 | learning workflow may look something like this:
16 |
17 | .. code-block:: Python
18 |
19 | import xarray as xr
20 | import xbatcher as xb
21 |
22 | da = xr.open_dataset(filename, chunks=chunks) # open a dataset and use dask
23 | da_train = preprocess(da) # perform some preprocessing
24 | bgen = xb.BatchGenerator(da_train, {'time': 10}) # create a generator
25 |
26 | for batch in bgen: # iterate through the generator
27 | model.fit(batch['x'], batch['y']) # fit a deep-learning model
28 | # or
29 | model.predict(batch['x']) # make one batch of predictions
30 |
31 | We are currently envisioning the project growing to support more complex
32 | extract-transform-load components commonly found in machine learning workflows
33 | that use multidimensional data. We note that many of the concepts in Xbatcher
34 | have been developed through collaborations in the `Pangeo Project Machine
35 | Learning Working Group `_.
36 |
37 | Batch generation
38 | ~~~~~~~~~~~~~~~~
39 |
40 | At the core of Xbatcher is the ability to define a schema that defines a
41 | selection of a larger dataset. Today, this schema is fairly simple (e.g.
42 | `{'time': 10}`) but this may evolve in the future. As we describe below,
43 | additional utilities for shuffling, sampling, and caching may provide enhanced
44 | batch generation functionality.
45 |
46 | Shuffle and Sampling APIs
47 | ~~~~~~~~~~~~~~~~~~~~~~~~~
48 |
49 | When training machine-learning models in batches, it is often necessary to
50 | selectively or randomly sample from your training data. Xbatcher can help
51 | facilitate seamless shuffling and sampling by providing APIs that operate on
52 | batches and/or full datasets. This may require working with Xarray and Dask to
53 | facilitate fast, distributed shuffles of Dask arrays.
54 |
55 | Caching APIs
56 | ~~~~~~~~~~~~
57 |
58 | A common pattern in ML is to perform the ETL tasks once before saving the results
59 | to a local file system. This is an effective approach for speeding up dataset
60 | loading during training but comes with numerous downsides (i.e. requires
61 | sufficient file space, breaks workflow continuity, etc.). We propose the
62 | development of a pluggable cache mechanism in Xbatcher that would help address
63 | these downsides while providing improved performance during model training and
64 | inference. For example, this pluggable cache mechanism may allow choosing
65 | between multiple cache types, such as an LRU in-memory cache, a Zarr filesystem
66 | or S3 bucket, or a Redis database cache.
67 |
68 | Integration with TensorFlow and PyTorch Dataset Loaders
69 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
70 |
71 | Deep-learning libraries like TensorFlow and PyTorch provide high-performance
72 | dataset-generator APIs that facilitate the construction of flexible and
73 | efficient input pipelines. In particular, they have been optimized to support
74 | asynchronous data loading and training, transfer to and from GPUs, and batch
75 | caching. Xbatcher will provide compatible dataset APIs that allow users to pass
76 | Xarray datasets directly to deep-learning frameworks.
77 |
78 | Dependencies
79 | ------------
80 |
81 | - Core: Xarray, Pandas, Dask, Scikit-learn, Numpy, Scipy
82 | - Optional: Keras, PyTorch, Tensorflow, etc.
83 |
84 | .. _Keras: https://keras.io/
85 | .. _PyTorch: https://pytorch.org/
86 |
--------------------------------------------------------------------------------
/doc/tutorials-and-presentations.rst:
--------------------------------------------------------------------------------
1 | .. _tutorials-and-presentations:
2 |
3 | Tutorials and Presentations
4 | ===========================
5 |
6 | Tutorials
7 | ---------
8 |
9 | .. toctree::
10 | :maxdepth: 1
11 | :hidden:
12 |
13 | demo
14 |
15 | .. grid:: 1 2 2 2
16 | :gutter: 2
17 |
18 |
19 | .. grid-item-card::
20 | :text-align: center
21 | :link: demo.html
22 |
23 | .. image:: https://xbatcher.readthedocs.io/en/latest/_images/demo_4_0.png
24 | :alt: Xbatcher demonstration
25 | +++
26 | Xbatcher demonstration
27 |
28 | Presentations
29 | -------------
30 |
31 | .. card:: Xbatcher - A Python Package That Simplifies Feeding Xarray Data Objects to Machine Learning Libraries
32 |
33 | 2023 AMS Annual Meeting
34 | ^^^
35 |
36 |
37 | | Presentation Recording (starts at 45:30): `AMS Confex `_
38 | | DOI: `10.6084/m9.figshare.22264072.v1 `_
39 |
40 | +++
41 | Max Jones, Joe Hamman, and Wei Ji Leong
42 |
--------------------------------------------------------------------------------
/doc/user-guide/caching.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Xbatcher Caching Feature \n",
8 | "\n",
9 | "This notebook demonstrates the new caching feature added to xbatcher's `BatchGenerator`. This feature allows you to cache batches, potentially improving performance for repeated access to the same batches. \n",
10 | "\n",
11 | "\n",
12 | "## Introduction\n",
13 | "\n",
14 | "The caching feature in xbatcher's `BatchGenerator` allows you to store generated batches in a cache, which can significantly speed up subsequent accesses to the same batches. This is particularly useful in scenarios where you need to iterate over the same dataset multiple times. \n",
15 | "\n",
16 | "\n",
17 | "The cache is pluggable, meaning you can use any dict-like object to store the cache. This flexibility allows for various storage backends, including local storage, distributed storage systems, or cloud storage solutions.\n",
18 | "\n",
19 | "## Installation \n",
20 | "\n",
21 | "To use the caching feature, you'll need to have xbatcher installed, along with zarr for serialization. If you haven't already, you can install these using pip:\n",
22 | "\n",
23 | "```bash\n",
24 | "python -m pip install xbatcher zarr\n",
25 | "```\n",
26 | "\n",
27 | "or \n",
28 | "\n",
29 | "using conda:\n",
30 | "\n",
31 | "```bash\n",
32 | "conda install -c conda-forge xbatcher zarr\n",
33 | "```\n"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "## Basic Usage \n",
41 | "\n",
42 | "Let's start with a basic example of how to use the caching feature:"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "import tempfile\n",
52 | "\n",
53 | "import xarray as xr\n",
54 | "import zarr\n",
55 | "\n",
56 | "import xbatcher"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "# create a cache using Zarr's DirectoryStore\n",
66 | "directory = f'{tempfile.mkdtemp()}/xbatcher-cache'\n",
67 | "print(directory)\n",
68 | "cache = zarr.storage.DirectoryStore(directory)"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "In this example, we're using a local directory to store the cache, but you could use any zarr-compatible store, such as S3, Redis, etc."
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "# load a sample dataset\n",
85 | "ds = xr.tutorial.open_dataset('air_temperature', chunks={})\n",
86 | "ds"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "# create a BatchGenerator with caching enabled\n",
96 | "gen = xbatcher.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10}, cache=cache)"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {},
102 | "source": [
103 | "### Performance Comparison\n",
104 | "\n",
105 | "\n",
106 | "Let's compare the performance with and without caching:\n"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "import time\n",
116 | "\n",
117 | "\n",
118 | "def time_iteration(gen):\n",
119 | " start = time.time()\n",
120 | " for batch in gen:\n",
121 | " pass\n",
122 | " end = time.time()\n",
123 | " return end - start"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "directory = f'{tempfile.mkdtemp()}/xbatcher-cache'\n",
133 | "cache = zarr.storage.DirectoryStore(directory)\n",
134 | "\n",
135 | "# Without cache\n",
136 | "gen_no_cache = xbatcher.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10})\n",
137 | "time_no_cache = time_iteration(gen_no_cache)\n",
138 | "print(f'Time without cache: {time_no_cache:.2f} seconds')"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "# With cache\n",
148 | "gen_with_cache = xbatcher.BatchGenerator(\n",
149 | " ds, input_dims={'lat': 10, 'lon': 10}, cache=cache\n",
150 | ")\n",
151 | "time_first_run = time_iteration(gen_with_cache)\n",
152 | "print(f'Time with cache (first run): {time_first_run:.2f} seconds')\n",
153 | "\n",
154 | "\n",
155 | "time_second_run = time_iteration(gen_with_cache)\n",
156 | "print(f'Time with cache (second run): {time_second_run:.2f} seconds')"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "You should see that the second run with cache is significantly faster than both the first run and the run without cache."
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {},
169 | "source": [
170 | "## Advanced Usage \n",
171 | "\n",
172 | "### Custom Cache Preprocessing\n",
173 | "\n",
174 | "You can also specify a custom preprocessing function to be applied to batches before they are cached:\n"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "metadata": {},
181 | "outputs": [],
182 | "source": [
183 | "# create a cache using Zarr's DirectoryStore\n",
184 | "directory = f'{tempfile.mkdtemp()}/xbatcher-cache'\n",
185 | "cache = zarr.storage.DirectoryStore(directory)\n",
186 | "\n",
187 | "\n",
188 | "def preprocess_batch(batch):\n",
189 | " # example: add a new variable to each batch\n",
190 | " batch['new_var'] = batch['air'] * 2\n",
191 | " return batch\n",
192 | "\n",
193 | "\n",
194 | "gen_with_preprocess = xbatcher.BatchGenerator(\n",
195 | " ds,\n",
196 | " input_dims={'lat': 10, 'lon': 10},\n",
197 | " cache=cache,\n",
198 | " cache_preprocess=preprocess_batch,\n",
199 | ")\n",
200 | "\n",
201 | "# Now, each cached batch will include the 'new_var' variable\n",
202 | "for batch in gen_with_preprocess:\n",
203 | " print(batch)\n",
204 | " break"
205 | ]
206 | },
207 | {
208 | "cell_type": "markdown",
209 | "metadata": {},
210 | "source": [
211 | "### Using Different Storage Backends\n",
212 | "\n",
213 | "While we've been using a local directory for caching, you can use any dict-like that is compatible with zarr. For example, you could use an S3 bucket as the cache storage backend:\n",
214 | "\n",
215 | "```python\n",
216 | "import s3fs\n",
217 | "import zarr \n",
218 | "\n",
219 | "# Set up S3 filesystem (you'll need appropriate credentials)\n",
220 | "s3 = s3fs.S3FileSystem(anon=False)\n",
221 | "cache = s3.get_mapper('s3://my-bucket/my-cache.zarr')\n",
222 | "\n",
223 | "# Use this cache with BatchGenerator\n",
224 | "gen_s3 = xbatcher.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10}, cache=cache)\n",
225 | "```\n"
226 | ]
227 | },
228 | {
229 | "cell_type": "markdown",
230 | "metadata": {},
231 | "source": [
232 | "## Considerations and Best Practices \n",
233 | "\n",
234 | "- **Storage Space**: Be mindful of the storage space required for your cache, especially when working with large datasets.\n",
235 | "- **Cache Invalidation**: The current implementation doesn't handle cache invalidation. If your source data changes, you'll need to manually clear or update the cache.\n",
236 | "- **Performance Tradeoffs**: While caching can significantly speed up repeated access to the same data, the initial caching process may be slower than processing without a cache. Consider your use case to determine if caching is beneficial.\n",
237 | "- **Storage Backend**: Choose a storage backend that's appropriate for your use case. Local storage might be fastest for single-machine applications, while distributed or cloud storage might be necessary for cluster computing or cloud-based workflows.\n",
238 | "\n"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {},
244 | "source": []
245 | }
246 | ],
247 | "metadata": {
248 | "kernelspec": {
249 | "display_name": "Python 3 (ipykernel)",
250 | "language": "python",
251 | "name": "python3"
252 | },
253 | "language_info": {
254 | "codemirror_mode": {
255 | "name": "ipython",
256 | "version": 3
257 | },
258 | "file_extension": ".py",
259 | "mimetype": "text/x-python",
260 | "name": "python",
261 | "nbconvert_exporter": "python",
262 | "pygments_lexer": "ipython3",
263 | "version": "3.11.9"
264 | }
265 | },
266 | "nbformat": 4,
267 | "nbformat_minor": 4
268 | }
269 |
--------------------------------------------------------------------------------
/doc/user-guide/index.rst:
--------------------------------------------------------------------------------
1 | User Guide
2 | ===========
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Contents:
7 |
8 | caching
9 | training-a-neural-network-with-Pytorch-and-xbatcher
10 | training-a-neural-network-with-keras-and-xbatcher
11 |
--------------------------------------------------------------------------------
/doc/user-guide/training-a-neural-network-with-Pytorch-and-xbatcher.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "e579c0e1-bb12-4c7b-8a97-bdb6fad01755",
6 | "metadata": {},
7 | "source": [
8 | "# End-to-End Tutorial: Training a Neural Network with PyTorch and Xbatcher\n",
9 | "\n",
10 | "This tutorial demonstrates how to use xarray, xbatcher, and PyTorch to train a simple neural network on the FashionMNIST dataset."
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "id": "5aa4bf55-588a-465d-affb-de5d16a54cdd",
16 | "metadata": {},
17 | "source": [
18 | "## Step 1: Setup \n",
19 | "\n",
20 | "Import the necessary libraries and load the dataset"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "916bb4a8-d2df-49e8-9109-a92299960886",
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "import matplotlib.pyplot as plt\n",
31 | "import torch\n",
32 | "import torch.nn as nn\n",
33 | "import torch.optim as optim\n",
34 | "import torch.utils.data\n",
35 | "import xarray as xr\n",
36 | "\n",
37 | "import xbatcher as xb\n",
38 | "import xbatcher.loaders.torch"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "id": "474b2cc1-9991-4060-92fe-559c15d96678",
45 | "metadata": {
46 | "scrolled": true
47 | },
48 | "outputs": [],
49 | "source": [
50 | "ds = xr.open_dataset(\n",
51 | " 's3://carbonplan-share/xbatcher/fashion-mnist-train.zarr',\n",
52 | " engine='zarr',\n",
53 | " chunks={},\n",
54 | " backend_kwargs={'storage_options': {'anon': True}},\n",
55 | ")\n",
56 | "ds"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "id": "d647846e-381a-4901-ba1c-d4d47ff7b1fa",
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "ds.sel(sample=1).images.plot(cmap='gray');"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "id": "533b9827-8c27-4229-8035-cf39f3e99e54",
72 | "metadata": {},
73 | "source": [
74 | "## Step 2: Create batch generator and data loader\n",
75 | "\n",
76 | "We use `xbatcher` to create batch generators for the images (`X_bgen`) and labels (`y_bgen`)"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "id": "303b2a1d-9126-44e7-b312-fa546eca8f2e",
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "# Define batch generators\n",
87 | "X_bgen = xb.BatchGenerator(\n",
88 | " ds['images'],\n",
89 | " input_dims={'sample': 2000, 'channel': 1, 'height': 28, 'width': 28},\n",
90 | " preload_batch=False,\n",
91 | ")\n",
92 | "y_bgen = xb.BatchGenerator(\n",
93 | " ds['labels'], input_dims={'sample': 2000}, preload_batch=False\n",
94 | ")\n",
95 | "X_bgen[0]"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "id": "1c1ab1a7-4bf2-4d73-a8e2-a88e7cdf6829",
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "# Map batches to a PyTorch-compatible dataset\n",
106 | "dataset = xbatcher.loaders.torch.MapDataset(X_bgen, y_bgen)"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "id": "942ee9c8-5369-49d3-8038-0658b80ff851",
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "# Create a DataLoader\n",
117 | "train_dataloader = torch.utils.data.DataLoader(\n",
118 | " dataset,\n",
119 | " batch_size=None, # Using batches defined by the dataset itself (via xbatcher)\n",
120 | " prefetch_factor=3, # Prefetch up to 3 batches in advance to reduce data loading latency\n",
121 | " num_workers=4, # Use 4 parallel worker processes to load data concurrently\n",
122 | " persistent_workers=True, # Keep workers alive between epochs for faster subsequent epochs\n",
123 | " multiprocessing_context='forkserver', # Use \"forkserver\" to spawn subprocesses, ensuring stability in multiprocessing\n",
124 | ")"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "id": "16953fd7-53d8-4d37-80e8-57f5b4daccee",
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "train_features, train_labels = next(iter(train_dataloader))"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "id": "119cbac9-a973-4b37-b42d-12e03f105826",
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "print(f'Feature batch shape: {train_features.size()}')\n",
145 | "print(f'Labels batch shape: {train_labels.size()}')\n",
146 | "img = train_features[0].squeeze()\n",
147 | "label = train_labels[0]\n",
148 | "plt.imshow(img, cmap='gray')\n",
149 | "plt.show()\n",
150 | "print(f'Label: {label}')"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "id": "4a6c3219-b281-4782-a77d-58860f7f7c83",
156 | "metadata": {},
157 | "source": [
158 | "## Step 3: Define the Neural Network\n",
159 | "\n",
160 | "We define a simple feedforward neural network for classification."
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "id": "09620988-bbda-4508-b39e-e1d81f2374c9",
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "class SimpleNN(nn.Module):\n",
171 | " def __init__(self):\n",
172 | " super().__init__()\n",
173 | " self.flatten = nn.Flatten()\n",
174 | " self.fc1 = nn.Linear(28 * 28, 128)\n",
175 | " self.fc2 = nn.Linear(128, 10)\n",
176 | "\n",
177 | " def forward(self, x):\n",
178 | " x = self.flatten(x)\n",
179 | " x = torch.relu(self.fc1(x))\n",
180 | " x = self.fc2(x)\n",
181 | " return x"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "id": "cff27830-a15e-4f4a-b529-7fed8ea7632d",
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "# Instantiate the model\n",
192 | "model = SimpleNN()\n",
193 | "model"
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "id": "c12073c2-a380-4a28-bc24-68c1811739b5",
199 | "metadata": {},
200 | "source": [
201 | "## Step 4: Define Loss Function and Optimizer\n",
202 | "We use Cross-Entropy Loss and the Adam optimizer."
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": null,
208 | "id": "ab29bcbc-11e6-48ac-8618-8b6da4627b8f",
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "loss_fn = nn.CrossEntropyLoss()\n",
213 | "optimizer = optim.Adam(model.parameters(), lr=0.001)"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "id": "ec0bd97d-c9fd-42ae-8962-ec9ed6434331",
219 | "metadata": {},
220 | "source": [
221 | "## Step 5: Train the Model\n",
222 | "We train the model using the data loader."
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "id": "76fe4f2d-d15d-43ba-a079-0c3c8c4d65a2",
229 | "metadata": {},
230 | "outputs": [],
231 | "source": [
232 | "%%time\n",
233 | "\n",
234 | "epochs = 5\n",
235 | "\n",
236 | "for epoch in range(epochs):\n",
237 | " print(f'Epoch {epoch+1}/{epochs}')\n",
238 | " for batch, (X, y) in enumerate(train_dataloader):\n",
239 | " # Forward pass\n",
240 | " predictions = model(X)\n",
241 | " loss = loss_fn(predictions, y)\n",
242 | "\n",
243 | " # Backward pass\n",
244 | " optimizer.zero_grad()\n",
245 | " loss.backward()\n",
246 | " optimizer.step()\n",
247 | "\n",
248 | " if batch % 10 == 0:\n",
249 | " print(f'Batch {batch}: Loss = {loss.item():.4f}')\n",
250 | "\n",
251 | "print('Training completed!')"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "id": "aa1db735-9670-41dc-a27e-74369e8c320d",
257 | "metadata": {},
258 | "source": [
259 | "## Step 6: Evaluate the Model\n",
260 | "You can evaluate the model on the test set or visualize some predictions."
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "id": "ea26801a-18a9-4ffe-9d04-92157c42bc8a",
267 | "metadata": {},
268 | "outputs": [],
269 | "source": [
270 | "# Visualize a sample prediction\n",
271 | "img = train_features[0].squeeze()\n",
272 | "label = train_labels[0]\n",
273 | "predicted_label = torch.argmax(model(train_features[0:1]), dim=1).item()\n",
274 | "\n",
275 | "plt.imshow(img, cmap='gray')\n",
276 | "plt.title(f'True Label: {label}, Predicted: {predicted_label}')\n",
277 | "plt.show()"
278 | ]
279 | },
280 | {
281 | "cell_type": "markdown",
282 | "id": "8c52dcc1-d583-4cfe-af34-8030d4b451f0",
283 | "metadata": {},
284 | "source": [
285 | "## Key Highlights\n",
286 | "\n",
287 | "- **Data Handling**: We use Xbatcher to create efficient, chunked data pipelines from Xarray datasets.\n",
288 |     "- **Integration**: The `xbatcher.loaders.torch.MapDataset` enables direct compatibility with PyTorch's DataLoader.\n",
289 | "- **Training**: PyTorch simplifies the model training loop while leveraging the custom data pipeline.\n"
290 | ]
291 | }
292 | ],
293 | "metadata": {
294 | "kernelspec": {
295 | "display_name": "Python 3 (ipykernel)",
296 | "language": "python",
297 | "name": "python3"
298 | },
299 | "language_info": {
300 | "codemirror_mode": {
301 | "name": "ipython",
302 | "version": 3
303 | },
304 | "file_extension": ".py",
305 | "mimetype": "text/x-python",
306 | "name": "python",
307 | "nbconvert_exporter": "python",
308 | "pygments_lexer": "ipython3",
309 | "version": "3.11.9"
310 | }
311 | },
312 | "nbformat": 4,
313 | "nbformat_minor": 5
314 | }
315 |
--------------------------------------------------------------------------------
/doc/user-guide/training-a-neural-network-with-keras-and-xbatcher.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b314e777-7ffb-4e62-b4c5-ce8a785c5181",
6 | "metadata": {},
7 | "source": [
8 | "# End-to-End Tutorial: Training a Neural Network with Keras and Xbatcher\n",
9 | "\n",
10 | "## Import Required Libraries"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "id": "5d912ff0-d808-4704-8dea-b9e1b5a53bf1",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import matplotlib.pyplot as plt\n",
21 | "import tensorflow as tf\n",
22 | "import xarray as xr\n",
23 | "from keras import layers, models, optimizers\n",
24 | "\n",
25 | "import xbatcher as xb\n",
26 | "import xbatcher.loaders.keras"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "id": "7fb892c1-50fd-48c8-8567-b150946b53c9",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "# Open the dataset stored in Zarr format\n",
37 | "ds = xr.open_dataset(\n",
38 | " 's3://carbonplan-share/xbatcher/fashion-mnist-train.zarr',\n",
39 | " engine='zarr',\n",
40 | " chunks={},\n",
41 | " backend_kwargs={'storage_options': {'anon': True}},\n",
42 | ")"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "id": "c98134fe-581f-412a-93e3-6b07b7706078",
48 | "metadata": {},
49 | "source": [
50 | "## Define Batch Generators"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "id": "c680ebd7-0310-4f40-91b5-e7cc1a59e853",
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "# Define batch generators for features (X) and labels (y)\n",
61 | "X_bgen = xb.BatchGenerator(\n",
62 | " ds['images'],\n",
63 | " input_dims={'sample': 2000, 'channel': 1, 'height': 28, 'width': 28},\n",
64 | " preload_batch=False, # Load each batch dynamically\n",
65 | ")\n",
66 | "y_bgen = xb.BatchGenerator(\n",
67 | " ds['labels'], input_dims={'sample': 2000}, preload_batch=False\n",
68 | ")"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "id": "91d63180-e3a6-49f7-a8e7-67b8b698b08c",
74 | "metadata": {},
75 | "source": [
76 | "## Map Batches to a Keras-Compatible Dataset"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "id": "d1195057-269b-44ba-a3e7-aeedaa4ba8df",
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "# Use xbatcher's MapDataset to wrap the generators\n",
 87 |     "# Use xbatcher's CustomTFDataset to wrap the generators\n",
88 | "\n",
89 | "# Create a DataLoader using tf.data.Dataset\n",
90 | "train_dataloader = tf.data.Dataset.from_generator(\n",
91 | " lambda: iter(dataset),\n",
92 | " output_signature=(\n",
93 | " tf.TensorSpec(shape=(2000, 1, 28, 28), dtype=tf.float32), # Images\n",
94 | " tf.TensorSpec(shape=(2000,), dtype=tf.int64), # Labels\n",
95 | " ),\n",
96 | ").prefetch(3) # Prefetch 3 batches to improve performance"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "1892411c-ca17-4d7f-b76b-5b5decaa78c1",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "## Visualize a Sample Batch"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "id": "133b24bc-e7bc-4734-ad0a-22a848dd204c",
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "# Extract a batch from the DataLoader\n",
117 | "for train_features, train_labels in train_dataloader.take(1):\n",
118 | " print(f'Feature batch shape: {train_features.shape}')\n",
119 | " print(f'Labels batch shape: {train_labels.shape}')\n",
120 | "\n",
121 | " img = train_features[0].numpy().squeeze() # Extract the first image\n",
122 | " label = train_labels[0].numpy()\n",
123 | " plt.imshow(img, cmap='gray')\n",
124 | " plt.title(f'Label: {label}')\n",
125 | " plt.show()\n",
126 | " break"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "id": "1e5d6a66-1943-47da-be67-9b54d51defed",
132 | "metadata": {},
133 | "source": [
134 | "## Build a Simple Neural Network with Keras"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "id": "8b0490e5-7ccc-47fe-90ec-d41a81c4eb20",
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "# Define a simple feedforward neural network\n",
145 | "model = models.Sequential(\n",
146 | " [\n",
147 | " layers.Flatten(input_shape=(1, 28, 28)), # Flatten input images\n",
148 | " layers.Dense(128, activation='relu'), # Fully connected layer with 128 units\n",
149 | " layers.Dense(10, activation='softmax'), # Output layer for 10 classes\n",
150 | " ]\n",
151 | ")\n",
152 | "\n",
153 | "# Compile the model\n",
154 | "model.compile(\n",
155 | " optimizer=optimizers.Adam(learning_rate=0.001),\n",
156 | " loss='sparse_categorical_crossentropy',\n",
157 | " metrics=['accuracy'],\n",
158 | ")\n",
159 | "\n",
160 | "# Display model summary\n",
161 | "model.summary()"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "id": "838df9c6-0753-4120-a0e0-dcc1480416b4",
167 | "metadata": {},
168 | "source": [
169 | "## Train the Model "
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": null,
175 | "id": "25e86eba-4d4e-47cc-a6a7-9f0be244b009",
176 | "metadata": {},
177 | "outputs": [],
178 | "source": [
179 | "%%time\n",
180 | "\n",
181 | "# Train the model for 5 epochs\n",
182 | "epochs = 5\n",
183 | "\n",
184 | "model.fit(\n",
185 | " train_dataloader, # Pass the DataLoader directly\n",
186 | " epochs=epochs,\n",
187 | " verbose=1, # Print progress during training\n",
188 | ")"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "id": "a0f4246c-6461-4e2a-a49d-df6c1ce770fc",
194 | "metadata": {},
195 | "source": [
196 | "## Visualize a Sample Prediction"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "id": "9361cb65-3c0d-40d6-be5c-18b309626817",
203 | "metadata": {},
204 | "outputs": [],
205 | "source": [
206 | "# Visualize a prediction on a sample image\n",
207 | "for train_features, train_labels in train_dataloader.take(1):\n",
208 | " img = train_features[0].numpy().squeeze()\n",
209 | " label = train_labels[0].numpy()\n",
210 | " predicted_label = tf.argmax(model.predict(train_features[:1]), axis=1).numpy()[0]\n",
211 | "\n",
212 | " plt.imshow(img, cmap='gray')\n",
213 | " plt.title(f'True Label: {label}, Predicted: {predicted_label}')\n",
214 | " plt.show()\n",
215 | " break"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "id": "372d0e0a-1542-4aa0-b3b9-9fd4337459ba",
221 | "metadata": {},
222 | "source": [
223 | "## Key Highlights \n",
224 | "\n",
225 |     "- **Dynamic Batching**: Xbatcher and the CustomTFDataset class allow for dynamic loading of batches, which reduces memory usage and speeds up data processing.\n",
226 | "- **Prefetching**: The prefetch feature in `tf.data.Dataset` overlaps data loading with model training to minimize idle time.\n",
227 | "- **Compatibility**: The pipeline works seamlessly with `keras.Model.fit`, simplifying training workflows."
228 | ]
229 | }
230 | ],
231 | "metadata": {
232 | "kernelspec": {
233 | "display_name": "Python 3 (ipykernel)",
234 | "language": "python",
235 | "name": "python3"
236 | },
237 | "language_info": {
238 | "codemirror_mode": {
239 | "name": "ipython",
240 | "version": 3
241 | },
242 | "file_extension": ".py",
243 | "mimetype": "text/x-python",
244 | "name": "python",
245 | "nbconvert_exporter": "python",
246 | "pygments_lexer": "ipython3",
247 | "version": "3.11.9"
248 | }
249 | },
250 | "nbformat": 4,
251 | "nbformat_minor": 5
252 | }
253 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "setuptools.build_meta"
3 | requires = ["setuptools-scm[toml]>=6.2", "setuptools>=64"]
4 |
5 | [project]
6 | authors = [
7 | { name = "xbatcher Developers", email = "rpa@ldeo.columbia.edu" },
8 | ]
9 | classifiers = [
10 | "Development Status :: 4 - Beta",
11 | "Intended Audience :: Science/Research",
12 | "License :: OSI Approved :: Apache Software License",
13 | "Operating System :: OS Independent",
14 | "Programming Language :: Python :: 3",
15 | "Programming Language :: Python :: 3.10",
16 | "Programming Language :: Python :: 3.11",
17 | "Programming Language :: Python :: 3.12",
18 | "Programming Language :: Python",
19 | "Topic :: Scientific/Engineering",
20 | ]
21 | dependencies = ["dask", "numpy", "xarray"]
22 | description = "Batch generation from Xarray objects"
23 | dynamic = ["version"]
24 | license = { text = "Apache" }
25 | name = "xbatcher"
26 | readme = "README.rst"
27 | requires-python = ">=3.10"
28 | [project.optional-dependencies]
29 | dev = [
30 | "asv",
31 | "coverage",
32 | "pytest",
33 | "pytest-cov",
34 | "s3fs",
35 | "tensorflow",
36 | "torch",
37 | "zarr<3.0",
38 | ]
39 | tensorflow = ["tensorflow"]
40 | torch = ["torch"]
41 | [project.urls]
42 | documentation = "https://xbatcher.readthedocs.io/en/latest/"
43 | repository = "https://github.com/xarray-contrib/xbatcher"
44 |
45 | [tool.setuptools.packages.find]
46 | include = ["xbatcher*"]
47 |
48 | [tool.setuptools_scm]
49 | fallback_version = "999"
50 | local_scheme = "node-and-date"
51 |
52 | [tool.ruff]
53 | extend-include = ["*.ipynb"]
54 | target-version = "py310"
55 |
56 | builtins = ["ellipsis"]
57 | # Exclude a variety of commonly ignored directories.
58 | exclude = [
59 | ".bzr",
60 | ".direnv",
61 | ".eggs",
62 | ".git",
63 | ".git-rewrite",
64 | ".hg",
65 | ".ipynb_checkpoints",
66 | ".mypy_cache",
67 | ".nox",
68 | ".pants.d",
69 | ".pyenv",
70 | ".pytest_cache",
71 | ".pytype",
72 | ".ruff_cache",
73 | ".svn",
74 | ".tox",
75 | ".venv",
76 | ".vscode",
77 | "__pypackages__",
78 | "_build",
79 | "buck-out",
80 | "build",
81 | "dist",
82 | "node_modules",
83 | "site-packages",
84 | "venv",
85 | ]
86 | [tool.ruff.lint]
87 | ignore = [
88 | "E501", # Conflicts with ruff format
89 | "E721", # Comparing types instead of isinstance
90 | "E741", # Ambiguous variable names
91 | ]
92 | per-file-ignores = {}
93 | select = [
94 | # Pyflakes
95 | "F",
96 | # Pycodestyle
97 | "E",
98 | "W",
99 | # isort
100 | "I",
101 | # Pyupgrade
102 | "UP",
103 | ]
104 |
105 | [tool.ruff.lint.mccabe]
106 | max-complexity = 18
107 |
108 | [tool.ruff.lint.isort]
109 | known-first-party = ["xbatcher"]
110 | known-third-party = [
111 | "numpy",
112 | "pandas",
113 | "pytest",
114 | "sphinx_autosummary_accessors",
115 | "torch",
116 | "xarray",
117 | ]
118 |
119 | combine-as-imports = true
120 |
121 | [tool.ruff.format]
122 | docstring-code-format = true
123 | quote-style = "single"
124 |
125 | [tool.ruff.lint.pydocstyle]
126 | convention = "numpy"
127 |
128 | [tool.ruff.lint.pyupgrade]
129 | # Preserve types, even if a file imports `from __future__ import annotations`.
130 | keep-runtime-typing = true
131 |
132 | [tool.pytest.ini_options]
133 | log_cli = true
134 | log_level = "INFO"
135 |
--------------------------------------------------------------------------------
/readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | build:
8 | os: "ubuntu-24.04"
9 | tools:
10 | python: "mambaforge-latest"
11 |
12 | # Build documentation in the doc/ directory with Sphinx
13 | sphinx:
14 | configuration: doc/conf.py
15 | fail_on_warning: true
16 |
17 | # Optionally declare the Python requirements required to build your docs
18 | conda:
19 | environment: ci/requirements/doc.yml
20 |
--------------------------------------------------------------------------------
/xbatcher/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import (
2 | PackageNotFoundError as _PackageNotFoundError,
3 | version as _version,
4 | )
5 |
6 | from . import testing # noqa: F401
7 | from .accessors import BatchAccessor # noqa: F401
8 | from .generators import BatchGenerator, BatchSchema # noqa: F401
9 | from .util.print_versions import show_versions # noqa: F401
10 |
# Resolve the installed distribution's version; fall back to a placeholder
# when xbatcher is imported from a source tree without installed metadata.
try:
    __version__ = _version(__name__)
except _PackageNotFoundError:
    # package is not installed
    __version__ = 'unknown'
16 |
--------------------------------------------------------------------------------
/xbatcher/accessors.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import xarray as xr
4 |
5 | from .generators import BatchGenerator
6 |
7 |
def _as_xarray_dataarray(xr_obj: xr.Dataset | xr.DataArray) -> xr.DataArray:
    """
    Coerce an xarray object to a DataArray so it can be converted into a
    Tensor object. DataArrays pass through unchanged; Datasets are stacked
    along a temporary ``variable`` dimension which is immediately squeezed.
    """
    if not isinstance(xr_obj, xr.Dataset):
        return xr_obj
    return xr_obj.to_array().squeeze(dim='variable')
17 |
18 |
@xr.register_dataarray_accessor('batch')
@xr.register_dataset_accessor('batch')
class BatchAccessor:
    def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
        """
        Accessor exposing a ``BatchGenerator`` via the ``.batch`` namespace
        on Datasets and DataArrays; see the ``generator`` method.
        """
        self._obj = xarray_obj

    def generator(self, *args, **kwargs) -> BatchGenerator:
        """
        Construct a ``BatchGenerator`` for the wrapped xarray object.

        Parameters
        ----------
        *args : iterable
            Positional arguments forwarded to the `BatchGenerator` constructor.
        **kwargs : dict
            Keyword arguments forwarded to the `BatchGenerator` constructor.
        """
        return BatchGenerator(self._obj, *args, **kwargs)
40 |
41 |
@xr.register_dataarray_accessor('tf')
@xr.register_dataset_accessor('tf')
class TFAccessor:
    def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
        self._obj = xarray_obj

    def to_tensor(self) -> Any:
        """Convert this DataArray to a tensorflow.Tensor"""
        # Deferred import keeps tensorflow an optional dependency.
        import tensorflow as tf

        return tf.convert_to_tensor(_as_xarray_dataarray(xr_obj=self._obj).data)
55 |
56 |
@xr.register_dataarray_accessor('torch')
@xr.register_dataset_accessor('torch')
class TorchAccessor:
    def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
        self._obj = xarray_obj

    def to_tensor(self) -> Any:
        """Convert this DataArray to a torch.Tensor"""
        # Deferred import keeps torch an optional dependency.
        import torch

        return torch.tensor(data=_as_xarray_dataarray(xr_obj=self._obj).data)

    def to_named_tensor(self) -> Any:
        """
        Convert this DataArray to a torch.Tensor with named dimensions.

        See https://pytorch.org/docs/stable/named_tensor.html
        """
        import torch

        dataarray = _as_xarray_dataarray(xr_obj=self._obj)
        dim_names = tuple(dataarray.sizes)
        return torch.tensor(data=dataarray.data, names=dim_names)
82 |
--------------------------------------------------------------------------------
/xbatcher/generators.py:
--------------------------------------------------------------------------------
1 | """Classes for iterating through xarray datarrays / datasets in batches."""
2 |
3 | import itertools
4 | import json
5 | import warnings
6 | from collections.abc import Callable, Hashable, Iterator, Sequence
7 | from operator import itemgetter
8 | from typing import Any
9 |
10 | import numpy as np
11 | import xarray as xr
12 |
# Type aliases for the selector plumbing used throughout this module.
PatchGenerator = Iterator[dict[Hashable, slice]]  # yields one isel-style selector per patch
BatchSelector = list[dict[Hashable, slice]]  # all patch selectors belonging to one batch
BatchSelectorSet = dict[int, BatchSelector]  # batch index -> selectors for that batch
16 |
17 |
class BatchSchema:
    """
    A representation of the indices and stacking/transposing parameters needed
    to generate batches from Xarray DataArrays and Datasets using
    xbatcher.BatchGenerator.

    Parameters
    ----------
    ds : ``xarray.Dataset`` or ``xarray.DataArray``
        The data to iterate over. Unlike for the BatchGenerator, the data is
        not retained as a class attribute for the BatchSchema.
    input_dims : dict
        A dictionary specifying the size of the inputs in each dimension,
        e.g. ``{'lat': 30, 'lon': 30}``
        These are the dimensions the ML library will see. All other dimensions
        will be stacked into one dimension called ``sample``.
    input_overlap : dict, optional
        A dictionary specifying the overlap along each dimension
        e.g. ``{'lat': 3, 'lon': 3}``
    batch_dims : dict, optional
        A dictionary specifying the size of the batch along each dimension
        e.g. ``{'time': 10}``. These will always be iterated over.
    concat_input_bins : bool, optional
        If ``True``, the dimension chunks specified in ``input_dims`` will be
        concatenated and stacked into the ``sample`` dimension. The batch index
        will be included as a new level ``input_batch`` in the ``sample``
        coordinate.
        If ``False``, the dimension chunks specified in ``input_dims`` will be
        iterated over.
        NOTE(review): the parameter name appears to be a typo for
        ``concat_input_dims`` (it is stored under that attribute name), but it
        is kept as-is for backward compatibility with existing callers.
    preload_batch : bool, optional
        If ``True``, each batch will be loaded into memory before reshaping /
        processing, triggering any dask arrays to be computed.

    Notes
    -----
    The BatchSchema is experimental and subject to change without notice.
    """

    def __init__(
        self,
        ds: xr.Dataset | xr.DataArray,
        input_dims: dict[Hashable, int],
        input_overlap: dict[Hashable, int] | None = None,
        batch_dims: dict[Hashable, int] | None = None,
        concat_input_bins: bool = True,
        preload_batch: bool = True,
    ):
        if input_overlap is None:
            input_overlap = {}
        if batch_dims is None:
            batch_dims = {}
        self.input_dims = dict(input_dims)
        self.input_overlap = input_overlap
        self.batch_dims = dict(batch_dims)
        self.concat_input_dims = concat_input_bins
        self.preload_batch = preload_batch
        # Store helpful information based on arguments
        # Dims appearing in both batch_dims and input_dims:
        self._duplicate_batch_dims: dict[Hashable, int] = {
            dim: length
            for dim, length in self.batch_dims.items()
            if self.input_dims.get(dim) is not None
        }
        # Dims appearing only in batch_dims:
        self._unique_batch_dims: dict[Hashable, int] = {
            dim: length
            for dim, length in self.batch_dims.items()
            if self.input_dims.get(dim) is None
        }
        # Step between consecutive patches along each input dim:
        self._input_stride: dict[Hashable, int] = {
            dim: length - self.input_overlap.get(dim, 0)
            for dim, length in self.input_dims.items()
        }
        # Every dimension that will be sliced when generating patches:
        self._all_sliced_dims: dict[Hashable, int] = dict(
            **self._unique_batch_dims, **self.input_dims
        )
        self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds)

    def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet:
        """
        Create batch selectors dict, which can be used to create a batch
        from an Xarray data object.
        """
        # Create an iterator that returns an object usable for .isel in xarray
        patch_selectors = self._gen_patch_selectors(ds)
        # Create the Dict containing batch selectors
        if self.concat_input_dims:  # Combine the patches into batches
            return self._combine_patches_into_batch(ds, patch_selectors)
        else:  # Each patch gets its own batch
            return {ind: [value] for ind, value in enumerate(patch_selectors)}

    def _gen_patch_selectors(self, ds: xr.DataArray | xr.Dataset) -> PatchGenerator:
        """
        Create an iterator that can be used to index an Xarray Dataset/DataArray.
        """
        if self._duplicate_batch_dims and not self.concat_input_dims:
            warnings.warn(
                'The following dimensions were included in both ``input_dims`` '
                'and ``batch_dims``. Since ``concat_input_dims`` is ``False``, '
                f'these dimensions will not impact batch generation: {self._duplicate_batch_dims}'
            )
        # Generate the slices by iterating over batch_dims and input_dims
        all_slices = _iterate_through_dimensions(
            ds,
            dims=self._all_sliced_dims,
            overlap=self.input_overlap,
        )
        return all_slices

    def _combine_patches_into_batch(
        self, ds: xr.DataArray | xr.Dataset, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Combine the patch selectors to form a batch
        """
        # Check that patches are only combined with concat_input_dims
        if not self.concat_input_dims:
            raise AssertionError(
                'Patches should only be combined into batches when ``concat_input_dims`` is ``True``'
            )
        if not self.batch_dims:
            return self._combine_patches_into_one_batch(patch_selectors)
        elif self._duplicate_batch_dims:
            return self._combine_patches_grouped_by_input_and_batch_dims(
                ds=ds, patch_selectors=patch_selectors
            )
        else:
            return self._combine_patches_grouped_by_batch_dims(patch_selectors)

    def _combine_patches_into_one_batch(
        self, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Group all patches into a single batch
        """
        return dict(enumerate([list(patch_selectors)]))

    def _combine_patches_grouped_by_batch_dims(
        self, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Group patches based on the unique slices for dimensions in ``batch_dims``
        """
        # NOTE: relies on patch_selectors being emitted in sorted order along
        # the batch dims (itertools.groupby only groups consecutive items).
        batch_selectors = [
            list(value)
            for _, value in itertools.groupby(
                patch_selectors, key=itemgetter(*self.batch_dims)
            )
        ]
        return dict(enumerate(batch_selectors))

    def _combine_patches_grouped_by_input_and_batch_dims(
        self, ds: xr.DataArray | xr.Dataset, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Combine patches with multiple slices along ``batch_dims`` grouped into
        each patch. Required when a dimension is duplicated between ``batch_dims``
        and ``input_dims``.
        """
        self._gen_patch_numbers(ds)
        self._gen_batch_numbers(ds)
        batch_id_per_patch = self._get_batch_multi_index_per_patch()
        patch_in_range = self._get_batch_in_range_per_batch(
            batch_multi_index=batch_id_per_patch
        )
        batch_id_per_patch = self._ravel_batch_multi_index(batch_id_per_patch)
        batch_selectors = self._gen_empty_batch_selectors()
        for i, patch in enumerate(patch_selectors):
            if patch_in_range[i]:
                batch_selectors[batch_id_per_patch[i]].append(patch)
        return batch_selectors

    def _gen_empty_batch_selectors(self) -> BatchSelectorSet:
        """
        Create an empty batch selector set that can be populated by appending
        patches to each batch.
        """
        n_batches = np.prod(list(self._n_batches_per_dim.values()))
        return {k: [] for k in range(n_batches)}

    def _gen_patch_numbers(self, ds: xr.DataArray | xr.Dataset):
        """
        Calculate the number of patches per dimension and the number of patches
        in each batch per dimension.
        """
        self._n_patches_per_batch: dict[Hashable, int] = {
            dim: int(np.ceil(length / self._input_stride.get(dim, length)))
            for dim, length in self.batch_dims.items()
        }
        self._n_patches_per_dim: dict[Hashable, int] = {
            dim: int(
                (ds.sizes[dim] - self.input_overlap.get(dim, 0))
                // (length - self.input_overlap.get(dim, 0))
            )
            for dim, length in self._all_sliced_dims.items()
        }

    def _gen_batch_numbers(self, ds: xr.DataArray | xr.Dataset):
        """
        Calculate the number of batches per dimension
        """
        self._n_batches_per_dim: dict[Hashable, int] = {
            dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim]))
            for dim in self._all_sliced_dims.keys()
        }

    def _get_batch_multi_index_per_patch(self):
        """
        Calculate the batch multi-index for each patch
        """
        batch_id_per_dim: dict[Hashable, Any] = {
            dim: np.floor(
                np.arange(0, n_patches)
                / self._n_patches_per_batch.get(dim, n_patches + 1)
            ).astype(np.int64)
            for dim, n_patches in self._n_patches_per_dim.items()
        }
        batch_id_per_patch = np.array(
            list(itertools.product(*batch_id_per_dim.values()))
        ).transpose()
        return batch_id_per_patch

    def _ravel_batch_multi_index(self, batch_multi_index):
        """
        Convert the batch multi-index to a flat index for each patch
        """
        return np.ravel_multi_index(
            multi_index=batch_multi_index,
            dims=tuple(self._n_batches_per_dim.values()),
            mode='clip',
        )

    def _get_batch_in_range_per_batch(self, batch_multi_index):
        """
        Determine whether each patch is contained within any of the batches.
        """
        batch_id_maximum = np.fromiter(self._n_batches_per_dim.values(), dtype=int)
        # Pad with 1s so unbatched dims only admit the single batch index 0
        batch_id_maximum = np.pad(
            batch_id_maximum,
            (0, (len(self._n_patches_per_dim) - len(self._n_batches_per_dim))),
            constant_values=(1),
        )
        batch_id_maximum = batch_id_maximum[:, np.newaxis]
        batch_in_range_per_patch = np.all(batch_multi_index < batch_id_maximum, axis=0)
        return batch_in_range_per_patch

    def to_json(self):
        """
        Dump the BatchSchema properties to a JSON string.

        Returns
        -------
        out_json: str
            The JSON representation of the BatchSchema, including the slice
            parameters of every member of every batch under ``'selectors'``.
        """
        out_dict = {
            'input_dims': self.input_dims,
            'input_overlap': self.input_overlap,
            'batch_dims': self.batch_dims,
            # Bug fix: previously serialized self.input_dims under this key.
            'concat_input_dims': self.concat_input_dims,
            'preload_batch': self.preload_batch,
        }
        # Bug fix: the previous implementation overwrote a single 'selector'
        # entry on every loop iteration, so only the last member of the last
        # batch was ever serialized. Serialize all batches instead.
        serialized_selectors: dict[int, list] = {}
        for batch_id, batch in self.selectors.items():
            serialized_selectors[batch_id] = [
                {
                    dim: {
                        'start': dim_slice.start,
                        'stop': dim_slice.stop,
                        'step': dim_slice.step,
                    }
                    for dim, dim_slice in member.items()
                }
                for member in batch
            ]
        out_dict['selectors'] = serialized_selectors
        return json.dumps(out_dict)

    def to_file(self, out_file_name: str):
        """
        Dumps the JSON representation of the BatchSchema object to a file.

        Parameters
        ----------
        out_file_name: str
            The path to the json file to write to.
        """
        out_json = self.to_json()
        with open(out_file_name, mode='w') as out_file:
            out_file.write(out_json)
306 |
307 |
308 | def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> list[slice]:
309 | # return a list of slices to chop up a single dimension
310 | if overlap >= slice_size:
311 | raise ValueError(
312 | 'input overlap must be less than the input sample length, but '
313 | f'the input sample length is {slice_size} and the overlap is {overlap}'
314 | )
315 | slices = []
316 | stride = slice_size - overlap
317 | for start in range(0, dim_size, stride):
318 | end = start + slice_size
319 | if end <= dim_size:
320 | slices.append(slice(start, end))
321 | return slices
322 |
323 |
def _iterate_through_dimensions(
    ds: xr.Dataset | xr.DataArray,
    *,
    dims: dict[Hashable, int],
    overlap: dict[Hashable, int] | None = None,
) -> Iterator[dict[Hashable, slice]]:
    """
    Yield isel-style selectors (dim -> slice) covering ``ds`` patch by patch.

    Selectors are produced as the cartesian product of the per-dimension
    slice lists, in ``dims`` order (the last dimension varies fastest).

    Raises
    ------
    ValueError
        If any requested sample length exceeds the corresponding dimension
        length of ``ds``.
    """
    overlap = {} if overlap is None else overlap
    per_dim_slices = []
    for dim, slice_size in dims.items():
        dim_size = ds.sizes[dim]
        if slice_size > dim_size:
            raise ValueError(
                'input sample length must be less than or equal to the '
                f'dimension length, but the sample length of {slice_size} '
                f'is greater than the dimension length of {dim_size} '
                f'for {dim}'
            )
        per_dim_slices.append(
            _gen_slices(
                dim_size=dim_size,
                slice_size=slice_size,
                overlap=overlap.get(dim, 0),
            )
        )
    for combination in itertools.product(*per_dim_slices):
        yield dict(zip(dims, combination))
349 |
350 |
def _drop_input_dims(
    ds: xr.Dataset | xr.DataArray,
    input_dims: dict[Hashable, int],
    suffix: str = '_input',
) -> xr.Dataset | xr.DataArray:
    """
    Rename each input dimension ``dim`` to ``f'{dim}{suffix}'`` and re-attach
    the original coordinate values (if present) under the old name as a
    non-index coordinate along the renamed dimension.
    """
    renamed = ds.copy()
    for dim in input_dims:
        renamed_dim = f'{dim}{suffix}'
        renamed = renamed.rename({dim: renamed_dim})
        # extra steps needed if there is a coordinate: drop the renamed
        # coordinate, then put input_dims back in as coordinates
        if renamed_dim in renamed:
            renamed = renamed.drop_vars(renamed_dim)
            renamed.coords[dim] = renamed_dim, ds[dim].data, ds[dim].attrs
    return renamed
367 |
368 |
def _maybe_stack_batch_dims(
    ds: xr.Dataset | xr.DataArray,
    input_dims: Sequence[Hashable],
) -> xr.Dataset | xr.DataArray:
    """
    Stack all non-input dimensions into a single ``sample`` dimension and
    move it to the front. No-op when fewer than two such dimensions exist.
    """
    leftover_dims = [dim for dim in ds.sizes if dim not in input_dims]
    if len(leftover_dims) < 2:
        return ds
    stacked = ds.stack(sample=leftover_dims)
    # ensure 'sample' leads, followed by the input dims in their given order
    return stacked.transpose(*(('sample',) + tuple(input_dims)))
380 |
381 |
class BatchGenerator:
    """Create generator for iterating through Xarray DataArrays / Datasets in
    batches.

    Parameters
    ----------
    ds : ``xarray.Dataset`` or ``xarray.DataArray``
        The data to iterate over
    input_dims : dict
        A dictionary specifying the size of the inputs in each dimension,
        e.g. ``{'lat': 30, 'lon': 30}``
        These are the dimensions the ML library will see. All other dimensions
        will be stacked into one dimension called ``sample``.
    input_overlap : dict, optional
        A dictionary specifying the overlap along each dimension
        e.g. ``{'lat': 3, 'lon': 3}``
    batch_dims : dict, optional
        A dictionary specifying the size of the batch along each dimension
        e.g. ``{'time': 10}``. These will always be iterated over.
    concat_input_dims : bool, optional
        If ``True``, the dimension chunks specified in ``input_dims`` will be
        concatenated and stacked into the ``sample`` dimension. The batch index
        will be included as a new level ``input_batch`` in the ``sample``
        coordinate.
        If ``False``, the dimension chunks specified in ``input_dims`` will be
        iterated over.
    preload_batch : bool, optional
        If ``True``, each batch will be loaded into memory before reshaping /
        processing, triggering any dask arrays to be computed.
    cache : dict, optional
        Dict-like object to cache batches in (e.g., Zarr DirectoryStore). Note:
        The caching API is experimental and subject to change.
    cache_preprocess: callable, optional
        A function to apply to batches prior to caching.
        Note: The caching API is experimental and subject to change.

    Yields
    ------
    ds_slice : ``xarray.Dataset`` or ``xarray.DataArray``
        Slices of the array matching the given batch size specification.
    """

    def __init__(
        self,
        ds: xr.Dataset | xr.DataArray,
        input_dims: dict[Hashable, int],
        input_overlap: dict[Hashable, int] | None = None,
        batch_dims: dict[Hashable, int] | None = None,
        concat_input_dims: bool = False,
        preload_batch: bool = True,
        cache: dict[str, Any] | None = None,
        cache_preprocess: Callable | None = None,
    ):
        # Normalize the optional mappings to plain dicts so downstream code
        # never sees None (and no mutable default arguments are used).
        if input_overlap is None:
            input_overlap = {}
        if batch_dims is None:
            batch_dims = {}
        self.ds = ds
        self.cache = cache
        self.cache_preprocess = cache_preprocess

        # BatchSchema precomputes the per-batch index selectors; all dimension
        # bookkeeping lives there and is re-exposed via the properties below.
        self._batch_selectors: BatchSchema = BatchSchema(
            ds,
            input_dims=input_dims,
            input_overlap=input_overlap,
            batch_dims=batch_dims,
            concat_input_bins=concat_input_dims,
            preload_batch=preload_batch,
        )

    @property
    def input_dims(self):
        # Sizes of the dimensions exposed to the ML framework.
        return self._batch_selectors.input_dims

    @property
    def input_overlap(self):
        # Overlap between successive windows along each input dimension.
        return self._batch_selectors.input_overlap

    @property
    def batch_dims(self):
        # Sizes of the dimensions iterated over per batch.
        return self._batch_selectors.batch_dims

    @property
    def concat_input_dims(self):
        # Whether input chunks are concatenated into the sample dimension.
        return self._batch_selectors.concat_input_dims

    @property
    def preload_batch(self):
        # Whether each batch is eagerly loaded (computing any dask arrays).
        return self._batch_selectors.preload_batch

    def __iter__(self) -> Iterator[xr.DataArray | xr.Dataset]:
        # Iterate over the precomputed selector keys, delegating to __getitem__
        # so caching behavior is shared between iteration and indexing.
        for idx in self._batch_selectors.selectors:
            yield self[idx]

    def __len__(self) -> int:
        return len(self._batch_selectors.selectors)

    def __getitem__(self, idx: int) -> xr.Dataset | xr.DataArray:
        # Only scalar integer indexing is supported (no slices or lists).
        if not isinstance(idx, int):
            raise NotImplementedError(
                f'{type(self).__name__}.__getitem__ currently requires a single integer key'
            )

        # Resolve a negative index to the corresponding selector key.
        if idx < 0:
            idx = list(self._batch_selectors.selectors)[idx]

        # Serve from the cache when this batch was previously stored there.
        if self.cache and self._batch_in_cache(idx):
            return self._get_cached_batch(idx)

        if idx not in self._batch_selectors.selectors:
            raise IndexError('list index out of range')

        if self.concat_input_dims:
            new_dim_suffix = '_input'
            all_dsets: list = []
            batch_selector = {}
            # Build one bounding selector spanning every patch in this batch
            # so the whole region can be selected (and loaded) in one pass.
            for dim in self._batch_selectors.batch_dims.keys():
                starts = [x[dim].start for x in self._batch_selectors.selectors[idx]]
                stops = [x[dim].stop for x in self._batch_selectors.selectors[idx]]
                batch_selector[dim] = slice(min(starts), max(stops))
            batch_ds = self.ds.isel(batch_selector)
            if self.preload_batch:
                batch_ds.load()
            # Extract each patch, renaming its input dims with the suffix so
            # the patches can be concatenated along a new 'input_batch' dim.
            for selector in self._batch_selectors.selectors[idx]:
                patch_ds = self.ds.isel(selector)
                all_dsets.append(
                    _drop_input_dims(
                        patch_ds,
                        self.input_dims,
                        suffix=new_dim_suffix,
                    )
                )
            dsc = xr.concat(all_dsets, dim='input_batch')
            new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
            batch = _maybe_stack_batch_dims(dsc, new_input_dims)
        else:
            # Simple case: a batch corresponds to a single selector.
            batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0])
            if self.preload_batch:
                batch_ds.load()
            batch = _maybe_stack_batch_dims(
                batch_ds,
                list(self.input_dims),
            )
        # Preprocessing is only applied on the caching path, before writing.
        if self.cache is not None and self.cache_preprocess is not None:
            batch = self.cache_preprocess(batch)
        if self.cache is not None:
            self._cache_batch(idx, batch)

        return batch

    def _batch_in_cache(self, idx: int) -> bool:
        # A batch is considered cached once its Zarr group metadata exists.
        return self.cache is not None and f'{idx}/.zgroup' in self.cache

    def _cache_batch(self, idx: int, batch: xr.Dataset | xr.DataArray) -> None:
        # Each batch is written to its own Zarr group keyed by batch index.
        batch.to_zarr(self.cache, group=str(idx), mode='a')

    def _get_cached_batch(self, idx: int) -> xr.Dataset:
        # Read a previously cached batch back, optionally loading eagerly.
        ds = xr.open_zarr(self.cache, group=str(idx))
        if self.preload_batch:
            ds = ds.load()
        return ds
543 |
--------------------------------------------------------------------------------
/xbatcher/loaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xarray-contrib/xbatcher/5898a76ce88200a28eed036f0fbec9890a4280b5/xbatcher/loaders/__init__.py
--------------------------------------------------------------------------------
/xbatcher/loaders/keras.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from typing import Any
3 |
4 | try:
5 | import tensorflow as tf
6 | except ImportError as exc:
7 | raise ImportError(
8 | 'The Xbatcher TensorFlow Dataset API depends on TensorFlow. Please '
9 | 'install TensorFlow to proceed.'
10 | ) from exc
11 |
12 | # Notes:
13 | # This module includes one Keras dataset, which can be provided to model.fit().
14 | # - The CustomTFDataset provides an indexable interface
15 | # Assumptions made:
16 | # - The dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
17 |
18 |
class CustomTFDataset(tf.keras.utils.Sequence):
    def __init__(
        self,
        X_generator,
        y_generator,
        *,
        transform: Callable | None = None,
        target_transform: Callable | None = None,
    ) -> None:
        """
        Keras Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
        y_generator : xbatcher.BatchGenerator
        transform : callable, optional
            A function/transform that takes in an array and returns a transformed version.
        target_transform : callable, optional
            A function/transform that takes in the target and transforms it.
        """
        self.X_generator = X_generator
        self.y_generator = y_generator
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        # One item per batch produced by the feature generator.
        return len(self.X_generator)

    def __getitem__(self, idx: int) -> tuple[Any, Any]:
        features = tf.convert_to_tensor(self.X_generator[idx].data)
        labels = tf.convert_to_tensor(self.y_generator[idx].data)

        # TODO: Should the transformations be applied before tensor conversion?
        if self.transform:
            features = self.transform(features)
        if self.target_transform:
            labels = self.target_transform(labels)
        return features, labels
59 |
--------------------------------------------------------------------------------
/xbatcher/loaders/torch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Callable
4 | from types import ModuleType
5 |
6 | import xarray as xr
7 |
8 | from xbatcher import BatchGenerator
9 |
10 | try:
11 | import torch
12 | except ImportError as exc:
13 | raise ImportError(
14 | 'The Xbatcher PyTorch Dataset API depends on PyTorch. Please '
15 | 'install PyTorch to proceed.'
16 | ) from exc
17 |
18 | try:
19 | import dask
20 | except ImportError:
21 | dask: ModuleType | None = None # type: ignore[no-redef]
22 |
23 | T_DataArrayOrSet = xr.DataArray | xr.Dataset
24 |
25 | # Notes:
26 | # This module includes two PyTorch datasets.
27 | # - The MapDataset provides an indexable interface
28 | # - The IterableDataset provides a simple iterable interface
29 | # Both can be provided as arguments to the the Torch DataLoader
30 | # Assumptions made:
31 | # - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
32 | # TODOs:
33 | # - need to test with additional dataset parameters (e.g. transforms)
34 |
35 |
def to_tensor(xr_obj: T_DataArrayOrSet) -> torch.Tensor:
    """Convert a DataArray or Dataset to a ``torch.Tensor``."""
    if isinstance(xr_obj, xr.Dataset):
        # Collapse data variables into one array; ``squeeze`` drops the
        # length-1 'variable' dimension introduced by ``to_array``.
        xr_obj = xr_obj.to_array().squeeze(dim='variable')
    if isinstance(xr_obj, xr.DataArray):
        # Hand the underlying array (numpy or dask) to torch.
        xr_obj = xr_obj.data
    return torch.tensor(xr_obj)
43 |
44 |
class MapDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        X_generator: BatchGenerator,
        y_generator: BatchGenerator | None = None,
        transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
        target_transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
    ) -> None:
        """
        PyTorch Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
        y_generator : xbatcher.BatchGenerator
        transform, target_transform : callable, optional
            A function/transform that takes in an Xarray object and returns a transformed version in the form of a torch.Tensor.
        """
        self.X_generator = X_generator
        self.y_generator = y_generator
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.X_generator)

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        # Accept a length-1 tensor index (as produced by some samplers);
        # anything longer is unsupported.
        if torch.is_tensor(idx):
            as_list = idx.tolist()
            if len(as_list) != 1:
                raise NotImplementedError(
                    f'{type(self).__name__}.__getitem__ currently requires a single integer key'
                )
            idx = as_list[0]

        # Fetch the feature batch, and the target batch when available.
        X_batch = self.X_generator[idx]
        y_batch = self.y_generator[idx] if self.y_generator is not None else None

        # Materialize any lazy (dask-backed) data in a single compute pass.
        if dask is not None:
            X_batch, y_batch = dask.compute(X_batch, y_batch)

        # Apply the transforms, which must each yield a torch.Tensor.
        X_batch_tensor = self.transform(X_batch)
        y_batch_tensor = self.target_transform(y_batch) if y_batch is not None else None

        assert isinstance(X_batch_tensor, torch.Tensor), self.transform

        if y_batch is None:
            return X_batch_tensor
        assert isinstance(y_batch_tensor, torch.Tensor)
        return X_batch_tensor, y_batch_tensor
102 |
103 |
class IterableDataset(torch.utils.data.IterableDataset):
    def __init__(
        self,
        X_generator,
        y_generator,
    ) -> None:
        """
        PyTorch Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
        y_generator : xbatcher.BatchGenerator
        """

        self.X_generator = X_generator
        self.y_generator = y_generator

    def __iter__(self):
        # Walk both generators in lockstep, converting each pair of batches
        # to tensors via the .torch accessor.
        for X_batch, y_batch in zip(self.X_generator, self.y_generator):
            yield (X_batch.torch.to_tensor(), y_batch.torch.to_tensor())
125 |
--------------------------------------------------------------------------------
/xbatcher/testing.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Hashable
2 | from unittest import TestCase
3 |
4 | import numpy as np
5 | import xarray as xr
6 |
7 | from .generators import BatchGenerator
8 |
9 |
10 | def _get_non_specified_dims(generator: BatchGenerator) -> dict[Hashable, int]:
11 | """
12 | Return all dimensions that are in the input dataset but not ``input_dims``
13 | or ``batch_dims``.
14 |
15 | Parameters
16 | ----------
17 | generator : xbatcher.BatchGenerator
18 | The batch generator object.
19 |
20 | Returns
21 | -------
22 | d : dict
23 | Dict containing all dimensions in the input dataset that are not
24 | in the input_dims or batch_dims attributes of the batch generator.
25 | """
26 | return {
27 | dim: length
28 | for dim, length in generator.ds.sizes.items()
29 | if generator.input_dims.get(dim) is None
30 | and generator.batch_dims.get(dim) is None
31 | }
32 |
33 |
34 | def _get_non_input_batch_dims(generator: BatchGenerator) -> dict[Hashable, int]:
35 | """
36 | Return all dimensions that are in batch_dims but not input_dims.
37 |
38 | Parameters
39 | ----------
40 | generator : xbatcher.BatchGenerator
41 | The batch generator object.
42 |
43 | Returns
44 | -------
45 | d : dict
46 | Dict containing all dimensions in specified in batch_dims that are
47 | not also in input_dims
48 | """
49 | return {
50 | dim: length
51 | for dim, length in generator.batch_dims.items()
52 | if generator.input_dims.get(dim) is None
53 | }
54 |
55 |
56 | def _get_duplicate_batch_dims(generator: BatchGenerator) -> dict[Hashable, int]:
57 | """
58 | Return all dimensions that are in both batch_dims and input_dims.
59 |
60 | Parameters
61 | ----------
62 | generator : xbatcher.BatchGenerator
63 | The batch generator object.
64 |
65 | Returns
66 | -------
67 | d : dict
68 | Dict containing all dimensions duplicated between batch_dims and input_dims.
69 | """
70 | return {
71 | dim: length
72 | for dim, length in generator.batch_dims.items()
73 | if generator.input_dims.get(dim) is not None
74 | }
75 |
76 |
77 | def _get_sample_length(
78 | *,
79 | generator: BatchGenerator,
80 | non_specified_ds_dims: dict[Hashable, int],
81 | non_input_batch_dims: dict[Hashable, int],
82 | ) -> int:
83 | """
84 | Return the expected length of the sample dimension.
85 |
86 | Parameters
87 | ----------
88 | generator : xbatcher.BatchGenerator
89 | The batch generator object.
90 | non_specified_ds_dics : dict
91 | Dict containing all dimensions in the input dataset that are not
92 | in the input_dims or batch_dims attributes of the batch generator.
93 | non_input_batch_dims : dict
94 | Dict containing all dimensions in specified in batch_dims that are
95 | not also in input_dims
96 |
97 | Returns
98 | -------
99 | s : int
100 | Expected length of the sample dimension
101 | """
102 | if generator.concat_input_dims:
103 | batch_concat_dims = [
104 | (
105 | generator.batch_dims.get(dim) // length
106 | if generator.batch_dims.get(dim)
107 | else generator.ds.sizes.get(dim) // length
108 | )
109 | for dim, length in generator.input_dims.items()
110 | ]
111 | else:
112 | batch_concat_dims = []
113 | return int(
114 | np.prod(list(non_specified_ds_dims.values()))
115 | * np.prod(list(non_input_batch_dims.values()))
116 | * np.prod(batch_concat_dims)
117 | )
118 |
119 |
def get_batch_dimensions(generator: BatchGenerator) -> dict[Hashable, int]:
    """
    Return the expected batch dimensions based on the ``input_dims``,
    ``batch_dims``, and ``concat_input_dims`` attributes of the batch
    generator.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.

    Returns
    -------
    d : dict
        Dict containing the expected dimensions for batches returned by the
        batch generator.
    """
    # dimensions that are in the input dataset but not input_dims or batch_dims
    non_specified_ds_dims = _get_non_specified_dims(generator)
    # dimensions that are in batch_dims but not input_dims
    non_input_batch_dims = _get_non_input_batch_dims(generator)
    expected_sample_length = _get_sample_length(
        generator=generator,
        non_specified_ds_dims=non_specified_ds_dims,
        non_input_batch_dims=non_input_batch_dims,
    )
    # input_dims stay the same, possibly with a new suffix
    expected_dims = {
        f'{k}_input' if generator.concat_input_dims else k: v
        for k, v in generator.input_dims.items()
    }
    # Add a sample dimension if there's anything to get stacked
    if (
        generator.concat_input_dims
        and (len(generator.ds.sizes) - len(generator.input_dims)) == 0
    ):
        # Every dataset dim is an input dim: concatenation only adds the
        # leading 'input_batch' dimension ahead of the suffixed input dims.
        expected_dims = {**{'input_batch': expected_sample_length}, **expected_dims}
    elif (
        generator.concat_input_dims
        or (len(generator.ds.sizes) - len(generator.input_dims)) > 1
    ):
        # Leftover dims get collapsed into a leading 'sample' dimension.
        expected_dims = {**{'sample': expected_sample_length}, **expected_dims}
    else:
        # Nothing is stacked; batches keep the remaining dims unchanged.
        expected_dims = dict(
            **non_specified_ds_dims,
            **non_input_batch_dims,
            **expected_dims,
        )
    return expected_dims
169 |
170 |
def validate_batch_dimensions(
    *, expected_dims: dict[Hashable, int], batch: xr.Dataset | xr.DataArray
) -> None:
    """
    Raises an AssertionError if the shape and dimensions of a batch do not
    match expected_dims.

    Parameters
    ----------
    expected_dims : Dict
        Dict containing the expected dimensions for batches.
    batch : xarray.Dataset or xarray.DataArray
        The xarray data object returned by the batch generator.
    """
    checker = TestCase()
    # Dimension names and lengths must match exactly.
    checker.assertDictEqual(
        expected_dims, dict(batch.sizes), msg='Dimension names and/or lengths differ'
    )
    # Every data variable must carry the dimensions in the expected order.
    expected_shape = tuple(expected_dims.values())
    for var in batch.data_vars:
        checker.assertEqual(
            expected_shape,
            batch[var].shape,
            msg=f'Order differs for dimensions of: {expected_dims}',
        )
197 |
198 |
199 | def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int:
200 | """
201 | Calculate the number of batches expected based on ``input_dims`` and
202 | ``input_overlap``.
203 |
204 | Parameters
205 | ----------
206 | generator : xbatcher.BatchGenerator
207 | The batch generator object.
208 |
209 | Returns
210 | -------
211 | s : int
212 | Number of batches expected given ``input_dims`` and ``input_overlap``.
213 | """
214 | nbatches_from_input_dims = np.prod(
215 | [
216 | generator.ds.sizes[dim] // length
217 | for dim, length in generator.input_dims.items()
218 | if generator.input_overlap.get(dim) is None
219 | and generator.batch_dims.get(dim) is None
220 | ]
221 | )
222 | if generator.input_overlap:
223 | nbatches_from_input_overlap = np.prod(
224 | [
225 | (generator.ds.sizes[dim] - overlap)
226 | // (generator.input_dims[dim] - overlap)
227 | for dim, overlap in generator.input_overlap.items()
228 | ]
229 | )
230 | return int(nbatches_from_input_overlap * nbatches_from_input_dims)
231 | else:
232 | return int(nbatches_from_input_dims)
233 |
234 |
def validate_generator_length(generator: BatchGenerator) -> None:
    """
    Raises an AssertionError if the generator length does not match
    expectations based on the input Dataset and ``input_dims``.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.
    """
    unique_batch_dims = _get_non_input_batch_dims(generator)
    shared_batch_dims = _get_duplicate_batch_dims(generator)
    # Batches contributed by batch_dims alone ...
    n_unique = np.prod(
        [
            generator.ds.sizes[dim] // length
            for dim, length in unique_batch_dims.items()
        ]
    )
    # ... and by dims appearing in both batch_dims and input_dims.
    n_shared = np.prod(
        [
            generator.ds.sizes[dim] // length
            for dim, length in shared_batch_dims.items()
        ]
    )
    if generator.concat_input_dims:
        # Concatenation folds the input-dim chunks into each batch.
        expected_length = int(n_unique * n_shared)
    else:
        expected_length = int(
            n_unique * n_shared * _get_nbatches_from_input_dims(generator)
        )
    TestCase().assertEqual(
        expected_length,
        len(generator),
        msg='Batch generator length differs',
    )
275 |
--------------------------------------------------------------------------------
/xbatcher/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xarray-contrib/xbatcher/5898a76ce88200a28eed036f0fbec9890a4280b5/xbatcher/tests/__init__.py
--------------------------------------------------------------------------------
/xbatcher/tests/test_accessors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import xarray as xr
4 |
5 | import xbatcher # noqa: F401
6 | from xbatcher import BatchGenerator
7 |
8 |
@pytest.fixture(scope='module')
def sample_ds_3d():
    """Random (time, y, x) dataset with float 'foo' and int 'bar' variables."""
    shape = (10, 50, 100)
    data_vars = {
        'foo': (['time', 'y', 'x'], np.random.rand(*shape)),
        'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape)),
    }
    coords = {
        'x': (['x'], np.arange(shape[-1])),
        'y': (['y'], np.arange(shape[-2])),
    }
    return xr.Dataset(data_vars, coords)
23 |
24 |
@pytest.fixture(scope='module')
def sample_dataArray():
    """A 2x4 all-zero integer DataArray named 'foo'."""
    data = np.zeros((2, 4), dtype='i4')
    return xr.DataArray(data, dims=('x', 'y'), name='foo')
28 |
29 |
@pytest.fixture(scope='module')
def sample_Dataset():
    """A 1D dataset with an 'x' index variable and a constant 'foo' variable."""
    x_var = xr.DataArray(np.arange(10), dims='x')
    foo_var = xr.DataArray(np.ones(10, dtype='float'), dims='x')
    return xr.Dataset({'x': x_var, 'foo': foo_var})
38 |
39 |
def test_as_xarray_dataarray(sample_dataArray, sample_Dataset):
    """_as_xarray_dataarray yields a DataArray for both input flavors."""
    for xr_obj in (sample_dataArray, sample_Dataset):
        converted = xbatcher.accessors._as_xarray_dataarray(xr_obj)
        assert isinstance(converted, xr.DataArray)
47 |
48 |
def test_batch_accessor_ds(sample_ds_3d):
    """The .batch accessor on a Dataset matches direct BatchGenerator use."""
    from_class = BatchGenerator(sample_ds_3d, input_dims={'x': 5})
    from_accessor = sample_ds_3d.batch.generator(input_dims={'x': 5})
    assert isinstance(from_accessor, BatchGenerator)
    for expected_batch, accessor_batch in zip(from_class, from_accessor):
        assert isinstance(accessor_batch, xr.Dataset)
        assert expected_batch.equals(accessor_batch)
56 |
57 |
def test_batch_accessor_da(sample_ds_3d):
    """The .batch accessor on a DataArray matches direct BatchGenerator use."""
    sample_da = sample_ds_3d['foo']
    from_class = BatchGenerator(sample_da, input_dims={'x': 5})
    from_accessor = sample_da.batch.generator(input_dims={'x': 5})
    assert isinstance(from_accessor, BatchGenerator)
    for expected_batch, accessor_batch in zip(from_class, from_accessor):
        assert expected_batch.equals(accessor_batch)
65 |
66 |
@pytest.mark.parametrize(
    'foo_var',
    [
        'foo',  # xr.DataArray
        ['foo'],  # xr.Dataset
    ],
)
def test_tf_to_tensor(sample_ds_3d, foo_var):
    """The .tf.to_tensor accessor returns a tf.Tensor with matching data."""
    tf = pytest.importorskip('tensorflow')

    foo = sample_ds_3d[foo_var]
    tensor = foo.tf.to_tensor()
    assert isinstance(tensor, tf.Tensor)
    assert tensor.shape == tuple(foo.sizes.values())

    # Datasets need to be collapsed to an array before comparing values.
    if hasattr(foo, 'to_array'):
        expected = foo.to_array().squeeze()
    else:
        expected = foo
    np.testing.assert_array_equal(tensor, expected.values)
84 |
85 |
@pytest.mark.parametrize(
    'foo_var',
    [
        'foo',  # xr.DataArray
        ['foo'],  # xr.Dataset
    ],
)
def test_torch_to_tensor(sample_ds_3d, foo_var):
    """The .torch.to_tensor accessor returns an unnamed torch.Tensor."""
    torch = pytest.importorskip('torch')

    foo = sample_ds_3d[foo_var]
    tensor = foo.torch.to_tensor()
    assert isinstance(tensor, torch.Tensor)
    assert tensor.names == (None, None, None)
    assert tensor.shape == tuple(foo.sizes.values())

    # Datasets need to be collapsed to an array before comparing values.
    if hasattr(foo, 'to_array'):
        expected = foo.to_array().squeeze()
    else:
        expected = foo
    np.testing.assert_array_equal(tensor, expected.values)
104 |
105 |
@pytest.mark.parametrize(
    'foo_var',
    [
        'foo',  # xr.DataArray
        ['foo'],  # xr.Dataset
    ],
)
def test_torch_to_named_tensor(sample_ds_3d, foo_var):
    """The .torch.to_named_tensor accessor carries the xarray dim names."""
    torch = pytest.importorskip('torch')

    foo = sample_ds_3d[foo_var]
    tensor = foo.torch.to_named_tensor()
    assert isinstance(tensor, torch.Tensor)
    assert tensor.names == tuple(foo.dims)
    assert tensor.shape == tuple(foo.sizes.values())

    # Datasets need to be collapsed to an array before comparing values.
    if hasattr(foo, 'to_array'):
        expected = foo.to_array().squeeze()
    else:
        expected = foo
    np.testing.assert_array_equal(tensor, expected.values)
124 |
--------------------------------------------------------------------------------
/xbatcher/tests/test_generators.py:
--------------------------------------------------------------------------------
1 | import json
2 | import tempfile
3 | from typing import Any
4 |
5 | import numpy as np
6 | import pytest
7 | import xarray as xr
8 |
9 | from xbatcher import BatchGenerator, BatchSchema
10 | from xbatcher.testing import (
11 | get_batch_dimensions,
12 | validate_batch_dimensions,
13 | validate_generator_length,
14 | )
15 |
16 |
@pytest.fixture(scope='module')
def sample_ds_1d():
    """
    Sample 1D xarray.Dataset for testing.
    """
    size = 100
    data_vars = {
        'foo': (['x'], np.random.rand(size)),
        'bar': (['x'], np.random.randint(0, 10, size)),
    }
    coords = {'x': (['x'], np.arange(size))}
    return xr.Dataset(data_vars, coords)
31 |
32 |
@pytest.fixture(scope='module')
def sample_ds_3d():
    """
    Sample 3D xarray.Dataset for testing.
    """
    shape = (10, 50, 100)
    data_vars = {
        'foo': (['time', 'y', 'x'], np.random.rand(*shape)),
        'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape)),
    }
    coords = {
        'x': (['x'], np.arange(shape[-1])),
        'y': (['y'], np.arange(shape[-2])),
    }
    return xr.Dataset(data_vars, coords)
50 |
51 |
def test_constructor_dataarray():
    """
    Test that the xarray.DataArray passed to the batch generator is stored
    in the .ds attribute.
    """
    source = xr.DataArray(np.random.rand(10), dims='x', name='foo')
    generator = BatchGenerator(source, input_dims={'x': 2})
    xr.testing.assert_identical(source, generator.ds)
60 |
61 |
@pytest.mark.parametrize('input_size', [5, 6])
def test_generator_length(sample_ds_1d, input_size):
    """
    Test the length of the batch generator.
    """
    bg = BatchGenerator(sample_ds_1d, input_dims={'x': input_size})
    validate_generator_length(bg)
69 |
70 |
def test_generator_getitem(sample_ds_1d):
    """
    Test indexing on the batch generator.
    """
    bg = BatchGenerator(sample_ds_1d, input_dims={'x': 10})
    first_batch = bg[0]
    last_batch = bg[-1]
    expected_dims = get_batch_dimensions(bg)
    # Positive and negative integer indexing both return valid batches.
    for batch in (first_batch, last_batch):
        validate_batch_dimensions(expected_dims=expected_dims, batch=batch)
    # raises IndexError for out of range index
    with pytest.raises(IndexError, match=r'list index out of range'):
        bg[9999999]
    # raises NotImplementedError for iterable index
    with pytest.raises(NotImplementedError):
        bg[[1, 2, 3]]
88 |
89 |
@pytest.mark.parametrize('input_size', [5, 10])
def test_batch_1d(sample_ds_1d, input_size):
    """
    Test batch generation for a 1D dataset using ``input_dims``.
    """
    bg = BatchGenerator(sample_ds_1d, input_dims={'x': input_size})
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    for n, ds_batch in enumerate(bg):
        assert ds_batch.dims['x'] == input_size
        # Each batch is a contiguous, non-overlapping window along x.
        window = slice(n * input_size, (n + 1) * input_size)
        xr.testing.assert_identical(sample_ds_1d.isel(x=window), ds_batch)
        validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
104 |
105 |
@pytest.mark.parametrize('input_size', [5, 10])
def test_batch_1d_concat(sample_ds_1d, input_size):
    """
    Test batch generation for a 1D dataset using ``input_dims`` and
    ``concat_input_dims``.
    """
    bg = BatchGenerator(
        sample_ds_1d, input_dims={'x': input_size}, concat_input_dims=True
    )
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    for ds_batch in bg:
        assert isinstance(ds_batch, xr.Dataset)
        validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
        # The original 'x' coordinate should survive concatenation.
        assert 'x' in ds_batch.coords
120 |
121 |
def test_batch_1d_concat_duplicate_dim(sample_ds_1d):
    """
    Test batch generation for a 1D dataset using ``concat_input_dims`` when
    the same dimension occurs in ``input_dims`` and ``batch_dims``.
    """
    bg = BatchGenerator(
        sample_ds_1d, input_dims={'x': 5}, batch_dims={'x': 10}, concat_input_dims=True
    )
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    for ds_batch in bg:
        assert isinstance(ds_batch, xr.Dataset)
        validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
135 |
136 |
@pytest.mark.parametrize('input_size', [5, 10])
def test_batch_1d_no_coordinate(sample_ds_1d, input_size):
    """
    Test batch generation for a 1D dataset without coordinates using ``input_dims``.

    Fix for https://github.com/xarray-contrib/xbatcher/issues/3.
    """
    ds_no_coords = sample_ds_1d.drop_vars('x')
    bg = BatchGenerator(ds_no_coords, input_dims={'x': input_size})
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    for n, ds_batch in enumerate(bg):
        assert ds_batch.dims['x'] == input_size
        window = slice(n * input_size, (n + 1) * input_size)
        xr.testing.assert_identical(ds_no_coords.isel(x=window), ds_batch)
        validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
154 |
155 |
@pytest.mark.parametrize('input_size', [5, 10])
def test_batch_1d_concat_no_coordinate(sample_ds_1d, input_size):
    """
    Test batch generation for a 1D dataset without coordinates using
    ``input_dims`` and ``concat_input_dims``.

    Fix for https://github.com/xarray-contrib/xbatcher/issues/3.
    """
    ds_no_coords = sample_ds_1d.drop_vars('x')
    bg = BatchGenerator(
        ds_no_coords, input_dims={'x': input_size}, concat_input_dims=True
    )
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    for ds_batch in bg:
        assert isinstance(ds_batch, xr.Dataset)
        validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
        # No 'x' coordinate should be invented for a coordinate-less input.
        assert 'x' not in ds_batch.coords
174 |
175 |
@pytest.mark.parametrize('input_overlap', [1, 4])
def test_batch_1d_overlap(sample_ds_1d, input_overlap):
    """
    Test batch generation for a 1D dataset using ``input_dims``
    and ``input_overlap``.
    """
    input_size = 10
    bg = BatchGenerator(
        sample_ds_1d, input_dims={'x': input_size}, input_overlap={'x': input_overlap}
    )
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    # With overlap, consecutive windows advance by (size - overlap).
    stride = input_size - input_overlap
    for n, ds_batch in enumerate(bg):
        assert ds_batch.dims['x'] == input_size
        expected_slice = slice(stride * n, stride * n + input_size)
        ds_batch_expected = sample_ds_1d.isel(x=expected_slice)
        xr.testing.assert_identical(ds_batch_expected, ds_batch)
        validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
195 |
196 |
@pytest.mark.parametrize('input_size', [5, 10])
def test_batch_3d_1d_input(sample_ds_3d, input_size):
    """
    Test batch generation for a 3D dataset with 1 dimension
    specified in ``input_dims``.
    """
    bg = BatchGenerator(sample_ds_3d, input_dims={'x': input_size})
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    for n, ds_batch in enumerate(bg):
        assert ds_batch.dims['x'] == input_size
        # time and y should be collapsed into batch dimension
        assert (
            ds_batch.dims['sample']
            == sample_ds_3d.dims['y'] * sample_ds_3d.dims['time']
        )
        expected_slice = slice(input_size * n, input_size * (n + 1))
        # Rebuild the expected batch by slicing, stacking, and transposing
        # the source dataset the same way the generator should.
        ds_batch_expected = (
            sample_ds_3d.isel(x=expected_slice)
            .stack(sample=['time', 'y'])
            .transpose('sample', 'x')
        )
        xr.testing.assert_identical(ds_batch_expected, ds_batch)
        validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
221 |
222 |
@pytest.mark.parametrize(
    'concat',
    [
        True,
        pytest.param(
            False,
            marks=pytest.mark.xfail(
                reason='Bug described in https://github.com/xarray-contrib/xbatcher/issues/126'
            ),
        ),
    ],
)
def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat):
    """
    Test batch generation for a 3D dataset using ``input_dims`` and
    ``batch_dims``.
    """
    bg = BatchGenerator(
        sample_ds_3d,
        input_dims={'x': 5, 'y': 10},
        batch_dims={'time': 2},
        concat_input_dims=concat,
    )
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    for ds_batch in bg:
        validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
249 |
250 |
def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d):
    """
    Test batch generation with ``concat_input_dims=True`` when the same
    dimension appears in both ``input_dims`` and ``batch_dims``.
    """
    generator = BatchGenerator(
        sample_ds_3d,
        input_dims={'x': 5, 'y': 10},
        batch_dims={'x': 10, 'y': 20},
        concat_input_dims=True,
    )
    validate_generator_length(generator)
    expected_dims = get_batch_dimensions(generator)
    for batch in generator:
        validate_batch_dimensions(expected_dims=expected_dims, batch=batch)
266 |
267 |
@pytest.mark.parametrize('input_size', [5, 10])
def test_batch_3d_2d_input(sample_ds_3d, input_size):
    """
    Test batch generation for a 3D dataset with two dimensions given in
    ``input_dims``.
    """
    x_input_size = 20
    bg = BatchGenerator(sample_ds_3d, input_dims={'y': input_size, 'x': x_input_size})
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    # number of patches along each of the batched dimensions
    grid_shape = (
        sample_ds_3d.dims['y'] // input_size,
        sample_ds_3d.dims['x'] // x_input_size,
    )
    for n, batch in enumerate(bg):
        # batches are emitted in row-major order over the (y, x) patch grid
        yn, xn = np.unravel_index(n, grid_shape)
        y_window = slice(yn * input_size, (yn + 1) * input_size)
        x_window = slice(xn * x_input_size, (xn + 1) * x_input_size)
        expected = sample_ds_3d.isel(x=x_window, y=y_window)
        xr.testing.assert_identical(expected, batch)
        validate_batch_dimensions(expected_dims=expected_dims, batch=batch)
291 |
292 |
@pytest.mark.parametrize('input_size', [5, 10])
def test_batch_3d_2d_input_concat(sample_ds_3d, input_size):
    """
    Test batch generation for a 3D dataset with two dimensions given in
    ``input_dims`` while using ``concat_input_dims``.
    """
    x_input_size = 20
    # exercise two different input-dimension combinations in turn
    for dims in (
        {'y': input_size, 'x': x_input_size},
        {'time': input_size, 'x': x_input_size},
    ):
        bg = BatchGenerator(sample_ds_3d, input_dims=dims, concat_input_dims=True)
        validate_generator_length(bg)
        expected_dims = get_batch_dimensions(bg)
        for batch in bg:
            assert isinstance(batch, xr.Dataset)
            validate_batch_dimensions(expected_dims=expected_dims, batch=batch)
321 |
322 |
def test_preload_batch_false(sample_ds_1d):
    """
    Test that ``preload_batch=False`` leaves Dask arrays uncomputed.
    """
    ds_dask = sample_ds_1d.chunk({'x': 2})
    bg = BatchGenerator(ds_dask, input_dims={'x': 2}, preload_batch=False)
    assert bg.preload_batch is False
    for batch in bg:
        assert isinstance(batch, xr.Dataset)
        # batches should still carry chunk information, i.e. remain lazy
        assert batch.chunks
333 |
334 |
def test_preload_batch_true(sample_ds_1d):
    """
    Test that ``preload_batch=True`` computes Dask arrays eagerly.
    """
    ds_dask = sample_ds_1d.chunk({'x': 2})
    bg = BatchGenerator(ds_dask, input_dims={'x': 2}, preload_batch=True)
    assert bg.preload_batch is True
    for batch in bg:
        assert isinstance(batch, xr.Dataset)
        # loaded batches no longer carry chunk information
        assert not batch.chunks
345 |
346 |
def test_input_dim_exceptions(sample_ds_1d):
    """
    Test that a ValueError is raised when input_dim[dim] > ds.sizes[dim]
    """
    # A 110-element window is expected to exceed the fixture's 'x' size.
    with pytest.raises(ValueError) as e:
        BatchGenerator(sample_ds_1d, input_dims={'x': 110})
    # NOTE(review): ``len(e)`` on pytest's ExceptionInfo is unusual -- consider
    # asserting on ``str(e.value)`` or using ``pytest.raises(..., match=...)``;
    # confirm the original intent before changing.
    assert len(e) == 1
354 |
355 |
def test_input_overlap_exceptions(sample_ds_1d):
    """
    Test that a ValueError is raised when input_overlap[dim] > input_dim[dim]
    """
    # Overlap (20) larger than the window itself (10) is invalid.
    with pytest.raises(ValueError) as e:
        BatchGenerator(sample_ds_1d, input_dims={'x': 10}, input_overlap={'x': 20})
    # NOTE(review): ``len(e)`` on pytest's ExceptionInfo is unusual -- consider
    # asserting on ``str(e.value)`` or using ``pytest.raises(..., match=...)``;
    # confirm the original intent before changing.
    assert len(e) == 1
363 |
364 |
@pytest.mark.parametrize('input_size', [5, 10])
def test_to_json(sample_ds_3d, input_size):
    """
    Test that ``BatchSchema.to_file`` writes the schema as JSON, round-tripping
    ``input_dims`` faithfully.
    """
    x_input_size = 20
    bg = BatchSchema(
        sample_ds_3d,
        input_dims={'time': input_size, 'x': x_input_size},
    )
    # Use a context manager so the temporary file is closed (and removed)
    # even if one of the assertions below fails.
    with tempfile.NamedTemporaryFile(mode='w+b') as out_file:
        bg.to_file(out_file.name)
        # the handle is still positioned at offset 0; read what to_file wrote
        in_dict = json.load(out_file)
        assert in_dict['input_dims']['time'] == input_size
        assert in_dict['input_dims']['x'] == x_input_size
378 |
379 |
@pytest.mark.parametrize('preload', [True, False])
def test_batcher_cached_getitem(sample_ds_1d, preload) -> None:
    """
    Test that ``BatchGenerator`` writes batches into ``cache`` on first access
    and that cached reads reproduce uncached batches exactly, with and without
    a ``cache_preprocess`` function.
    """
    pytest.importorskip('zarr')
    cache: dict[str, Any] = {}

    def preproc(ds):
        # Load and rechunk so the batch is cacheable, and tag it so we can
        # verify the preprocess ran for both cached and uncached reads.
        processed = ds.load().chunk(-1)
        processed.attrs['foo'] = 'bar'
        return processed

    bg = BatchGenerator(
        sample_ds_1d,
        input_dims={'x': 10},
        cache=cache,
        cache_preprocess=preproc,
        preload_batch=preload,
    )

    # first batch
    assert bg[0].sizes['x'] == 10
    ds_no_cache = bg[1]
    # last batch
    assert bg[-1].sizes['x'] == 10

    # zarr group keys show batch 0 was written to the cache
    assert '0/.zgroup' in cache

    # now from cache
    # first batch
    assert bg[0].sizes['x'] == 10
    # last batch
    assert bg[-1].sizes['x'] == 10
    ds_cache = bg[1]

    # preprocess attrs must survive the cache round trip
    assert ds_no_cache.attrs['foo'] == 'bar'
    assert ds_cache.attrs['foo'] == 'bar'

    xr.testing.assert_equal(ds_no_cache, ds_cache)
    xr.testing.assert_identical(ds_no_cache, ds_cache)

    # without preprocess func
    bg = BatchGenerator(
        sample_ds_1d, input_dims={'x': 10}, cache=cache, preload_batch=preload
    )
    assert bg.cache_preprocess is None
    assert bg[0].sizes['x'] == 10
    ds_no_cache = bg[1]
    assert '1/.zgroup' in cache
    ds_cache = bg[1]
    xr.testing.assert_equal(ds_no_cache, ds_cache)
    xr.testing.assert_identical(ds_no_cache, ds_cache)
430 |
--------------------------------------------------------------------------------
/xbatcher/tests/test_keras_loaders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import xarray as xr
4 |
5 | from xbatcher import BatchGenerator
6 | from xbatcher.loaders.keras import CustomTFDataset
7 |
8 | tf = pytest.importorskip('tensorflow')
9 |
10 |
@pytest.fixture(scope='module')
def ds_xy():
    """Dataset with a 2D feature variable ``x`` and a 1D target ``y``."""
    n_samples, n_features = 100, 5
    return xr.Dataset(
        {
            'x': (
                ['sample', 'feature'],
                np.random.random((n_samples, n_features)),
            ),
            'y': (['sample'], np.random.random(n_samples)),
        },
    )
25 |
26 |
def test_custom_dataarray(ds_xy):
    """Check ``CustomTFDataset`` indexing and length against its generators."""
    x_gen = BatchGenerator(ds_xy['x'], {'sample': 10})
    y_gen = BatchGenerator(ds_xy['y'], {'sample': 10})

    dataset = CustomTFDataset(x_gen, y_gen)

    # __getitem__ yields one (x, y) pair of tensors per batch
    x_batch, y_batch = dataset[0]
    assert x_batch.shape == (10, 5)
    assert y_batch.shape == (10,)
    assert tf.is_tensor(x_batch)
    assert tf.is_tensor(y_batch)

    # __len__ matches the underlying feature generator
    assert len(dataset) == len(x_gen)
45 |
46 |
def test_custom_dataarray_with_transform(ds_xy):
    """Check that ``transform``/``target_transform`` are applied per batch."""
    x_gen = BatchGenerator(ds_xy['x'], {'sample': 10})
    y_gen = BatchGenerator(ds_xy['y'], {'sample': 10})

    def x_transform(batch):
        return batch * 0 + 1

    def y_transform(batch):
        return batch * 0 - 1

    dataset = CustomTFDataset(
        x_gen, y_gen, transform=x_transform, target_transform=y_transform
    )
    x_batch, y_batch = dataset[0]
    assert x_batch.shape == (10, 5)
    assert y_batch.shape == (10,)
    assert tf.is_tensor(x_batch)
    assert tf.is_tensor(y_batch)
    # the transforms map every feature to 1 and every target to -1
    assert tf.experimental.numpy.all(x_batch == 1)
    assert tf.experimental.numpy.all(y_batch == -1)
70 |
--------------------------------------------------------------------------------
/xbatcher/tests/test_print_versions.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import io
4 |
5 | import xbatcher
6 |
7 |
def test_show_versions() -> None:
    """
    Test xbatcher.show_versions()

    Based on https://github.com/pydata/xarray/blob/main/xarray/tests/test_print_versions.py
    """
    buffer = io.StringIO()
    xbatcher.show_versions(file=buffer)
    assert 'xbatcher information' in buffer.getvalue()
17 |
--------------------------------------------------------------------------------
/xbatcher/tests/test_torch_loaders.py:
--------------------------------------------------------------------------------
1 | from importlib import reload
2 |
3 | import numpy as np
4 | import pytest
5 | import xarray as xr
6 |
7 | from xbatcher import BatchGenerator
8 | from xbatcher.loaders.torch import IterableDataset, MapDataset, to_tensor
9 |
10 | torch = pytest.importorskip('torch')
11 |
12 |
def test_import_torch_failure(monkeypatch):
    """
    Importing ``xbatcher.loaders.torch`` without PyTorch available should
    raise an informative ImportError.
    """
    import sys

    import xbatcher.loaders

    monkeypatch.setitem(sys.modules, 'torch', None)

    try:
        with pytest.raises(ImportError) as excinfo:
            reload(xbatcher.loaders.torch)

        assert 'install PyTorch to proceed' in str(excinfo.value)
    finally:
        # reload() mutates the module object in place, so undo the patch and
        # reload again to avoid leaking a broken module into other tests
        monkeypatch.undo()
        reload(xbatcher.loaders.torch)
24 |
25 |
def test_import_dask_failure(monkeypatch):
    """
    ``xbatcher.loaders.torch`` should degrade gracefully (``dask`` bound to
    ``None``) when Dask is unavailable.
    """
    import sys

    import xbatcher.loaders

    monkeypatch.setitem(sys.modules, 'dask', None)
    try:
        reload(xbatcher.loaders.torch)
        assert xbatcher.loaders.torch.dask is None
    finally:
        # reload() mutates the module object in place; undo the patch and
        # reload again so later tests see the module with the real dask
        monkeypatch.undo()
        reload(xbatcher.loaders.torch)
35 |
36 |
@pytest.fixture(scope='module', params=[True, False])
def ds_xy(request):
    """Feature/target dataset; parametrized to also cover a dask-chunked copy."""
    n_samples, n_features = 100, 5
    ds = xr.Dataset(
        {
            'x': (
                ['sample', 'feature'],
                np.random.random((n_samples, n_features)),
            ),
            'y': (['sample'], np.random.random(n_samples)),
        },
    )
    return ds.chunk({'sample': 10}) if request.param else ds
55 |
56 |
@pytest.mark.parametrize('x_var', ['x', ['x']])
def test_map_dataset_without_y(ds_xy, x_var) -> None:
    """Check ``MapDataset`` indexing, length, and DataLoader integration
    when no target generator is supplied."""
    x_gen = BatchGenerator(ds_xy[x_var], {'sample': 10})

    dataset = MapDataset(x_gen)

    # plain integer indexing
    batch = dataset[0]
    assert batch.shape == (10, 5)  # type: ignore[union-attr]
    assert isinstance(batch, torch.Tensor)

    # a single-element tensor index behaves like the integer form
    batch = dataset[torch.tensor([0])]
    assert batch.shape == (10, 5)
    assert isinstance(batch, torch.Tensor)

    # multi-element tensor indexing is not implemented
    with pytest.raises(NotImplementedError):
        batch = dataset[torch.tensor([0, 1])]

    assert len(dataset) == len(x_gen)

    # integration with the torch DataLoader
    loader = torch.utils.data.DataLoader(dataset, batch_size=None)
    for batch in loader:
        assert batch.shape == (10, 5)  # type: ignore[union-attr]
        assert isinstance(batch, torch.Tensor)

    # the final DataLoader batch must match the generator's last item,
    # both in shape and in values
    last = x_gen[-1]
    assert tuple(last.sizes.values()) == batch.shape  # type: ignore[union-attr]
    gen_array = last.to_array().squeeze() if hasattr(last, 'to_array') else last
    np.testing.assert_array_equal(gen_array, batch)  # type: ignore
96 |
97 |
@pytest.mark.parametrize(
    ('x_var', 'y_var'),
    [
        ('x', 'y'),  # xr.DataArray
        (['x'], ['y']),  # xr.Dataset
    ],
)
def test_map_dataset(ds_xy, x_var, y_var) -> None:
    """Check ``MapDataset`` with both feature and target generators."""
    x_gen = BatchGenerator(ds_xy[x_var], {'sample': 10})
    y_gen = BatchGenerator(ds_xy[y_var], {'sample': 10})

    dataset = MapDataset(x_gen, y_gen)

    # integer indexing returns an (x, y) tensor pair
    x_batch, y_batch = dataset[0]
    assert x_batch.shape == (10, 5)
    assert y_batch.shape == (10,)
    assert isinstance(x_batch, torch.Tensor)

    # a single-element tensor index behaves like the integer form
    x_batch, y_batch = dataset[torch.tensor([0])]
    assert x_batch.shape == (10, 5)
    assert y_batch.shape == (10,)
    assert isinstance(x_batch, torch.Tensor)

    # multi-element tensor indexing is not implemented
    with pytest.raises(NotImplementedError):
        x_batch, y_batch = dataset[torch.tensor([0, 1])]

    assert len(dataset) == len(x_gen)

    # integration with the torch DataLoader
    loader = torch.utils.data.DataLoader(dataset, batch_size=None)
    for x_batch, y_batch in loader:
        assert x_batch.shape == (10, 5)
        assert y_batch.shape == (10,)
        assert isinstance(x_batch, torch.Tensor)

    # the final DataLoader batch must match the generator's last item,
    # both in shape and in values
    last = x_gen[-1]
    assert tuple(last.sizes.values()) == x_batch.shape
    gen_array = last.to_array().squeeze() if hasattr(last, 'to_array') else last
    np.testing.assert_array_equal(gen_array, x_batch)  # type: ignore
148 |
149 |
@pytest.mark.parametrize(
    ('x_var', 'y_var'),
    [
        ('x', 'y'),  # xr.DataArray
        (['x'], ['y']),  # xr.Dataset
    ],
)
def test_map_dataset_with_transform(ds_xy, x_var, y_var) -> None:
    """Check that ``MapDataset`` applies ``transform``/``target_transform``."""
    x_gen = BatchGenerator(ds_xy[x_var], {'sample': 10})
    y_gen = BatchGenerator(ds_xy[y_var], {'sample': 10})

    def x_transform(batch):
        return to_tensor(batch * 0 + 1)

    def y_transform(batch):
        return to_tensor(batch * 0 - 1)

    dataset = MapDataset(
        x_gen, y_gen, transform=x_transform, target_transform=y_transform
    )
    x_batch, y_batch = dataset[0]
    assert x_batch.shape == (10, 5)
    assert y_batch.shape == (10,)
    assert isinstance(x_batch, torch.Tensor)
    # the transforms map all features to 1 and all targets to -1
    assert (x_batch == 1).all()
    assert (y_batch == -1).all()
179 |
180 |
@pytest.mark.parametrize(
    ('x_var', 'y_var'),
    [
        ('x', 'y'),  # xr.DataArray
        (['x'], ['y']),  # xr.Dataset
    ],
)
def test_iterable_dataset(ds_xy, x_var, y_var):
    """Check ``IterableDataset`` streaming through a torch DataLoader."""
    x_gen = BatchGenerator(ds_xy[x_var], {'sample': 10})
    y_gen = BatchGenerator(ds_xy[y_var], {'sample': 10})

    dataset = IterableDataset(x_gen, y_gen)

    # integration with the torch DataLoader
    loader = torch.utils.data.DataLoader(dataset, batch_size=None)
    for x_batch, y_batch in loader:
        assert x_batch.shape == (10, 5)
        assert y_batch.shape == (10,)
        assert isinstance(x_batch, torch.Tensor)

    # the final DataLoader batch must match the generator's last item,
    # both in shape and in values
    last = x_gen[-1]
    assert tuple(last.sizes.values()) == x_batch.shape
    gen_array = last.to_array().squeeze() if hasattr(last, 'to_array') else last
    np.testing.assert_array_equal(gen_array, x_batch)
212 |
--------------------------------------------------------------------------------
/xbatcher/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xarray-contrib/xbatcher/5898a76ce88200a28eed036f0fbec9890a4280b5/xbatcher/util/__init__.py
--------------------------------------------------------------------------------
/xbatcher/util/print_versions.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import platform
3 | import sys
4 |
5 |
def show_versions(file=sys.stdout):
    """
    Print various dependency versions, including information about:

    - xbatcher
    - System information (Python version, Operating System)
    - Dependency versions (Xarray, etc)

    Based on https://github.com/GenericMappingTools/pygmt/blob/09f9e65ebebfa929f9ddc2af90e05f3302c2239d/pygmt/__init__.py#L95

    Parameters
    ----------
    file : file-like, optional
        Writable stream the report is printed to (default: ``sys.stdout``).
    """
    # ``import importlib`` alone does not guarantee the ``importlib.metadata``
    # submodule is loaded; import it explicitly before use.
    import importlib.metadata

    def _get_module_version(modname):
        """
        Get version information of a Python module.

        Copied from https://github.com/GenericMappingTools/pygmt/blob/09f9e65ebebfa929f9ddc2af90e05f3302c2239d/pygmt/__init__.py#L111
        """
        try:
            if modname in sys.modules:
                module = sys.modules[modname]
            else:
                module = importlib.import_module(modname)

            return getattr(module, '__version__', 'installed')
        except ImportError:
            # module not available; report None rather than failing the report
            return None

    sys_info = {
        'python': sys.version.replace('\n', ' '),
        'executable': sys.executable,
        'machine': platform.platform(),
    }

    deps = [
        # Required
        'dask',
        'numpy',
        'xarray',
        # Optional
        'torch',
        # Setup/test
        'pip',
        'conda',
        'pytest',
        # Misc.
        'IPython',
        'sphinx',
    ]
    try:
        __version__ = f'v{importlib.metadata.version("xbatcher")}'
    except importlib.metadata.PackageNotFoundError:
        # e.g. running from a source checkout without an installed distribution
        __version__ = 'unknown'

    print('xbatcher information:', file=file)
    print(f'  version: {__version__}', file=file)

    print('System information:', file=file)
    for key, val in sys_info.items():
        print(f'  {key}: {val}', file=file)

    print('Dependency information:', file=file)
    for modname in deps:
        print(f'  {modname}: {_get_module_version(modname)}', file=file)


if __name__ == '__main__':  # pragma: no cover
    show_versions()
70 |
--------------------------------------------------------------------------------